This commit is contained in:
2026-05-14 15:07:34 +08:00
parent c2a398930d
commit 10d04c4083
179 changed files with 24073 additions and 1243 deletions

View File

@@ -0,0 +1,7 @@
# src/services/agent/__init__.py
"""Agent service package.

Re-exports the names imported below from ``qa_agent`` and ``session_manager``
so callers can import them directly from this package.
"""
from .qa_agent import QAAgent, ask_compliance_question
from .session_manager import SessionManager, ChatSession
# Public API of the package.
__all__ = ["QAAgent", "ask_compliance_question", "SessionManager", "ChatSession"]

View File

@@ -0,0 +1,412 @@
# src/services/agent/qa_agent.py
"""RAG问答Agent - 合规智能问答核心实现"""
import time
from typing import List, Dict, Optional, Any, Generator
from dataclasses import dataclass, field
from loguru import logger
from app.services.llm import get_llm_client, BaseLLMClient, LLMResponse
from app.services.llm.llm_factory import LLMFactory
from app.services.rag.retriever import Retriever, RetrievedDocument
from app.services.rag.context_builder import ContextBuilder, RAGContext
from app.services.rag.prompt_templates import get_prompt_template, PromptTemplate
from app.config.settings import settings
@dataclass
class AgentResponse:
    """Result object produced by the agent for one question.

    Only ``answer`` is required; every other field has a default.
    """

    # Answer text.
    answer: str
    # Citation entries associated with the answer (fresh list per instance).
    sources: List[Dict] = field(default_factory=list)
    # Name of the model that produced the answer.
    model: str = ""
    # Elapsed time in milliseconds.
    latency_ms: int = 0
    # Number of documents returned by retrieval.
    retrieved_count: int = 0
    # Token count reported for the built context.
    context_tokens: int = 0
    # Whether the built context was truncated.
    truncated: bool = False
    # Error description; ``None`` means no error was recorded.
    error: Optional[str] = None

    @property
    def is_success(self) -> bool:
        """Return ``True`` when no error has been recorded."""
        has_error = self.error is not None
        return not has_error
@dataclass
class AgentConfig:
    """Tunable settings for the question-answering agent."""

    # --- LLM selection -----------------------------------------------------
    # Provider name and model identifier used to obtain the LLM client.
    llm_provider: str = "deepseek"
    llm_model: str = "deepseek-v4-flash"

    # --- Retrieval -----------------------------------------------------------
    # Number of documents to retrieve and the minimum score handed to the
    # retriever.
    top_k: int = 5
    min_score: float = 0.3

    # --- Context / generation ------------------------------------------------
    # Token budget handed to the context builder.
    max_context_tokens: int = 2000
    # Sampling temperature passed to the LLM.
    temperature: float = 0.7
    # Name of the prompt template used when the caller does not override it.
    prompt_template: str = "compliance_qa"
    # Forwarded to the context builder as ``include_metadata``.
    include_metadata: bool = True
