Fix SSE route dependency and align architecture docs

This commit is contained in:
ash66
2026-05-18 16:32:42 +08:00
parent 86b9ac806a
commit 3f69cad404
149 changed files with 4786 additions and 5957 deletions

View File

@@ -1,6 +1,8 @@
"""Agent服务模块"""
"""Initialize the app.services.agent package."""
from .qa_agent import QAAgent, ask_compliance_question
from .session_manager import SessionManager, ChatSession
# Keep package boundaries explicit so backend imports stay predictable.
__all__ = ["QAAgent", "ask_compliance_question", "SessionManager", "ChatSession"]

View File

@@ -1,21 +1,19 @@
"""RAG问答Agent - 合规智能问答核心实现"""
"""Provide service-layer logic for qa agent."""
from __future__ import annotations
import time
from typing import List, Dict, Optional, Any, Generator
from dataclasses import dataclass, field
from loguru import logger
from typing import Dict, Generator, List, Optional
from app.services.llm import get_llm_client, BaseLLMClient, LLMResponse
from app.services.llm.llm_factory import LLMFactory
from app.services.rag.retriever import Retriever, RetrievedDocument
from app.services.rag.context_builder import ContextBuilder, RAGContext
from app.services.rag.prompt_templates import get_prompt_template, PromptTemplate
from app.config.settings import settings
from app.shared.bootstrap import get_agent_conversation_service
# Keep service responsibilities explicit so downstream behavior stays predictable.
@dataclass
class AgentResponse:
"""Agent响应结果"""
"""Represent the Agent Response type."""
answer: str
sources: List[Dict] = field(default_factory=list)
model: str = ""
@@ -27,385 +25,73 @@ class AgentResponse:
@property
def is_success(self) -> bool:
    """Return True when no error is recorded on this response."""
    has_error = self.error is not None
    return not has_error
