Files
AIRegulation-DocAnalysis/backend/app/services/llm/llm_factory.py

259 lines
8.3 KiB
Python
Raw Normal View History

2026-05-14 15:07:34 +08:00
# src/services/llm/llm_factory.py
"""LLM工厂 - 统一创建和管理LLM客户端"""
from typing import Optional, Dict, Any
from loguru import logger
from functools import lru_cache
from .base_client import BaseLLMClient, LLMConfig, LLMProvider, LLMResponse
from .deepseek_client import DeepSeekClient
from .qwen_client import QwenClient, QwenVLClient
# Default model per provider; used when the caller does not pass `model`.
DEFAULT_MODELS = {
    LLMProvider.DEEPSEEK: "deepseek-v4-flash",
    LLMProvider.QWEN: "qwen3.5-flash",
    LLMProvider.QWEN_VL: "qwen3-vl-plus",
}
# Default API base URL per provider; used when the caller does not pass
# `base_url`. Every provider is routed through the same unified proxy service,
# so the mapping is built from one shared URL.
DEFAULT_BASE_URLS = dict.fromkeys(
    (LLMProvider.DEEPSEEK, LLMProvider.QWEN, LLMProvider.QWEN_VL),
    "http://6.86.80.4:30080/v1",
)
class LLMFactory:
    """
    Factory for LLM clients with a process-wide client cache.

    Supported providers (the default model of each comes from
    ``DEFAULT_MODELS``):
        - DeepSeek -> ``DeepSeekClient``
        - Qwen     -> ``QwenClient``
        - Qwen-VL  -> ``QwenVLClient`` (multimodal)

    Clients are cached per (canonical provider, model). Every accepted
    spelling of a provider ("qwen", "Qwen", "qwen-plus", ...) therefore
    resolves to the same cache entry, and ``create``, ``get_cached`` and
    ``get_global_client`` all agree on the key.

    Example:
        factory = LLMFactory()
        # default configuration
        client = factory.create("deepseek")
        # custom configuration
        client = factory.create("qwen", model="qwen-max", temperature=0.5)
        # call the LLM
        response = client.complete("Hello, please introduce yourself")
    """

    # Global client cache (class level, shared across all factory instances).
    _global_instances: Dict[str, BaseLLMClient] = {}

    # Accepted provider spellings (lower-case) -> provider enum. Some entries
    # are model names that callers pass in the provider position.
    _PROVIDER_ALIASES: Dict[str, LLMProvider] = {
        "deepseek": LLMProvider.DEEPSEEK,
        "deepseek-v3": LLMProvider.DEEPSEEK,
        "deepseek_chat": LLMProvider.DEEPSEEK,
        "qwen": LLMProvider.QWEN,
        "qwen-turbo": LLMProvider.QWEN,
        "qwen-plus": LLMProvider.QWEN,
        "qwen-max": LLMProvider.QWEN,
        "qwen3.5-flash": LLMProvider.QWEN,
        "qwen3.5-plus": LLMProvider.QWEN,
        "qwen_vl": LLMProvider.QWEN_VL,
        "qwen-vl": LLMProvider.QWEN_VL,
        "qwen-vl-plus": LLMProvider.QWEN_VL,
        "qwen-vl-max": LLMProvider.QWEN_VL,
    }

    # Canonical name used as the cache-key prefix for each provider, so that
    # aliases and case variants share one cached client.
    _CANONICAL_NAMES: Dict[LLMProvider, str] = {
        LLMProvider.DEEPSEEK: "deepseek",
        LLMProvider.QWEN: "qwen",
        LLMProvider.QWEN_VL: "qwen_vl",
    }

    def __init__(self):
        # Per-instance scratch dict; nothing in this class reads it.
        self._config_cache: Dict[str, Any] = {}

    @classmethod
    def _cache_key(cls, provider: LLMProvider, model: Optional[str]) -> str:
        """Build the cache key shared by every lookup and store path."""
        prefix = cls._CANONICAL_NAMES.get(provider, str(provider))
        return f"{prefix}_{model}"

    def create(
        self,
        provider: str,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        base_url: Optional[str] = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        **kwargs
    ) -> BaseLLMClient:
        """
        Create an LLM client, or return the cached one for this provider/model.

        The cache key is built from the provider and model only. On a cache
        hit the existing client is returned as-is, so ``api_key``,
        ``base_url``, ``max_tokens``, ``temperature`` and ``kwargs`` take
        effect only for the call that first creates the client.

        Args:
            provider: Provider name or alias ("deepseek", "qwen", "qwen_vl", ...).
            api_key: API key; read from environment variables when omitted.
            model: Model name; the provider's default model when omitted.
            base_url: API base URL; the provider's default URL when omitted.
            max_tokens: Maximum number of output tokens.
            temperature: Sampling temperature.
            **kwargs: Extra fields forwarded to ``LLMConfig``.

        Returns:
            BaseLLMClient: The LLM client instance.

        Raises:
            ValueError: If the provider is not supported or no API key is found.
        """
        provider_enum = self._parse_provider(provider)

        # Resolve configuration, falling back to environment / module defaults.
        api_key = api_key or self._get_api_key(provider_enum)
        model = model or DEFAULT_MODELS.get(provider_enum)
        base_url = base_url or DEFAULT_BASE_URLS.get(provider_enum)

        if not api_key:
            raise ValueError("缺少API密钥请设置环境变量或传入api_key参数")

        # Check the global cache under the canonical key.
        cache_key = self._cache_key(provider_enum, model)
        cached = LLMFactory._global_instances.get(cache_key)
        if cached is not None:
            logger.debug(f"使用缓存的LLM客户端: {cache_key}")
            return cached

        config = LLMConfig(
            provider=provider_enum,
            model=model,
            api_key=api_key,
            base_url=base_url,
            max_tokens=max_tokens,
            temperature=temperature,
            **kwargs
        )

        # Create the client and store it in the global cache.
        client = self._create_client(config)
        LLMFactory._global_instances[cache_key] = client
        logger.info(f"LLM客户端创建成功并缓存: {provider} - {model}")
        return client

    def _parse_provider(self, provider: str) -> LLMProvider:
        """
        Resolve a provider name or alias (case-insensitive) to its enum.

        Raises:
            ValueError: If the name is not a known provider alias.
        """
        aliases = LLMFactory._PROVIDER_ALIASES
        provider_lower = provider.lower()
        if provider_lower not in aliases:
            raise ValueError(f"不支持的提供商: {provider},支持的: {list(aliases.keys())}")
        return aliases[provider_lower]

    def _get_api_key(self, provider: LLMProvider) -> Optional[str]:
        """Return the first non-empty API key from the provider's environment variables, or None."""
        import os

        # Candidate variable names per provider, tried in order.
        key_map = {
            LLMProvider.DEEPSEEK: ["DEEPSEEK_API_KEY", "OPENAI_API_KEY"],
            LLMProvider.QWEN: ["QWEN_API_KEY", "DASHSCOPE_API_KEY", "ALIBABA_API_KEY"],
            LLMProvider.QWEN_VL: ["QWEN_API_KEY", "DASHSCOPE_API_KEY", "ALIBABA_API_KEY"],
        }
        for key_name in key_map.get(provider, []):
            api_key = os.getenv(key_name)
            if api_key:
                return api_key
        return None

    def _create_client(self, config: LLMConfig) -> BaseLLMClient:
        """
        Instantiate the concrete client class for ``config.provider``.

        Raises:
            ValueError: If no client class is registered for the provider.
        """
        client_map = {
            LLMProvider.DEEPSEEK: DeepSeekClient,
            LLMProvider.QWEN: QwenClient,
            LLMProvider.QWEN_VL: QwenVLClient,
        }
        client_class = client_map.get(config.provider)
        if not client_class:
            raise ValueError(f"不支持的提供商: {config.provider}")
        return client_class(config)

    def get_cached(self, provider: str, model: Optional[str] = None) -> Optional[BaseLLMClient]:
        """
        Return the cached client for a provider/model, or None on a cache miss.

        Args:
            provider: Provider name or alias.
            model: Model name; the provider's default model when omitted.

        Raises:
            ValueError: If the provider is not supported.
        """
        provider_enum = self._parse_provider(provider)
        model = model or DEFAULT_MODELS.get(provider_enum)
        return LLMFactory._global_instances.get(self._cache_key(provider_enum, model))

    def list_available_providers(self) -> Dict[str, list]:
        """Return each provider name mapped to its client class's ``SUPPORTED_MODELS``."""
        return {
            "deepseek": DeepSeekClient.SUPPORTED_MODELS,
            "qwen": QwenClient.SUPPORTED_MODELS,
            "qwen_vl": QwenVLClient.SUPPORTED_MODELS,
        }

    @classmethod
    def preload_clients(cls, providers: Optional[list] = None):
        """
        Pre-create LLM clients so they are cached (intended for application startup).

        A failure for one provider is logged as a warning and does not stop
        the remaining providers from being loaded.

        Args:
            providers: Provider names to preload; defaults to qwen and deepseek.
        """
        if providers is None:
            providers = ["qwen", "deepseek"]
        factory = cls()
        for provider in providers:
            try:
                factory.create(provider)
                logger.success(f"预加载LLM客户端成功: {provider}")
            except Exception as e:
                # Best effort: preloading must not abort the caller.
                logger.warning(f"预加载LLM客户端失败: {provider} - {e}")

    @classmethod
    def get_global_client(cls, provider: str, model: Optional[str] = None) -> Optional[BaseLLMClient]:
        """
        Return the globally cached client for a provider/model, or None.

        Unlike ``get_cached``, an unknown provider does not raise: it simply
        yields None. Known aliases resolve to their own provider, so the
        Qwen-VL aliases look up the Qwen-VL client with the Qwen-VL default
        model rather than being folded into plain Qwen.

        Args:
            provider: Provider name, alias, or a model name used as provider.
            model: Model name; the resolved provider's default model when omitted.
        """
        provider_lower = provider.lower()
        provider_enum = cls._PROVIDER_ALIASES.get(provider_lower)
        if provider_enum is None:
            # Handle an unlisted Qwen model name passed as the provider
            # (e.g. a newer "qwen..." model): treat it as the Qwen provider.
            if not provider_lower.startswith("qwen"):
                return None
            provider_enum = LLMProvider.QWEN
        model = model or DEFAULT_MODELS.get(provider_enum)
        return cls._global_instances.get(cls._cache_key(provider_enum, model))

    @classmethod
    def cleanup(cls):
        """Close every cached client (logging, not raising, on failure) and empty the cache."""
        # Iterate over a snapshot so the dict is never mutated mid-iteration.
        for cache_key, client in list(cls._global_instances.items()):
            try:
                client.close()
                logger.debug(f"关闭LLM客户端: {cache_key}")
            except Exception as e:
                logger.warning(f"关闭LLM客户端失败: {cache_key} - {e}")
        cls._global_instances.clear()
        logger.info("所有LLM客户端已清理")
@lru_cache(maxsize=128)
def get_llm_factory() -> LLMFactory:
    """Return the shared ``LLMFactory``; ``lru_cache`` memoizes the instance after the first call."""
    factory = LLMFactory()
    return factory
def get_llm_client(
    provider: str = "qwen",
    model: Optional[str] = None,
    **kwargs
) -> BaseLLMClient:
    """
    Convenience accessor for an LLM client that prefers the cached instance.

    Args:
        provider: Provider name.
        model: Model name.
        **kwargs: Extra configuration, forwarded to ``create`` only when no
            cached client is found.

    Returns:
        BaseLLMClient: The LLM client instance.
    """
    llm_factory = get_llm_factory()
    existing = llm_factory.get_cached(provider, model)
    if not existing:
        # Cache miss: build the client through the factory.
        return llm_factory.create(provider, model=model, **kwargs)
    return existing