class QAAgent:
    """
    Compliance question-answering agent (retrieval-augmented generation).

    Pipeline:
        1. Receive the user question.
        2. Retrieve relevant regulation clauses (the original author's note
           describes this as Milvus hybrid retrieval).
        3. Build the RAG context.
        4. Call the LLM to generate the answer.
        5. Return the answer together with its cited sources.

    Example:
        agent = QAAgent()
        response = agent.ask("What does a motor vehicle safety inspection require?")
        print(response.answer)
        for source in response.sources:
            print(f"{source['doc_name']} - {source['clause_number']}")
    """

    def __init__(self, config: Optional[AgentConfig] = None):
        """
        Create the agent.

        Args:
            config: Agent configuration. When omitted, a configuration is
                built from the application ``settings`` object.
        """
        if config is None:
            config = AgentConfig(
                llm_provider=settings.llm_provider,
                llm_model=settings.llm_model,
                top_k=settings.rag_top_k,
                max_context_tokens=settings.rag_max_context_tokens
            )
        self.config = config
        # Components are created lazily by the ``_init_*`` helpers.
        self.llm: Optional[BaseLLMClient] = None
        self.retriever: Optional[Retriever] = None
        self.context_builder: Optional[ContextBuilder] = None
        logger.info(f"问答Agent初始化: provider={self.config.llm_provider}, model={self.config.llm_model}")

    def _init_llm(self):
        """Lazily obtain the LLM client, preferring the globally cached one."""
        if self.llm is not None:
            return
        cached = LLMFactory.get_global_client(self.config.llm_provider, self.config.llm_model)
        if cached:
            self.llm = cached
            logger.debug(f"使用全局缓存的LLM客户端: {self.config.llm_provider} - {self.config.llm_model}")
            return
        logger.info("创建新的LLM客户端...")
        self.llm = get_llm_client(
            provider=self.config.llm_provider,
            model=self.config.llm_model,
            temperature=self.config.temperature
        )

    def _init_retriever(self):
        """Lazily create the retriever."""
        if self.retriever is None:
            logger.info("初始化检索器...")
            self.retriever = Retriever(
                top_k=self.config.top_k,
                min_score=self.config.min_score
            )

    def _init_context_builder(self):
        """Lazily create the context builder."""
        if self.context_builder is None:
            logger.info("初始化上下文构建器...")
            self.context_builder = ContextBuilder(
                max_context_tokens=self.config.max_context_tokens,
                include_metadata=self.config.include_metadata
            )

    @staticmethod
    def _elapsed_ms(start_time: float) -> int:
        """Return whole milliseconds elapsed since ``start_time`` (a ``time.time()`` value)."""
        return int((time.time() - start_time) * 1000)

    def _prepare_prompt(
        self,
        query: str,
        documents: List[RetrievedDocument],
        prompt_template: Optional[str] = None
    ) -> tuple:
        """
        Build the RAG context and the LLM messages for a question.

        Shared by ``ask``, ``ask_with_context`` and ``ask_stream`` so the three
        entry points assemble prompts identically.

        Returns:
            A ``(context, messages)`` pair.
        """
        self._init_context_builder()
        template = get_prompt_template(prompt_template or self.config.prompt_template)
        context = self.context_builder.build(
            query=query,
            documents=documents,
            system_prompt=template.system_prompt
        )
        return context, self._build_messages(template, context)

    def ask(
        self,
        query: str,
        filters: Optional[str] = None,
        prompt_template: Optional[str] = None
    ) -> AgentResponse:
        """
        Answer a user question.

        Args:
            query: The user question.
            filters: Retrieval filter expression passed to the retriever
                (e.g. "regulation_type=='车辆安全'").
            prompt_template: Prompt template name; overrides the configured
                default when given.

        Returns:
            AgentResponse: the answer with its sources, or a response whose
            ``error`` is set. Exceptions are caught and reported via ``error``.
        """
        start_time = time.time()
        logger.info(f"收到问题: {query}")
        try:
            # Step 1: retrieve relevant regulations.
            self._init_retriever()
            documents = self.retriever.retrieve(query, filters)
            retrieved_count = len(documents)
            if retrieved_count == 0:
                return AgentResponse(
                    answer="抱歉,未找到与您的问题相关的法规条款。请尝试用不同的关键词重新提问,或提供更具体的法规名称。",
                    latency_ms=self._elapsed_ms(start_time),
                    retrieved_count=0,
                    error="no_retrieved_documents"
                )
            # Steps 2-3: build the RAG context and the LLM input messages.
            context, messages = self._prepare_prompt(query, documents, prompt_template)
            # Step 4: call the LLM.
            self._init_llm()
            llm_response = self.llm.chat(
                messages=messages,
                temperature=self.config.temperature
            )
            if not llm_response.is_success:
                return AgentResponse(
                    answer="",
                    latency_ms=self._elapsed_ms(start_time),
                    retrieved_count=retrieved_count,
                    error=llm_response.error
                )
            latency_ms = self._elapsed_ms(start_time)
            # Step 5: return the result.
            logger.success(f"问答完成: {latency_ms}ms, {retrieved_count}条引用")
            return AgentResponse(
                answer=llm_response.content,
                sources=context.sources,
                model=llm_response.model,
                latency_ms=latency_ms,
                retrieved_count=retrieved_count,
                context_tokens=context.total_tokens,
                truncated=context.truncated
            )
        except Exception as e:
            logger.error(f"问答失败: {e}")
            return AgentResponse(
                answer="",
                latency_ms=self._elapsed_ms(start_time),
                error=str(e)
            )

    def ask_with_context(
        self,
        query: str,
        documents: List[RetrievedDocument],
        prompt_template: Optional[str] = None
    ) -> AgentResponse:
        """
        Answer a question from caller-supplied documents (no retrieval).

        Args:
            query: The user question.
            documents: Documents that were already retrieved.
            prompt_template: Prompt template name; overrides the configured
                default when given.

        Returns:
            AgentResponse: the answer with its sources, or a response whose
            ``error`` is set. Exceptions are caught and reported via ``error``.
        """
        start_time = time.time()
        try:
            context, messages = self._prepare_prompt(query, documents, prompt_template)
            self._init_llm()
            # Pass the configured temperature, as ``ask`` and ``ask_stream`` do.
            llm_response = self.llm.chat(
                messages=messages,
                temperature=self.config.temperature
            )
            # An unsuccessful LLM response must not be reported as an answer.
            if not llm_response.is_success:
                return AgentResponse(
                    answer="",
                    latency_ms=self._elapsed_ms(start_time),
                    retrieved_count=len(documents),
                    error=llm_response.error
                )
            return AgentResponse(
                answer=llm_response.content,
                sources=context.sources,
                model=llm_response.model,
                latency_ms=self._elapsed_ms(start_time),
                retrieved_count=len(documents),
                context_tokens=context.total_tokens,
                truncated=context.truncated
            )
        except Exception as e:
            logger.error(f"问答失败: {e}")
            return AgentResponse(
                answer="",
                latency_ms=self._elapsed_ms(start_time),
                error=str(e)
            )

    def _build_messages(
        self,
        template: PromptTemplate,
        context: RAGContext
    ) -> List[Dict[str, str]]:
        """Build the system + user message list sent to the LLM."""
        user_content = template.user_template.format(
            context=context.context_text,
            query=context.user_query
        )
        return [
            {"role": "system", "content": template.system_prompt},
            {"role": "user", "content": user_content}
        ]

    def ask_stream(
        self,
        query: str,
        filters: Optional[str] = None,
        prompt_template: Optional[str] = None
    ) -> Generator[Dict[str, Any], None, None]:
        """
        Answer a user question as a stream of events (SSE style).

        Event shapes:
            - {"event": "status", "data": "..."}   progress update
            - {"event": "sources", "data": [...]}  cited sources
            - {"event": "content", "data": "..."}  piece of the answer
            - {"event": "done", "data": {...}}     completion metadata
            - {"event": "error", "data": "..."}    failure; no "done" follows

        Args:
            query: The user question.
            filters: Retrieval filter expression passed to the retriever.
            prompt_template: Prompt template name; overrides the configured
                default when given.

        Yields:
            Dict: one event per item.
        """
        start_time = time.time()
        logger.info(f"收到流式问题: {query}")
        try:
            # Step 1: retrieve relevant regulations.
            yield {"event": "status", "data": "正在检索相关法规..."}
            self._init_retriever()
            documents = self.retriever.retrieve(query, filters)
            retrieved_count = len(documents)
            if retrieved_count == 0:
                yield {"event": "status", "data": "未找到相关法规"}
                yield {"event": "content", "data": "抱歉,未找到与您的问题相关的法规条款。请尝试用不同的关键词重新提问。"}
                # Report the time actually spent rather than a hard-coded 0.
                yield {
                    "event": "done",
                    "data": {"latency_ms": self._elapsed_ms(start_time), "retrieved_count": 0}
                }
                return
            # Step 2: publish the retrieval result.
            yield {"event": "status", "data": f"找到{retrieved_count}条相关法规,正在生成回答..."}
            sources = [
                {
                    "doc_name": doc.doc_name,
                    "doc_id": doc.doc_id,
                    "clause_number": doc.clause_number,
                    "score": doc.score
                }
                for doc in documents[:5]  # only the first 5 citations are sent
            ]
            yield {"event": "sources", "data": sources}
            # Steps 3-4: build the RAG context and the LLM input messages.
            context, messages = self._prepare_prompt(query, documents, prompt_template)
            # Step 5: generate the answer, streaming when the client supports it.
            self._init_llm()
            if hasattr(self.llm, 'stream_chat'):
                yield {"event": "status", "data": "思考中..."}
                for chunk in self.llm.stream_chat(
                    messages=messages,
                    temperature=self.config.temperature
                ):
                    yield {"event": "content", "data": chunk}
            else:
                # No streaming support: fall back to a single blocking call.
                yield {"event": "status", "data": "生成回答中..."}
                llm_response = self.llm.chat(
                    messages=messages,
                    temperature=self.config.temperature
                )
                if not llm_response.is_success:
                    # Surface the failure instead of ending with an empty "done".
                    logger.error(f"流式问答失败: {llm_response.error}")
                    yield {"event": "error", "data": str(llm_response.error)}
                    return
                yield {"event": "content", "data": llm_response.content}
            # Step 6: completion event.
            latency_ms = self._elapsed_ms(start_time)
            logger.success(f"流式问答完成: {latency_ms}ms, {retrieved_count}条引用")
            yield {
                "event": "done",
                "data": {
                    "latency_ms": latency_ms,
                    "model": self.config.llm_model,
                    "retrieved_count": retrieved_count,
                    "context_tokens": context.total_tokens
                }
            }
        except Exception as e:
            logger.error(f"流式问答失败: {e}")
            yield {"event": "error", "data": str(e)}

    def close(self):
        """
        Release agent resources.

        Only the retriever is closed; the LLM client is left alone because it
        may be the globally cached instance. The retriever reference is
        dropped so a later call lazily creates a fresh one instead of reusing
        a closed object.
        """
        if self.retriever:
            try:
                self.retriever.close()
            finally:
                self.retriever = None
        logger.info("问答Agent已关闭")
