Files
AIRegulation-DocAnalysis/backend/app/services/llm/qwen_client.py
2026-05-14 15:07:34 +08:00

393 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# backend/app/services/llm/qwen_client.py
"""Qwen LLM客户端 - 支持OpenAI兼容API格式"""
import json
import time
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional

import httpx
from loguru import logger

from .base_client import BaseLLMClient, LLMConfig, LLMProvider, LLMResponse
class QwenClient(BaseLLMClient):
    """
    Qwen API client speaking the OpenAI-compatible chat-completions protocol.

    Can be used through proxy services such as new-api. Models advertised in
    ``SUPPORTED_MODELS`` include:
    - qwen-turbo
    - qwen-plus
    - qwen-max
    - qwen3.5-flash (recommended: fast responses)
    - qwen3.5-plus
    - qwen-long
    - the qwen2.5 series
    """

    SUPPORTED_MODELS = [
        "qwen-turbo",
        "qwen-plus",
        "qwen-max",
        "qwen-max-longcontext",
        "qwen-long",
        "qwen3.5-flash",
        "qwen3.5-plus",
        "qwen3-plus",
        "qwen2.5-72b-instruct",
        "qwen2.5-32b-instruct",
        "qwen2.5-14b-instruct",
        "qwen2.5-7b-instruct",
    ]

    def __init__(self, config: LLMConfig):
        """Validate the provider and create the HTTP client.

        Args:
            config: Connection and model settings; ``provider`` must be QWEN or QWEN_VL.

        Raises:
            ValueError: If ``config.provider`` is neither QWEN nor QWEN_VL.
        """
        if config.provider not in [LLMProvider.QWEN, LLMProvider.QWEN_VL]:
            raise ValueError(f"配置provider应为Qwen实际为{config.provider}")
        super().__init__(config)
        self._init_client()

    def _init_client(self):
        """Create the ``httpx.Client`` using base URL, bearer token and timeout from config."""
        self._client = httpx.Client(
            base_url=self.config.base_url,
            headers={
                "Authorization": f"Bearer {self.config.api_key}",
                "Content-Type": "application/json",
            },
            timeout=self.config.timeout,
        )
        logger.info(f"Qwen客户端初始化完成: {self.config.base_url} - {self.config.model}")

    def _build_payload(
        self,
        messages: List[Dict[str, str]],
        max_tokens: Optional[int],
        temperature: Optional[float],
        extra: Dict[str, Any],
        stream: bool,
    ) -> Dict[str, Any]:
        """Build an OpenAI-compatible request body, filling gaps from ``self.config``.

        ``extra`` is the caller's ``**kwargs`` dict; only ``top_p`` is read from it.
        """
        return {
            "model": self.config.model,
            "messages": messages,
            # None or 0 both mean "use the configured limit" (unchanged behavior).
            "max_tokens": max_tokens or self.config.max_tokens,
            # `is None` instead of `or`: an explicit temperature of 0.0 must be honored.
            "temperature": temperature if temperature is not None else self.config.temperature,
            "top_p": extra.get("top_p", self.config.top_p),
            "stream": stream,
        }

    def _parse_response(self, data: Dict[str, Any], latency_ms: int) -> LLMResponse:
        """Convert an OpenAI-compatible completion JSON body into an ``LLMResponse``."""
        # `or` guards also cover explicit [] / null values, which previously raised
        # IndexError or leaked None into the response.
        choices = data.get("choices") or [{}]
        first_choice = choices[0] or {}
        message = first_choice.get("message") or {}
        return LLMResponse(
            content=message.get("content") or "",
            model=data.get("model") or self.config.model,
            usage=data.get("usage") or {},
            finish_reason=first_choice.get("finish_reason") or "stop",
            latency_ms=latency_ms,
        )

    @staticmethod
    def _iter_sse_content(response) -> Generator[str, None, None]:
        """Yield non-empty ``choices[0].delta.content`` fragments from an SSE response.

        Stops at the ``[DONE]`` sentinel; lines that are not ``data:`` lines or whose
        payload is not valid JSON are skipped.
        """
        for raw_line in response.iter_lines():
            line = raw_line.strip()
            # SSE data line: "data:<payload>"; the space after the colon is optional.
            if not line.startswith("data:"):
                continue
            data_str = line[len("data:"):].strip()
            if data_str == "[DONE]":
                break
            try:
                data = json.loads(data_str)
            except json.JSONDecodeError:
                continue
            if not isinstance(data, dict):
                continue
            choices = data.get("choices") or []
            if not choices:
                continue  # e.g. a usage-only chunk with empty choices
            delta = (choices[0] or {}).get("delta") or {}
            content = delta.get("content")
            if content:
                yield content

    def chat(
        self,
        messages: List[Dict[str, str]],
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        **kwargs
    ) -> LLMResponse:
        """Run a non-streaming chat completion (OpenAI-compatible ``/chat/completions``).

        Args:
            messages: OpenAI-style message list.
            max_tokens: Completion limit; ``None`` or 0 falls back to ``config.max_tokens``.
            temperature: Sampling temperature; ``None`` falls back to ``config.temperature``.
            **kwargs: Only ``top_p`` is read; other keys are ignored.

        Returns:
            The parsed reply. On HTTP or other failures this does not raise; it returns
            an ``LLMResponse`` with ``content=""`` and ``error`` set.
        """
        start_time = time.perf_counter()  # monotonic, unaffected by wall-clock changes
        try:
            payload = self._build_payload(messages, max_tokens, temperature, kwargs, stream=False)
            response = self._client.post("/chat/completions", json=payload)
            response.raise_for_status()
            data = response.json()
            latency_ms = int((time.perf_counter() - start_time) * 1000)
            return self._parse_response(data, latency_ms)
        except httpx.HTTPStatusError as e:
            logger.error(f"Qwen API错误: {e.response.status_code} - {e.response.text}")
            return LLMResponse(
                content="",
                model=self.config.model,
                error=f"API错误: {e.response.status_code} - {e.response.text[:200]}"
            )
        except Exception as e:
            logger.error(f"Qwen调用失败: {e}")
            return LLMResponse(
                content="",
                model=self.config.model,
                error=str(e)
            )

    def stream_chat(
        self,
        messages: List[Dict[str, str]],
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        **kwargs
    ) -> Generator[str, None, None]:
        """
        Stream a chat completion over SSE.

        Yields:
            str: One text fragment per chunk. On failure, a single
            ``"[ERROR: ...]"`` string is yielded instead of raising.

        Example:
            for chunk in client.stream_chat(messages):
                print(chunk, end="", flush=True)
        """
        try:
            payload = self._build_payload(messages, max_tokens, temperature, kwargs, stream=True)
            with self._client.stream("POST", "/chat/completions", json=payload) as response:
                if response.is_error:
                    # Load the error body now (the stream is closed once we leave the
                    # `with`) so the handler below can log it.
                    response.read()
                # Without this, an HTTP error body would be silently parsed as SSE.
                response.raise_for_status()
                yield from self._iter_sse_content(response)
        except httpx.HTTPStatusError as e:
            logger.error(f"Qwen流式API错误: {e.response.status_code} - {e.response.text}")
            yield f"[ERROR: API返回错误 {e.response.status_code}]"
        except Exception as e:
            logger.error(f"Qwen流式调用失败: {e}")
            yield f"[ERROR: {str(e)}]"

    async def async_stream_chat(
        self,
        messages: List[Dict[str, str]],
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        **kwargs
    ) -> AsyncGenerator[str, None]:
        """
        Async streaming chat completion, intended for FastAPI SSE responses.

        Wraps ``stream_chat``. Each blocking read of the HTTP stream runs in the
        default executor, so the event loop keeps serving other tasks meanwhile.

        Yields:
            str: The same fragments ``stream_chat`` yields.
        """
        import asyncio

        loop = asyncio.get_running_loop()
        exhausted = object()  # sentinel returned by next() once the sync stream ends
        sync_chunks = self.stream_chat(messages, max_tokens, temperature, **kwargs)
        try:
            while True:
                chunk = await loop.run_in_executor(None, next, sync_chunks, exhausted)
                if chunk is exhausted:
                    break
                yield chunk
        finally:
            try:
                # Closes the underlying HTTP stream if the consumer stops early.
                sync_chunks.close()
            except ValueError:
                # Generator is still executing in a worker thread (consumer was
                # cancelled mid-read); it is finalized when garbage-collected.
                pass

    def get_available_models(self) -> List[str]:
        """Return a copy of ``SUPPORTED_MODELS`` (callers cannot mutate the class list)."""
        return list(self.SUPPORTED_MODELS)

    def close(self):
        """Close the underlying HTTP client, if one was created."""
        client = getattr(self, "_client", None)
        if client:
            client.close()
