From cbe0477ba12c0f462bb4c12c317be8a036957db0 Mon Sep 17 00:00:00 2001
From: dangzerong <429714019@qq.com>
Date: Tue, 28 Oct 2025 11:49:07 +0800
Subject: [PATCH] Refactor llm.py to FastAPI
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 api/apps/__init___fastapi.py |   2 +
 api/apps/llm_app.py          | 125 +++++++++++++++++------------------
 api/models/llm_models.py     |  84 +++++++++++++++++++++++
 3 files changed, 148 insertions(+), 63 deletions(-)
 create mode 100644 api/models/llm_models.py

diff --git a/api/apps/__init___fastapi.py b/api/apps/__init___fastapi.py
index a26b9f2..6794166 100644
--- a/api/apps/__init___fastapi.py
+++ b/api/apps/__init___fastapi.py
@@ -128,6 +128,7 @@ def setup_routes(app: FastAPI):
     from api.apps.file2document_app import router as file2document_router
     from api.apps.mcp_server_app import router as mcp_router
     from api.apps.tenant_app import router as tenant_router
+    from api.apps.llm_app import router as llm_router
 
     app.include_router(user_router, prefix=f"/{API_VERSION}/user", tags=["User"])
     app.include_router(kb_router, prefix=f"/{API_VERSION}/kb", tags=["KB"])
@@ -136,6 +137,7 @@ def setup_routes(app: FastAPI):
     app.include_router(file2document_router, prefix=f"/{API_VERSION}/file2document", tags=["File2Document"])
     app.include_router(mcp_router, prefix=f"/{API_VERSION}/mcp", tags=["MCP"])
     app.include_router(tenant_router, prefix=f"/{API_VERSION}/tenant", tags=["Tenant"])
+    app.include_router(llm_router, prefix=f"/{API_VERSION}/llm", tags=["LLM"])
 
 
 def get_current_user_from_token(authorization: str):
     """Get the current user from the token."""
diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py
index 1ef1c3e..a332345 100644
--- a/api/apps/llm_app.py
+++ b/api/apps/llm_app.py
@@ -15,22 +15,26 @@
 #
 import logging
 import json
-from flask import request
-from flask_login import login_required, current_user
+from typing import Optional
+from fastapi import APIRouter, Depends, Query
+from fastapi.responses import JSONResponse
 from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
 from api.db.services.llm_service import LLMService
 from api import settings
-from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
+from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result
 from api.db import StatusEnum, LLMType
 from api.db.db_models import TenantLLM
-from api.utils.api_utils import get_json_result
 from api.utils.base64_image import test_image
 from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel
+from api.models.llm_models import SetApiKeyRequest, AddLLMRequest, DeleteLLMRequest, DeleteFactoryRequest
+from api.utils.api_utils import get_current_user
+
+# Create the FastAPI router
+router = APIRouter()
 
 
-@manager.route('/factories', methods=['GET']) # noqa: F821
-@login_required
-def factories():
+@router.get('/factories')
+async def factories(current_user = Depends(get_current_user)):
     try:
         fac = LLMFactoriesService.get_all()
         fac = [f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed", "BAAI"]]
@@ -50,21 +54,18 @@ def factories():
         return server_error_response(e)
 
 
-@manager.route('/set_api_key', methods=['POST']) # noqa: F821
-@login_required
-@validate_request("llm_factory", "api_key")
-def set_api_key():
-    req = request.json
+@router.post('/set_api_key')
+async def set_api_key(request: SetApiKeyRequest, current_user = Depends(get_current_user)):
     # test if api key works
     chat_passed, embd_passed, rerank_passed = False, False, False
-    factory = req["llm_factory"]
+    factory = request.llm_factory
     extra = {"provider": factory}
     msg = ""
     for llm in LLMService.query(fid=factory):
         if not embd_passed and llm.model_type == LLMType.EMBEDDING.value:
             assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
             mdl = EmbeddingModel[factory](
-                req["api_key"], llm.llm_name, base_url=req.get("base_url"))
+                request.api_key, llm.llm_name, base_url=request.base_url)
             try:
                 arr, tc = mdl.encode(["Test if the api key is available"])
                 if len(arr[0]) == 0:
@@ -75,7 +76,7 @@ def set_api_key():
         elif not chat_passed and llm.model_type == LLMType.CHAT.value:
             assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
             mdl = ChatModel[factory](
-                req["api_key"], llm.llm_name, base_url=req.get("base_url"), **extra)
+                request.api_key, llm.llm_name, base_url=request.base_url, **extra)
             try:
                 m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}],
                                  {"temperature": 0.9, 'max_tokens': 50})
@@ -88,7 +89,7 @@ def set_api_key():
         elif not rerank_passed and llm.model_type == LLMType.RERANK:
             assert factory in RerankModel, f"Re-rank model from {factory} is not supported yet."
             mdl = RerankModel[factory](
-                req["api_key"], llm.llm_name, base_url=req.get("base_url"))
+                request.api_key, llm.llm_name, base_url=request.base_url)
             try:
                 arr, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
                 if len(arr) == 0 or tc == 0:
@@ -106,12 +107,9 @@ def set_api_key():
         return get_data_error_result(message=msg)
 
     llm_config = {
-        "api_key": req["api_key"],
-        "api_base": req.get("base_url", "")
+        "api_key": request.api_key,
+        "api_base": request.base_url or ""
     }
-    for n in ["model_type", "llm_name"]:
-        if n in req:
-            llm_config[n] = req[n]
 
     for llm in LLMService.query(fid=factory):
         llm_config["max_tokens"]=llm.max_tokens
@@ -133,18 +131,15 @@ def set_api_key():
     return get_json_result(data=True)
 
 
-@manager.route('/add_llm', methods=['POST']) # noqa: F821
-@login_required
-@validate_request("llm_factory")
-def add_llm():
-    req = request.json
-    factory = req["llm_factory"]
-    api_key = req.get("api_key", "x")
-    llm_name = req.get("llm_name")
+@router.post('/add_llm')
+async def add_llm(request: AddLLMRequest, current_user = Depends(get_current_user)):
+    factory = request.llm_factory
+    api_key = request.api_key or "x"
+    llm_name = request.llm_name
 
     def apikey_json(keys):
-        nonlocal req
-        return json.dumps({k: getattr(request, k, "") or "" for k in keys})
+        nonlocal request
+        return json.dumps({k: getattr(request, k, "") or "" for k in keys})
 
     if factory == "VolcEngine":
         # For VolcEngine, due to its special authentication method
@@ -152,12 +147,21 @@ def add_llm():
         api_key = apikey_json(["ark_api_key", "endpoint_id"])
 
     elif factory == "Tencent Hunyuan":
-        req["api_key"] = apikey_json(["hunyuan_sid", "hunyuan_sk"])
-        return set_api_key()
+        # Create a temporary request object for set_api_key
+        temp_request = SetApiKeyRequest(
+            llm_factory=factory,
+            api_key=apikey_json(["hunyuan_sid", "hunyuan_sk"]),
+            base_url=request.api_base
+        )
+        return await set_api_key(temp_request, current_user)
 
     elif factory == "Tencent Cloud":
-        req["api_key"] = apikey_json(["tencent_cloud_sid", "tencent_cloud_sk"])
-        return set_api_key()
+        temp_request = SetApiKeyRequest(
+            llm_factory=factory,
+            api_key=apikey_json(["tencent_cloud_sid", "tencent_cloud_sk"]),
+            base_url=request.api_base
+        )
+        return await set_api_key(temp_request, current_user)
 
     elif factory == "Bedrock":
         # For Bedrock, due to its special authentication method
@@ -177,9 +181,9 @@ def add_llm():
         llm_name += "___VLLM"
 
     elif factory == "XunFei Spark":
-        if req["model_type"] == "chat":
-            api_key = req.get("spark_api_password", "")
-        elif req["model_type"] == "tts":
+        if request.model_type == "chat":
+            api_key = request.spark_api_password or ""
+        elif request.model_type == "tts":
             api_key = apikey_json(["spark_app_id", "spark_api_secret", "spark_api_key"])
 
     elif factory == "BaiduYiyan":
@@ -197,11 +201,11 @@ def add_llm():
     llm = {
         "tenant_id": current_user.id,
         "llm_factory": factory,
-        "model_type": req["model_type"],
+        "model_type": request.model_type,
         "llm_name": llm_name,
-        "api_base": req.get("api_base", ""),
+        "api_base": request.api_base or "",
         "api_key": api_key,
-        "max_tokens": req.get("max_tokens")
+        "max_tokens": request.max_tokens
     }
 
     msg = ""
@@ -290,33 +294,27 @@ def add_llm():
     return get_json_result(data=True)
 
 
-@manager.route('/delete_llm', methods=['POST']) # noqa: F821
-@login_required
-@validate_request("llm_factory", "llm_name")
-def delete_llm():
-    req = request.json
+@router.post('/delete_llm')
+async def delete_llm(request: DeleteLLMRequest, current_user = Depends(get_current_user)):
     TenantLLMService.filter_delete(
-        [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"],
-         TenantLLM.llm_name == req["llm_name"]])
+        [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == request.llm_factory,
+         TenantLLM.llm_name == request.llm_name])
     return get_json_result(data=True)
 
 
-@manager.route('/delete_factory', methods=['POST']) # noqa: F821
-@login_required
-@validate_request("llm_factory")
-def delete_factory():
-    req = request.json
+@router.post('/delete_factory')
+async def delete_factory(request: DeleteFactoryRequest, current_user = Depends(get_current_user)):
     TenantLLMService.filter_delete(
-        [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"]])
+        [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == request.llm_factory])
     return get_json_result(data=True)
 
 
-@manager.route('/my_llms', methods=['GET']) # noqa: F821
-@login_required
-def my_llms():
+@router.get('/my_llms')
+async def my_llms(
+    include_details: bool = Query(False, description="Whether to include detailed information"),
+    current_user = Depends(get_current_user)
+):
     try:
-        include_details = request.args.get('include_details', 'false').lower() == 'true'
-
         if include_details:
             res = {}
             objs = TenantLLMService.query(tenant_id=current_user.id)
@@ -362,12 +360,13 @@ def my_llms():
         return server_error_response(e)
 
 
-@manager.route('/list', methods=['GET']) # noqa: F821
-@login_required
-def list_app():
+@router.get('/list')
+async def list_app(
+    model_type: Optional[str] = Query(None, description="Model type"),
+    current_user = Depends(get_current_user)
+):
     self_deployed = ["Youdao", "FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio", "GPUStack"]
     weighted = ["Youdao", "FastEmbed", "BAAI"] if settings.LIGHTEN != 0 else []
-    model_type = request.args.get("model_type")
     try:
         objs = TenantLLMService.query(tenant_id=current_user.id)
         facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key])
diff --git a/api/models/llm_models.py b/api/models/llm_models.py
new file mode 100644
index 0000000..bac70d0
--- /dev/null
+++ b/api/models/llm_models.py
@@ -0,0 +1,84 @@
+#
+# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import Optional, Dict, Any
+from pydantic import BaseModel, Field
+
+
+class SetApiKeyRequest(BaseModel):
+    """Request model for setting an API key."""
+    llm_factory: str = Field(..., description="LLM factory name")
+    api_key: str = Field(..., description="API key")
+    base_url: Optional[str] = Field(None, description="Base URL")
+
+
+class AddLLMRequest(BaseModel):
+    """Request model for adding an LLM."""
+    llm_factory: str = Field(..., description="LLM factory name")
+    api_key: Optional[str] = Field("x", description="API key")
+    llm_name: Optional[str] = Field(None, description="LLM name")
+    model_type: Optional[str] = Field(None, description="Model type")
+    api_base: Optional[str] = Field(None, description="API base URL")
+    max_tokens: Optional[int] = Field(None, description="Maximum number of tokens")
+
+    # VolcEngine specific fields
+    ark_api_key: Optional[str] = Field(None, description="VolcEngine ARK API key")
+    endpoint_id: Optional[str] = Field(None, description="VolcEngine endpoint ID")
+
+    # Tencent Hunyuan specific fields
+    hunyuan_sid: Optional[str] = Field(None, description="Tencent Hunyuan SID")
+    hunyuan_sk: Optional[str] = Field(None, description="Tencent Hunyuan SK")
+
+    # Tencent Cloud specific fields
+    tencent_cloud_sid: Optional[str] = Field(None, description="Tencent Cloud SID")
+    tencent_cloud_sk: Optional[str] = Field(None, description="Tencent Cloud SK")
+
+    # Bedrock specific fields
+    bedrock_ak: Optional[str] = Field(None, description="Bedrock access key")
+    bedrock_sk: Optional[str] = Field(None, description="Bedrock secret key")
+    bedrock_region: Optional[str] = Field(None, description="Bedrock region")
+
+    # XunFei Spark specific fields
+    spark_api_password: Optional[str] = Field(None, description="XunFei Spark API password")
+    spark_app_id: Optional[str] = Field(None, description="XunFei Spark app ID")
+    spark_api_secret: Optional[str] = Field(None, description="XunFei Spark API secret")
+    spark_api_key: Optional[str] = Field(None, description="XunFei Spark API key")
+
+    # BaiduYiyan specific fields
+    yiyan_ak: Optional[str] = Field(None, description="Baidu Yiyan AK")
+    yiyan_sk: Optional[str] = Field(None, description="Baidu Yiyan SK")
+
+    # Fish Audio specific fields
+    fish_audio_ak: Optional[str] = Field(None, description="Fish Audio AK")
+    fish_audio_refid: Optional[str] = Field(None, description="Fish Audio reference ID")
+
+    # Google Cloud specific fields
+    google_project_id: Optional[str] = Field(None, description="Google Cloud project ID")
+    google_region: Optional[str] = Field(None, description="Google Cloud region")
+    google_service_account_key: Optional[str] = Field(None, description="Google Cloud service account key")
+
+    # Azure OpenAI specific fields
+    api_version: Optional[str] = Field(None, description="Azure OpenAI API version")
+
+
+class DeleteLLMRequest(BaseModel):
+    """Request model for deleting an LLM."""
+    llm_factory: str = Field(..., description="LLM factory name")
+    llm_name: str = Field(..., description="LLM name")
+
+
+class DeleteFactoryRequest(BaseModel):
+    """Request model for deleting a factory."""
+    llm_factory: str = Field(..., description="LLM factory name")
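Notes: below is a minimal smoke-test sketch (not part of the commit) showing
how the migrated LLM routes can be exercised once this patch is applied. It
assumes the FastAPI app factory in api/apps/__init___fastapi.py is named
create_app, that get_current_user accepts a Bearer token, and that API_VERSION
resolves to "v1"; all three are assumptions to adjust to the project's real
entry points.

    # Hypothetical smoke test; create_app and the auth scheme are assumed.
    from fastapi.testclient import TestClient

    from api.apps.__init___fastapi import create_app  # assumed factory name

    client = TestClient(create_app())
    headers = {"Authorization": "Bearer <token>"}  # token elided

    # GET /v1/llm/factories: auth is supplied by the get_current_user dependency.
    print(client.get("/v1/llm/factories", headers=headers).json())

    # POST /v1/llm/set_api_key: the JSON body is validated against
    # SetApiKeyRequest, so a missing api_key fails with 422 before the
    # handler runs instead of being caught by validate_request.
    body = {"llm_factory": "OpenAI", "api_key": "sk-..."}
    print(client.post("/v1/llm/set_api_key", headers=headers, json=body).json())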