def ask_compliance_question(
    query: str,
    provider: str = "deepseek",
    model: str = "deepseek-v4-flash",
    top_k: int = 10
) -> AgentResponse:
    """
    Convenience wrapper: answer one compliance question with a throwaway agent.

    Args:
        query: The user question.
        provider: LLM provider name.
        model: LLM model identifier.
        top_k: Number of documents to retrieve.

    Returns:
        AgentResponse: the result of ``QAAgent.ask``.
    """
    config = AgentConfig(
        llm_provider=provider,
        llm_model=model,
        top_k=top_k
    )
    agent = QAAgent(config)
    try:
        return agent.ask(query)
    finally:
        # Always release the agent, even if ``ask`` is interrupted.
        agent.close()

View File

@@ -0,0 +1,247 @@
# src/services/agent/session_manager.py
"""多轮对话会话管理"""
import time
import uuid
from typing import Dict, List, Optional
from dataclasses import dataclass, field
from loguru import logger
@dataclass
class ChatMessage:
    """One message of a conversation."""

    # Speaker role: "user" / "assistant" / "system".
    role: str
    # Message text.
    content: str
    # Integer timestamp supplied by the creator of the message.
    timestamp: int
    # Citation entries attached to the message (fresh list per instance).
    sources: List[Dict] = field(default_factory=list)
    # Free-form extra data (fresh dict per instance).
    metadata: Dict = field(default_factory=dict)
