update
This commit is contained in:
15
backend/app/services/llm/__init__.py
Normal file
15
backend/app/services/llm/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# src/services/llm/__init__.py
"""LLM service package: re-exports the factory, the clients and the summarizer."""

from .llm_factory import LLMFactory, get_llm_client
from .base_client import BaseLLMClient, LLMResponse, LLMConfig, LLMProvider
from .deepseek_client import DeepSeekClient
from .qwen_client import QwenClient, QwenVLClient
# Keep this import last: document_summarizer imports get_llm_client and
# BaseLLMClient from this package, so they must already be bound here.
from .document_summarizer import DocumentSummarizer, summarize_document, DocumentSummary

# Public API of the package, one line per originating module.
__all__ = [
    "LLMFactory", "get_llm_client",
    "BaseLLMClient", "LLMResponse", "LLMConfig", "LLMProvider",
    "DeepSeekClient", "QwenClient", "QwenVLClient",
    "DocumentSummarizer", "summarize_document", "DocumentSummary",
]
|
||||
116
backend/app/services/llm/base_client.py
Normal file
116
backend/app/services/llm/base_client.py
Normal file
@@ -0,0 +1,116 @@
|
||||
# src/services/llm/base_client.py
|
||||
"""LLM客户端基类 - 统一接口定义"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Dict, Optional, Any
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class LLMProvider(Enum):
    """Identifiers of the LLM providers known to this package.

    The string values are the canonical provider names.
    """

    DEEPSEEK = "deepseek"
    QWEN = "qwen"
    QWEN_VL = "qwen_vl"
|
||||
|
||||
|
||||
@dataclass
class LLMResponse:
    """Result of a single LLM call.

    A response is considered failed when ``error`` is set; in that case the
    clients in this package return an empty ``content``.
    """

    content: str                                          # generated text
    model: str                                            # model name reported for the call
    usage: Dict[str, int] = field(default_factory=dict)   # token usage counters
    finish_reason: str = "stop"
    latency_ms: int = 0                                   # wall-clock duration in milliseconds
    error: Optional[str] = None                           # error description, None on success

    @property
    def is_success(self) -> bool:
        """Return True when no error was recorded for this response."""
        has_error = self.error is not None
        return not has_error
|
||||
|
||||
|
||||
@dataclass
class LLMConfig:
    """Connection and sampling settings for one LLM client.

    ``api_key`` is excluded from the generated ``repr`` so that logging or
    printing a config object does not leak the credential. It still takes
    part in equality comparison and is a normal attribute otherwise.
    """

    provider: LLMProvider
    model: str
    api_key: str = field(repr=False)   # secret: keep it out of repr/log output
    base_url: str
    max_tokens: int = 4096
    temperature: float = 0.7
    top_p: float = 0.9
    # Default timeout of 300 seconds: summary / skills generation may take a long time.
    timeout: int = 300
|
||||
|
||||
|
||||
class BaseLLMClient(ABC):
    """Abstract base class defining the interface shared by all LLM clients.

    Subclasses implement ``_init_client``, ``chat`` and
    ``get_available_models``. ``complete``, ``estimate_tokens``, ``close``
    and context-manager support are provided here.
    """

    def __init__(self, config: LLMConfig):
        """Store the configuration; the transport itself is created later."""
        self.config = config
        # Underlying transport object; assigned by _init_client() in subclasses.
        self._client = None

    @abstractmethod
    def _init_client(self):
        """Create the underlying client and store it in ``self._client``."""
        pass

    @abstractmethod
    def chat(
        self,
        messages: List[Dict[str, str]],
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        **kwargs
    ) -> LLMResponse:
        """Run a chat completion.

        Args:
            messages: Conversation messages, e.g.
                ``[{"role": "user/assistant/system", "content": "..."}]``.
            max_tokens: Maximum number of output tokens.
            temperature: Sampling temperature.
            **kwargs: Additional implementation-specific parameters.

        Returns:
            LLMResponse: The result of the call.
        """
        pass

    def complete(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        **kwargs
    ) -> LLMResponse:
        """Run a single-turn completion (convenience wrapper around ``chat``).

        Args:
            prompt: The user input.
            system_prompt: Optional system prompt; omitted from the message
                list when empty or None.
            max_tokens: Maximum number of output tokens.
            temperature: Sampling temperature.
            **kwargs: Forwarded unchanged to ``chat``.

        Returns:
            LLMResponse: Whatever ``chat`` returns for the built messages.
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        return self.chat(messages, max_tokens, temperature, **kwargs)

    @abstractmethod
    def get_available_models(self) -> List[str]:
        """Return the list of model names this client supports."""
        pass

    def estimate_tokens(self, text: str) -> int:
        """Roughly estimate the number of tokens in ``text``.

        Heuristic: 1.5 tokens per character in the CJK range
        U+4E00..U+9FFF and 0.25 tokens per any other character; the result
        is truncated to an int.
        """
        chinese_chars = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
        other_chars = len(text) - chinese_chars
        return int(chinese_chars * 1.5 + other_chars * 0.25)

    def close(self) -> None:
        """Release the underlying client, if there is one.

        Default implementation so that every client can be closed through
        the base interface (``LLMFactory.cleanup`` calls ``close()`` on each
        cached client). Subclasses may override it. Calling it repeatedly
        is harmless.
        """
        client = self._client
        self._client = None
        if client is not None and hasattr(client, "close"):
            client.close()

    def __enter__(self):
        """Return the client itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Close the client when the ``with`` block ends; never suppresses errors."""
        self.close()
