Convert llm.py to FastAPI

2025-10-28 11:49:07 +08:00
parent 45f69ab3d5
commit cbe0477ba1
3 changed files with 148 additions and 63 deletions


@@ -15,22 +15,26 @@
 #
 import logging
 import json
-from flask import request
-from flask_login import login_required, current_user
+from typing import Optional
+from fastapi import APIRouter, Depends, Query
+from fastapi.responses import JSONResponse
 from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
 from api.db.services.llm_service import LLMService
 from api import settings
-from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
+from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result
 from api.db import StatusEnum, LLMType
 from api.db.db_models import TenantLLM
-from api.utils.api_utils import get_json_result
 from api.utils.base64_image import test_image
 from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel
+from api.models.llm_models import SetApiKeyRequest, AddLLMRequest, DeleteLLMRequest, DeleteFactoryRequest
+from api.utils.api_utils import get_current_user
+
+# Create the FastAPI router
+router = APIRouter()
-@manager.route('/factories', methods=['GET']) # noqa: F821
-@login_required
-def factories():
+@router.get('/factories')
+async def factories(current_user = Depends(get_current_user)):
     try:
         fac = LLMFactoriesService.get_all()
         fac = [f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed", "BAAI"]]
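
The new Pydantic request models are imported from api.models.llm_models, which is not part of this hunk. A minimal sketch of what those models might look like, inferred only from the fields the handlers below access; defaults and any fields beyond those are assumptions:

from typing import Optional
from pydantic import BaseModel

class SetApiKeyRequest(BaseModel):
    llm_factory: str
    api_key: str
    base_url: Optional[str] = None

class AddLLMRequest(BaseModel):
    llm_factory: str
    model_type: str
    llm_name: Optional[str] = None
    api_key: Optional[str] = None
    api_base: Optional[str] = None
    max_tokens: Optional[int] = None
    # provider-specific credentials are read via getattr() in apikey_json(),
    # e.g. ark_api_key, endpoint_id, hunyuan_sid, hunyuan_sk, spark_api_password

class DeleteLLMRequest(BaseModel):
    llm_factory: str
    llm_name: str

class DeleteFactoryRequest(BaseModel):
    llm_factory: str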
@@ -50,21 +54,18 @@ def factories():
         return server_error_response(e)
-@manager.route('/set_api_key', methods=['POST']) # noqa: F821
-@login_required
-@validate_request("llm_factory", "api_key")
-def set_api_key():
-    req = request.json
+@router.post('/set_api_key')
+async def set_api_key(request: SetApiKeyRequest, current_user = Depends(get_current_user)):
     # test if api key works
     chat_passed, embd_passed, rerank_passed = False, False, False
-    factory = req["llm_factory"]
+    factory = request.llm_factory
     extra = {"provider": factory}
     msg = ""
     for llm in LLMService.query(fid=factory):
         if not embd_passed and llm.model_type == LLMType.EMBEDDING.value:
             assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
             mdl = EmbeddingModel[factory](
-                req["api_key"], llm.llm_name, base_url=req.get("base_url"))
+                request.api_key, llm.llm_name, base_url=request.base_url)
             try:
                 arr, tc = mdl.encode(["Test if the api key is available"])
                 if len(arr[0]) == 0:
@@ -75,7 +76,7 @@ def set_api_key():
         elif not chat_passed and llm.model_type == LLMType.CHAT.value:
             assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
             mdl = ChatModel[factory](
-                req["api_key"], llm.llm_name, base_url=req.get("base_url"), **extra)
+                request.api_key, llm.llm_name, base_url=request.base_url, **extra)
             try:
                 m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}],
                                  {"temperature": 0.9, 'max_tokens': 50})
@@ -88,7 +89,7 @@ def set_api_key():
         elif not rerank_passed and llm.model_type == LLMType.RERANK:
             assert factory in RerankModel, f"Re-rank model from {factory} is not supported yet."
             mdl = RerankModel[factory](
-                req["api_key"], llm.llm_name, base_url=req.get("base_url"))
+                request.api_key, llm.llm_name, base_url=request.base_url)
             try:
                 arr, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
                 if len(arr) == 0 or tc == 0:
@@ -106,12 +107,9 @@ def set_api_key():
         return get_data_error_result(message=msg)
     llm_config = {
-        "api_key": req["api_key"],
-        "api_base": req.get("base_url", "")
+        "api_key": request.api_key,
+        "api_base": request.base_url or ""
     }
-    for n in ["model_type", "llm_name"]:
-        if n in req:
-            llm_config[n] = req[n]
     for llm in LLMService.query(fid=factory):
         llm_config["max_tokens"]=llm.max_tokens
@@ -133,18 +131,15 @@ def set_api_key():
     return get_json_result(data=True)
-@manager.route('/add_llm', methods=['POST']) # noqa: F821
-@login_required
-@validate_request("llm_factory")
-def add_llm():
-    req = request.json
-    factory = req["llm_factory"]
-    api_key = req.get("api_key", "x")
-    llm_name = req.get("llm_name")
+@router.post('/add_llm')
+async def add_llm(request: AddLLMRequest, current_user = Depends(get_current_user)):
+    factory = request.llm_factory
+    api_key = request.api_key or "x"
+    llm_name = request.llm_name
     def apikey_json(keys):
-        nonlocal req
-        return json.dumps({k: req.get(k, "") for k in keys})
+        nonlocal request
+        return json.dumps({k: getattr(request, k, "") for k in keys})
     if factory == "VolcEngine":
         # For VolcEngine, due to its special authentication method
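
apikey_json now reads attributes off the Pydantic model instead of dict keys. The getattr(..., "") fallback only fires when the attribute is missing from the model entirely; a quick illustration, using the sketched AddLLMRequest from above:

import json

req = AddLLMRequest(llm_factory="Tencent Hunyuan", model_type="chat")
json.dumps({k: getattr(req, k, "") for k in ["hunyuan_sid", "hunyuan_sk"]})
# -> '{"hunyuan_sid": "", "hunyuan_sk": ""}' when the fields are not declared on the model

If those provider fields were instead declared on AddLLMRequest with a default of None, this would serialize null rather than an empty string, so the choice of model fields matters here.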
@@ -152,12 +147,21 @@ def add_llm():
         api_key = apikey_json(["ark_api_key", "endpoint_id"])
     elif factory == "Tencent Hunyuan":
-        req["api_key"] = apikey_json(["hunyuan_sid", "hunyuan_sk"])
-        return set_api_key()
+        # Create a temporary request object for set_api_key
+        temp_request = SetApiKeyRequest(
+            llm_factory=factory,
+            api_key=apikey_json(["hunyuan_sid", "hunyuan_sk"]),
+            base_url=request.api_base
+        )
+        return await set_api_key(temp_request, current_user)
     elif factory == "Tencent Cloud":
-        req["api_key"] = apikey_json(["tencent_cloud_sid", "tencent_cloud_sk"])
-        return set_api_key()
+        temp_request = SetApiKeyRequest(
+            llm_factory=factory,
+            api_key=apikey_json(["tencent_cloud_sid", "tencent_cloud_sk"]),
+            base_url=request.api_base
+        )
+        return await set_api_key(temp_request, current_user)
     elif factory == "Bedrock":
         # For Bedrock, due to its special authentication method
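
Note on the delegation above: FastAPI route functions remain ordinary coroutines, so add_llm can hand a freshly built SetApiKeyRequest to set_api_key and await it directly; the call happens in-process, with no second HTTP round trip, unlike the Flask version's implicit reuse of the same request context.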
@@ -177,9 +181,9 @@ def add_llm():
         llm_name += "___VLLM"
     elif factory == "XunFei Spark":
-        if req["model_type"] == "chat":
-            api_key = req.get("spark_api_password", "")
-        elif req["model_type"] == "tts":
+        if request.model_type == "chat":
+            api_key = request.spark_api_password or ""
+        elif request.model_type == "tts":
             api_key = apikey_json(["spark_app_id", "spark_api_secret", "spark_api_key"])
     elif factory == "BaiduYiyan":
@@ -197,11 +201,11 @@ def add_llm():
     llm = {
         "tenant_id": current_user.id,
         "llm_factory": factory,
-        "model_type": req["model_type"],
+        "model_type": request.model_type,
         "llm_name": llm_name,
-        "api_base": req.get("api_base", ""),
+        "api_base": request.api_base or "",
         "api_key": api_key,
-        "max_tokens": req.get("max_tokens")
+        "max_tokens": request.max_tokens
     }
     msg = ""
@@ -290,33 +294,27 @@ def add_llm():
     return get_json_result(data=True)
-@manager.route('/delete_llm', methods=['POST']) # noqa: F821
-@login_required
-@validate_request("llm_factory", "llm_name")
-def delete_llm():
-    req = request.json
+@router.post('/delete_llm')
+async def delete_llm(request: DeleteLLMRequest, current_user = Depends(get_current_user)):
     TenantLLMService.filter_delete(
-        [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"],
-         TenantLLM.llm_name == req["llm_name"]])
+        [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == request.llm_factory,
+         TenantLLM.llm_name == request.llm_name])
     return get_json_result(data=True)
-@manager.route('/delete_factory', methods=['POST']) # noqa: F821
-@login_required
-@validate_request("llm_factory")
-def delete_factory():
-    req = request.json
+@router.post('/delete_factory')
+async def delete_factory(request: DeleteFactoryRequest, current_user = Depends(get_current_user)):
     TenantLLMService.filter_delete(
-        [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"]])
+        [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == request.llm_factory])
     return get_json_result(data=True)
-@manager.route('/my_llms', methods=['GET']) # noqa: F821
-@login_required
-def my_llms():
+@router.get('/my_llms')
+async def my_llms(
+    include_details: bool = Query(False, description="Whether to include detailed information"),
+    current_user = Depends(get_current_user)
+):
     try:
-        include_details = request.args.get('include_details', 'false').lower() == 'true'
         if include_details:
             res = {}
             objs = TenantLLMService.query(tenant_id=current_user.id)
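
include_details is now a typed Query parameter, so FastAPI performs the string-to-bool conversion the Flask version did by hand and advertises the parameter in the OpenAPI schema. For example, reusing the client from the earlier sketch:

resp = client.get("/llm/my_llms", params={"include_details": "true"})
# "true", "1", "yes", "on" all coerce to True; an unrecognized value returns 422,
# whereas the Flask version silently treated any value other than "true" as False.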
@@ -362,12 +360,13 @@ def my_llms():
         return server_error_response(e)
-@manager.route('/list', methods=['GET']) # noqa: F821
-@login_required
-def list_app():
+@router.get('/list')
+async def list_app(
+    model_type: Optional[str] = Query(None, description="Model type"),
+    current_user = Depends(get_current_user)
+):
     self_deployed = ["Youdao", "FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio", "GPUStack"]
     weighted = ["Youdao", "FastEmbed", "BAAI"] if settings.LIGHTEN != 0 else []
-    model_type = request.args.get("model_type")
     try:
         objs = TenantLLMService.query(tenant_id=current_user.id)
         facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key])