@dataclass
class ChatSession:
    """A conversation: an ordered list of messages plus bookkeeping timestamps."""

    # Identifier of the session.
    session_id: str
    # Messages in insertion order.
    messages: List[ChatMessage] = field(default_factory=list)
    # Creation / last-modification time as integer ``time.time()`` seconds.
    created_at: int = field(default_factory=lambda: int(time.time()))
    updated_at: int = field(default_factory=lambda: int(time.time()))
    # Free-form session data.
    metadata: Dict = field(default_factory=dict)

    def add_user_message(self, content: str) -> ChatMessage:
        """Append a user message and return it."""
        # A single clock read keeps the message timestamp and ``updated_at`` equal.
        now = int(time.time())
        message = ChatMessage(role="user", content=content, timestamp=now)
        self.messages.append(message)
        self.updated_at = now
        return message

    def add_assistant_message(
        self,
        content: str,
        sources: Optional[List[Dict]] = None
    ) -> ChatMessage:
        """Append an assistant message (optionally with sources) and return it."""
        now = int(time.time())
        message = ChatMessage(
            role="assistant",
            content=content,
            timestamp=now,
            sources=sources or []
        )
        self.messages.append(message)
        self.updated_at = now
        return message

    def get_history(self, max_turns: int = 5) -> List[Dict[str, str]]:
        """
        Return recent history as ``{"role", "content"}`` dicts for LLM context.

        Args:
            max_turns: Number of most recent turns to include; each turn is
                counted as two messages (user + assistant). A value of zero
                or less yields an empty list.
        """
        # Guard needed because ``messages[-0:]`` would return the full list.
        if max_turns <= 0:
            return []
        recent_messages = self.messages[-(max_turns * 2):]
        return [{"role": msg.role, "content": msg.content} for msg in recent_messages]

    def clear_history(self):
        """Drop every message and refresh ``updated_at``."""
        self.messages = []
        self.updated_at = int(time.time())
        logger.info(f"会话历史已清空: {self.session_id}")

    @property
    def message_count(self) -> int:
        """Number of messages in the session."""
        return len(self.messages)

    @property
    def is_empty(self) -> bool:
        """Whether the session holds no messages."""
        return not self.messages