|
||||
130
backend/app/services/llm/deepseek_client.py
Normal file
130
backend/app/services/llm/deepseek_client.py
Normal file
@@ -0,0 +1,130 @@
|
||||
# src/services/llm/deepseek_client.py
|
||||
"""DeepSeek LLM客户端 - OpenAI兼容API"""
|
||||
|
||||
import time
|
||||
from typing import List, Dict, Optional
|
||||
from loguru import logger
|
||||
import httpx
|
||||
|
||||
from .base_client import BaseLLMClient, LLMResponse, LLMConfig, LLMProvider
|
||||
|
||||
|
||||
class DeepSeekClient(BaseLLMClient):
    """DeepSeek API client using the OpenAI-compatible chat format.

    The model names this client advertises are listed in
    ``SUPPORTED_MODELS``.
    """

    SUPPORTED_MODELS = [
        "deepseek-chat",
        "deepseek-coder",
        "deepseek-reasoner",
        "deepseek-v3",
        "deepseek-v3.2",
        "deepseek-v4-flash"
    ]

    def __init__(self, config: LLMConfig):
        """Validate the provider and create the HTTP client.

        Raises:
            ValueError: If ``config.provider`` is not ``LLMProvider.DEEPSEEK``.
        """
        if config.provider != LLMProvider.DEEPSEEK:
            raise ValueError(f"配置provider应为DEEPSEEK,实际为{config.provider}")
        super().__init__(config)
        self._init_client()

    def _init_client(self):
        """Create the ``httpx.Client`` with base URL, auth headers and timeout."""
        self._client = httpx.Client(
            base_url=self.config.base_url,
            headers={
                "Authorization": f"Bearer {self.config.api_key}",
                "Content-Type": "application/json"
            },
            timeout=self.config.timeout
        )
        logger.info(f"DeepSeek客户端初始化完成: {self.config.base_url} - {self.config.model}")

    def chat(
        self,
        messages: List[Dict[str, str]],
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        **kwargs
    ) -> LLMResponse:
        """Run a non-streaming chat completion.

        Args:
            messages: Conversation messages in OpenAI chat format.
            max_tokens: Output token limit; the configured value is used
                when omitted.
            temperature: Sampling temperature; the configured value is used
                only when this is ``None`` (an explicit ``0.0`` is honoured).
            **kwargs: Only ``top_p`` is read; other keys are ignored.

        Returns:
            LLMResponse: On failure no exception is raised; ``content`` is
            empty and ``error`` describes the problem.
        """
        start_time = time.time()

        try:
            payload = {
                "model": self.config.model,
                "messages": messages,
                # 0 is not a usable token limit, so falsy values fall back to the default.
                "max_tokens": max_tokens or self.config.max_tokens,
                # 0.0 is a valid temperature: only None selects the default.
                "temperature": self.config.temperature if temperature is None else temperature,
                "top_p": kwargs.get("top_p", self.config.top_p),
                "stream": False
            }

            response = self._client.post("/chat/completions", json=payload)
            response.raise_for_status()

            data = response.json()

            latency_ms = int((time.time() - start_time) * 1000)

            choices = data.get("choices") or []
            if not choices:
                # Report a missing/empty choices list explicitly instead of
                # letting it surface as an opaque IndexError message.
                logger.error("DeepSeek API响应缺少choices")
                return LLMResponse(
                    content="",
                    model=data.get("model", self.config.model),
                    usage=data.get("usage") or {},
                    latency_ms=latency_ms,
                    error="API响应缺少choices"
                )

            first_choice = choices[0]
            message = first_choice.get("message") or {}

            return LLMResponse(
                # "content" may be JSON null; keep the field a string.
                content=message.get("content") or "",
                model=data.get("model", self.config.model),
                usage=data.get("usage") or {},
                finish_reason=first_choice.get("finish_reason") or "stop",
                latency_ms=latency_ms
            )

        except httpx.HTTPStatusError as e:
            logger.error(f"DeepSeek API错误: {e.response.status_code} - {e.response.text}")
            return LLMResponse(
                content="",
                model=self.config.model,
                error=f"API错误: {e.response.status_code} - {e.response.text[:200]}"
            )

        except Exception as e:
            # Boundary of the client: any other failure is converted into an
            # error response rather than propagated.
            logger.error(f"DeepSeek调用失败: {e}")
            return LLMResponse(
                content="",
                model=self.config.model,
                error=str(e)
            )

    def get_available_models(self) -> List[str]:
        """Return the list of supported model names."""
        return self.SUPPORTED_MODELS

    def close(self):
        """Close the underlying HTTP client, if it was created."""
        if self._client:
            self._client.close()
|
||||
|
||||
|
||||
def create_deepseek_client(
    api_key: str,
    model: str = "deepseek-v4-flash",
    base_url: str = "http://6.86.80.4:30080/v1",
    **kwargs
) -> DeepSeekClient:
    """Convenience helper: build a ``DeepSeekClient`` from plain arguments.

    Args:
        api_key: API key sent as the bearer token.
        model: Model name to request.
        base_url: Base URL of the OpenAI-compatible endpoint.
        **kwargs: Extra ``LLMConfig`` fields.

    Returns:
        DeepSeekClient: A client configured for the DeepSeek provider.
    """
    return DeepSeekClient(
        LLMConfig(
            provider=LLMProvider.DEEPSEEK,
            model=model,
            api_key=api_key,
            base_url=base_url,
            **kwargs
        )
    )
