update
This commit is contained in:
116
backend/app/services/llm/base_client.py
Normal file
116
backend/app/services/llm/base_client.py
Normal file
@@ -0,0 +1,116 @@
|
||||
# src/services/llm/base_client.py
|
||||
"""LLM客户端基类 - 统一接口定义"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Dict, Optional, Any
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class LLMProvider(Enum):
|
||||
"""LLM提供商"""
|
||||
DEEPSEEK = "deepseek"
|
||||
QWEN = "qwen"
|
||||
QWEN_VL = "qwen_vl"
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
"""LLM响应结果"""
|
||||
content: str
|
||||
model: str
|
||||
usage: Dict[str, int] = field(default_factory=dict)
|
||||
finish_reason: str = "stop"
|
||||
latency_ms: int = 0
|
||||
error: Optional[str] = None
|
||||
|
||||
@property
|
||||
def is_success(self) -> bool:
|
||||
return self.error is None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMConfig:
|
||||
"""LLM配置"""
|
||||
provider: LLMProvider
|
||||
model: str
|
||||
api_key: str
|
||||
base_url: str
|
||||
max_tokens: int = 4096
|
||||
temperature: float = 0.7
|
||||
top_p: float = 0.9
|
||||
timeout: int = 300 # 默认超时300秒(摘要/Skills生成可能需要较长时间)
|
||||
|
||||
|
||||
class BaseLLMClient(ABC):
|
||||
"""LLM客户端基类"""
|
||||
|
||||
def __init__(self, config: LLMConfig):
|
||||
self.config = config
|
||||
self._client = None
|
||||
|
||||
@abstractmethod
|
||||
def _init_client(self):
|
||||
"""初始化客户端"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def chat(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
max_tokens: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
**kwargs
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
对话补全
|
||||
|
||||
Args:
|
||||
messages: 对话消息列表 [{"role": "user/assistant/system", "content": "..."}]
|
||||
max_tokens: 最大输出token数
|
||||
temperature: 温度参数
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
LLMResponse: 响应结果
|
||||
"""
|
||||
pass
|
||||
|
||||
def complete(
|
||||
self,
|
||||
prompt: str,
|
||||
system_prompt: Optional[str] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
**kwargs
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
单轮补全(便捷方法)
|
||||
|
||||
Args:
|
||||
prompt: 用户输入
|
||||
system_prompt: 系统提示词
|
||||
max_tokens: 最大输出token数
|
||||
temperature: 温度参数
|
||||
|
||||
Returns:
|
||||
LLMResponse: 响应结果
|
||||
"""
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
return self.chat(messages, max_tokens, temperature, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def get_available_models(self) -> List[str]:
|
||||
"""获取可用模型列表"""
|
||||
pass
|
||||
|
||||
def estimate_tokens(self, text: str) -> int:
|
||||
"""估算文本token数(粗略估计)"""
|
||||
# 中文字符约1.5 token,英文约0.25 token
|
||||
chinese_chars = sum(1 for c in text if '一' <= c <= '鿿')
|
||||
other_chars = len(text) - chinese_chars
|
||||
return int(chinese_chars * 1.5 + other_chars * 0.25)
|
||||
Reference in New Issue
Block a user