v0.21.1-fastapi

2025-11-04 16:06:36 +08:00
parent 3e58c3d0e9
commit d57b5d76ae
218 changed files with 19617 additions and 72339 deletions


@@ -15,26 +15,36 @@
#
import logging
import json
from typing import Optional
from fastapi import APIRouter, Depends, Query
from fastapi.responses import JSONResponse
from api.apps.models.auth_dependencies import get_current_user
from api.apps.models.llm_models import (
SetApiKeyRequest,
AddLLMRequest,
DeleteLLMRequest,
DeleteFactoryRequest,
MyLLMsQuery,
ListLLMsQuery,
)
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
from api.db.services.llm_service import LLMService
from api import settings
from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result
from api.utils.api_utils import server_error_response, get_data_error_result
from api.db import StatusEnum, LLMType
from api.db.db_models import TenantLLM
from api.utils.api_utils import get_json_result
from api.utils.base64_image import test_image
from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel
from api.models.llm_models import SetApiKeyRequest, AddLLMRequest, DeleteLLMRequest, DeleteFactoryRequest
from api.utils.api_utils import get_current_user
# Create the FastAPI router
# Create the router
router = APIRouter()
@router.get('/factories')
async def factories(current_user = Depends(get_current_user)):
async def factories(
current_user = Depends(get_current_user)
):
"""获取 LLM 工厂列表"""
try:
fac = LLMFactoriesService.get_all()
fac = [f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed", "BAAI"]]
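
Every handler below receives the caller via current_user = Depends(get_current_user). The dependency itself lives in api/apps/models/auth_dependencies.py and is not part of this diff; the sketch below only shows the usual shape of such a FastAPI dependency (the bearer scheme and the lookup helper are assumptions):

from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

bearer_scheme = HTTPBearer()

async def get_current_user_sketch(creds: HTTPAuthorizationCredentials = Depends(bearer_scheme)):
    # resolve_user_from_token is a hypothetical helper standing in for the
    # project's real token lookup.
    user = resolve_user_from_token(creds.credentials)
    if user is None:
        raise HTTPException(status_code=401, detail="Unauthorized")
    return user
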
@@ -55,17 +65,22 @@ async def factories(current_user = Depends(get_current_user)):
@router.post('/set_api_key')
async def set_api_key(request: SetApiKeyRequest, current_user = Depends(get_current_user)):
async def set_api_key(
request: SetApiKeyRequest,
current_user = Depends(get_current_user)
):
"""设置 API Key"""
req = request.model_dump(exclude_unset=True)
# Test whether the API key works
chat_passed, embd_passed, rerank_passed = False, False, False
factory = request.llm_factory
factory = req["llm_factory"]
extra = {"provider": factory}
msg = ""
for llm in LLMService.query(fid=factory):
if not embd_passed and llm.model_type == LLMType.EMBEDDING.value:
assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
mdl = EmbeddingModel[factory](
request.api_key, llm.llm_name, base_url=request.base_url)
req["api_key"], llm.llm_name, base_url=req.get("base_url", ""))
try:
arr, tc = mdl.encode(["Test if the api key is available"])
if len(arr[0]) == 0:
@@ -76,7 +91,7 @@ async def set_api_key(request: SetApiKeyRequest, current_user = Depends(get_curr
elif not chat_passed and llm.model_type == LLMType.CHAT.value:
assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
mdl = ChatModel[factory](
request.api_key, llm.llm_name, base_url=request.base_url, **extra)
req["api_key"], llm.llm_name, base_url=req.get("base_url", ""), **extra)
try:
m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}],
{"temperature": 0.9, 'max_tokens': 50})
@@ -89,7 +104,7 @@ async def set_api_key(request: SetApiKeyRequest, current_user = Depends(get_curr
elif not rerank_passed and llm.model_type == LLMType.RERANK:
assert factory in RerankModel, f"Re-rank model from {factory} is not supported yet."
mdl = RerankModel[factory](
request.api_key, llm.llm_name, base_url=request.base_url)
req["api_key"], llm.llm_name, base_url=req.get("base_url", ""))
try:
arr, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
if len(arr) == 0 or tc == 0:
@@ -107,9 +122,12 @@ async def set_api_key(request: SetApiKeyRequest, current_user = Depends(get_curr
return get_data_error_result(message=msg)
llm_config = {
"api_key": request.api_key,
"api_base": request.base_url or ""
"api_key": req["api_key"],
"api_base": req.get("base_url", "")
}
for n in ["model_type", "llm_name"]:
if n in req:
llm_config[n] = req[n]
for llm in LLMService.query(fid=factory):
llm_config["max_tokens"]=llm.max_tokens
@@ -132,14 +150,19 @@ async def set_api_key(request: SetApiKeyRequest, current_user = Depends(get_curr
@router.post('/add_llm')
async def add_llm(request: AddLLMRequest, current_user = Depends(get_current_user)):
factory = request.llm_factory
api_key = request.api_key or "x"
llm_name = request.llm_name
async def add_llm(
request: AddLLMRequest,
current_user = Depends(get_current_user)
):
"""添加 LLM"""
req = request.model_dump(exclude_unset=True)
factory = req["llm_factory"]
api_key = req.get("api_key", "x")
llm_name = req.get("llm_name")
def apikey_json(keys):
nonlocal request
return json.dumps({k: getattr(request, k, "") for k in keys})
nonlocal req
return json.dumps({k: req.get(k, "") for k in keys})
if factory == "VolcEngine":
# For VolcEngine, due to its special authentication method
@@ -147,21 +170,28 @@ async def add_llm(request: AddLLMRequest, current_user = Depends(get_current_use
api_key = apikey_json(["ark_api_key", "endpoint_id"])
elif factory == "Tencent Hunyuan":
# Create a temporary request object for set_api_key
temp_request = SetApiKeyRequest(
llm_factory=factory,
api_key=apikey_json(["hunyuan_sid", "hunyuan_sk"]),
base_url=request.api_base
req["api_key"] = apikey_json(["hunyuan_sid", "hunyuan_sk"])
# Build a SetApiKeyRequest and reuse the set_api_key logic
set_api_key_req = SetApiKeyRequest(
llm_factory=req["llm_factory"],
api_key=req["api_key"],
base_url=req.get("api_base", req.get("base_url", "")),
model_type=req.get("model_type"),
llm_name=req.get("llm_name")
)
return await set_api_key(temp_request, current_user)
return await set_api_key(set_api_key_req, current_user)
elif factory == "Tencent Cloud":
temp_request = SetApiKeyRequest(
llm_factory=factory,
api_key=apikey_json(["tencent_cloud_sid", "tencent_cloud_sk"]),
base_url=request.api_base
req["api_key"] = apikey_json(["tencent_cloud_sid", "tencent_cloud_sk"])
# Build a SetApiKeyRequest and reuse the set_api_key logic
set_api_key_req = SetApiKeyRequest(
llm_factory=req["llm_factory"],
api_key=req["api_key"],
base_url=req.get("api_base", req.get("base_url", "")),
model_type=req.get("model_type"),
llm_name=req.get("llm_name")
)
return await set_api_key(temp_request, current_user)
return await set_api_key(set_api_key_req, current_user)
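
The Tencent Hunyuan and Tencent Cloud branches are identical apart from the key fields they pack; the repeated block could be lifted into one helper. A sketch assembled from the code above (not part of the commit):

async def _delegate_set_api_key(req: dict, key_fields: list, current_user):
    # Pack the provider-specific credential parts, then reuse set_api_key.
    req["api_key"] = json.dumps({k: req.get(k, "") for k in key_fields})
    set_api_key_req = SetApiKeyRequest(
        llm_factory=req["llm_factory"],
        api_key=req["api_key"],
        base_url=req.get("api_base", req.get("base_url", "")),
        model_type=req.get("model_type"),
        llm_name=req.get("llm_name"),
    )
    return await set_api_key(set_api_key_req, current_user)

# e.g. return await _delegate_set_api_key(req, ["hunyuan_sid", "hunyuan_sk"], current_user)
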
elif factory == "Bedrock":
# For Bedrock, due to its special authentication method
@@ -181,9 +211,9 @@ async def add_llm(request: AddLLMRequest, current_user = Depends(get_current_use
llm_name += "___VLLM"
elif factory == "XunFei Spark":
if request.model_type == "chat":
api_key = request.spark_api_password or ""
elif request.model_type == "tts":
if req["model_type"] == "chat":
api_key = req.get("spark_api_password", "")
elif req["model_type"] == "tts":
api_key = apikey_json(["spark_app_id", "spark_api_secret", "spark_api_key"])
elif factory == "BaiduYiyan":
@@ -198,14 +228,17 @@ async def add_llm(request: AddLLMRequest, current_user = Depends(get_current_use
elif factory == "Azure-OpenAI":
api_key = apikey_json(["api_key", "api_version"])
elif factory == "OpenRouter":
api_key = apikey_json(["api_key", "provider_order"])
llm = {
"tenant_id": current_user.id,
"llm_factory": factory,
"model_type": request.model_type,
"model_type": req["model_type"],
"llm_name": llm_name,
"api_base": request.api_base or "",
"api_base": req.get("api_base", ""),
"api_key": api_key,
"max_tokens": request.max_tokens
"max_tokens": req.get("max_tokens")
}
msg = ""
@@ -295,7 +328,11 @@ async def add_llm(request: AddLLMRequest, current_user = Depends(get_current_use
@router.post('/delete_llm')
async def delete_llm(request: DeleteLLMRequest, current_user = Depends(get_current_user)):
async def delete_llm(
request: DeleteLLMRequest,
current_user = Depends(get_current_user)
):
"""删除 LLM"""
TenantLLMService.filter_delete(
[TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == request.llm_factory,
TenantLLM.llm_name == request.llm_name])
@@ -303,7 +340,11 @@ async def delete_llm(request: DeleteLLMRequest, current_user = Depends(get_curre
@router.post('/delete_factory')
async def delete_factory(request: DeleteFactoryRequest, current_user = Depends(get_current_user)):
async def delete_factory(
request: DeleteFactoryRequest,
current_user = Depends(get_current_user)
):
"""删除工厂"""
TenantLLMService.filter_delete(
[TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == request.llm_factory])
return get_json_result(data=True)
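
Both delete endpoints pass a list of peewee conditions to TenantLLMService.filter_delete, scoped to the current tenant. Hedged client calls (BASE_URL, the auth header, factory and model names are placeholders):

import httpx

auth = {"Authorization": "Bearer TOKEN"}
delete_one = {"llm_factory": "OpenAI", "llm_name": "some-model"}
delete_all = {"llm_factory": "OpenAI"}
# httpx.post(f"{BASE_URL}/delete_llm", json=delete_one, headers=auth)
# httpx.post(f"{BASE_URL}/delete_factory", json=delete_all, headers=auth)
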
@@ -311,10 +352,13 @@ async def delete_factory(request: DeleteFactoryRequest, current_user = Depends(g
@router.get('/my_llms')
async def my_llms(
include_details: bool = Query(False, description="Whether to include details"),
query: MyLLMsQuery = Depends(),
current_user = Depends(get_current_user)
):
"""获取我的 LLMs"""
try:
include_details = query.include_details.lower() == 'true'
if include_details:
res = {}
objs = TenantLLMService.query(tenant_id=current_user.id)
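
my_llms now binds its query string through a Pydantic model: with query: MyLLMsQuery = Depends(), FastAPI reads each model field from the query parameters, so GET /my_llms?include_details=true keeps working. The real MyLLMsQuery lives in api/apps/models/llm_models.py; judging from query.include_details.lower() == 'true' above, the field is a raw string. A compatible sketch:

from pydantic import BaseModel

class MyLLMsQuerySketch(BaseModel):
    # A string rather than a bool, because the handler lowercases and
    # compares against 'true' itself.
    include_details: str = "false"
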
@@ -362,11 +406,13 @@ async def my_llms(
@router.get('/list')
async def list_app(
model_type: Optional[str] = Query(None, description="Model type"),
query: ListLLMsQuery = Depends(),
current_user = Depends(get_current_user)
):
"""列出 LLMs"""
self_deployed = ["Youdao", "FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio", "GPUStack"]
weighted = ["Youdao", "FastEmbed", "BAAI"] if settings.LIGHTEN != 0 else []
model_type = query.model_type
try:
objs = TenantLLMService.query(tenant_id=current_user.id)
facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key])
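
An end-to-end sketch of exercising the refactored /list endpoint with the auth dependency overridden. The mount prefix is an assumption, and the handler still queries the database, so this only succeeds against a configured backend:

from fastapi import FastAPI
from fastapi.testclient import TestClient

app = FastAPI()
app.include_router(router, prefix="/v1/llm")  # prefix assumed, not shown in the diff

class _FakeUser:
    id = "tenant-demo"

app.dependency_overrides[get_current_user] = lambda: _FakeUser()

client = TestClient(app)
resp = client.get("/v1/llm/list", params={"model_type": "chat"})  # binds ListLLMsQuery.model_type
print(resp.status_code, resp.json())
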