|
||||
231
backend/app/services/llm/document_summarizer.py
Normal file
231
backend/app/services/llm/document_summarizer.py
Normal file
@@ -0,0 +1,231 @@
|
||||
# src/services/llm/document_summarizer.py
|
||||
"""文档摘要生成服务 - LLM生成法规文档摘要"""
|
||||
|
||||
from typing import Dict, Optional
|
||||
from dataclasses import dataclass
|
||||
from loguru import logger
|
||||
|
||||
from app.services.llm import get_llm_client, BaseLLMClient
|
||||
from app.services.rag.prompt_templates import get_prompt_template
|
||||
from app.config.settings import settings
|
||||
|
||||
|
||||
@dataclass
class DocumentSummary:
    """Result of summarizing one document.

    ``error`` is None on success; ``is_success`` reflects that.
    """

    doc_name: str                 # name of the summarized document
    summary: str                  # summary text
    applicable_scope: str         # extracted scope of application
    key_clauses: list             # extracted key clause lines
    key_terms: list               # extracted key terms
    compliance_points: list       # extracted compliance points
    model: str                    # model name recorded for the summary
    latency_ms: int               # duration in milliseconds
    error: Optional[str] = None   # error description, None on success

    @property
    def is_success(self) -> bool:
        """Return True when the summary was produced without an error."""
        failed = self.error is not None
        return not failed
|
||||
|
||||
|
||||
class DocumentSummarizer:
    """Generate summaries of regulation documents with an LLM.

    The summary text comes from the LLM; ``_parse_summary`` then extracts
    the applicable scope, key clauses, key terms and compliance points from
    that text with simple line-based heuristics.

    Example:
        summarizer = DocumentSummarizer()
        result = summarizer.summarize("GB 7258-2017", markdown_content)
        print(result.summary)
    """

    def __init__(
        self,
        provider: Optional[str] = None,
        model: Optional[str] = None,
        max_tokens: Optional[int] = None
    ):
        """Initialize the summarizer.

        Args:
            provider: LLM provider name; defaults to ``settings.llm_provider``.
            model: LLM model name; defaults to ``settings.llm_model``.
            max_tokens: Output token limit; defaults to
                ``settings.rag_summary_max_tokens``.
        """
        self.provider = provider or settings.llm_provider
        self.model = model or settings.llm_model
        self.max_tokens = max_tokens or settings.rag_summary_max_tokens

        # The LLM client is created lazily on first use.
        self.llm: Optional[BaseLLMClient] = None

        logger.info(f"摘要生成器初始化: provider={self.provider}, model={self.model}")

    def _init_llm(self):
        """Create the LLM client on first use."""
        if self.llm is None:
            self.llm = get_llm_client(
                provider=self.provider,
                model=self.model
            )

    def _failed_summary(self, doc_name: str, latency_ms: int, error: Optional[str]) -> DocumentSummary:
        """Build the empty ``DocumentSummary`` returned for a failed run."""
        return DocumentSummary(
            doc_name=doc_name,
            summary="",
            applicable_scope="",
            key_clauses=[],
            key_terms=[],
            compliance_points=[],
            model=self.model,
            latency_ms=latency_ms,
            error=error
        )

    def summarize(
        self,
        doc_name: str,
        content: str,
        regulation_type: str = "",
        max_tokens: Optional[int] = None
    ) -> DocumentSummary:
        """Generate the summary of one document.

        Args:
            doc_name: Document name.
            content: Document content (Markdown). Only the first 8000
                characters are sent to the LLM.
            regulation_type: Regulation type. Accepted for interface
                compatibility; not used by the current implementation.
            max_tokens: Output token limit for this call; falls back to the
                value configured on the instance.

        Returns:
            DocumentSummary: Never raises; on failure ``error`` is set and
            the text/list fields are empty.
        """
        import time
        start_time = time.time()

        logger.info(f"生成文档摘要: {doc_name}")

        try:
            self._init_llm()

            # Prompt template dedicated to document summaries.
            template = get_prompt_template("document_summary")

            # Truncate the document to stay within the token limit.
            user_content = template.user_template.format(
                doc_name=doc_name,
                content=content[:8000]
            )

            response = self.llm.chat(
                messages=[
                    {"role": "system", "content": template.system_prompt},
                    {"role": "user", "content": user_content}
                ],
                max_tokens=max_tokens or self.max_tokens,
                temperature=0.3  # low temperature for a more faithful summary
            )

            latency_ms = int((time.time() - start_time) * 1000)

            if not response.is_success:
                return self._failed_summary(doc_name, latency_ms, response.error)

            summary_data = self._parse_summary(response.content)

            logger.success(f"摘要生成完成: {doc_name}, {latency_ms}ms")

            return DocumentSummary(
                doc_name=doc_name,
                summary=summary_data.get("summary", response.content),
                applicable_scope=summary_data.get("applicable_scope", ""),
                key_clauses=summary_data.get("key_clauses", []),
                key_terms=summary_data.get("key_terms", []),
                compliance_points=summary_data.get("compliance_points", []),
                model=response.model,
                latency_ms=latency_ms
            )

        except Exception as e:
            # Boundary of the service: report the failure in the result
            # object, including the time actually spent before failing.
            logger.error(f"摘要生成失败: {e}")
            elapsed_ms = int((time.time() - start_time) * 1000)
            return self._failed_summary(doc_name, elapsed_ms, str(e))

    @staticmethod
    def _value_after_colon(line: str) -> Optional[str]:
        """Return the text after the first ASCII or full-width colon.

        Returns None when the line contains neither kind of colon.
        """
        positions = [pos for pos in (line.find(":"), line.find("\uff1a")) if pos != -1]
        if not positions:
            return None
        return line[min(positions) + 1:].strip()

    def _parse_summary(self, content: str) -> Dict:
        """Extract structured fields from the LLM summary text.

        Heuristics, applied line by line:

        * a line mentioning 适用范围 / 适用对象 sets ``applicable_scope`` to the
          text after its first colon (ASCII or full-width), or to the whole
          line when it has no colon; a later match overwrites an earlier one;
        * a line starting with ``- 【条款`` or ``【条款`` is a key clause;
        * a line mentioning 关键术语 / 术语定义 (or 合规要点 / 必须满足) opens the
          key-terms (or compliance-points) section: text after a colon on
          that line and the list items that follow are collected until the
          first non-empty line that is not a list item.

        Returns:
            Dict: keys ``summary`` (the unmodified input), ``applicable_scope``,
            ``key_clauses``, ``key_terms`` and ``compliance_points``.
        """
        import re

        result = {
            "summary": content,
            "applicable_scope": "",
            "key_clauses": [],
            "key_terms": [],
            "compliance_points": []
        }

        # "- x", "* x", "+ x", "• x", "1. x", "1、x", "1) x"; a dot followed by
        # a digit (e.g. "3.5") is not treated as a list marker.
        item_pattern = re.compile(
            r"^(?:[-*+\u2022\u00b7]\s+|\d{1,2}(?:[\u3001)\uff09]|\.(?!\d))\s*)(.+)$"
        )
        section_keywords = (
            ("key_terms", ("关键术语", "术语定义")),
            ("compliance_points", ("合规要点", "必须满足")),
        )

        active_section = None

        for raw_line in content.split("\n"):
            line = raw_line.strip()
            if not line:
                # Blank lines neither match anything nor end a section.
                continue

            if "适用范围" in line or "适用对象" in line:
                value = self._value_after_colon(line)
                result["applicable_scope"] = line if value is None else value

            if line.startswith(("- 【条款", "【条款")):
                result["key_clauses"].append(line)

            item = item_pattern.match(line)

            heading = None
            for section, keywords in section_keywords:
                if any(keyword in line for keyword in keywords):
                    heading = section
                    break

            # A list item that repeats the keyword of the section it already
            # belongs to (e.g. "- 必须满足…") is content, not a new heading.
            if heading is not None and not (item and active_section == heading):
                active_section = heading
                inline_value = self._value_after_colon(line)
                if inline_value:
                    result[active_section].append(inline_value)
                continue

            if item:
                if active_section:
                    result[active_section].append(item.group(1).strip())
            else:
                active_section = None

        return result

    def batch_summarize(
        self,
        documents: list
    ) -> list:
        """Summarize several documents one after another.

        Args:
            documents: ``[{"doc_name": str, "content": str}, ...]``.

        Returns:
            list: One ``DocumentSummary`` per input document, in order.
        """
        return [self.summarize(doc["doc_name"], doc["content"]) for doc in documents]
