Fix SSE route dependency and align architecture docs
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
"""LLM工厂 - 统一创建和管理LLM客户端"""
|
||||
"""Provide service-layer logic for llm factory."""
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
from loguru import logger
|
||||
@@ -7,16 +7,18 @@ from functools import lru_cache
|
||||
from .base_client import BaseLLMClient, LLMConfig, LLMProvider, LLMResponse
|
||||
from .deepseek_client import DeepSeekClient
|
||||
from .qwen_client import QwenClient, QwenVLClient
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
|
||||
|
||||
# 默认模型映射
|
||||
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
DEFAULT_MODELS = {
|
||||
LLMProvider.DEEPSEEK: "deepseek-v4-flash",
|
||||
LLMProvider.QWEN: "qwen3.5-flash",
|
||||
LLMProvider.QWEN_VL: "qwen3-vl-plus"
|
||||
}
|
||||
|
||||
# API基础URL(使用统一代理服务)
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
DEFAULT_BASE_URLS = {
|
||||
LLMProvider.DEEPSEEK: "http://6.86.80.4:30080/v1",
|
||||
LLMProvider.QWEN: "http://6.86.80.4:30080/v1",
|
||||
@@ -25,31 +27,13 @@ DEFAULT_BASE_URLS = {
|
||||
|
||||
|
||||
class LLMFactory:
|
||||
"""
|
||||
LLM客户端工厂(支持全局缓存)
|
||||
"""Represent the L L M Factory type."""
|
||||
|
||||
支持的提供商和模型:
|
||||
- DeepSeek: deepseek-chat (DeepSeek-V3), deepseek-coder
|
||||
- Qwen: qwen-turbo, qwen-plus, qwen-max, qwen-long
|
||||
- QwenVL: qwen-vl-plus, qwen-vl-max (多模态)
|
||||
|
||||
使用示例:
|
||||
factory = LLMFactory()
|
||||
|
||||
# 使用默认配置
|
||||
client = factory.create("deepseek")
|
||||
|
||||
# 自定义配置
|
||||
client = factory.create("qwen", model="qwen-max", temperature=0.5)
|
||||
|
||||
# 调用LLM
|
||||
response = client.complete("你好,介绍一下自己")
|
||||
"""
|
||||
|
||||
# 全局客户端缓存(类级别,跨实例共享)
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
_global_instances: Dict[str, BaseLLMClient] = {}
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the L L M Factory instance."""
|
||||
self._config_cache: Dict[str, Any] = {}
|
||||
|
||||
def create(
|
||||
@@ -62,24 +46,10 @@ class LLMFactory:
|
||||
temperature: float = 0.7,
|
||||
**kwargs
|
||||
) -> BaseLLMClient:
|
||||
"""
|
||||
创建LLM客户端
|
||||
|
||||
Args:
|
||||
provider: 提供商名称 ("deepseek", "qwen", "qwen_vl")
|
||||
api_key: API密钥(如未提供,从环境变量获取)
|
||||
model: 模型名称(如未提供,使用默认模型)
|
||||
base_url: API基础URL
|
||||
max_tokens: 最大输出token数
|
||||
temperature: 温度参数
|
||||
**kwargs: 其他配置参数
|
||||
|
||||
Returns:
|
||||
BaseLLMClient: LLM客户端实例
|
||||
"""
|
||||
"""Handle create for the L L M Factory instance."""
|
||||
provider_enum = self._parse_provider(provider)
|
||||
|
||||
# 获取配置
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
api_key = api_key or self._get_api_key(provider_enum)
|
||||
model = model or DEFAULT_MODELS.get(provider_enum)
|
||||
base_url = base_url or DEFAULT_BASE_URLS.get(provider_enum)
|
||||
@@ -87,7 +57,7 @@ class LLMFactory:
|
||||
if not api_key:
|
||||
raise ValueError(f"缺少API密钥,请设置环境变量或传入api_key参数")
|
||||
|
||||
# 检查全局缓存
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
cache_key = f"{provider}_{model}"
|
||||
if cache_key in LLMFactory._global_instances:
|
||||
logger.debug(f"使用缓存的LLM客户端: {cache_key}")
|
||||
@@ -103,17 +73,17 @@ class LLMFactory:
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# 创建客户端
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
client = self._create_client(config)
|
||||
|
||||
# 缓存到全局实例
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
LLMFactory._global_instances[cache_key] = client
|
||||
|
||||
logger.info(f"LLM客户端创建成功并缓存: {provider} - {model}")
|
||||
return client
|
||||
|
||||
def _parse_provider(self, provider: str) -> LLMProvider:
|
||||
"""解析提供商名称"""
|
||||
"""Handle parse provider for this module for the L L M Factory instance."""
|
||||
provider_map = {
|
||||
"deepseek": LLMProvider.DEEPSEEK,
|
||||
"deepseek-v3": LLMProvider.DEEPSEEK,
|
||||
@@ -137,7 +107,7 @@ class LLMFactory:
|
||||
return provider_map[provider_lower]
|
||||
|
||||
def _get_api_key(self, provider: LLMProvider) -> Optional[str]:
|
||||
"""从环境变量获取API密钥"""
|
||||
"""Handle get api key for this module for the L L M Factory instance."""
|
||||
import os
|
||||
|
||||
key_map = {
|
||||
@@ -154,7 +124,7 @@ class LLMFactory:
|
||||
return None
|
||||
|
||||
def _create_client(self, config: LLMConfig) -> BaseLLMClient:
|
||||
"""创建具体客户端"""
|
||||
"""Handle create client for this module for the L L M Factory instance."""
|
||||
client_map = {
|
||||
LLMProvider.DEEPSEEK: DeepSeekClient,
|
||||
LLMProvider.QWEN: QwenClient,
|
||||
@@ -168,14 +138,14 @@ class LLMFactory:
|
||||
return client_class(config)
|
||||
|
||||
def get_cached(self, provider: str, model: Optional[str] = None) -> Optional[BaseLLMClient]:
|
||||
"""获取缓存的客户端"""
|
||||
"""Return cached for the L L M Factory instance."""
|
||||
provider_enum = self._parse_provider(provider)
|
||||
model = model or DEFAULT_MODELS.get(provider_enum)
|
||||
cache_key = f"{provider}_{model}"
|
||||
return LLMFactory._global_instances.get(cache_key)
|
||||
|
||||
def list_available_providers(self) -> Dict[str, list]:
|
||||
"""列出可用的提供商和模型"""
|
||||
"""List available providers for the L L M Factory instance."""
|
||||
return {
|
||||
"deepseek": DeepSeekClient.SUPPORTED_MODELS,
|
||||
"qwen": QwenClient.SUPPORTED_MODELS,
|
||||
@@ -184,12 +154,7 @@ class LLMFactory:
|
||||
|
||||
@classmethod
|
||||
def preload_clients(cls, providers: list = None):
|
||||
"""
|
||||
预加载LLM客户端(应用启动时调用)
|
||||
|
||||
Args:
|
||||
providers: 要预加载的提供商列表,默认加载qwen和deepseek
|
||||
"""
|
||||
"""Handle preload clients for the L L M Factory instance."""
|
||||
if providers is None:
|
||||
providers = ["qwen", "deepseek"]
|
||||
|
||||
@@ -203,9 +168,9 @@ class LLMFactory:
|
||||
|
||||
@classmethod
|
||||
def get_global_client(cls, provider: str, model: Optional[str] = None) -> Optional[BaseLLMClient]:
|
||||
"""获取全局缓存的客户端"""
|
||||
"""Return global client for the L L M Factory instance."""
|
||||
provider_lower = provider.lower()
|
||||
# 处理模型名作为provider的情况(如 qwen3.5-flash)
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
if provider_lower.startswith("qwen"):
|
||||
provider_lower = "qwen"
|
||||
model = model or DEFAULT_MODELS.get(LLMProvider.QWEN if provider_lower == "qwen" else LLMProvider.DEEPSEEK)
|
||||
@@ -214,7 +179,7 @@ class LLMFactory:
|
||||
|
||||
@classmethod
|
||||
def cleanup(cls):
|
||||
"""清理所有缓存的客户端"""
|
||||
"""Handle cleanup for the L L M Factory instance."""
|
||||
for cache_key, client in cls._global_instances.items():
|
||||
try:
|
||||
client.close()
|
||||
@@ -227,7 +192,7 @@ class LLMFactory:
|
||||
|
||||
@lru_cache
|
||||
def get_llm_factory() -> LLMFactory:
|
||||
"""获取LLM工厂实例(缓存)"""
|
||||
"""Return llm factory."""
|
||||
return LLMFactory()
|
||||
|
||||
|
||||
@@ -236,20 +201,10 @@ def get_llm_client(
|
||||
model: Optional[str] = None,
|
||||
**kwargs
|
||||
) -> BaseLLMClient:
|
||||
"""
|
||||
便捷函数:获取LLM客户端(优先使用缓存)
|
||||
|
||||
Args:
|
||||
provider: 提供商名称
|
||||
model: 模型名称
|
||||
**kwargs: 其他配置
|
||||
|
||||
Returns:
|
||||
BaseLLMClient: LLM客户端实例
|
||||
"""
|
||||
"""Return llm client."""
|
||||
factory = get_llm_factory()
|
||||
|
||||
# 先尝试获取缓存的实例
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
cached = factory.get_cached(provider, model)
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
Reference in New Issue
Block a user