Refactor llm.py to FastAPI

2025-10-28 11:49:07 +08:00
parent 45f69ab3d5
commit cbe0477ba1
3 changed files with 148 additions and 63 deletions

View File

@@ -128,6 +128,7 @@ def setup_routes(app: FastAPI):
     from api.apps.file2document_app import router as file2document_router
     from api.apps.mcp_server_app import router as mcp_router
     from api.apps.tenant_app import router as tenant_router
+    from api.apps.llm_app import router as llm_router

     app.include_router(user_router, prefix=f"/{API_VERSION}/user", tags=["User"])
     app.include_router(kb_router, prefix=f"/{API_VERSION}/kb", tags=["KB"])
@@ -136,6 +137,7 @@ def setup_routes(app: FastAPI):
     app.include_router(file2document_router, prefix=f"/{API_VERSION}/file2document", tags=["File2Document"])
     app.include_router(mcp_router, prefix=f"/{API_VERSION}/mcp", tags=["MCP"])
     app.include_router(tenant_router, prefix=f"/{API_VERSION}/tenant", tags=["Tenant"])
+    app.include_router(llm_router, prefix=f"/{API_VERSION}/llm", tags=["LLM"])

 def get_current_user_from_token(authorization: str):
     """Get the current user from the token"""

api/apps/llm_app.py

View File

@@ -15,22 +15,26 @@
 #
 import logging
 import json
-from flask import request
-from flask_login import login_required, current_user
+from typing import Optional
+from fastapi import APIRouter, Depends, Query
+from fastapi.responses import JSONResponse
 from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
 from api.db.services.llm_service import LLMService
 from api import settings
-from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
+from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result
 from api.db import StatusEnum, LLMType
 from api.db.db_models import TenantLLM
-from api.utils.api_utils import get_json_result
 from api.utils.base64_image import test_image
 from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel
+from api.models.llm_models import SetApiKeyRequest, AddLLMRequest, DeleteLLMRequest, DeleteFactoryRequest
+from api.utils.api_utils import get_current_user

+# Create the FastAPI router
+router = APIRouter()

-@manager.route('/factories', methods=['GET'])  # noqa: F821
-@login_required
-def factories():
+@router.get('/factories')
+async def factories(current_user = Depends(get_current_user)):
     try:
         fac = LLMFactoriesService.get_all()
         fac = [f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed", "BAAI"]]
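With the router mounted under `/{API_VERSION}/llm` (see the first file), the converted endpoint can be exercised over HTTP. A hedged example using httpx; the host, port, version string "v1", and bearer-token scheme are assumptions:

import httpx

# Assumed deployment details: localhost:9380 and API_VERSION == "v1".
resp = httpx.get(
    "http://localhost:9380/v1/llm/factories",
    headers={"Authorization": "Bearer <token>"},  # auth scheme is an assumption
)
print(resp.json())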
@@ -50,21 +54,18 @@ def factories():
         return server_error_response(e)

-@manager.route('/set_api_key', methods=['POST'])  # noqa: F821
-@login_required
-@validate_request("llm_factory", "api_key")
-def set_api_key():
-    req = request.json
+@router.post('/set_api_key')
+async def set_api_key(request: SetApiKeyRequest, current_user = Depends(get_current_user)):
     # test if api key works
     chat_passed, embd_passed, rerank_passed = False, False, False
-    factory = req["llm_factory"]
+    factory = request.llm_factory
     extra = {"provider": factory}
     msg = ""
     for llm in LLMService.query(fid=factory):
         if not embd_passed and llm.model_type == LLMType.EMBEDDING.value:
             assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
             mdl = EmbeddingModel[factory](
-                req["api_key"], llm.llm_name, base_url=req.get("base_url"))
+                request.api_key, llm.llm_name, base_url=request.base_url)
             try:
                 arr, tc = mdl.encode(["Test if the api key is available"])
                 if len(arr[0]) == 0:
@@ -75,7 +76,7 @@ def set_api_key():
         elif not chat_passed and llm.model_type == LLMType.CHAT.value:
             assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
             mdl = ChatModel[factory](
-                req["api_key"], llm.llm_name, base_url=req.get("base_url"), **extra)
+                request.api_key, llm.llm_name, base_url=request.base_url, **extra)
             try:
                 m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}],
                                  {"temperature": 0.9, 'max_tokens': 50})
@@ -88,7 +89,7 @@ def set_api_key():
         elif not rerank_passed and llm.model_type == LLMType.RERANK:
             assert factory in RerankModel, f"Re-rank model from {factory} is not supported yet."
             mdl = RerankModel[factory](
-                req["api_key"], llm.llm_name, base_url=req.get("base_url"))
+                request.api_key, llm.llm_name, base_url=request.base_url)
             try:
                 arr, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
                 if len(arr) == 0 or tc == 0:
@@ -106,12 +107,9 @@ def set_api_key():
         return get_data_error_result(message=msg)

     llm_config = {
-        "api_key": req["api_key"],
-        "api_base": req.get("base_url", "")
+        "api_key": request.api_key,
+        "api_base": request.base_url or ""
     }
-    for n in ["model_type", "llm_name"]:
-        if n in req:
-            llm_config[n] = req[n]

     for llm in LLMService.query(fid=factory):
         llm_config["max_tokens"]=llm.max_tokens
@@ -133,18 +131,15 @@ def set_api_key():
     return get_json_result(data=True)

-@manager.route('/add_llm', methods=['POST'])  # noqa: F821
-@login_required
-@validate_request("llm_factory")
-def add_llm():
-    req = request.json
-    factory = req["llm_factory"]
-    api_key = req.get("api_key", "x")
-    llm_name = req.get("llm_name")
+@router.post('/add_llm')
+async def add_llm(request: AddLLMRequest, current_user = Depends(get_current_user)):
+    factory = request.llm_factory
+    api_key = request.api_key or "x"
+    llm_name = request.llm_name

     def apikey_json(keys):
-        nonlocal req
-        return json.dumps({k: req.get(k, "") for k in keys})
+        nonlocal request
+        return json.dumps({k: getattr(request, k, "") for k in keys})

     if factory == "VolcEngine":
         # For VolcEngine, due to its special authentication method
@@ -152,12 +147,21 @@ def add_llm():
         api_key = apikey_json(["ark_api_key", "endpoint_id"])
     elif factory == "Tencent Hunyuan":
-        req["api_key"] = apikey_json(["hunyuan_sid", "hunyuan_sk"])
-        return set_api_key()
+        # Create a temporary request object for set_api_key
+        temp_request = SetApiKeyRequest(
+            llm_factory=factory,
+            api_key=apikey_json(["hunyuan_sid", "hunyuan_sk"]),
+            base_url=request.api_base
+        )
+        return await set_api_key(temp_request, current_user)
     elif factory == "Tencent Cloud":
-        req["api_key"] = apikey_json(["tencent_cloud_sid", "tencent_cloud_sk"])
-        return set_api_key()
+        temp_request = SetApiKeyRequest(
+            llm_factory=factory,
+            api_key=apikey_json(["tencent_cloud_sid", "tencent_cloud_sk"]),
+            base_url=request.api_base
+        )
+        return await set_api_key(temp_request, current_user)
     elif factory == "Bedrock":
         # For Bedrock, due to its special authentication method
@@ -177,9 +181,9 @@ def add_llm():
             llm_name += "___VLLM"
     elif factory == "XunFei Spark":
-        if req["model_type"] == "chat":
-            api_key = req.get("spark_api_password", "")
-        elif req["model_type"] == "tts":
+        if request.model_type == "chat":
+            api_key = request.spark_api_password or ""
+        elif request.model_type == "tts":
             api_key = apikey_json(["spark_app_id", "spark_api_secret", "spark_api_key"])
     elif factory == "BaiduYiyan":
@@ -197,11 +201,11 @@ def add_llm():
     llm = {
         "tenant_id": current_user.id,
         "llm_factory": factory,
-        "model_type": req["model_type"],
+        "model_type": request.model_type,
         "llm_name": llm_name,
-        "api_base": req.get("api_base", ""),
+        "api_base": request.api_base or "",
         "api_key": api_key,
-        "max_tokens": req.get("max_tokens")
+        "max_tokens": request.max_tokens
     }
     msg = ""
@@ -290,33 +294,27 @@ def add_llm():
     return get_json_result(data=True)

-@manager.route('/delete_llm', methods=['POST'])  # noqa: F821
-@login_required
-@validate_request("llm_factory", "llm_name")
-def delete_llm():
-    req = request.json
+@router.post('/delete_llm')
+async def delete_llm(request: DeleteLLMRequest, current_user = Depends(get_current_user)):
     TenantLLMService.filter_delete(
-        [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"],
-         TenantLLM.llm_name == req["llm_name"]])
+        [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == request.llm_factory,
+         TenantLLM.llm_name == request.llm_name])
     return get_json_result(data=True)

-@manager.route('/delete_factory', methods=['POST'])  # noqa: F821
-@login_required
-@validate_request("llm_factory")
-def delete_factory():
-    req = request.json
+@router.post('/delete_factory')
+async def delete_factory(request: DeleteFactoryRequest, current_user = Depends(get_current_user)):
     TenantLLMService.filter_delete(
-        [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"]])
+        [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == request.llm_factory])
     return get_json_result(data=True)

-@manager.route('/my_llms', methods=['GET'])  # noqa: F821
-@login_required
-def my_llms():
+@router.get('/my_llms')
+async def my_llms(
+    include_details: bool = Query(False, description="Whether to include details"),
+    current_user = Depends(get_current_user)
+):
     try:
-        include_details = request.args.get('include_details', 'false').lower() == 'true'
         if include_details:
             res = {}
             objs = TenantLLMService.query(tenant_id=current_user.id)
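The `include_details` flag is now a typed `Query` parameter, so FastAPI parses and coerces it from the query string, replacing the manual `request.args.get(...).lower() == 'true'` check. For example (host, port, and version string assumed as before):

import httpx

resp = httpx.get(
    "http://localhost:9380/v1/llm/my_llms",
    params={"include_details": "true"},  # coerced to bool by FastAPI
    headers={"Authorization": "Bearer <token>"},
)
print(resp.json())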
@@ -362,12 +360,13 @@ def my_llms():
         return server_error_response(e)

-@manager.route('/list', methods=['GET'])  # noqa: F821
-@login_required
-def list_app():
+@router.get('/list')
+async def list_app(
+    model_type: Optional[str] = Query(None, description="Model type"),
+    current_user = Depends(get_current_user)
+):
     self_deployed = ["Youdao", "FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio", "GPUStack"]
     weighted = ["Youdao", "FastEmbed", "BAAI"] if settings.LIGHTEN != 0 else []
-    model_type = request.args.get("model_type")
     try:
         objs = TenantLLMService.query(tenant_id=current_user.id)
         facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key])
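A minimal sketch of exercising the converted router in isolation with FastAPI's TestClient; the dependency override and the fake user are test-only assumptions, not part of this commit, and the service layer must still be reachable:

from fastapi import FastAPI
from fastapi.testclient import TestClient
from api.apps.llm_app import router
from api.utils.api_utils import get_current_user

app = FastAPI()
app.include_router(router, prefix="/v1/llm")  # mirrors setup_routes with API_VERSION == "v1"

class FakeUser:
    id = "test-tenant"  # hypothetical tenant id for testing

# Bypass real authentication for the test.
app.dependency_overrides[get_current_user] = lambda: FakeUser()

client = TestClient(app)
print(client.get("/v1/llm/factories").json())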

api/models/llm_models.py (new file, +84 lines)
View File

@@ -0,0 +1,84 @@
+#
+#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+#
+from typing import Optional, Dict, Any
+from pydantic import BaseModel, Field
+
+
+class SetApiKeyRequest(BaseModel):
+    """Request model for setting an API key"""
+    llm_factory: str = Field(..., description="LLM factory name")
+    api_key: str = Field(..., description="API key")
+    base_url: Optional[str] = Field(None, description="Base URL")
+
+
+class AddLLMRequest(BaseModel):
+    """Request model for adding an LLM"""
+    llm_factory: str = Field(..., description="LLM factory name")
+    api_key: Optional[str] = Field("x", description="API key")
+    llm_name: Optional[str] = Field(None, description="LLM name")
+    model_type: Optional[str] = Field(None, description="Model type")
+    api_base: Optional[str] = Field(None, description="API base URL")
+    max_tokens: Optional[int] = Field(None, description="Maximum number of tokens")
+
+    # VolcEngine specific fields
+    ark_api_key: Optional[str] = Field(None, description="VolcEngine ARK API key")
+    endpoint_id: Optional[str] = Field(None, description="VolcEngine endpoint ID")
+
+    # Tencent Hunyuan specific fields
+    hunyuan_sid: Optional[str] = Field(None, description="Tencent Hunyuan SID")
+    hunyuan_sk: Optional[str] = Field(None, description="Tencent Hunyuan SK")
+
+    # Tencent Cloud specific fields
+    tencent_cloud_sid: Optional[str] = Field(None, description="Tencent Cloud SID")
+    tencent_cloud_sk: Optional[str] = Field(None, description="Tencent Cloud SK")
+
+    # Bedrock specific fields
+    bedrock_ak: Optional[str] = Field(None, description="Bedrock access key")
+    bedrock_sk: Optional[str] = Field(None, description="Bedrock secret key")
+    bedrock_region: Optional[str] = Field(None, description="Bedrock region")
+
+    # XunFei Spark specific fields
+    spark_api_password: Optional[str] = Field(None, description="XunFei Spark API password")
+    spark_app_id: Optional[str] = Field(None, description="XunFei Spark app ID")
+    spark_api_secret: Optional[str] = Field(None, description="XunFei Spark API secret")
+    spark_api_key: Optional[str] = Field(None, description="XunFei Spark API key")
+
+    # BaiduYiyan specific fields
+    yiyan_ak: Optional[str] = Field(None, description="Baidu Yiyan AK")
+    yiyan_sk: Optional[str] = Field(None, description="Baidu Yiyan SK")
+
+    # Fish Audio specific fields
+    fish_audio_ak: Optional[str] = Field(None, description="Fish Audio AK")
+    fish_audio_refid: Optional[str] = Field(None, description="Fish Audio reference ID")
+
+    # Google Cloud specific fields
+    google_project_id: Optional[str] = Field(None, description="Google Cloud project ID")
+    google_region: Optional[str] = Field(None, description="Google Cloud region")
+    google_service_account_key: Optional[str] = Field(None, description="Google Cloud service account key")
+
+    # Azure OpenAI specific fields
+    api_version: Optional[str] = Field(None, description="Azure OpenAI API version")
+
+
+class DeleteLLMRequest(BaseModel):
+    """Request model for deleting an LLM"""
+    llm_factory: str = Field(..., description="LLM factory name")
+    llm_name: str = Field(..., description="LLM name")
+
+
+class DeleteFactoryRequest(BaseModel):
+    """Request model for deleting a factory"""
+    llm_factory: str = Field(..., description="LLM factory name")
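A quick sketch of how these models behave at the request boundary: required fields raise a `ValidationError` (which FastAPI surfaces as a 422 response), while optional fields fall back to their declared defaults:

from pydantic import ValidationError
from api.models.llm_models import SetApiKeyRequest, AddLLMRequest

req = SetApiKeyRequest(llm_factory="OpenAI", api_key="<key>")
assert req.base_url is None          # optional field defaults to None

add = AddLLMRequest(llm_factory="VolcEngine", ark_api_key="<ak>", endpoint_id="<ep>")
assert add.api_key == "x"            # default declared on the model

try:
    SetApiKeyRequest(llm_factory="OpenAI")   # api_key is required
except ValidationError as e:
    print(e.errors()[0]["loc"])              # ('api_key',)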