class SessionManager:
    """
    In-memory session manager.

    Features:
        - create / get / delete sessions
        - removal of timed-out sessions
        - access to session summaries

    Example:
        manager = SessionManager()
        session = manager.create_session()
        session.add_user_message("What is a motor vehicle safety inspection?")
        session.add_assistant_message("According to GB 7258...", sources=[...])
        # History for a multi-turn LLM call
        history = session.get_history(max_turns=3)
    """

    def __init__(
        self,
        max_sessions: int = 100,
        session_timeout_minutes: int = 30
    ):
        """
        Create the manager.

        Args:
            max_sessions: Maximum number of sessions kept at once.
            session_timeout_minutes: Idle timeout in minutes.
        """
        self.max_sessions = max_sessions
        # Timeout is stored in seconds to compare against time.time() values.
        self.session_timeout = session_timeout_minutes * 60
        # Session store (memory only), keyed by session id.
        self._sessions: Dict[str, ChatSession] = {}
        logger.info(f"会话管理器初始化: max_sessions={max_sessions}, timeout={session_timeout_minutes}min")

    def _generate_session_id(self) -> str:
        """Return an 8-character hex id that is not currently in use."""
        while True:
            # Same characters as ``str(uuid.uuid4())[:8]``; retried on collision
            # so an existing session is never silently overwritten.
            candidate = uuid.uuid4().hex[:8]
            if candidate not in self._sessions:
                return candidate

    def create_session(self, metadata: Optional[Dict] = None) -> ChatSession:
        """
        Create and register a new session.

        When the store is full, expired sessions are removed first; if it is
        still full, the session with the smallest ``created_at`` is deleted.

        Args:
            metadata: Optional session metadata.

        Returns:
            ChatSession: the newly created session.
        """
        if len(self._sessions) >= self.max_sessions:
            self._cleanup_expired_sessions()
            # ``self._sessions`` must be non-empty before calling min(); this
            # avoids a ValueError when max_sessions is zero or negative.
            if self._sessions and len(self._sessions) >= self.max_sessions:
                oldest_id = min(
                    self._sessions,
                    key=lambda sid: self._sessions[sid].created_at
                )
                self.delete_session(oldest_id)
                logger.warning(f"删除最老会话以腾出空间: {oldest_id}")
        session_id = self._generate_session_id()
        session = ChatSession(
            session_id=session_id,
            metadata=metadata or {}
        )
        self._sessions[session_id] = session
        logger.info(f"创建新会话: {session_id}")
        return session

    def get_session(self, session_id: str) -> Optional[ChatSession]:
        """
        Look up a session.

        Args:
            session_id: Session id.

        Returns:
            The session, or ``None`` when it does not exist or has expired
            (an expired session is deleted as a side effect).
        """
        session = self._sessions.get(session_id)
        if session and self._is_session_expired(session):
            self.delete_session(session_id)
            logger.info(f"会话已过期,已删除: {session_id}")
            return None
        return session

    def delete_session(self, session_id: str) -> bool:
        """
        Delete a session.

        Args:
            session_id: Session id.

        Returns:
            bool: ``True`` when a session was removed, ``False`` when the id
            was unknown.
        """
        if session_id not in self._sessions:
            return False
        del self._sessions[session_id]
        logger.info(f"删除会话: {session_id}")
        return True

    def list_sessions(self) -> List[Dict]:
        """
        Summarise every stored session.

        Returns:
            List[Dict]: one dict per session with ``session_id``,
            ``message_count``, ``created_at`` and ``updated_at``.
        """
        return [
            {
                "session_id": s.session_id,
                "message_count": s.message_count,
                "created_at": s.created_at,
                "updated_at": s.updated_at
            }
            for s in self._sessions.values()
        ]

    def _is_session_expired(self, session: ChatSession) -> bool:
        """Return whether the session has been idle longer than the timeout."""
        current_time = int(time.time())
        return (current_time - session.updated_at) > self.session_timeout

    def _cleanup_expired_sessions(self) -> int:
        """Delete every expired session and return how many were removed."""
        # Collect ids first: the dict must not change size while iterated.
        expired_ids = [
            sid for sid, session in self._sessions.items()
            if self._is_session_expired(session)
        ]
        for sid in expired_ids:
            self.delete_session(sid)
        if expired_ids:
            logger.info(f"清理过期会话: {len(expired_ids)}")
        return len(expired_ids)

    def get_session_count(self) -> int:
        """Return the number of stored sessions."""
        return len(self._sessions)

    def clear_all_sessions(self):
        """Remove every session."""
        self._sessions.clear()
        logger.info("所有会话已清空")