|
||||
|
||||
|
||||
def summarize_document(
    doc_name: str,
    content: str,
    **kwargs
) -> DocumentSummary:
    """Convenience helper: summarize one document with a fresh summarizer.

    Args:
        doc_name: Document name.
        content: Document content.
        **kwargs: Passed to the ``DocumentSummarizer`` constructor.

    Returns:
        DocumentSummary: The summary result.
    """
    return DocumentSummarizer(**kwargs).summarize(doc_name, content)
|
||||
258
backend/app/services/llm/llm_factory.py
Normal file
258
backend/app/services/llm/llm_factory.py
Normal file
@@ -0,0 +1,258 @@
|
||||
# src/services/llm/llm_factory.py
|
||||
"""LLM工厂 - 统一创建和管理LLM客户端"""
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
from loguru import logger
|
||||
from functools import lru_cache
|
||||
|
||||
from .base_client import BaseLLMClient, LLMConfig, LLMProvider, LLMResponse
|
||||
from .deepseek_client import DeepSeekClient
|
||||
from .qwen_client import QwenClient, QwenVLClient
|
||||
|
||||
|
||||
# Model used for each provider when the caller does not name one.
DEFAULT_MODELS = {
    LLMProvider.DEEPSEEK: "deepseek-v4-flash",
    LLMProvider.QWEN: "qwen3.5-flash",
    LLMProvider.QWEN_VL: "qwen3-vl-plus",
}

# API base URLs: every provider goes through the same unified proxy service.
_PROXY_BASE_URL = "http://6.86.80.4:30080/v1"

