Fix SSE route dependency and align architecture docs
This commit is contained in:
@@ -1,21 +1,19 @@
|
||||
"""RAG问答Agent - 合规智能问答核心实现"""
|
||||
"""Provide service-layer logic for qa agent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import List, Dict, Optional, Any, Generator
|
||||
from dataclasses import dataclass, field
|
||||
from loguru import logger
|
||||
from typing import Dict, Generator, List, Optional
|
||||
|
||||
from app.services.llm import get_llm_client, BaseLLMClient, LLMResponse
|
||||
from app.services.llm.llm_factory import LLMFactory
|
||||
from app.services.rag.retriever import Retriever, RetrievedDocument
|
||||
from app.services.rag.context_builder import ContextBuilder, RAGContext
|
||||
from app.services.rag.prompt_templates import get_prompt_template, PromptTemplate
|
||||
from app.config.settings import settings
|
||||
from app.shared.bootstrap import get_agent_conversation_service
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentResponse:
|
||||
"""Agent响应结果"""
|
||||
"""Represent the Agent Response type."""
|
||||
answer: str
|
||||
sources: List[Dict] = field(default_factory=list)
|
||||
model: str = ""
|
||||
@@ -27,385 +25,73 @@ class AgentResponse:
|
||||
|
||||
@property
|
||||
def is_success(self) -> bool:
|
||||
"""Return whether success for the Agent Response instance."""
|
||||
return self.error is None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentConfig:
|
||||
"""Agent配置"""
|
||||
llm_provider: str = "deepseek"
|
||||
llm_model: str = "deepseek-v4-flash"
|
||||
top_k: int = 5
|
||||
min_score: float = 0.3
|
||||
max_context_tokens: int = 2000
|
||||
temperature: float = 0.7
|
||||
"""Define configuration for agent config."""
|
||||
llm_provider: str = settings.llm_provider
|
||||
llm_model: str = settings.llm_model
|
||||
top_k: int = settings.rag_top_k
|
||||
min_score: float = 0.0
|
||||
max_context_tokens: int = settings.rag_max_context_tokens
|
||||
temperature: float = settings.llm_temperature
|
||||
prompt_template: str = "compliance_qa"
|
||||
include_metadata: bool = True
|
||||
|
||||
|
||||
class QAAgent:
|
||||
"""
|
||||
合规问答Agent
|
||||
|
||||
核心流程:
|
||||
1. 接收用户问题
|
||||
2. Milvus混合检索相关法规条款
|
||||
3. 构建RAG上下文
|
||||
4. 调用LLM生成回答
|
||||
5. 返回答案和引用来源
|
||||
|
||||
使用示例:
|
||||
agent = QAAgent()
|
||||
response = agent.ask("机动车安全技术检验有哪些要求?")
|
||||
print(response.answer)
|
||||
for source in response.sources:
|
||||
print(f"引用: {source['doc_name']} - {source['clause_number']}")
|
||||
"""
|
||||
|
||||
"""Represent the Q A Agent type."""
|
||||
def __init__(self, config: Optional[AgentConfig] = None):
|
||||
"""
|
||||
初始化问答Agent
|
||||
|
||||
Args:
|
||||
config: Agent配置(可选,使用默认配置)
|
||||
"""
|
||||
self.config = config or AgentConfig(
|
||||
llm_provider=settings.llm_provider,
|
||||
llm_model=settings.llm_model,
|
||||
top_k=settings.rag_top_k,
|
||||
max_context_tokens=settings.rag_max_context_tokens
|
||||
)
|
||||
|
||||
# 初始化组件(延迟加载)
|
||||
self.llm: Optional[BaseLLMClient] = None
|
||||
self.retriever: Optional[Retriever] = None
|
||||
self.context_builder: Optional[ContextBuilder] = None
|
||||
|
||||
logger.info(f"问答Agent初始化: provider={self.config.llm_provider}, model={self.config.llm_model}")
|
||||
|
||||
def _init_llm(self):
|
||||
"""延迟初始化LLM客户端(优先使用全局缓存)"""
|
||||
if self.llm is None:
|
||||
# 尝试先获取全局缓存的客户端
|
||||
cached = LLMFactory.get_global_client(self.config.llm_provider, self.config.llm_model)
|
||||
if cached:
|
||||
self.llm = cached
|
||||
logger.debug(f"使用全局缓存的LLM客户端: {self.config.llm_provider} - {self.config.llm_model}")
|
||||
else:
|
||||
logger.info("创建新的LLM客户端...")
|
||||
self.llm = get_llm_client(
|
||||
provider=self.config.llm_provider,
|
||||
model=self.config.llm_model,
|
||||
temperature=self.config.temperature
|
||||
)
|
||||
|
||||
def _init_retriever(self):
|
||||
"""延迟初始化检索器"""
|
||||
if self.retriever is None:
|
||||
logger.info("初始化检索器...")
|
||||
self.retriever = Retriever(
|
||||
top_k=self.config.top_k,
|
||||
min_score=self.config.min_score
|
||||
)
|
||||
|
||||
def _init_context_builder(self):
|
||||
"""延迟初始化上下文构建器"""
|
||||
if self.context_builder is None:
|
||||
logger.info("初始化上下文构建器...")
|
||||
self.context_builder = ContextBuilder(
|
||||
max_context_tokens=self.config.max_context_tokens,
|
||||
include_metadata=self.config.include_metadata
|
||||
)
|
||||
"""Initialize the Q A Agent instance."""
|
||||
self.config = config or AgentConfig()
|
||||
|
||||
def ask(
|
||||
self,
|
||||
query: str,
|
||||
filters: Optional[str] = None,
|
||||
prompt_template: Optional[str] = None
|
||||
prompt_template: Optional[str] = None,
|
||||
) -> AgentResponse:
|
||||
"""
|
||||
回答用户问题
|
||||
|
||||
Args:
|
||||
query: 用户问题
|
||||
filters: 检索过滤条件(如 "regulation_type=='车辆安全'")
|
||||
prompt_template: Prompt模板名称(可选,覆盖默认配置)
|
||||
|
||||
Returns:
|
||||
AgentResponse: 包含答案和引用来源的响应对象
|
||||
"""
|
||||
start_time = time.time()
|
||||
logger.info(f"收到问题: {query}")
|
||||
|
||||
try:
|
||||
# Step 1: 检索相关法规
|
||||
self._init_retriever()
|
||||
documents = self.retriever.retrieve(query, filters)
|
||||
retrieved_count = len(documents)
|
||||
|
||||
if retrieved_count == 0:
|
||||
return AgentResponse(
|
||||
answer="抱歉,未找到与您的问题相关的法规条款。请尝试用不同的关键词重新提问,或提供更具体的法规名称。",
|
||||
retrieved_count=0,
|
||||
error="no_retrieved_documents"
|
||||
)
|
||||
|
||||
# Step 2: 构建RAG上下文
|
||||
self._init_context_builder()
|
||||
template_name = prompt_template or self.config.prompt_template
|
||||
template = get_prompt_template(template_name)
|
||||
|
||||
context = self.context_builder.build(
|
||||
query=query,
|
||||
documents=documents,
|
||||
system_prompt=template.system_prompt
|
||||
)
|
||||
|
||||
# Step 3: 构建LLM输入消息
|
||||
messages = self._build_messages(template, context)
|
||||
|
||||
# Step 4: 调用LLM生成回答
|
||||
self._init_llm()
|
||||
llm_response = self.llm.chat(
|
||||
messages=messages,
|
||||
temperature=self.config.temperature
|
||||
)
|
||||
|
||||
if not llm_response.is_success:
|
||||
return AgentResponse(
|
||||
answer="",
|
||||
retrieved_count=retrieved_count,
|
||||
error=llm_response.error
|
||||
)
|
||||
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
# Step 5: 返回结果
|
||||
logger.success(f"问答完成: {latency_ms}ms, {retrieved_count}条引用")
|
||||
|
||||
return AgentResponse(
|
||||
answer=llm_response.content,
|
||||
sources=context.sources,
|
||||
model=llm_response.model,
|
||||
latency_ms=latency_ms,
|
||||
retrieved_count=retrieved_count,
|
||||
context_tokens=context.total_tokens,
|
||||
truncated=context.truncated
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"问答失败: {e}")
|
||||
return AgentResponse(
|
||||
answer="",
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
def ask_with_context(
    self,
    query: str,
    documents: List[RetrievedDocument],
    prompt_template: Optional[str] = None
) -> AgentResponse:
    """Answer a question from caller-supplied documents, skipping retrieval.

    Args:
        query: The user's question.
        documents: Already-retrieved documents used to build the context.
        prompt_template: Prompt template name; falls back to
            ``self.config.prompt_template`` when omitted or empty.

    Returns:
        AgentResponse: On success, the LLM answer together with the
        context's sources, token count and truncation flag. On failure
        (LLM error or any raised exception), an empty answer with
        ``error`` set.
    """
    start_time = time.time()

    try:
        self._init_context_builder()
        self._init_llm()

        template_name = prompt_template or self.config.prompt_template
        template = get_prompt_template(template_name)

        context = self.context_builder.build(
            query=query,
            documents=documents,
            system_prompt=template.system_prompt
        )

        messages = self._build_messages(template, context)

        llm_response = self.llm.chat(messages)

        # Surface LLM failures instead of returning them as an apparently
        # successful response; this mirrors the check performed in ask().
        if not llm_response.is_success:
            return AgentResponse(
                answer="",
                retrieved_count=len(documents),
                error=llm_response.error
            )

        latency_ms = int((time.time() - start_time) * 1000)

        return AgentResponse(
            answer=llm_response.content,
            sources=context.sources,
            model=llm_response.model,
            latency_ms=latency_ms,
            retrieved_count=len(documents),
            context_tokens=context.total_tokens,
            truncated=context.truncated
        )

    except Exception as e:
        # Boundary handler: callers receive an error response, never a raise.
        logger.error(f"问答失败: {e}")
        return AgentResponse(answer="", error=str(e))
|
||||
|
||||
def _build_messages(
|
||||
self,
|
||||
template: PromptTemplate,
|
||||
context: RAGContext
|
||||
) -> List[Dict[str, str]]:
|
||||
"""构建LLM输入消息"""
|
||||
user_content = template.user_template.format(
|
||||
context=context.context_text,
|
||||
query=context.user_query
|
||||
"""Handle ask for the Q A Agent instance."""
|
||||
_, result = get_agent_conversation_service().ask(
|
||||
query=query,
|
||||
filters=filters,
|
||||
provider=self.config.llm_provider,
|
||||
model=self.config.llm_model,
|
||||
top_k=self.config.top_k,
|
||||
prompt_template=prompt_template or self.config.prompt_template,
|
||||
)
|
||||
return AgentResponse(
|
||||
answer=result.answer,
|
||||
sources=[source.__dict__ for source in result.sources],
|
||||
model=result.model,
|
||||
latency_ms=result.latency_ms,
|
||||
retrieved_count=result.retrieved_count,
|
||||
context_tokens=result.context_tokens,
|
||||
truncated=result.truncated,
|
||||
error=result.error,
|
||||
)
|
||||
|
||||
return [
|
||||
{"role": "system", "content": template.system_prompt},
|
||||
{"role": "user", "content": user_content}
|
||||
]
|
||||
|
||||
def ask_stream(
    self,
    query: str,
    filters: Optional[str] = None,
    prompt_template: Optional[str] = None
) -> Generator[Dict[str, Any], None, None]:
    """Answer a question as a stream of event dicts (SSE mode).

    Emitted event shapes:
        - ``{"event": "status", "data": str}``: progress update
        - ``{"event": "sources", "data": [...]}``: cited sources (at most 5)
        - ``{"event": "content", "data": str}``: a piece of the answer
        - ``{"event": "done", "data": {...}}``: completion statistics
        - ``{"event": "error", "data": ...}``: failure; no ``done`` follows

    Args:
        query: The user's question.
        filters: Retrieval filter expression passed to the retriever.
        prompt_template: Prompt template name; falls back to
            ``self.config.prompt_template`` when omitted or empty.

    Yields:
        Dict: One event per step, in the shapes listed above.
    """
    start_time = time.time()
    logger.info(f"收到流式问题: {query}")

    try:
        # Step 1: retrieve the relevant regulations.
        yield {"event": "status", "data": "正在检索相关法规..."}
        self._init_retriever()
        documents = self.retriever.retrieve(query, filters)
        retrieved_count = len(documents)

        if retrieved_count == 0:
            yield {"event": "status", "data": "未找到相关法规"}
            yield {"event": "content", "data": "抱歉,未找到与您的问题相关的法规条款。请尝试用不同的关键词重新提问。"}
            yield {"event": "done", "data": {"latency_ms": 0, "retrieved_count": 0}}
            return

        # Step 2: publish the retrieval result before generation starts.
        yield {"event": "status", "data": f"找到{retrieved_count}条相关法规,正在生成回答..."}
        sources = [
            {
                "doc_name": doc.doc_name,
                "doc_id": doc.doc_id,
                "clause_number": doc.clause_number,
                "score": doc.score
            }
            for doc in documents[:5]  # only the first 5 citations are sent
        ]
        yield {"event": "sources", "data": sources}

        # Step 3: build the RAG context.
        self._init_context_builder()
        template_name = prompt_template or self.config.prompt_template
        template = get_prompt_template(template_name)
        context = self.context_builder.build(
            query=query,
            documents=documents,
            system_prompt=template.system_prompt
        )

        # Step 4: build the LLM input messages.
        messages = self._build_messages(template, context)

        # Step 5: generate the answer, streaming when the client supports it.
        self._init_llm()

        if hasattr(self.llm, 'stream_chat'):
            yield {"event": "status", "data": "思考中..."}
            for chunk in self.llm.stream_chat(
                messages=messages,
                temperature=self.config.temperature
            ):
                yield {"event": "content", "data": chunk}
        else:
            # No streaming support: fall back to one blocking call.
            yield {"event": "status", "data": "生成回答中..."}
            llm_response = self.llm.chat(
                messages=messages,
                temperature=self.config.temperature
            )
            if not llm_response.is_success:
                # Report the failure rather than finishing with an empty
                # answer; like the exception path, no "done" event follows.
                yield {"event": "error", "data": llm_response.error}
                return
            yield {"event": "content", "data": llm_response.content}

        # Step 6: send the completion event.
        latency_ms = int((time.time() - start_time) * 1000)
        logger.success(f"流式问答完成: {latency_ms}ms, {retrieved_count}条引用")

        yield {
            "event": "done",
            "data": {
                "latency_ms": latency_ms,
                "model": self.config.llm_model,
                "retrieved_count": retrieved_count,
                "context_tokens": context.total_tokens
            }
        }

    except Exception as e:
        # Boundary handler: failures are reported to the consumer as an event.
        logger.error(f"流式问答失败: {e}")
        yield {"event": "error", "data": str(e)}
|
||||
def ask_stream(
    self,
    query: str,
    filters: Optional[str] = None,
    prompt_template: Optional[str] = None,
) -> Generator[dict, None, None]:
    """Stream answer events by delegating to the agent conversation service.

    Args:
        query: The user's question.
        filters: Retrieval filter expression forwarded to the service.
        prompt_template: Prompt template name; falls back to
            ``self.config.prompt_template`` when omitted or empty, matching
            the override accepted by ``ask``.

    Yields:
        dict: Each event produced by the service's ``stream_chat`` stream,
        unchanged.
    """
    # stream_chat returns a pair; only the event stream (second item) is used.
    _, stream = get_agent_conversation_service().stream_chat(
        query=query,
        filters=filters,
        provider=self.config.llm_provider,
        model=self.config.llm_model,
        top_k=self.config.top_k,
        prompt_template=prompt_template or self.config.prompt_template,
    )
    # Delegate directly so closing this generator also closes the inner stream.
    yield from stream
|
||||
|
||||
def close(self):
|
||||
"""关闭Agent资源(不关闭LLM客户端,因为它全局缓存)"""
|
||||
if self.retriever:
|
||||
self.retriever.close()
|
||||
logger.info("问答Agent已关闭")
|
||||
"""Release the resources held by this component."""
|
||||
return None
|
||||
|
||||
|
||||
def ask_compliance_question(
    query: str,
    provider: str = "deepseek",
    model: str = "deepseek-v4-flash",
    top_k: int = 10
) -> AgentResponse:
    """Convenience helper: answer one compliance question with a one-off agent.

    Args:
        query: The user's question.
        provider: LLM provider name.
        model: LLM model name.
        top_k: Number of documents to retrieve.

    Returns:
        AgentResponse: The result of ``QAAgent.ask`` for ``query``.
    """
    config = AgentConfig(
        llm_provider=provider,
        llm_model=model,
        top_k=top_k
    )
    agent = QAAgent(config)
    try:
        return agent.ask(query)
    finally:
        # Release the agent's resources even if ask() raises.
        agent.close()
|
||||
def ask_compliance_question(query: str, top_k: int = 5) -> AgentResponse:
    """Answer one compliance question with a throwaway ``QAAgent``.

    Args:
        query: The user's question.
        top_k: Value for ``AgentConfig.top_k``; every other setting keeps
            its ``AgentConfig`` default.

    Returns:
        AgentResponse: Whatever ``QAAgent.ask`` returns for ``query``.
    """
    config = AgentConfig(top_k=top_k)
    agent = QAAgent(config)
    return agent.ask(query)
|
||||
|
||||
Reference in New Issue
Block a user