llm.py改造 fastapi
This commit is contained in:
@@ -128,6 +128,7 @@ def setup_routes(app: FastAPI):
|
||||
from api.apps.file2document_app import router as file2document_router
|
||||
from api.apps.mcp_server_app import router as mcp_router
|
||||
from api.apps.tenant_app import router as tenant_router
|
||||
from api.apps.llm_app import router as llm_router
|
||||
|
||||
app.include_router(user_router, prefix=f"/{API_VERSION}/user", tags=["User"])
|
||||
app.include_router(kb_router, prefix=f"/{API_VERSION}/kb", tags=["KB"])
|
||||
@@ -136,6 +137,7 @@ def setup_routes(app: FastAPI):
|
||||
app.include_router(file2document_router, prefix=f"/{API_VERSION}/file2document", tags=["File2Document"])
|
||||
app.include_router(mcp_router, prefix=f"/{API_VERSION}/mcp", tags=["MCP"])
|
||||
app.include_router(tenant_router, prefix=f"/{API_VERSION}/tenant", tags=["Tenant"])
|
||||
app.include_router(llm_router, prefix=f"/{API_VERSION}/llm", tags=["LLM"])
|
||||
|
||||
def get_current_user_from_token(authorization: str):
|
||||
"""从token获取当前用户"""
|
||||
|
||||
@@ -15,22 +15,26 @@
|
||||
#
|
||||
import logging
|
||||
import json
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from fastapi.responses import JSONResponse
|
||||
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
|
||||
from api.db.services.llm_service import LLMService
|
||||
from api import settings
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result
|
||||
from api.db import StatusEnum, LLMType
|
||||
from api.db.db_models import TenantLLM
|
||||
from api.utils.api_utils import get_json_result
|
||||
from api.utils.base64_image import test_image
|
||||
from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel
|
||||
from api.models.llm_models import SetApiKeyRequest, AddLLMRequest, DeleteLLMRequest, DeleteFactoryRequest
|
||||
from api.utils.api_utils import get_current_user
|
||||
|
||||
# 创建 FastAPI 路由器
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@manager.route('/factories', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def factories():
|
||||
@router.get('/factories')
|
||||
async def factories(current_user = Depends(get_current_user)):
|
||||
try:
|
||||
fac = LLMFactoriesService.get_all()
|
||||
fac = [f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed", "BAAI"]]
|
||||
@@ -50,21 +54,18 @@ def factories():
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/set_api_key', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("llm_factory", "api_key")
|
||||
def set_api_key():
|
||||
req = request.json
|
||||
@router.post('/set_api_key')
|
||||
async def set_api_key(request: SetApiKeyRequest, current_user = Depends(get_current_user)):
|
||||
# test if api key works
|
||||
chat_passed, embd_passed, rerank_passed = False, False, False
|
||||
factory = req["llm_factory"]
|
||||
factory = request.llm_factory
|
||||
extra = {"provider": factory}
|
||||
msg = ""
|
||||
for llm in LLMService.query(fid=factory):
|
||||
if not embd_passed and llm.model_type == LLMType.EMBEDDING.value:
|
||||
assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
|
||||
mdl = EmbeddingModel[factory](
|
||||
req["api_key"], llm.llm_name, base_url=req.get("base_url"))
|
||||
request.api_key, llm.llm_name, base_url=request.base_url)
|
||||
try:
|
||||
arr, tc = mdl.encode(["Test if the api key is available"])
|
||||
if len(arr[0]) == 0:
|
||||
@@ -75,7 +76,7 @@ def set_api_key():
|
||||
elif not chat_passed and llm.model_type == LLMType.CHAT.value:
|
||||
assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
|
||||
mdl = ChatModel[factory](
|
||||
req["api_key"], llm.llm_name, base_url=req.get("base_url"), **extra)
|
||||
request.api_key, llm.llm_name, base_url=request.base_url, **extra)
|
||||
try:
|
||||
m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}],
|
||||
{"temperature": 0.9, 'max_tokens': 50})
|
||||
@@ -88,7 +89,7 @@ def set_api_key():
|
||||
elif not rerank_passed and llm.model_type == LLMType.RERANK:
|
||||
assert factory in RerankModel, f"Re-rank model from {factory} is not supported yet."
|
||||
mdl = RerankModel[factory](
|
||||
req["api_key"], llm.llm_name, base_url=req.get("base_url"))
|
||||
request.api_key, llm.llm_name, base_url=request.base_url)
|
||||
try:
|
||||
arr, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
|
||||
if len(arr) == 0 or tc == 0:
|
||||
@@ -106,12 +107,9 @@ def set_api_key():
|
||||
return get_data_error_result(message=msg)
|
||||
|
||||
llm_config = {
|
||||
"api_key": req["api_key"],
|
||||
"api_base": req.get("base_url", "")
|
||||
"api_key": request.api_key,
|
||||
"api_base": request.base_url or ""
|
||||
}
|
||||
for n in ["model_type", "llm_name"]:
|
||||
if n in req:
|
||||
llm_config[n] = req[n]
|
||||
|
||||
for llm in LLMService.query(fid=factory):
|
||||
llm_config["max_tokens"]=llm.max_tokens
|
||||
@@ -133,18 +131,15 @@ def set_api_key():
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route('/add_llm', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("llm_factory")
|
||||
def add_llm():
|
||||
req = request.json
|
||||
factory = req["llm_factory"]
|
||||
api_key = req.get("api_key", "x")
|
||||
llm_name = req.get("llm_name")
|
||||
@router.post('/add_llm')
|
||||
async def add_llm(request: AddLLMRequest, current_user = Depends(get_current_user)):
|
||||
factory = request.llm_factory
|
||||
api_key = request.api_key or "x"
|
||||
llm_name = request.llm_name
|
||||
|
||||
def apikey_json(keys):
|
||||
nonlocal req
|
||||
return json.dumps({k: req.get(k, "") for k in keys})
|
||||
nonlocal request
|
||||
return json.dumps({k: getattr(request, k, "") for k in keys})
|
||||
|
||||
if factory == "VolcEngine":
|
||||
# For VolcEngine, due to its special authentication method
|
||||
@@ -152,12 +147,21 @@ def add_llm():
|
||||
api_key = apikey_json(["ark_api_key", "endpoint_id"])
|
||||
|
||||
elif factory == "Tencent Hunyuan":
|
||||
req["api_key"] = apikey_json(["hunyuan_sid", "hunyuan_sk"])
|
||||
return set_api_key()
|
||||
# Create a temporary request object for set_api_key
|
||||
temp_request = SetApiKeyRequest(
|
||||
llm_factory=factory,
|
||||
api_key=apikey_json(["hunyuan_sid", "hunyuan_sk"]),
|
||||
base_url=request.api_base
|
||||
)
|
||||
return await set_api_key(temp_request, current_user)
|
||||
|
||||
elif factory == "Tencent Cloud":
|
||||
req["api_key"] = apikey_json(["tencent_cloud_sid", "tencent_cloud_sk"])
|
||||
return set_api_key()
|
||||
temp_request = SetApiKeyRequest(
|
||||
llm_factory=factory,
|
||||
api_key=apikey_json(["tencent_cloud_sid", "tencent_cloud_sk"]),
|
||||
base_url=request.api_base
|
||||
)
|
||||
return await set_api_key(temp_request, current_user)
|
||||
|
||||
elif factory == "Bedrock":
|
||||
# For Bedrock, due to its special authentication method
|
||||
@@ -177,9 +181,9 @@ def add_llm():
|
||||
llm_name += "___VLLM"
|
||||
|
||||
elif factory == "XunFei Spark":
|
||||
if req["model_type"] == "chat":
|
||||
api_key = req.get("spark_api_password", "")
|
||||
elif req["model_type"] == "tts":
|
||||
if request.model_type == "chat":
|
||||
api_key = request.spark_api_password or ""
|
||||
elif request.model_type == "tts":
|
||||
api_key = apikey_json(["spark_app_id", "spark_api_secret", "spark_api_key"])
|
||||
|
||||
elif factory == "BaiduYiyan":
|
||||
@@ -197,11 +201,11 @@ def add_llm():
|
||||
llm = {
|
||||
"tenant_id": current_user.id,
|
||||
"llm_factory": factory,
|
||||
"model_type": req["model_type"],
|
||||
"model_type": request.model_type,
|
||||
"llm_name": llm_name,
|
||||
"api_base": req.get("api_base", ""),
|
||||
"api_base": request.api_base or "",
|
||||
"api_key": api_key,
|
||||
"max_tokens": req.get("max_tokens")
|
||||
"max_tokens": request.max_tokens
|
||||
}
|
||||
|
||||
msg = ""
|
||||
@@ -290,33 +294,27 @@ def add_llm():
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route('/delete_llm', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("llm_factory", "llm_name")
|
||||
def delete_llm():
|
||||
req = request.json
|
||||
@router.post('/delete_llm')
|
||||
async def delete_llm(request: DeleteLLMRequest, current_user = Depends(get_current_user)):
|
||||
TenantLLMService.filter_delete(
|
||||
[TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"],
|
||||
TenantLLM.llm_name == req["llm_name"]])
|
||||
[TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == request.llm_factory,
|
||||
TenantLLM.llm_name == request.llm_name])
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route('/delete_factory', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("llm_factory")
|
||||
def delete_factory():
|
||||
req = request.json
|
||||
@router.post('/delete_factory')
|
||||
async def delete_factory(request: DeleteFactoryRequest, current_user = Depends(get_current_user)):
|
||||
TenantLLMService.filter_delete(
|
||||
[TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"]])
|
||||
[TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == request.llm_factory])
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route('/my_llms', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def my_llms():
|
||||
@router.get('/my_llms')
|
||||
async def my_llms(
|
||||
include_details: bool = Query(False, description="是否包含详细信息"),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
try:
|
||||
include_details = request.args.get('include_details', 'false').lower() == 'true'
|
||||
|
||||
if include_details:
|
||||
res = {}
|
||||
objs = TenantLLMService.query(tenant_id=current_user.id)
|
||||
@@ -362,12 +360,13 @@ def my_llms():
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/list', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def list_app():
|
||||
@router.get('/list')
|
||||
async def list_app(
|
||||
model_type: Optional[str] = Query(None, description="模型类型"),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
self_deployed = ["Youdao", "FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio", "GPUStack"]
|
||||
weighted = ["Youdao", "FastEmbed", "BAAI"] if settings.LIGHTEN != 0 else []
|
||||
model_type = request.args.get("model_type")
|
||||
try:
|
||||
objs = TenantLLMService.query(tenant_id=current_user.id)
|
||||
facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key])
|
||||
|
||||
Reference in New Issue
Block a user