v0.21.1-fastapi
@@ -15,26 +15,36 @@
 #
 import logging
 import json
 from typing import Optional
 from fastapi import APIRouter, Depends, Query
 from fastapi.responses import JSONResponse

+from api.apps.models.auth_dependencies import get_current_user
+from api.apps.models.llm_models import (
+    SetApiKeyRequest,
+    AddLLMRequest,
+    DeleteLLMRequest,
+    DeleteFactoryRequest,
+    MyLLMsQuery,
+    ListLLMsQuery,
+)
 from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
 from api.db.services.llm_service import LLMService
 from api import settings
-from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result
+from api.utils.api_utils import server_error_response, get_data_error_result
 from api.db import StatusEnum, LLMType
 from api.db.db_models import TenantLLM
 from api.utils.api_utils import get_json_result
 from api.utils.base64_image import test_image
 from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel
-from api.models.llm_models import SetApiKeyRequest, AddLLMRequest, DeleteLLMRequest, DeleteFactoryRequest
-from api.utils.api_utils import get_current_user

-# Create the FastAPI router
+# Create the router
 router = APIRouter()


 @router.get('/factories')
-async def factories(current_user = Depends(get_current_user)):
+async def factories(
+    current_user = Depends(get_current_user)
+):
+    """Get the list of LLM factories"""
     try:
         fac = LLMFactoriesService.get_all()
         fac = [f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed", "BAAI"]]
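Note: the Pydantic request/query models imported from api.apps.models.llm_models are not part of this diff. A minimal sketch of what the handlers below appear to assume — field names and optionality are inferred from usage, not the committed definitions:

```python
# Sketch only: inferred from how the handlers read these models.
from typing import Optional
from pydantic import BaseModel

class SetApiKeyRequest(BaseModel):
    llm_factory: str
    api_key: str
    base_url: Optional[str] = None
    model_type: Optional[str] = None
    llm_name: Optional[str] = None

class MyLLMsQuery(BaseModel):
    include_details: str = "false"  # my_llms compares .lower() == 'true'

class ListLLMsQuery(BaseModel):
    model_type: Optional[str] = None
```

With Depends(), FastAPI reads these model fields as query parameters via the model's generated signature.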
@@ -55,17 +65,22 @@ async def factories(current_user = Depends(get_current_user)):


 @router.post('/set_api_key')
-async def set_api_key(request: SetApiKeyRequest, current_user = Depends(get_current_user)):
+async def set_api_key(
+    request: SetApiKeyRequest,
+    current_user = Depends(get_current_user)
+):
+    """Set the API key"""
+    req = request.model_dump(exclude_unset=True)
     # test if api key works
     chat_passed, embd_passed, rerank_passed = False, False, False
-    factory = request.llm_factory
+    factory = req["llm_factory"]
     extra = {"provider": factory}
     msg = ""
     for llm in LLMService.query(fid=factory):
         if not embd_passed and llm.model_type == LLMType.EMBEDDING.value:
             assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
             mdl = EmbeddingModel[factory](
-                request.api_key, llm.llm_name, base_url=request.base_url)
+                req["api_key"], llm.llm_name, base_url=req.get("base_url", ""))
             try:
                 arr, tc = mdl.encode(["Test if the api key is available"])
                 if len(arr[0]) == 0:
@@ -76,7 +91,7 @@ async def set_api_key(request: SetApiKeyRequest, current_user = Depends(get_curr
         elif not chat_passed and llm.model_type == LLMType.CHAT.value:
             assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
             mdl = ChatModel[factory](
-                request.api_key, llm.llm_name, base_url=request.base_url, **extra)
+                req["api_key"], llm.llm_name, base_url=req.get("base_url", ""), **extra)
             try:
                 m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}],
                                  {"temperature": 0.9, 'max_tokens': 50})
@@ -89,7 +104,7 @@ async def set_api_key(request: SetApiKeyRequest, current_user = Depends(get_curr
         elif not rerank_passed and llm.model_type == LLMType.RERANK:
             assert factory in RerankModel, f"Re-rank model from {factory} is not supported yet."
             mdl = RerankModel[factory](
-                request.api_key, llm.llm_name, base_url=request.base_url)
+                req["api_key"], llm.llm_name, base_url=req.get("base_url", ""))
             try:
                 arr, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
                 if len(arr) == 0 or tc == 0:
@@ -107,9 +122,12 @@ async def set_api_key(request: SetApiKeyRequest, current_user = Depends(get_curr
             return get_data_error_result(message=msg)

     llm_config = {
-        "api_key": request.api_key,
-        "api_base": request.base_url or ""
+        "api_key": req["api_key"],
+        "api_base": req.get("base_url", "")
     }
+    for n in ["model_type", "llm_name"]:
+        if n in req:
+            llm_config[n] = req[n]

     for llm in LLMService.query(fid=factory):
         llm_config["max_tokens"] = llm.max_tokens
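The switch from attribute access (request.api_key) to a plain dict via model_dump(exclude_unset=True) changes how missing fields behave: fields the client never sent are absent from req entirely, hence the req.get(..., default) reads above. A quick illustration against the sketched model:

```python
# Pydantic v2: exclude_unset drops fields the client did not send.
r = SetApiKeyRequest(llm_factory="OpenAI", api_key="sk-test")
print(r.model_dump(exclude_unset=True))
# {'llm_factory': 'OpenAI', 'api_key': 'sk-test'}  -- no 'base_url' key at all
print(r.model_dump().get("base_url"))
# None -- without the flag the key is present but unset
```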
@@ -132,14 +150,19 @@ async def set_api_key(request: SetApiKeyRequest, current_user = Depends(get_curr


 @router.post('/add_llm')
-async def add_llm(request: AddLLMRequest, current_user = Depends(get_current_user)):
-    factory = request.llm_factory
-    api_key = request.api_key or "x"
-    llm_name = request.llm_name
+async def add_llm(
+    request: AddLLMRequest,
+    current_user = Depends(get_current_user)
+):
+    """Add an LLM"""
+    req = request.model_dump(exclude_unset=True)
+    factory = req["llm_factory"]
+    api_key = req.get("api_key", "x")
+    llm_name = req.get("llm_name")

     def apikey_json(keys):
-        nonlocal request
-        return json.dumps({k: getattr(request, k, "") for k in keys})
+        nonlocal req
+        return json.dumps({k: req.get(k, "") for k in keys})

     if factory == "VolcEngine":
         # For VolcEngine, due to its special authentication method
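apikey_json now closes over the req dict instead of the request object (the nonlocal is not strictly needed for read-only access, but it is harmless). It packs several provider-specific credentials into the single api_key column as one JSON string; for example, with made-up values:

```python
# Illustrative: what apikey_json(["ark_api_key", "endpoint_id"]) stores for VolcEngine.
import json
req = {"ark_api_key": "ak-123", "endpoint_id": "ep-456"}  # hypothetical values
api_key = json.dumps({k: req.get(k, "") for k in ["ark_api_key", "endpoint_id"]})
# api_key == '{"ark_api_key": "ak-123", "endpoint_id": "ep-456"}'
```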
@@ -147,21 +170,28 @@ async def add_llm(request: AddLLMRequest, current_user = Depends(get_current_use
         api_key = apikey_json(["ark_api_key", "endpoint_id"])

     elif factory == "Tencent Hunyuan":
-        # Create a temporary request object for set_api_key
-        temp_request = SetApiKeyRequest(
-            llm_factory=factory,
-            api_key=apikey_json(["hunyuan_sid", "hunyuan_sk"]),
-            base_url=request.api_base
-        )
-        return await set_api_key(temp_request, current_user)
+        req["api_key"] = apikey_json(["hunyuan_sid", "hunyuan_sk"])
+        # Create a SetApiKeyRequest and invoke the set_api_key logic
+        set_api_key_req = SetApiKeyRequest(
+            llm_factory=req["llm_factory"],
+            api_key=req["api_key"],
+            base_url=req.get("api_base", req.get("base_url", "")),
+            model_type=req.get("model_type"),
+            llm_name=req.get("llm_name")
+        )
+        return await set_api_key(set_api_key_req, current_user)

     elif factory == "Tencent Cloud":
-        temp_request = SetApiKeyRequest(
-            llm_factory=factory,
-            api_key=apikey_json(["tencent_cloud_sid", "tencent_cloud_sk"]),
-            base_url=request.api_base
-        )
-        return await set_api_key(temp_request, current_user)
+        req["api_key"] = apikey_json(["tencent_cloud_sid", "tencent_cloud_sk"])
+        # Create a SetApiKeyRequest and invoke the set_api_key logic
+        set_api_key_req = SetApiKeyRequest(
+            llm_factory=req["llm_factory"],
+            api_key=req["api_key"],
+            base_url=req.get("api_base", req.get("base_url", "")),
+            model_type=req.get("model_type"),
+            llm_name=req.get("llm_name")
+        )
+        return await set_api_key(set_api_key_req, current_user)

     elif factory == "Bedrock":
         # For Bedrock, due to its special authentication method
@@ -181,9 +211,9 @@ async def add_llm(request: AddLLMRequest, current_user = Depends(get_current_use
             llm_name += "___VLLM"

     elif factory == "XunFei Spark":
-        if request.model_type == "chat":
-            api_key = request.spark_api_password or ""
-        elif request.model_type == "tts":
+        if req["model_type"] == "chat":
+            api_key = req.get("spark_api_password", "")
+        elif req["model_type"] == "tts":
             api_key = apikey_json(["spark_app_id", "spark_api_secret", "spark_api_key"])

     elif factory == "BaiduYiyan":
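Reviewer note: req.get("spark_api_password", "") is not an exact drop-in for the old request.spark_api_password or "". With exclude_unset=True, a field the client sends explicitly as null stays in req with value None, and dict.get only falls back when the key is absent:

```python
# dict.get falls back on a missing key, not on a stored None.
req = {"spark_api_password": None}       # client sent an explicit null
print(req.get("spark_api_password", ""))  # None -- the old `... or ""` yielded ""
```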
@@ -198,14 +228,17 @@ async def add_llm(request: AddLLMRequest, current_user = Depends(get_current_use
     elif factory == "Azure-OpenAI":
         api_key = apikey_json(["api_key", "api_version"])

+    elif factory == "OpenRouter":
+        api_key = apikey_json(["api_key", "provider_order"])
+
     llm = {
         "tenant_id": current_user.id,
         "llm_factory": factory,
-        "model_type": request.model_type,
+        "model_type": req["model_type"],
         "llm_name": llm_name,
-        "api_base": request.api_base or "",
+        "api_base": req.get("api_base", ""),
         "api_key": api_key,
-        "max_tokens": request.max_tokens
+        "max_tokens": req.get("max_tokens")
     }

     msg = ""
@@ -295,7 +328,11 @@ async def add_llm(request: AddLLMRequest, current_user = Depends(get_current_use


 @router.post('/delete_llm')
-async def delete_llm(request: DeleteLLMRequest, current_user = Depends(get_current_user)):
+async def delete_llm(
+    request: DeleteLLMRequest,
+    current_user = Depends(get_current_user)
+):
+    """Delete an LLM"""
     TenantLLMService.filter_delete(
         [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == request.llm_factory,
          TenantLLM.llm_name == request.llm_name])
@@ -303,7 +340,11 @@ async def delete_llm(request: DeleteLLMRequest, current_user = Depends(get_curre


 @router.post('/delete_factory')
-async def delete_factory(request: DeleteFactoryRequest, current_user = Depends(get_current_user)):
+async def delete_factory(
+    request: DeleteFactoryRequest,
+    current_user = Depends(get_current_user)
+):
+    """Delete a factory"""
     TenantLLMService.filter_delete(
         [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == request.llm_factory])
     return get_json_result(data=True)
@@ -311,10 +352,13 @@ async def delete_factory(request: DeleteFactoryRequest, current_user = Depends(g

 @router.get('/my_llms')
 async def my_llms(
-    include_details: bool = Query(False, description="Whether to include details"),
+    query: MyLLMsQuery = Depends(),
     current_user = Depends(get_current_user)
 ):
+    """Get my LLMs"""
     try:
+        include_details = query.include_details.lower() == 'true'
+
         if include_details:
             res = {}
             objs = TenantLLMService.query(tenant_id=current_user.id)
@@ -362,11 +406,13 @@ async def my_llms(

 @router.get('/list')
 async def list_app(
-    model_type: Optional[str] = Query(None, description="Model type"),
+    query: ListLLMsQuery = Depends(),
     current_user = Depends(get_current_user)
 ):
+    """List LLMs"""
     self_deployed = ["Youdao", "FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio", "GPUStack"]
     weighted = ["Youdao", "FastEmbed", "BAAI"] if settings.LIGHTEN != 0 else []
+    model_type = query.model_type
     try:
         objs = TenantLLMService.query(tenant_id=current_user.id)
         facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key])
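The commit does not show how this router is mounted. A minimal wiring-and-smoke-test sketch, assuming the module lives at api/apps/llm_app.py, the "/v1/llm" prefix, and an overridable auth dependency (all three are assumptions, not part of this diff):

```python
# Sketch only: mount the router and hit /factories with a stubbed user.
from fastapi import FastAPI
from fastapi.testclient import TestClient
from api.apps.llm_app import router, get_current_user  # assumed module path

app = FastAPI()
app.include_router(router, prefix="/v1/llm")  # prefix is an assumption

class _StubUser:
    id = "tenant-1"  # handlers only read current_user.id

app.dependency_overrides[get_current_user] = lambda: _StubUser()
print(TestClient(app).get("/v1/llm/factories").status_code)
```

Overriding by the same get_current_user object the handlers pass to Depends() is what makes the stub take effect.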