# src/services/llm/llm_factory.py
"""LLM factory: a single place to create, cache and look up LLM clients."""

import os
from functools import lru_cache
from typing import Any, Dict, List, Optional

from loguru import logger

from .base_client import BaseLLMClient, LLMConfig, LLMProvider, LLMResponse
from .deepseek_client import DeepSeekClient
from .qwen_client import QwenClient, QwenVLClient

# Model used for each provider when the caller does not name one.
DEFAULT_MODELS = {
    LLMProvider.DEEPSEEK: "deepseek-v4-flash",
    LLMProvider.QWEN: "qwen3.5-flash",
    LLMProvider.QWEN_VL: "qwen3-vl-plus",
}

# API base URLs (all providers go through the same unified proxy service).
DEFAULT_BASE_URLS = {
    LLMProvider.DEEPSEEK: "http://6.86.80.4:30080/v1",
    LLMProvider.QWEN: "http://6.86.80.4:30080/v1",
    LLMProvider.QWEN_VL: "http://6.86.80.4:30080/v1",
}

# Canonical provider name used as the prefix of every cache key, so that all
# aliases and spellings of one provider share the same cached client.
_CANONICAL_PROVIDER_NAMES = {
    LLMProvider.DEEPSEEK: "deepseek",
    LLMProvider.QWEN: "qwen",
    LLMProvider.QWEN_VL: "qwen_vl",
}

# Accepted (lower-cased) provider names and aliases. Insertion order matters:
# it is echoed verbatim in the "unsupported provider" error message.
_PROVIDER_ALIASES = {
    "deepseek": LLMProvider.DEEPSEEK,
    "deepseek-v3": LLMProvider.DEEPSEEK,
    "deepseek_chat": LLMProvider.DEEPSEEK,
    "qwen": LLMProvider.QWEN,
    "qwen-turbo": LLMProvider.QWEN,
    "qwen-plus": LLMProvider.QWEN,
    "qwen-max": LLMProvider.QWEN,
    "qwen3.5-flash": LLMProvider.QWEN,
    "qwen3.5-plus": LLMProvider.QWEN,
    "qwen_vl": LLMProvider.QWEN_VL,
    "qwen-vl": LLMProvider.QWEN_VL,
    "qwen-vl-plus": LLMProvider.QWEN_VL,
    "qwen-vl-max": LLMProvider.QWEN_VL,
}

# Environment variables probed, in order, for each provider's API key.
_API_KEY_ENV_VARS = {
    LLMProvider.DEEPSEEK: ["DEEPSEEK_API_KEY", "OPENAI_API_KEY"],
    LLMProvider.QWEN: ["QWEN_API_KEY", "DASHSCOPE_API_KEY", "ALIBABA_API_KEY"],
    LLMProvider.QWEN_VL: ["QWEN_API_KEY", "DASHSCOPE_API_KEY", "ALIBABA_API_KEY"],
}