class QwenVLClient(BaseLLMClient):
    """
    Qwen VL multimodal client speaking the OpenAI-compatible chat-completions protocol.

    Models advertised in ``SUPPORTED_MODELS``:
    - qwen-vl-plus
    - qwen-vl-max
    - qwen3-vl-plus
    - qwen2-vl-7b-instruct
    - qwen2-vl-72b-instruct
    """

    SUPPORTED_MODELS = [
        "qwen-vl-plus",
        "qwen-vl-max",
        "qwen3-vl-plus",
        "qwen2-vl-7b-instruct",
        "qwen2-vl-72b-instruct",
    ]

    def __init__(self, config: LLMConfig):
        """Validate the provider and create the HTTP client.

        Args:
            config: Connection and model settings; ``provider`` must be QWEN_VL.

        Raises:
            ValueError: If ``config.provider`` is not QWEN_VL.
        """
        if config.provider != LLMProvider.QWEN_VL:
            raise ValueError(f"配置provider应为QWEN_VL实际为{config.provider}")
        super().__init__(config)
        self._init_client()

    def _init_client(self):
        """Create the ``httpx.Client`` using base URL, bearer token and timeout from config."""
        self._client = httpx.Client(
            base_url=self.config.base_url,
            headers={
                "Authorization": f"Bearer {self.config.api_key}",
                "Content-Type": "application/json",
            },
            timeout=self.config.timeout,
        )
        logger.info(f"QwenVL客户端初始化完成: {self.config.base_url} - {self.config.model}")

    def _build_payload(
        self,
        messages: List[Dict[str, Any]],
        max_tokens: Optional[int],
        temperature: Optional[float],
        extra: Dict[str, Any],
        stream: bool,
    ) -> Dict[str, Any]:
        """Build an OpenAI-compatible request body, filling gaps from ``self.config``.

        ``extra`` is the caller's ``**kwargs`` dict; only ``top_p`` is read from it.
        """
        return {
            "model": self.config.model,
            "messages": messages,
            # None or 0 both mean "use the configured limit" (unchanged behavior).
            "max_tokens": max_tokens or self.config.max_tokens,
            # `is None` instead of `or`: an explicit temperature of 0.0 must be honored.
            "temperature": temperature if temperature is not None else self.config.temperature,
            "top_p": extra.get("top_p", self.config.top_p),
            "stream": stream,
        }

    def _parse_response(self, data: Dict[str, Any], latency_ms: int) -> LLMResponse:
        """Convert an OpenAI-compatible completion JSON body into an ``LLMResponse``."""
        # `or` guards also cover explicit [] / null values, which previously raised
        # IndexError or leaked None into the response.
        choices = data.get("choices") or [{}]
        first_choice = choices[0] or {}
        message = first_choice.get("message") or {}
        return LLMResponse(
            content=message.get("content") or "",
            model=data.get("model") or self.config.model,
            usage=data.get("usage") or {},
            finish_reason=first_choice.get("finish_reason") or "stop",
            latency_ms=latency_ms,
        )

    @staticmethod
    def _iter_sse_content(response) -> Generator[str, None, None]:
        """Yield non-empty ``choices[0].delta.content`` fragments from an SSE response.

        Stops at the ``[DONE]`` sentinel; lines that are not ``data:`` lines or whose
        payload is not valid JSON are skipped.
        """
        for raw_line in response.iter_lines():
            line = raw_line.strip()
            # SSE data line: "data:<payload>"; the space after the colon is optional.
            if not line.startswith("data:"):
                continue
            data_str = line[len("data:"):].strip()
            if data_str == "[DONE]":
                break
            try:
                data = json.loads(data_str)
            except json.JSONDecodeError:
                continue
            if not isinstance(data, dict):
                continue
            choices = data.get("choices") or []
            if not choices:
                continue  # e.g. a usage-only chunk with empty choices
            delta = (choices[0] or {}).get("delta") or {}
            content = delta.get("content")
            if content:
                yield content

    def chat(
        self,
        messages: List[Dict[str, Any]],
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        **kwargs
    ) -> LLMResponse:
        """Run a non-streaming multimodal chat completion (OpenAI-compatible format).

        Image input uses list-valued message content, e.g.:
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}},
                {"type": "text", "text": "Describe this image"}
            ]
        }

        Args:
            messages: OpenAI-style message list (content may be a string or a part list).
            max_tokens: Completion limit; ``None`` or 0 falls back to ``config.max_tokens``.
            temperature: Sampling temperature; ``None`` falls back to ``config.temperature``.
            **kwargs: Only ``top_p`` is read; other keys are ignored.

        Returns:
            The parsed reply. On failure this does not raise; it returns an
            ``LLMResponse`` with ``content=""`` and ``error`` set.
        """
        start_time = time.perf_counter()  # monotonic, unaffected by wall-clock changes
        try:
            payload = self._build_payload(messages, max_tokens, temperature, kwargs, stream=False)
            response = self._client.post("/chat/completions", json=payload)
            response.raise_for_status()
            data = response.json()
            latency_ms = int((time.perf_counter() - start_time) * 1000)
            return self._parse_response(data, latency_ms)
        except httpx.HTTPStatusError as e:
            logger.error(f"QwenVL API错误: {e.response.status_code} - {e.response.text}")
            return LLMResponse(
                content="",
                model=self.config.model,
                error=f"API错误: {e.response.status_code} - {e.response.text[:200]}"
            )
        except Exception as e:
            logger.error(f"QwenVL调用失败: {e}")
            return LLMResponse(
                content="",
                model=self.config.model,
                error=str(e)
            )

    def stream_chat(
        self,
        messages: List[Dict[str, Any]],
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        **kwargs
    ) -> Generator[str, None, None]:
        """Stream a multimodal chat completion over SSE.

        Yields:
            str: One text fragment per chunk. On failure, a single
            ``"[ERROR: ...]"`` string is yielded instead of raising.
        """
        try:
            payload = self._build_payload(messages, max_tokens, temperature, kwargs, stream=True)
            with self._client.stream("POST", "/chat/completions", json=payload) as response:
                if response.is_error:
                    # Load the error body now (the stream is closed once we leave the
                    # `with`) so the handler below can log it.
                    response.read()
                # Without this, an HTTP error body would be silently parsed as SSE.
                response.raise_for_status()
                yield from self._iter_sse_content(response)
        except httpx.HTTPStatusError as e:
            logger.error(f"QwenVL流式API错误: {e.response.status_code} - {e.response.text}")
            yield f"[ERROR: API返回错误 {e.response.status_code}]"
        except Exception as e:
            logger.error(f"QwenVL流式调用失败: {e}")
            yield f"[ERROR: {str(e)}]"

    def get_available_models(self) -> List[str]:
        """Return a copy of ``SUPPORTED_MODELS`` (callers cannot mutate the class list)."""
        return list(self.SUPPORTED_MODELS)

    def close(self):
        """Close the underlying HTTP client, if one was created."""
        client = getattr(self, "_client", None)
        if client:
            client.close()