DEFAULT_BASE_URLS = {
    LLMProvider.DEEPSEEK: _PROXY_BASE_URL,
    LLMProvider.QWEN: _PROXY_BASE_URL,
    LLMProvider.QWEN_VL: _PROXY_BASE_URL,
}
|
||||
|
||||
|
||||
class LLMFactory:
    """Factory for LLM clients with a process-wide client cache.

    Clients are cached per canonical provider and model, so provider
    aliases and different letter case (for example ``"qwen"``, ``"Qwen"``
    and ``"qwen-plus"``) share one cache entry for the same model.

    Example:
        factory = LLMFactory()

        # default configuration
        client = factory.create("deepseek")

        # custom configuration
        client = factory.create("qwen", model="qwen-max", temperature=0.5)

        # call the LLM
        response = client.complete("你好,介绍一下自己")
    """

    # Global client cache (class level, shared by all factory instances).
    _global_instances: Dict[str, BaseLLMClient] = {}

    def __init__(self):
        self._config_cache: Dict[str, Any] = {}

    @staticmethod
    def _cache_key(provider: LLMProvider, model: Optional[str]) -> str:
        """Build the cache key from the canonical provider value and the model."""
        return f"{provider.value}_{model}"

    def create(
        self,
        provider: str,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        base_url: Optional[str] = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        **kwargs
    ) -> BaseLLMClient:
        """Create an LLM client, or return the cached one.

        Args:
            provider: Provider name or alias (``"deepseek"``, ``"qwen"``,
                ``"qwen_vl"``, ...).
            api_key: API key; read from environment variables when omitted.
            model: Model name; the provider default is used when omitted.
            base_url: API base URL; the provider default is used when omitted.
            max_tokens: Maximum number of output tokens.
            temperature: Sampling temperature.
            **kwargs: Additional ``LLMConfig`` fields.

        Returns:
            BaseLLMClient: The client instance.

        Raises:
            ValueError: If the provider is unknown or no API key is available.
        """
        provider_enum = self._parse_provider(provider)

        api_key = api_key or self._get_api_key(provider_enum)
        model = model or DEFAULT_MODELS.get(provider_enum)
        base_url = base_url or DEFAULT_BASE_URLS.get(provider_enum)

        if not api_key:
            raise ValueError("缺少API密钥,请设置环境变量或传入api_key参数")

        # NOTE: the key covers provider and model only, so a cached client
        # is returned even when the other settings passed here differ.
        cache_key = self._cache_key(provider_enum, model)
        if cache_key in LLMFactory._global_instances:
            logger.debug(f"使用缓存的LLM客户端: {cache_key}")
            return LLMFactory._global_instances[cache_key]

        config = LLMConfig(
            provider=provider_enum,
            model=model,
            api_key=api_key,
            base_url=base_url,
            max_tokens=max_tokens,
            temperature=temperature,
            **kwargs
        )

        client = self._create_client(config)

        LLMFactory._global_instances[cache_key] = client

        logger.info(f"LLM客户端创建成功并缓存: {provider} - {model}")
        return client

    @staticmethod
    def _parse_provider(provider: str) -> LLMProvider:
        """Map a provider name or alias (case-insensitive) to ``LLMProvider``.

        Raises:
            ValueError: If the name is not a known provider or alias.
        """
        provider_map = {
            "deepseek": LLMProvider.DEEPSEEK,
            "deepseek-v3": LLMProvider.DEEPSEEK,
            "deepseek_chat": LLMProvider.DEEPSEEK,
            "qwen": LLMProvider.QWEN,
            "qwen-turbo": LLMProvider.QWEN,
            "qwen-plus": LLMProvider.QWEN,
            "qwen-max": LLMProvider.QWEN,
            "qwen3.5-flash": LLMProvider.QWEN,
            "qwen3.5-plus": LLMProvider.QWEN,
            "qwen_vl": LLMProvider.QWEN_VL,
            "qwen-vl": LLMProvider.QWEN_VL,
            "qwen-vl-plus": LLMProvider.QWEN_VL,
            "qwen-vl-max": LLMProvider.QWEN_VL
        }

        provider_lower = provider.lower()
        if provider_lower not in provider_map:
            raise ValueError(f"不支持的提供商: {provider},支持的: {list(provider_map.keys())}")

        return provider_map[provider_lower]

    def _get_api_key(self, provider: LLMProvider) -> Optional[str]:
        """Return the first non-empty API key found in the environment.

        The environment variables are tried in the order listed per provider.
        """
        import os

        key_map = {
            LLMProvider.DEEPSEEK: ["DEEPSEEK_API_KEY", "OPENAI_API_KEY"],
            LLMProvider.QWEN: ["QWEN_API_KEY", "DASHSCOPE_API_KEY", "ALIBABA_API_KEY"],
            LLMProvider.QWEN_VL: ["QWEN_API_KEY", "DASHSCOPE_API_KEY", "ALIBABA_API_KEY"]
        }

        for key_name in key_map.get(provider, []):
            api_key = os.getenv(key_name)
            if api_key:
                return api_key

        return None

    def _create_client(self, config: LLMConfig) -> BaseLLMClient:
        """Instantiate the client class registered for ``config.provider``.

        Raises:
            ValueError: If no client class is registered for the provider.
        """
        client_map = {
            LLMProvider.DEEPSEEK: DeepSeekClient,
            LLMProvider.QWEN: QwenClient,
            LLMProvider.QWEN_VL: QwenVLClient
        }

        client_class = client_map.get(config.provider)
        if not client_class:
            raise ValueError(f"不支持的提供商: {config.provider}")

        return client_class(config)

    def get_cached(self, provider: str, model: Optional[str] = None) -> Optional[BaseLLMClient]:
        """Return the cached client for provider/model, or None.

        Raises:
            ValueError: If the provider is unknown.
        """
        provider_enum = self._parse_provider(provider)
        model = model or DEFAULT_MODELS.get(provider_enum)
        return LLMFactory._global_instances.get(self._cache_key(provider_enum, model))

    def list_available_providers(self) -> Dict[str, list]:
        """Return the supported model names keyed by provider name."""
        return {
            "deepseek": DeepSeekClient.SUPPORTED_MODELS,
            "qwen": QwenClient.SUPPORTED_MODELS,
            "qwen_vl": QwenVLClient.SUPPORTED_MODELS
        }

    @classmethod
    def preload_clients(cls, providers: list = None):
        """Create and cache clients ahead of time (call at application start-up).

        Failures are logged as warnings and do not stop the remaining
        providers from being loaded.

        Args:
            providers: Provider names to preload; defaults to
                ``["qwen", "deepseek"]``.
        """
        if providers is None:
            providers = ["qwen", "deepseek"]

        factory = cls()
        for provider in providers:
            try:
                factory.create(provider)
                logger.success(f"预加载LLM客户端成功: {provider}")
            except Exception as e:
                logger.warning(f"预加载LLM客户端失败: {provider} - {e}")

    @classmethod
    def get_global_client(cls, provider: str, model: Optional[str] = None) -> Optional[BaseLLMClient]:
        """Return the globally cached client for provider/model, or None.

        Known names and aliases are resolved exactly like in ``create``, so
        ``"qwen_vl"`` resolves to the Qwen-VL provider. Unknown names never
        raise: those starting with ``"qwen"`` (a model name passed as the
        provider) are treated as the Qwen provider, anything else yields None.
        """
        provider_lower = provider.lower()
        try:
            provider_enum = cls._parse_provider(provider_lower)
        except ValueError:
            if not provider_lower.startswith("qwen"):
                return None
            provider_enum = LLMProvider.QWEN

        model = model or DEFAULT_MODELS.get(provider_enum)
        return cls._global_instances.get(cls._cache_key(provider_enum, model))

    @classmethod
    def cleanup(cls):
        """Close every cached client and empty the cache.

        A client that fails to close is logged and skipped.
        """
        for cache_key, client in cls._global_instances.items():
            try:
                client.close()
                logger.debug(f"关闭LLM客户端: {cache_key}")
            except Exception as e:
                logger.warning(f"关闭LLM客户端失败: {cache_key} - {e}")
        cls._global_instances.clear()
        logger.info("所有LLM客户端已清理")