@dataclass
class AgentConfig:
"""Agent配置"""
llm_provider: str = "deepseek"
llm_model: str = "deepseek-v4-flash"
top_k: int = 5
min_score: float = 0.3
max_context_tokens: int = 2000
temperature: float = 0.7
"""Define configuration for agent config."""
llm_provider: str = settings.llm_provider
llm_model: str = settings.llm_model
top_k: int = settings.rag_top_k
min_score: float = 0.0
max_context_tokens: int = settings.rag_max_context_tokens
temperature: float = settings.llm_temperature
prompt_template: str = "compliance_qa"
include_metadata: bool = True
class QAAgent:
"""
合规问答Agent
核心流程:
1. 接收用户问题
2. Milvus混合检索相关法规条款
3. 构建RAG上下文
4. 调用LLM生成回答
5. 返回答案和引用来源
使用示例:
agent = QAAgent()
response = agent.ask("机动车安全技术检验有哪些要求?")
print(response.answer)
for source in response.sources:
print(f"引用: {source['doc_name']} - {source['clause_number']}")
"""
"""Represent the Q A Agent type."""
def __init__(self, config: Optional[AgentConfig] = None):
"""
初始化问答Agent
Args:
config: Agent配置可选使用默认配置
"""
self.config = config or AgentConfig(
llm_provider=settings.llm_provider,
llm_model=settings.llm_model,
top_k=settings.rag_top_k,
max_context_tokens=settings.rag_max_context_tokens
)
# 初始化组件(延迟加载)
self.llm: Optional[BaseLLMClient] = None
self.retriever: Optional[Retriever] = None
self.context_builder: Optional[ContextBuilder] = None
logger.info(f"问答Agent初始化: provider={self.config.llm_provider}, model={self.config.llm_model}")
def _init_llm(self):
"""延迟初始化LLM客户端优先使用全局缓存"""
if self.llm is None:
# 尝试先获取全局缓存的客户端
cached = LLMFactory.get_global_client(self.config.llm_provider, self.config.llm_model)
if cached:
self.llm = cached
logger.debug(f"使用全局缓存的LLM客户端: {self.config.llm_provider} - {self.config.llm_model}")
else:
logger.info("创建新的LLM客户端...")
self.llm = get_llm_client(
provider=self.config.llm_provider,
model=self.config.llm_model,
temperature=self.config.temperature
)
def _init_retriever(self):
"""延迟初始化检索器"""
if self.retriever is None:
logger.info("初始化检索器...")
self.retriever = Retriever(
top_k=self.config.top_k,
min_score=self.config.min_score
)
def _init_context_builder(self):
"""延迟初始化上下文构建器"""
if self.context_builder is None:
logger.info("初始化上下文构建器...")
self.context_builder = ContextBuilder(
max_context_tokens=self.config.max_context_tokens,
include_metadata=self.config.include_metadata
)
"""Initialize the Q A Agent instance."""
self.config = config or AgentConfig()
def ask(
self,
query: str,
filters: Optional[str] = None,
prompt_template: Optional[str] = None
prompt_template: Optional[str] = None,
) -> AgentResponse:
"""
回答用户问题
Args:
query: 用户问题
filters: 检索过滤条件(如 "regulation_type=='车辆安全'"
prompt_template: Prompt模板名称可选覆盖默认配置
Returns:
AgentResponse: 包含答案和引用来源的响应对象
"""
start_time = time.time()
logger.info(f"收到问题: {query}")
try:
# Step 1: 检索相关法规
self._init_retriever()
documents = self.retriever.retrieve(query, filters)
retrieved_count = len(documents)
if retrieved_count == 0:
return AgentResponse(
answer="抱歉,未找到与您的问题相关的法规条款。请尝试用不同的关键词重新提问,或提供更具体的法规名称。",
retrieved_count=0,
error="no_retrieved_documents"
)
# Step 2: 构建RAG上下文
self._init_context_builder()
template_name = prompt_template or self.config.prompt_template
template = get_prompt_template(template_name)
context = self.context_builder.build(
query=query,
documents=documents,
system_prompt=template.system_prompt
)
# Step 3: 构建LLM输入消息
messages = self._build_messages(template, context)
# Step 4: 调用LLM生成回答
self._init_llm()
llm_response = self.llm.chat(
messages=messages,
temperature=self.config.temperature
)
if not llm_response.is_success:
return AgentResponse(
answer="",
retrieved_count=retrieved_count,
error=llm_response.error
)
latency_ms = int((time.time() - start_time) * 1000)
# Step 5: 返回结果
logger.success(f"问答完成: {latency_ms}ms, {retrieved_count}条引用")
return AgentResponse(
answer=llm_response.content,
sources=context.sources,
model=llm_response.model,
latency_ms=latency_ms,
retrieved_count=retrieved_count,
context_tokens=context.total_tokens,
truncated=context.truncated
)
except Exception as e:
logger.error(f"问答失败: {e}")
return AgentResponse(
answer="",
error=str(e)
)
def ask_with_context(
self,
query: str,
documents: List[RetrievedDocument],
prompt_template: Optional[str] = None
) -> AgentResponse:
"""
使用提供的文档回答问题(不执行检索)
Args:
query: 用户问题
documents: 已检索的文档列表
prompt_template: Prompt模板名称
Returns:
AgentResponse: 响应结果
"""
start_time = time.time()
try:
self._init_context_builder()
self._init_llm()
template_name = prompt_template or self.config.prompt_template
template = get_prompt_template(template_name)
context = self.context_builder.build(
query=query,
documents=documents,
system_prompt=template.system_prompt
)
messages = self._build_messages(template, context)
llm_response = self.llm.chat(messages)
latency_ms = int((time.time() - start_time) * 1000)
return AgentResponse(
answer=llm_response.content,
sources=context.sources,
model=llm_response.model,
latency_ms=latency_ms,
retrieved_count=len(documents),
context_tokens=context.total_tokens,
truncated=context.truncated
)
except Exception as e:
logger.error(f"问答失败: {e}")
return AgentResponse(answer="", error=str(e))
def _build_messages(
self,
template: PromptTemplate,
context: RAGContext
) -> List[Dict[str, str]]:
"""构建LLM输入消息"""
user_content = template.user_template.format(
context=context.context_text,
query=context.user_query
"""Handle ask for the Q A Agent instance."""
_, result = get_agent_conversation_service().ask(
query=query,
filters=filters,
provider=self.config.llm_provider,
model=self.config.llm_model,
top_k=self.config.top_k,
prompt_template=prompt_template or self.config.prompt_template,
)
return AgentResponse(
answer=result.answer,
sources=[source.__dict__ for source in result.sources],
model=result.model,
latency_ms=result.latency_ms,
retrieved_count=result.retrieved_count,
context_tokens=result.context_tokens,
truncated=result.truncated,
error=result.error,
)
return [
{"role": "system", "content": template.system_prompt},
{"role": "user", "content": user_content}
]
def ask_stream(
self,
query: str,
filters: Optional[str] = None,
prompt_template: Optional[str] = None
) -> Generator[Dict[str, Any], None, None]:
"""
流式回答用户问题SSE模式
返回事件类型:
- {"event": "status", "data": "正在检索..."} - 状态更新
- {"event": "sources", "data": [...]} - 引用来源
- {"event": "content", "data": "文本片段"} - 回答内容
- {"event": "done", "data": {"latency_ms": ..., "model": ...}} - 完成
Args:
query: 用户问题
filters: 检索过滤条件
prompt_template: Prompt模板名称
Yields:
Dict: SSE事件数据
"""
start_time = time.time()
logger.info(f"收到流式问题: {query}")
try:
# Step 1: 检索相关法规
yield {"event": "status", "data": "正在检索相关法规..."}
self._init_retriever()
documents = self.retriever.retrieve(query, filters)
retrieved_count = len(documents)
if retrieved_count == 0:
yield {"event": "status", "data": "未找到相关法规"}
yield {"event": "content", "data": "抱歉,未找到与您的问题相关的法规条款。请尝试用不同的关键词重新提问。"}
yield {"event": "done", "data": {"latency_ms": 0, "retrieved_count": 0}}
return
# Step 2: 发送检索结果
yield {"event": "status", "data": f"找到{retrieved_count}条相关法规,正在生成回答..."}
sources = [
{
"doc_name": doc.doc_name,
"doc_id": doc.doc_id,
"clause_number": doc.clause_number,
"score": doc.score
}
for doc in documents[:5] # 只返回前5条引用
]
yield {"event": "sources", "data": sources}
# Step 3: 构建RAG上下文
self._init_context_builder()
template_name = prompt_template or self.config.prompt_template
template = get_prompt_template(template_name)
context = self.context_builder.build(
query=query,
documents=documents,
system_prompt=template.system_prompt
)
# Step 4: 构建LLM输入消息
messages = self._build_messages(template, context)
# Step 5: 流式调用LLM生成回答
self._init_llm()
full_answer = ""
# 检查LLM是否支持流式输出
if hasattr(self.llm, 'stream_chat'):
yield {"event": "status", "data": "思考中..."}
for chunk in self.llm.stream_chat(
messages=messages,
temperature=self.config.temperature
):
full_answer += chunk
yield {"event": "content", "data": chunk}
else:
# 如果不支持流式,回退到普通调用
yield {"event": "status", "data": "生成回答中..."}
llm_response = self.llm.chat(
messages=messages,
temperature=self.config.temperature
)
if llm_response.is_success:
full_answer = llm_response.content
yield {"event": "content", "data": full_answer}
# Step 6: 发送完成事件
latency_ms = int((time.time() - start_time) * 1000)
logger.success(f"流式问答完成: {latency_ms}ms, {retrieved_count}条引用")
yield {
"event": "done",
"data": {
"latency_ms": latency_ms,
"model": self.config.llm_model,
"retrieved_count": retrieved_count,
"context_tokens": context.total_tokens
}
}
except Exception as e:
logger.error(f"流式问答失败: {e}")
yield {"event": "error", "data": str(e)}
def ask_stream(
    self,
    query: str,
    filters: Optional[str] = None,
    prompt_template: Optional[str] = None,
) -> Generator[dict, None, None]:
    """Stream the answer to a question as a sequence of event dicts.

    Delegates to the shared agent conversation service and relays every
    event it produces, unchanged and in order.

    Args:
        query: The user question.
        filters: Optional retrieval filter expression, forwarded as-is.
        prompt_template: Optional prompt template name. Falls back to
            ``self.config.prompt_template`` when omitted or empty, which
            mirrors how ``ask`` resolves its template.

    Yields:
        The events produced by the conversation service stream.
    """
    # The service returns a pair; only the event stream (second item) is
    # relayed to the caller.
    _, stream = get_agent_conversation_service().stream_chat(
        query=query,
        filters=filters,
        provider=self.config.llm_provider,
        model=self.config.llm_model,
        top_k=self.config.top_k,
        prompt_template=prompt_template or self.config.prompt_template,
    )
    yield from stream
def close(self):
"""关闭Agent资源不关闭LLM客户端因为它全局缓存"""
if self.retriever:
self.retriever.close()
logger.info("问答Agent已关闭")
"""Release the resources held by this component."""
return None
def ask_compliance_question(
query: str,
provider: str = "deepseek",
model: str = "deepseek-v4-flash",
top_k: int = 10
) -> AgentResponse:
"""
便捷函数:问答合规问题
Args:
query: 用户问题
provider: LLM提供商
model: LLM模型
top_k: 检索数量
Returns:
AgentResponse: 响应结果
"""
config = AgentConfig(
llm_provider=provider,
llm_model=model,
top_k=top_k
)
agent = QAAgent(config)
response = agent.ask(query)
agent.close()
return response
def ask_compliance_question(
    query: str,
    top_k: int = 5,
    provider: Optional[str] = None,
    model: Optional[str] = None,
) -> AgentResponse:
    """Answer one compliance question with a throwaway ``QAAgent``.

    Args:
        query: The user question.
        top_k: Value for ``AgentConfig.top_k``.
        provider: Optional override for ``AgentConfig.llm_provider``; the
            ``AgentConfig`` default is used when ``None``.
        model: Optional override for ``AgentConfig.llm_model``; the
            ``AgentConfig`` default is used when ``None``.

    Returns:
        The ``AgentResponse`` produced by ``QAAgent.ask``.
    """
    # Pass only the overrides that were actually given so AgentConfig's own
    # defaults stay in effect for everything else.
    overrides = {"top_k": top_k}
    if provider is not None:
        overrides["llm_provider"] = provider
    if model is not None:
        overrides["llm_model"] = model
    return QAAgent(AgentConfig(**overrides)).ask(query)

View File

@@ -1,4 +1,4 @@
"""多轮对话会话管理"""
"""Provide service-layer logic for session manager."""
import time
import uuid
@@ -9,7 +9,7 @@ from loguru import logger
@dataclass
class ChatMessage:
"""对话消息"""
"""Represent the Chat Message type."""
role: str # "user" / "assistant" / "system"
content: str
timestamp: int
@@ -19,7 +19,7 @@ class ChatMessage:
@dataclass
class ChatSession:
"""对话会话"""
"""Represent the Chat Session type."""
session_id: str
messages: List[ChatMessage] = field(default_factory=list)
created_at: int = field(default_factory=lambda: int(time.time()))
@@ -27,7 +27,7 @@ class ChatSession:
metadata: Dict = field(default_factory=dict)
def add_user_message(self, content: str) -> ChatMessage:
"""添加用户消息"""
"""Handle add user message for the Chat Session instance."""
message = ChatMessage(
role="user",
content=content,
@@ -42,7 +42,7 @@ class ChatSession:
content: str,
sources: List[Dict] = None
) -> ChatMessage:
"""添加助手消息"""
"""Handle add assistant message for the Chat Session instance."""
message = ChatMessage(
role="assistant",
content=content,
@@ -54,9 +54,9 @@ class ChatSession:
return message
def get_history(self, max_turns: int = 5) -> List[Dict[str, str]]:
"""获取历史对话用于LLM上下文"""
"""Return history for the Chat Session instance."""
history = []
# 获取最近N轮对话每轮包含user + assistant
# Keep service responsibilities explicit so downstream behavior stays predictable.
recent_messages = self.messages[-(max_turns * 2):]
for msg in recent_messages:
@@ -68,81 +68,47 @@ class ChatSession:
return history
def clear_history(self):
    """Drop every stored message and refresh the session's update timestamp."""
    self.messages, self.updated_at = [], int(time.time())
    logger.info(f"会话历史已清空: {self.session_id}")
@property
def message_count(self) -> int:
    """Return how many messages the session currently holds."""
    stored = self.messages
    return len(stored)
@property
def is_empty(self) -> bool:
    """Return True when the session holds no messages."""
    return not len(self.messages)
class SessionManager:
"""
会话管理器
功能:
- 创建/获取/删除会话
- 会话超时清理
- 会话历史记录管理
使用示例:
manager = SessionManager()
# 创建会话
session = manager.create_session()
# 添加消息
session.add_user_message("什么是机动车安全技术检验?")
session.add_assistant_message("根据GB 7258...", sources=[...])
# 获取历史用于LLM多轮对话
history = session.get_history(max_turns=3)
"""
"""Represent the Session Manager type."""
def __init__(
    self,
    max_sessions: int = 100,
    session_timeout_minutes: int = 30
):
    """Set up an empty session store with a size cap and an idle timeout.

    Args:
        max_sessions: Stored as ``self.max_sessions``.
        session_timeout_minutes: Idle timeout in minutes; kept in seconds
            as ``self.session_timeout``.
    """
    seconds_per_minute = 60

    # Sessions are held in a plain dict keyed by session id.
    self._sessions: Dict[str, ChatSession] = {}
    self.session_timeout = session_timeout_minutes * seconds_per_minute
    self.max_sessions = max_sessions

    logger.info(f"会话管理器初始化: max_sessions={max_sessions}, timeout={session_timeout_minutes}min")
def create_session(self, metadata: Dict = None) -> ChatSession:
"""
创建新会话
Args:
metadata: 会话元数据(可选)
Returns:
ChatSession: 新创建的会话
"""
# 检查会话数量限制
"""Create session for the Session Manager instance."""
# Keep service responsibilities explicit so downstream behavior stays predictable.
if len(self._sessions) >= self.max_sessions:
# 清理过期会话
# Keep service responsibilities explicit so downstream behavior stays predictable.
self._cleanup_expired_sessions()
# 如果仍然超出限制,删除最老的会话
# Keep service responsibilities explicit so downstream behavior stays predictable.
if len(self._sessions) >= self.max_sessions:
oldest_id = min(
self._sessions.keys(),
@@ -163,19 +129,11 @@ class SessionManager:
return session
def get_session(self, session_id: str) -> Optional[ChatSession]:
"""
获取会话
Args:
session_id: 会话ID
Returns:
ChatSession: 会话对象如不存在返回None
"""
"""Return session for the Session Manager instance."""
session = self._sessions.get(session_id)
if session:
# 检查是否过期
# Keep service responsibilities explicit so downstream behavior stays predictable.
if self._is_session_expired(session):
self.delete_session(session_id)
logger.info(f"会话已过期,已删除: {session_id}")
@@ -184,15 +142,7 @@ class SessionManager:
return session
def delete_session(self, session_id: str) -> bool:
"""
删除会话
Args:
session_id: 会话ID
Returns:
bool: 是否成功删除
"""
"""Delete session for the Session Manager instance."""
if session_id in self._sessions:
del self._sessions[session_id]
logger.info(f"删除会话: {session_id}")
@@ -200,12 +150,7 @@ class SessionManager:
return False
def list_sessions(self) -> List[Dict]:
"""
列出所有会话
Returns:
List[Dict]: 会话列表摘要
"""
"""List sessions for the Session Manager instance."""
return [
{
"session_id": s.session_id,
@@ -217,12 +162,12 @@ class SessionManager:
]
def _is_session_expired(self, session: ChatSession) -> bool:
    """Return True once the session has been idle longer than the timeout.

    Idle time is the whole-second difference between now and
    ``session.updated_at``; a session idle for exactly
    ``self.session_timeout`` seconds is not yet expired.
    """
    idle_seconds = int(time.time()) - session.updated_at
    return idle_seconds > self.session_timeout
def _cleanup_expired_sessions(self) -> int:
"""清理过期会话"""
"""Handle cleanup expired sessions for this module for the Session Manager instance."""
expired_ids = [
sid for sid, session in self._sessions.items()
if self._is_session_expired(session)
@@ -237,10 +182,10 @@ class SessionManager:
return len(expired_ids)
def get_session_count(self) -> int:
    """Return the number of sessions currently held in the store."""
    store = self._sessions
    return len(store)
def clear_all_sessions(self):
    """Remove every session from the store and log that it was emptied."""
    store = self._sessions
    store.clear()
    logger.info("所有会话已清空")