class LLMFactory:
    """
    Factory for LLM clients, backed by a process-wide client cache.

    Provider names are case-insensitive; "deepseek", "qwen" and "qwen_vl"
    are the canonical ones and a number of aliases are accepted as well (see
    ``_PROVIDER_ALIASES``). Default models come from ``DEFAULT_MODELS``;
    ``list_available_providers()`` reports the models each client class
    declares.

    Caching note: clients are cached per (canonical provider, model) only.
    When a matching client is already cached, ``create`` returns it and the
    other arguments (api_key, base_url, max_tokens, temperature, **kwargs)
    of that call are NOT applied.

    Example:
        factory = LLMFactory()

        # Default configuration
        client = factory.create("deepseek")

        # Custom configuration
        client = factory.create("qwen", model="qwen-max", temperature=0.5)

        # Call the LLM
        response = client.complete("Hello, please introduce yourself")
    """

    # Global client cache (class-level, shared by every factory instance).
    _global_instances: Dict[str, BaseLLMClient] = {}

    def __init__(self):
        self._config_cache: Dict[str, Any] = {}

    @staticmethod
    def _cache_key(provider_enum: LLMProvider, model: Optional[str]) -> str:
        """Build the cache key for a provider/model pair.

        The key uses the canonical provider name rather than the string the
        caller typed, so "Qwen", "qwen" and "qwen-max" all resolve to the
        same entry and every lookup helper agrees with ``create``.
        """
        return f"{_CANONICAL_PROVIDER_NAMES[provider_enum]}_{model}"

    def create(
        self,
        provider: str,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        base_url: Optional[str] = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        **kwargs
    ) -> BaseLLMClient:
        """
        Create an LLM client, or return the cached one for this provider/model.

        Args:
            provider: Provider name or alias ("deepseek", "qwen", "qwen_vl", ...).
            api_key: API key; falls back to the provider's environment variables.
            model: Model name; falls back to ``DEFAULT_MODELS``.
            base_url: API base URL; falls back to ``DEFAULT_BASE_URLS``.
            max_tokens: Maximum number of output tokens.
            temperature: Sampling temperature.
            **kwargs: Extra fields forwarded to ``LLMConfig``.

        Returns:
            BaseLLMClient: The (possibly cached) client instance.

        Raises:
            ValueError: If the provider is not supported, or no API key was
                passed and none is found in the environment.
        """
        provider_enum = self._parse_provider(provider)

        # Resolve effective settings: explicit argument first, then defaults.
        api_key = api_key or self._get_api_key(provider_enum)
        model = model or DEFAULT_MODELS.get(provider_enum)
        base_url = base_url or DEFAULT_BASE_URLS.get(provider_enum)

        # The key is required even when a cached client exists.
        if not api_key:
            raise ValueError("缺少API密钥,请设置环境变量或传入api_key参数")

        # Check the global cache before building anything.
        cache_key = self._cache_key(provider_enum, model)
        cached = LLMFactory._global_instances.get(cache_key)
        if cached is not None:
            logger.debug(f"使用缓存的LLM客户端: {cache_key}")
            return cached

        config = LLMConfig(
            provider=provider_enum,
            model=model,
            api_key=api_key,
            base_url=base_url,
            max_tokens=max_tokens,
            temperature=temperature,
            **kwargs
        )

        client = self._create_client(config)

        # Publish the new client in the global cache.
        LLMFactory._global_instances[cache_key] = client
        logger.info(f"LLM客户端创建成功并缓存: {provider} - {model}")

        return client

    def _parse_provider(self, provider: str) -> LLMProvider:
        """Resolve a case-insensitive provider name or alias to its ``LLMProvider``.

        Raises:
            ValueError: If the name is not a known provider or alias.
        """
        provider_lower = provider.lower()
        if provider_lower not in _PROVIDER_ALIASES:
            raise ValueError(f"不支持的提供商: {provider},支持的: {list(_PROVIDER_ALIASES)}")
        return _PROVIDER_ALIASES[provider_lower]

    def _get_api_key(self, provider: LLMProvider) -> Optional[str]:
        """Return the first non-empty API key found in the provider's env vars, else None."""
        for key_name in _API_KEY_ENV_VARS.get(provider, []):
            api_key = os.getenv(key_name)
            if api_key:
                return api_key
        return None

    def _create_client(self, config: LLMConfig) -> BaseLLMClient:
        """Instantiate the client class registered for ``config.provider``.

        Raises:
            ValueError: If no client class is registered for the provider.
        """
        # Built per call so the classes are looked up by name at call time.
        client_map = {
            LLMProvider.DEEPSEEK: DeepSeekClient,
            LLMProvider.QWEN: QwenClient,
            LLMProvider.QWEN_VL: QwenVLClient,
        }

        client_class = client_map.get(config.provider)
        if not client_class:
            raise ValueError(f"不支持的提供商: {config.provider}")

        return client_class(config)

    def get_cached(self, provider: str, model: Optional[str] = None) -> Optional[BaseLLMClient]:
        """Return the cached client for provider/model, or None if nothing is cached.

        Raises:
            ValueError: If the provider name is not supported.
        """
        provider_enum = self._parse_provider(provider)
        model = model or DEFAULT_MODELS.get(provider_enum)
        return LLMFactory._global_instances.get(self._cache_key(provider_enum, model))

    def list_available_providers(self) -> Dict[str, list]:
        """Map each canonical provider name to its client class's ``SUPPORTED_MODELS``."""
        return {
            "deepseek": DeepSeekClient.SUPPORTED_MODELS,
            "qwen": QwenClient.SUPPORTED_MODELS,
            "qwen_vl": QwenVLClient.SUPPORTED_MODELS,
        }

    @classmethod
    def preload_clients(cls, providers: Optional[List[str]] = None):
        """
        Pre-create LLM clients (intended to be called at application start-up).

        Failures are logged as warnings and do not abort the remaining providers.

        Args:
            providers: Provider names to preload; defaults to ["qwen", "deepseek"].
        """
        if providers is None:
            providers = ["qwen", "deepseek"]

        factory = cls()
        for provider in providers:
            # Deliberately best-effort: one broken provider must not block the rest.
            try:
                factory.create(provider)
                logger.success(f"预加载LLM客户端成功: {provider}")
            except Exception as e:
                logger.warning(f"预加载LLM客户端失败: {provider} - {e}")

    @classmethod
    def get_global_client(cls, provider: str, model: Optional[str] = None) -> Optional[BaseLLMClient]:
        """Return the globally cached client for provider/model, or None.

        Unlike ``get_cached`` this never raises for an unknown provider: it
        simply returns None, since nothing can be cached under such a name.
        """
        provider_lower = provider.lower()
        provider_enum = _PROVIDER_ALIASES.get(provider_lower)
        if provider_enum is None:
            # Unlisted model names used as the provider (any other "qwen*"
            # string) are still treated as the plain Qwen provider.
            if not provider_lower.startswith("qwen"):
                return None
            provider_enum = LLMProvider.QWEN

        model = model or DEFAULT_MODELS.get(provider_enum)
        return cls._global_instances.get(cls._cache_key(provider_enum, model))

    @classmethod
    def cleanup(cls):
        """Close every cached client (logging failures) and empty the cache."""
        for cache_key, client in cls._global_instances.items():
            try:
                client.close()
                logger.debug(f"关闭LLM客户端: {cache_key}")
            except Exception as e:
                logger.warning(f"关闭LLM客户端失败: {cache_key} - {e}")
        cls._global_instances.clear()
        logger.info("所有LLM客户端已清理")


@lru_cache
def get_llm_factory() -> LLMFactory:
    """Return the shared ``LLMFactory`` instance (memoised by ``lru_cache``)."""
    return LLMFactory()


def get_llm_client(
    provider: str = "qwen",
    model: Optional[str] = None,
    **kwargs
) -> BaseLLMClient:
    """
    Convenience function: get an LLM client, preferring a cached one.

    Args:
        provider: Provider name or alias.
        model: Model name; the provider's default is used when omitted.
        **kwargs: Extra arguments forwarded to ``LLMFactory.create``; they are
            only used when no cached client exists for provider/model.

    Returns:
        BaseLLMClient: The LLM client instance.

    Raises:
        ValueError: If the provider is not supported, or a new client must be
            created and no API key is available.
    """
    factory = get_llm_factory()

    # Try the cache first; this path does not require an API key.
    cached = factory.get_cached(provider, model)
    if cached is not None:
        return cached

    return factory.create(provider, model=model, **kwargs)