|
||||
|
||||
|
||||
@lru_cache
def get_llm_factory() -> LLMFactory:
    """Return the shared ``LLMFactory``; created on the first call, then reused."""
    shared_factory = LLMFactory()
    return shared_factory
|
||||
|
||||
|
||||
def get_llm_client(
    provider: str = "qwen",
    model: Optional[str] = None,
    **kwargs
) -> BaseLLMClient:
    """Convenience helper: return an LLM client, preferring a cached one.

    Args:
        provider: Provider name.
        model: Model name.
        **kwargs: Extra settings, used only when a new client is created.

    Returns:
        BaseLLMClient: The cached client if there is one, otherwise a newly
        created client.
    """
    llm_factory = get_llm_factory()

    existing = llm_factory.get_cached(provider, model)
    if not existing:
        # Nothing cached for this provider/model yet: build a new client.
        return llm_factory.create(provider, model=model, **kwargs)

    return existing
|
||||
392
backend/app/services/llm/qwen_client.py
Normal file
392
backend/app/services/llm/qwen_client.py
Normal file
@@ -0,0 +1,392 @@
|
||||
# src/services/llm/qwen_client.py
|
||||
"""Qwen LLM客户端 - 支持OpenAI兼容API格式"""
|
||||
|
||||
import time
|
||||
import json
|
||||
from typing import List, Dict, Optional, Generator, AsyncGenerator
|
||||
from loguru import logger
|
||||
import httpx
|
||||
|
||||
from .base_client import BaseLLMClient, LLMResponse, LLMConfig, LLMProvider
|
||||
|
||||
|
||||
class QwenClient(BaseLLMClient):
    """Qwen API client using the OpenAI-compatible chat format.

    Meant to be used through an OpenAI-compatible proxy service such as
    new-api. The model names this client advertises are listed in
    ``SUPPORTED_MODELS`` (``qwen3.5-flash`` is the recommended fast model).
    """

    SUPPORTED_MODELS = [
        "qwen-turbo",
        "qwen-plus",
        "qwen-max",
        "qwen-max-longcontext",
        "qwen-long",
        "qwen3.5-flash",
        "qwen3.5-plus",
        "qwen3-plus",
        "qwen2.5-72b-instruct",
        "qwen2.5-32b-instruct",
        "qwen2.5-14b-instruct",
        "qwen2.5-7b-instruct"
    ]

    def __init__(self, config: LLMConfig):
        """Validate the provider and create the HTTP client.

        Raises:
            ValueError: If ``config.provider`` is neither ``QWEN`` nor ``QWEN_VL``.
        """
        if config.provider not in [LLMProvider.QWEN, LLMProvider.QWEN_VL]:
            raise ValueError(f"配置provider应为Qwen,实际为{config.provider}")
        super().__init__(config)
        self._init_client()

    def _init_client(self):
        """Create the ``httpx.Client`` with base URL, auth headers and timeout."""
        self._client = httpx.Client(
            base_url=self.config.base_url,
            headers={
                "Authorization": f"Bearer {self.config.api_key}",
                "Content-Type": "application/json"
            },
            timeout=self.config.timeout
        )
        logger.info(f"Qwen客户端初始化完成: {self.config.base_url} - {self.config.model}")

    def _build_payload(
        self,
        messages: List[Dict[str, str]],
        max_tokens: Optional[int],
        temperature: Optional[float],
        stream: bool,
        **kwargs
    ) -> Dict:
        """Build the OpenAI-compatible request body shared by all chat calls."""
        return {
            "model": self.config.model,
            "messages": messages,
            # 0 is not a usable token limit, so falsy values fall back to the default.
            "max_tokens": max_tokens or self.config.max_tokens,
            # 0.0 is a valid temperature: only None selects the default.
            "temperature": self.config.temperature if temperature is None else temperature,
            "top_p": kwargs.get("top_p", self.config.top_p),
            "stream": stream
        }

    def chat(
        self,
        messages: List[Dict[str, str]],
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        **kwargs
    ) -> LLMResponse:
        """Run a non-streaming chat completion.

        Args:
            messages: Conversation messages in OpenAI chat format.
            max_tokens: Output token limit; the configured value is used
                when omitted.
            temperature: Sampling temperature; the configured value is used
                only when this is ``None`` (an explicit ``0.0`` is honoured).
            **kwargs: Only ``top_p`` is read; other keys are ignored.

        Returns:
            LLMResponse: On failure no exception is raised; ``content`` is
            empty and ``error`` describes the problem.
        """
        start_time = time.time()

        try:
            payload = self._build_payload(messages, max_tokens, temperature, False, **kwargs)

            response = self._client.post("/chat/completions", json=payload)
            response.raise_for_status()

            data = response.json()

            latency_ms = int((time.time() - start_time) * 1000)

            choices = data.get("choices") or []
            if not choices:
                # Report a missing/empty choices list explicitly instead of
                # letting it surface as an opaque IndexError message.
                logger.error("Qwen API响应缺少choices")
                return LLMResponse(
                    content="",
                    model=data.get("model", self.config.model),
                    usage=data.get("usage") or {},
                    latency_ms=latency_ms,
                    error="API响应缺少choices"
                )

            first_choice = choices[0]
            message = first_choice.get("message") or {}

            return LLMResponse(
                # "content" may be JSON null; keep the field a string.
                content=message.get("content") or "",
                model=data.get("model", self.config.model),
                usage=data.get("usage") or {},
                finish_reason=first_choice.get("finish_reason") or "stop",
                latency_ms=latency_ms
            )

        except httpx.HTTPStatusError as e:
            logger.error(f"Qwen API错误: {e.response.status_code} - {e.response.text}")
            return LLMResponse(
                content="",
                model=self.config.model,
                error=f"API错误: {e.response.status_code} - {e.response.text[:200]}"
            )

        except Exception as e:
            # Boundary of the client: any other failure is converted into an
            # error response rather than propagated.
            logger.error(f"Qwen调用失败: {e}")
            return LLMResponse(
                content="",
                model=self.config.model,
                error=str(e)
            )

    def stream_chat(
        self,
        messages: List[Dict[str, str]],
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        **kwargs
    ) -> Generator[str, None, None]:
        """Run a streaming chat completion (server-sent events).

        Errors are not raised: a single ``"[ERROR: ...]"`` chunk is yielded
        instead and the stream ends.

        Yields:
            str: One text fragment per received delta.

        Example:
            for chunk in client.stream_chat(messages):
                print(chunk, end="", flush=True)
        """
        try:
            payload = self._build_payload(messages, max_tokens, temperature, True, **kwargs)

            with self._client.stream("POST", "/chat/completions", json=payload) as response:
                # Without this check an HTTP error body would be read as an
                # event stream with no events and the failure would be lost.
                response.raise_for_status()

                for line in response.iter_lines():
                    line = line.strip()
                    # SSE data lines look like "data: {...}"; the space after
                    # the colon is optional.
                    if not line.startswith("data:"):
                        continue
                    data_str = line[5:].strip()
                    if data_str == "[DONE]":
                        break
                    try:
                        data = json.loads(data_str)
                    except json.JSONDecodeError:
                        continue
                    choices = data.get("choices") or []
                    if not choices:
                        continue  # skip events without choices
                    delta = choices[0].get("delta") or {}
                    content = delta.get("content") or ""
                    if content:
                        yield content

        except httpx.HTTPStatusError as e:
            logger.error(f"Qwen流式API错误: {e.response.status_code}")
            yield f"[ERROR: API返回错误 {e.response.status_code}]"

        except Exception as e:
            logger.error(f"Qwen流式调用失败: {e}")
            yield f"[ERROR: {str(e)}]"

    async def async_stream_chat(
        self,
        messages: List[Dict[str, str]],
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        **kwargs
    ) -> AsyncGenerator[str, None]:
        """Asynchronous streaming chat completion (for FastAPI SSE responses).

        Wraps the synchronous ``stream_chat`` generator. Each blocking
        ``next()`` call runs in the event loop's default executor so the
        loop is not blocked while waiting for the network.

        Yields:
            str: One text fragment per received delta.
        """
        import asyncio

        chunks = self.stream_chat(messages, max_tokens, temperature, **kwargs)
        loop = asyncio.get_running_loop()
        # Sentinel default for next(): StopIteration cannot be propagated
        # cleanly through an executor future.
        exhausted = object()

        while True:
            chunk = await loop.run_in_executor(None, next, chunks, exhausted)
            if chunk is exhausted:
                break
            yield chunk

    def get_available_models(self) -> List[str]:
        """Return the list of supported model names."""
        return self.SUPPORTED_MODELS

    def close(self):
        """Close the underlying HTTP client, if it was created."""
        if self._client:
            self._client.close()
