Files
AIRegulation-DocAnalysis/backend/app/services/llm/llm_factory.py

213 lines
7.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Provide service-layer logic for llm factory."""
from typing import Optional, Dict, Any
from loguru import logger
from functools import lru_cache
from .base_client import BaseLLMClient, LLMConfig, LLMProvider, LLMResponse
from .deepseek_client import DeepSeekClient
from .qwen_client import QwenClient, QwenVLClient
# Default model name used for each provider when the caller does not pass one.
DEFAULT_MODELS = {
    LLMProvider.DEEPSEEK: "deepseek-v4-flash",
    LLMProvider.QWEN: "qwen3.5-flash",
    LLMProvider.QWEN_VL: "qwen3-vl-plus",
}

# Every provider currently points at the same hard-coded endpoint, so it is
# defined once here instead of being repeated per provider.
# NOTE(review): this is a fixed internal IP address baked into the code;
# consider moving it to configuration (settings / environment variable).
_DEFAULT_GATEWAY_URL = "http://6.86.80.4:30080/v1"

# Default base URL for each provider when the caller does not pass ``base_url``.
DEFAULT_BASE_URLS = {
    LLMProvider.DEEPSEEK: _DEFAULT_GATEWAY_URL,
    LLMProvider.QWEN: _DEFAULT_GATEWAY_URL,
    LLMProvider.QWEN_VL: _DEFAULT_GATEWAY_URL,
}
class LLMFactory:
    """Create, cache and look up LLM client instances by provider and model.

    Clients are stored in the class-level ``_global_instances`` dict, so every
    ``LLMFactory`` instance and the class methods share a single cache.
    """

    # Accepted provider spellings (matched after ``str.lower()``) -> provider enum.
    _PROVIDER_ALIASES: Dict[str, LLMProvider] = {
        "deepseek": LLMProvider.DEEPSEEK,
        "deepseek-v3": LLMProvider.DEEPSEEK,
        "deepseek_chat": LLMProvider.DEEPSEEK,
        "qwen": LLMProvider.QWEN,
        "qwen-turbo": LLMProvider.QWEN,
        "qwen-plus": LLMProvider.QWEN,
        "qwen-max": LLMProvider.QWEN,
        "qwen3.5-flash": LLMProvider.QWEN,
        "qwen3.5-plus": LLMProvider.QWEN,
        "qwen_vl": LLMProvider.QWEN_VL,
        "qwen-vl": LLMProvider.QWEN_VL,
        "qwen-vl-plus": LLMProvider.QWEN_VL,
        "qwen-vl-max": LLMProvider.QWEN_VL,
    }

    # Canonical provider names used in cache keys, so aliases and different
    # capitalisations of one provider share a single cached client. These are
    # the same names returned by ``list_available_providers``.
    _CANONICAL_NAMES: Dict[LLMProvider, str] = {
        LLMProvider.DEEPSEEK: "deepseek",
        LLMProvider.QWEN: "qwen",
        LLMProvider.QWEN_VL: "qwen_vl",
    }

    # Cache shared by all instances, keyed by "<canonical provider>_<model>".
    _global_instances: Dict[str, BaseLLMClient] = {}

    def __init__(self):
        """Initialize the factory with an empty per-instance config cache."""
        # NOTE(review): not read by any code visible here; kept for external users.
        self._config_cache: Dict[str, Any] = {}

    @classmethod
    def _cache_key(cls, provider_enum: LLMProvider, model: Optional[str]) -> str:
        """Build the cache key for a provider/model pair from the canonical provider name."""
        canonical = cls._CANONICAL_NAMES.get(provider_enum, str(provider_enum))
        return f"{canonical}_{model}"

    def create(
        self,
        provider: str,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        base_url: Optional[str] = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        **kwargs
    ) -> BaseLLMClient:
        """Return a client for ``provider``/``model``, building and caching it if needed.

        Args:
            provider: Provider name or alias (case-insensitive), see ``_PROVIDER_ALIASES``.
            api_key: API key; when omitted it is read from provider-specific
                environment variables (see ``_get_api_key``).
            model: Model name; defaults to ``DEFAULT_MODELS[provider]``.
            base_url: Endpoint URL; defaults to ``DEFAULT_BASE_URLS[provider]``.
            max_tokens: Forwarded to ``LLMConfig``.
            temperature: Forwarded to ``LLMConfig``.
            **kwargs: Extra fields forwarded to ``LLMConfig``.

        Returns:
            The cached client for this provider/model if present, otherwise a new one.

        Raises:
            ValueError: If the provider is unsupported, or if a new client must be
                built and no API key is available.

        Note:
            The cache key is only provider + model. A cached client is returned
            unchanged even if ``api_key``, ``base_url``, ``max_tokens``,
            ``temperature`` or ``kwargs`` differ from those it was built with.
        """
        provider_enum = self._parse_provider(provider)
        model = model or DEFAULT_MODELS.get(provider_enum)
        cache_key = self._cache_key(provider_enum, model)

        # Look in the cache before resolving credentials: an existing client is
        # usable even if no API key is available in the current environment.
        cached = LLMFactory._global_instances.get(cache_key)
        if cached is not None:
            logger.debug(f"使用缓存的LLM客户端: {cache_key}")
            return cached

        api_key = api_key or self._get_api_key(provider_enum)
        if not api_key:
            raise ValueError("缺少API密钥请设置环境变量或传入api_key参数")
        base_url = base_url or DEFAULT_BASE_URLS.get(provider_enum)

        config = LLMConfig(
            provider=provider_enum,
            model=model,
            api_key=api_key,
            base_url=base_url,
            max_tokens=max_tokens,
            temperature=temperature,
            **kwargs
        )
        client = self._create_client(config)
        LLMFactory._global_instances[cache_key] = client
        logger.info(f"LLM客户端创建成功并缓存: {provider} - {model}")
        return client

    def _parse_provider(self, provider: str) -> LLMProvider:
        """Map a provider name or alias (case-insensitive) to an ``LLMProvider``.

        Raises:
            ValueError: If ``provider`` is not one of the known aliases.
        """
        provider_lower = provider.lower()
        if provider_lower not in self._PROVIDER_ALIASES:
            raise ValueError(f"不支持的提供商: {provider},支持的: {list(self._PROVIDER_ALIASES.keys())}")
        return self._PROVIDER_ALIASES[provider_lower]

    def _get_api_key(self, provider: LLMProvider) -> Optional[str]:
        """Return the first non-empty API key found in the provider's environment variables.

        Variables are checked in the listed order; returns ``None`` if none is set.
        """
        import os
        key_map = {
            LLMProvider.DEEPSEEK: ["DEEPSEEK_API_KEY", "OPENAI_API_KEY"],
            LLMProvider.QWEN: ["QWEN_API_KEY", "DASHSCOPE_API_KEY", "ALIBABA_API_KEY"],
            LLMProvider.QWEN_VL: ["QWEN_API_KEY", "DASHSCOPE_API_KEY", "ALIBABA_API_KEY"],
        }
        for key_name in key_map.get(provider, []):
            api_key = os.getenv(key_name)
            if api_key:
                return api_key
        return None

    def _create_client(self, config: LLMConfig) -> BaseLLMClient:
        """Instantiate the concrete client class for ``config.provider``.

        Raises:
            ValueError: If no client class is registered for the provider.
        """
        client_map = {
            LLMProvider.DEEPSEEK: DeepSeekClient,
            LLMProvider.QWEN: QwenClient,
            LLMProvider.QWEN_VL: QwenVLClient,
        }
        client_class = client_map.get(config.provider)
        if not client_class:
            raise ValueError(f"不支持的提供商: {config.provider}")
        return client_class(config)

    def get_cached(self, provider: str, model: Optional[str] = None) -> Optional[BaseLLMClient]:
        """Return the cached client for ``provider``/``model``, or ``None`` if not cached.

        Raises:
            ValueError: If ``provider`` is not a known alias.
        """
        provider_enum = self._parse_provider(provider)
        model = model or DEFAULT_MODELS.get(provider_enum)
        return LLMFactory._global_instances.get(self._cache_key(provider_enum, model))

    def list_available_providers(self) -> Dict[str, list]:
        """Return each canonical provider name mapped to its client's ``SUPPORTED_MODELS``."""
        return {
            "deepseek": DeepSeekClient.SUPPORTED_MODELS,
            "qwen": QwenClient.SUPPORTED_MODELS,
            "qwen_vl": QwenVLClient.SUPPORTED_MODELS,
        }

    @classmethod
    def preload_clients(cls, providers: Optional[list] = None):
        """Eagerly create and cache clients for ``providers`` (default: qwen, deepseek).

        Failures are logged as warnings and do not stop the remaining providers.
        """
        if providers is None:
            providers = ["qwen", "deepseek"]
        factory = cls()
        for provider in providers:
            try:
                factory.create(provider)
                logger.success(f"预加载LLM客户端成功: {provider}")
            except Exception as e:
                # Best-effort warm-up: a missing key for one provider must not block others.
                logger.warning(f"预加载LLM客户端失败: {provider} - {e}")

    @classmethod
    def get_global_client(cls, provider: str, model: Optional[str] = None) -> Optional[BaseLLMClient]:
        """Return the cached client for ``provider``/``model`` without raising.

        Uses the same alias resolution and cache key as ``create``/``get_cached``.
        As a result, ``qwen_vl`` resolves to the VL client rather than the text client.
        Unknown providers return ``None``.
        """
        provider_enum = cls._PROVIDER_ALIASES.get(provider.lower())
        if provider_enum is None:
            return None
        model = model or DEFAULT_MODELS.get(provider_enum)
        return cls._global_instances.get(cls._cache_key(provider_enum, model))

    @classmethod
    def cleanup(cls):
        """Call ``close()`` on every cached client, then empty the cache.

        A failing ``close()`` is logged and does not prevent closing the rest.
        """
        # Iterate over a snapshot so a close() that touches the cache cannot
        # invalidate the iteration.
        for cache_key, client in list(cls._global_instances.items()):
            try:
                client.close()
                logger.debug(f"关闭LLM客户端: {cache_key}")
            except Exception as e:
                logger.warning(f"关闭LLM客户端失败: {cache_key} - {e}")
        cls._global_instances.clear()
        logger.info("所有LLM客户端已清理")
@lru_cache
def get_llm_factory() -> LLMFactory:
    """Return a shared ``LLMFactory`` instance.

    ``lru_cache`` on this zero-argument function memoizes the first instance
    built, so later calls return that same object.
    """
    shared_factory = LLMFactory()
    return shared_factory
def get_llm_client(
    provider: str = "qwen",
    model: Optional[str] = None,
    **kwargs
) -> BaseLLMClient:
    """Return a client for ``provider``/``model``, reusing a cached one if available.

    Extra keyword arguments are forwarded to ``LLMFactory.create``. They only
    matter when no cached client exists and a new one has to be built.
    """
    llm_factory = get_llm_factory()
    existing = llm_factory.get_cached(provider, model)
    return existing if existing else llm_factory.create(provider, model=model, **kwargs)