def create_qwen_client(
    api_key: str,
    model: str = "qwen3.5-flash",
    base_url: str = "http://6.86.80.4:30080/v1",
    **kwargs
) -> QwenClient:
    """Convenience factory that builds a QWEN-provider config and wraps it in a client.

    Args:
        api_key: Bearer token sent in the ``Authorization`` header.
        model: Model name; defaults to ``"qwen3.5-flash"``.
        base_url: OpenAI-compatible endpoint root; the default is a hard-coded
            internal address, so override it per deployment.
        **kwargs: Forwarded verbatim to ``LLMConfig``.

    Returns:
        A ready-to-use ``QwenClient``.
    """
    qwen_config = LLMConfig(
        api_key=api_key,
        base_url=base_url,
        model=model,
        provider=LLMProvider.QWEN,
        **kwargs
    )
    return QwenClient(qwen_config)
def create_qwen_vl_client(
    api_key: str,
    model: str = "qwen3-vl-plus",
    base_url: str = "http://6.86.80.4:30080/v1",
    **kwargs
) -> QwenVLClient:
    """Convenience factory that builds a QWEN_VL-provider config and wraps it in a client.

    Args:
        api_key: Bearer token sent in the ``Authorization`` header.
        model: Vision-language model name; defaults to ``"qwen3-vl-plus"``.
        base_url: OpenAI-compatible endpoint root; the default is a hard-coded
            internal address, so override it per deployment.
        **kwargs: Forwarded verbatim to ``LLMConfig``.

    Returns:
        A ready-to-use ``QwenVLClient``.
    """
    vl_config = LLMConfig(
        api_key=api_key,
        base_url=base_url,
        model=model,
        provider=LLMProvider.QWEN_VL,
        **kwargs
    )
    return QwenVLClient(vl_config)