|
||||
|
||||
|
||||
class QwenVLClient(BaseLLMClient):
    """Qwen-VL multimodal client using the OpenAI-compatible chat format.

    The model names this client advertises are listed in
    ``SUPPORTED_MODELS``.
    """

    SUPPORTED_MODELS = [
        "qwen-vl-plus",
        "qwen-vl-max",
        "qwen3-vl-plus",
        "qwen2-vl-7b-instruct",
        "qwen2-vl-72b-instruct"
    ]

    def __init__(self, config: LLMConfig):
        """Validate the provider and create the HTTP client.

        Raises:
            ValueError: If ``config.provider`` is not ``LLMProvider.QWEN_VL``.
        """
        if config.provider != LLMProvider.QWEN_VL:
            raise ValueError(f"配置provider应为QWEN_VL,实际为{config.provider}")
        super().__init__(config)
        self._init_client()

    def _init_client(self):
        """Create the ``httpx.Client`` with base URL, auth headers and timeout."""
        self._client = httpx.Client(
            base_url=self.config.base_url,
            headers={
                "Authorization": f"Bearer {self.config.api_key}",
                "Content-Type": "application/json"
            },
            timeout=self.config.timeout
        )
        logger.info(f"QwenVL客户端初始化完成: {self.config.base_url} - {self.config.model}")

    def _build_payload(
        self,
        messages: List[Dict[str, str]],
        max_tokens: Optional[int],
        temperature: Optional[float],
        stream: bool,
        **kwargs
    ) -> Dict:
        """Build the OpenAI-compatible request body shared by all chat calls."""
        return {
            "model": self.config.model,
            "messages": messages,
            # 0 is not a usable token limit, so falsy values fall back to the default.
            "max_tokens": max_tokens or self.config.max_tokens,
            # 0.0 is a valid temperature: only None selects the default.
            "temperature": self.config.temperature if temperature is None else temperature,
            "top_p": kwargs.get("top_p", self.config.top_p),
            "stream": stream
        }

    def chat(
        self,
        messages: List[Dict[str, str]],
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        **kwargs
    ) -> LLMResponse:
        """Run a non-streaming multimodal chat completion.

        Image input is supported with messages of this form:
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}},
                    {"type": "text", "text": "描述这张图片"}
                ]
            }

        Args:
            messages: Conversation messages; passed to the API unchanged.
            max_tokens: Output token limit; the configured value is used
                when omitted.
            temperature: Sampling temperature; the configured value is used
                only when this is ``None`` (an explicit ``0.0`` is honoured).
            **kwargs: Only ``top_p`` is read; other keys are ignored.

        Returns:
            LLMResponse: On failure no exception is raised; ``content`` is
            empty and ``error`` describes the problem.
        """
        start_time = time.time()

        try:
            payload = self._build_payload(messages, max_tokens, temperature, False, **kwargs)

            response = self._client.post("/chat/completions", json=payload)
            response.raise_for_status()

            data = response.json()
            latency_ms = int((time.time() - start_time) * 1000)

            choices = data.get("choices") or []
            if not choices:
                # Report a missing/empty choices list explicitly instead of
                # letting it surface as an opaque IndexError message.
                logger.error("QwenVL API响应缺少choices")
                return LLMResponse(
                    content="",
                    model=data.get("model", self.config.model),
                    usage=data.get("usage") or {},
                    latency_ms=latency_ms,
                    error="API响应缺少choices"
                )

            first_choice = choices[0]
            message = first_choice.get("message") or {}

            return LLMResponse(
                # "content" may be JSON null; keep the field a string.
                content=message.get("content") or "",
                model=data.get("model", self.config.model),
                usage=data.get("usage") or {},
                finish_reason=first_choice.get("finish_reason") or "stop",
                latency_ms=latency_ms
            )

        except httpx.HTTPStatusError as e:
            logger.error(f"QwenVL API错误: {e.response.status_code} - {e.response.text}")
            return LLMResponse(
                content="",
                model=self.config.model,
                error=f"API错误: {e.response.status_code} - {e.response.text[:200]}"
            )

        except Exception as e:
            # Boundary of the client: any other failure is converted into an
            # error response rather than propagated.
            logger.error(f"QwenVL调用失败: {e}")
            return LLMResponse(
                content="",
                model=self.config.model,
                error=str(e)
            )

    def stream_chat(
        self,
        messages: List[Dict[str, str]],
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        **kwargs
    ) -> Generator[str, None, None]:
        """Run a streaming multimodal chat completion (server-sent events).

        Errors are not raised: a single ``"[ERROR: ...]"`` chunk is yielded
        instead and the stream ends.

        Yields:
            str: One text fragment per received delta.
        """
        try:
            payload = self._build_payload(messages, max_tokens, temperature, True, **kwargs)

            with self._client.stream("POST", "/chat/completions", json=payload) as response:
                # Without this check an HTTP error body would be read as an
                # event stream with no events and the failure would be lost.
                response.raise_for_status()

                for line in response.iter_lines():
                    line = line.strip()
                    # SSE data lines look like "data: {...}"; the space after
                    # the colon is optional.
                    if not line.startswith("data:"):
                        continue
                    data_str = line[5:].strip()
                    if data_str == "[DONE]":
                        break
                    try:
                        data = json.loads(data_str)
                    except json.JSONDecodeError:
                        continue
                    choices = data.get("choices") or []
                    if not choices:
                        continue  # skip events without choices
                    delta = choices[0].get("delta") or {}
                    content = delta.get("content") or ""
                    if content:
                        yield content

        except httpx.HTTPStatusError as e:
            logger.error(f"QwenVL流式API错误: {e.response.status_code}")
            yield f"[ERROR: API返回错误 {e.response.status_code}]"

        except Exception as e:
            logger.error(f"QwenVL流式调用失败: {e}")
            yield f"[ERROR: {str(e)}]"

    def get_available_models(self) -> List[str]:
        """Return the list of supported model names."""
        return self.SUPPORTED_MODELS

    def close(self):
        """Close the underlying HTTP client, if it was created."""
        if self._client:
            self._client.close()
