Files
AIRegulation-DocAnalysis/backend/app/services/llm/base_client.py

116 lines
2.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""LLM客户端基类 - 统一接口定义"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Dict, Optional, Any
from enum import Enum
class LLMProvider(Enum):
    """Supported LLM providers.

    Each member's value is the lowercase string identifier of the provider.
    """

    DEEPSEEK = "deepseek"
    QWEN = "qwen"
    QWEN_VL = "qwen_vl"
@dataclass
class LLMResponse:
    """Result of an LLM call.

    A response is considered successful when ``error`` is ``None``.
    """

    content: str
    model: str
    # default_factory gives every instance its own dict instead of a shared one.
    # NOTE(review): presumably token counts keyed by name; confirm against the
    # concrete clients that populate it.
    usage: Dict[str, int] = field(default_factory=dict)
    finish_reason: str = "stop"
    latency_ms: int = 0
    # None means no error was recorded for this response.
    error: Optional[str] = None

    @property
    def is_success(self) -> bool:
        """Return True when no error was recorded on this response."""
        return self.error is None
@dataclass
class LLMConfig:
    """Configuration for an LLM client."""

    provider: LLMProvider
    model: str
    api_key: str
    base_url: str
    max_tokens: int = 4096
    temperature: float = 0.7
    top_p: float = 0.9
    # Default timeout is 300 seconds: summary/Skills generation may take a
    # long time.
    timeout: int = 300
class BaseLLMClient(ABC):
    """Abstract base class defining the unified interface for LLM clients.

    Subclasses must implement ``_init_client``, ``chat`` and
    ``get_available_models``. ``complete`` and ``estimate_tokens`` are
    provided here.
    """

    def __init__(self, config: "LLMConfig"):
        """Store the configuration; the underlying client starts unset.

        Args:
            config: Settings for this client.
        """
        self.config = config
        # Left as None here: this constructor does not call _init_client().
        self._client = None

    @abstractmethod
    def _init_client(self):
        """Initialize the underlying client."""
        pass

    @abstractmethod
    def chat(
        self,
        messages: List[Dict[str, str]],
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        **kwargs
    ) -> "LLMResponse":
        """Run a chat completion.

        Args:
            messages: Conversation messages, e.g.
                ``[{"role": "user/assistant/system", "content": "..."}]``.
            max_tokens: Maximum number of output tokens.
            temperature: Sampling temperature.
            **kwargs: Additional parameters.

        Returns:
            LLMResponse: The response result.
        """
        pass

    def complete(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        **kwargs
    ) -> "LLMResponse":
        """Run a single-turn completion (convenience wrapper around ``chat``).

        Builds a message list from the prompt, preceded by a system message
        when ``system_prompt`` is truthy, and forwards it to ``chat``.

        Args:
            prompt: User input.
            system_prompt: Optional system prompt.
            max_tokens: Maximum number of output tokens.
            temperature: Sampling temperature.
            **kwargs: Extra keyword arguments passed through to ``chat``.

        Returns:
            LLMResponse: Whatever ``chat`` returns.
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        return self.chat(messages, max_tokens, temperature, **kwargs)

    @abstractmethod
    def get_available_models(self) -> List[str]:
        """Return the list of available model names."""
        pass

    def estimate_tokens(self, text: str) -> int:
        """Roughly estimate the number of tokens in ``text``.

        Heuristic: each character in the CJK Unified Ideographs block
        (U+4E00-U+9FFF) counts as about 1.5 tokens and every other character
        as about 0.25 tokens. The weighted sum is truncated to an int.

        Args:
            text: Text to estimate.

        Returns:
            int: Estimated token count.
        """
        # The bounds are written as escape sequences so they stay intact even
        # if tooling strips or rewrites raw CJK literals. An empty lower bound
        # would match every character up to U+9FFF, ASCII included.
        chinese_chars = sum(1 for c in text if "\u4e00" <= c <= "\u9fff")
        other_chars = len(text) - chinese_chars
        return int(chinese_chars * 1.5 + other_chars * 0.25)