|
||||
|
||||
|
||||
def create_qwen_client(
    api_key: str,
    model: str = "qwen3.5-flash",
    base_url: str = "http://6.86.80.4:30080/v1",
    **kwargs
) -> QwenClient:
    """Convenience helper: build a ``QwenClient`` from plain arguments.

    Args:
        api_key: API key sent as the bearer token.
        model: Model name to request.
        base_url: Base URL of the OpenAI-compatible endpoint.
        **kwargs: Extra ``LLMConfig`` fields.

    Returns:
        QwenClient: A client configured for the Qwen provider.
    """
    return QwenClient(
        LLMConfig(
            provider=LLMProvider.QWEN,
            model=model,
            api_key=api_key,
            base_url=base_url,
            **kwargs
        )
    )
|
||||
|
||||
|
||||
def create_qwen_vl_client(
    api_key: str,
    model: str = "qwen3-vl-plus",
    base_url: str = "http://6.86.80.4:30080/v1",
    **kwargs
) -> QwenVLClient:
    """Convenience helper: build a ``QwenVLClient`` from plain arguments.

    Args:
        api_key: API key sent as the bearer token.
        model: Model name to request.
        base_url: Base URL of the OpenAI-compatible endpoint.
        **kwargs: Extra ``LLMConfig`` fields.

    Returns:
        QwenVLClient: A client configured for the Qwen-VL provider.
    """
    return QwenVLClient(
        LLMConfig(
            provider=LLMProvider.QWEN_VL,
            model=model,
            api_key=api_key,
            base_url=base_url,
            **kwargs
        )
    )
|
||||
Reference in New Issue
Block a user