413 lines
14 KiB
Python
413 lines
14 KiB
Python
|
|
# src/services/agent/qa_agent.py
|
|||
|
|
"""RAG问答Agent - 合规智能问答核心实现"""
|
|||
|
|
|
|||
|
|
import time
|
|||
|
|
from typing import List, Dict, Optional, Any, Generator
|
|||
|
|
from dataclasses import dataclass, field
|
|||
|
|
from loguru import logger
|
|||
|
|
|
|||
|
|
from app.services.llm import get_llm_client, BaseLLMClient, LLMResponse
|
|||
|
|
from app.services.llm.llm_factory import LLMFactory
|
|||
|
|
from app.services.rag.retriever import Retriever, RetrievedDocument
|
|||
|
|
from app.services.rag.context_builder import ContextBuilder, RAGContext
|
|||
|
|
from app.services.rag.prompt_templates import get_prompt_template, PromptTemplate
|
|||
|
|
from app.config.settings import settings
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class AgentResponse:
    """Outcome of a single agent question-answering call.

    ``error`` stays ``None`` when the call succeeded; any other value is an
    error code or message, in which case ``is_success`` reports ``False``.
    """

    answer: str                                        # answer text returned to the caller
    sources: List[Dict] = field(default_factory=list)  # citation entries backing the answer
    model: str = ""                                    # model name reported by the LLM response
    latency_ms: int = 0                                # elapsed time of the call, in milliseconds
    retrieved_count: int = 0                           # number of documents retrieved / used
    context_tokens: int = 0                            # token total of the assembled RAG context
    truncated: bool = False                            # True if the context had to be truncated
    error: Optional[str] = None                        # None on success, otherwise error code/message

    @property
    def is_success(self) -> bool:
        """Return True when no error has been recorded on this response."""
        has_error = self.error is not None
        return not has_error
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class AgentConfig:
    """Tunable settings for a QAAgent.

    Every field has a default, so ``AgentConfig()`` is a complete configuration.
    """

    # LLM selection: provider and model names handed to the LLM client factory.
    llm_provider: str = "deepseek"
    llm_model: str = "deepseek-v4-flash"

    # Retrieval: how many documents to fetch and the minimum score to keep one.
    top_k: int = 5
    min_score: float = 0.3

    # Context building: token budget for the RAG context.
    max_context_tokens: int = 2000

    # Generation: sampling temperature passed to the LLM.
    temperature: float = 0.7

    # Name of the prompt template to look up.
    prompt_template: str = "compliance_qa"

    # Whether the context builder should include document metadata.
    include_metadata: bool = True
|
|||
|
|
|
|||
|
|
|
|||
|
|
class QAAgent:
    """
    Compliance question-answering agent.

    Pipeline:
        1. Receive the user's question.
        2. Retrieve relevant regulation clauses (described by the original
           author as a Milvus hybrid search, performed by ``Retriever``).
        3. Build the RAG context.
        4. Call the LLM to generate an answer.
        5. Return the answer together with its cited sources.

    Example:
        agent = QAAgent()
        response = agent.ask("机动车安全技术检验有哪些要求?")
        print(response.answer)
        for source in response.sources:
            print(f"引用: {source['doc_name']} - {source['clause_number']}")
    """

    # Maximum number of citations emitted in the streaming "sources" event.
    STREAM_MAX_SOURCES = 5

    def __init__(self, config: Optional[AgentConfig] = None):
        """
        Create the agent.

        Construction only stores configuration. The LLM client, retriever and
        context builder are created lazily on first use.

        Args:
            config: Agent configuration. When omitted, one is built from
                ``settings`` (provider, model, top_k, max context tokens). The
                remaining fields keep their ``AgentConfig`` defaults.
        """
        self.config = config or AgentConfig(
            llm_provider=settings.llm_provider,
            llm_model=settings.llm_model,
            top_k=settings.rag_top_k,
            max_context_tokens=settings.rag_max_context_tokens
        )

        # Components are initialised lazily by the _init_* helpers.
        self.llm: Optional[BaseLLMClient] = None
        self.retriever: Optional[Retriever] = None
        self.context_builder: Optional[ContextBuilder] = None

        logger.info(f"问答Agent初始化: provider={self.config.llm_provider}, model={self.config.llm_model}")

    # ------------------------------------------------------------------ #
    # Lazy component initialisation
    # ------------------------------------------------------------------ #

    def _init_llm(self):
        """Lazily obtain the LLM client, preferring the globally cached one.

        A cached client is reused as-is, so it was not necessarily created with
        ``self.config.temperature``. For that reason every chat call in this
        class passes ``temperature`` explicitly.
        """
        if self.llm is None:
            # Try the process-wide cached client first.
            cached = LLMFactory.get_global_client(self.config.llm_provider, self.config.llm_model)
            if cached:
                self.llm = cached
                logger.debug(f"使用全局缓存的LLM客户端: {self.config.llm_provider} - {self.config.llm_model}")
            else:
                logger.info("创建新的LLM客户端...")
                self.llm = get_llm_client(
                    provider=self.config.llm_provider,
                    model=self.config.llm_model,
                    temperature=self.config.temperature
                )

    def _init_retriever(self):
        """Lazily create the retriever from ``config.top_k`` / ``config.min_score``."""
        if self.retriever is None:
            logger.info("初始化检索器...")
            self.retriever = Retriever(
                top_k=self.config.top_k,
                min_score=self.config.min_score
            )

    def _init_context_builder(self):
        """Lazily create the context builder from the configured token budget."""
        if self.context_builder is None:
            logger.info("初始化上下文构建器...")
            self.context_builder = ContextBuilder(
                max_context_tokens=self.config.max_context_tokens,
                include_metadata=self.config.include_metadata
            )

    # ------------------------------------------------------------------ #
    # Shared helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _elapsed_ms(start_time: float) -> int:
        """Return whole milliseconds elapsed since ``start_time`` (a ``time.time()`` value)."""
        return int((time.time() - start_time) * 1000)

    @staticmethod
    def _no_documents_response(latency_ms: int = 0) -> AgentResponse:
        """Build the standard response returned when there are no documents to answer from."""
        return AgentResponse(
            answer="抱歉,未找到与您的问题相关的法规条款。请尝试用不同的关键词重新提问,或提供更具体的法规名称。",
            latency_ms=latency_ms,
            retrieved_count=0,
            error="no_retrieved_documents"
        )

    def _build_context_and_messages(
        self,
        query: str,
        documents: List[RetrievedDocument],
        prompt_template: Optional[str]
    ) -> tuple:
        """Resolve the prompt template, then build the RAG context and the LLM messages.

        Returns:
            ``(context, messages)``, where ``context`` is the ``RAGContext`` and
            ``messages`` is the chat message list for the LLM.
        """
        self._init_context_builder()
        # The per-call template overrides the configured default.
        template_name = prompt_template or self.config.prompt_template
        template = get_prompt_template(template_name)

        context = self.context_builder.build(
            query=query,
            documents=documents,
            system_prompt=template.system_prompt
        )
        return context, self._build_messages(template, context)

    def _build_messages(
        self,
        template: PromptTemplate,
        context: RAGContext
    ) -> List[Dict[str, str]]:
        """Build the system and user chat messages from a template and a RAG context."""
        # str.format only interprets braces in the template itself, so braces
        # inside the user's query or the context text are inserted verbatim.
        user_content = template.user_template.format(
            context=context.context_text,
            query=context.user_query
        )

        return [
            {"role": "system", "content": template.system_prompt},
            {"role": "user", "content": user_content}
        ]

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #

    def ask(
        self,
        query: str,
        filters: Optional[str] = None,
        prompt_template: Optional[str] = None
    ) -> AgentResponse:
        """
        Answer a user question using retrieval-augmented generation.

        Args:
            query: The user's question.
            filters: Retrieval filter expression (e.g. "regulation_type=='车辆安全'").
            prompt_template: Prompt template name. Overrides the configured default.

        Returns:
            AgentResponse with the answer and cited sources.
            - If nothing is retrieved, ``error`` is ``"no_retrieved_documents"``
              and ``answer`` holds an apology message.
            - If the LLM call fails, ``error`` carries the LLM's error.
            - Any ``Exception`` is caught, logged and returned as ``error``
              rather than raised.
        """
        start_time = time.time()
        logger.info(f"收到问题: {query}")
        retrieved_count = 0

        try:
            # Step 1: retrieve relevant regulations.
            self._init_retriever()
            documents = self.retriever.retrieve(query, filters)
            retrieved_count = len(documents)

            if retrieved_count == 0:
                return self._no_documents_response(self._elapsed_ms(start_time))

            # Steps 2-3: build the RAG context and the LLM messages.
            context, messages = self._build_context_and_messages(query, documents, prompt_template)

            # Step 4: ask the LLM.
            self._init_llm()
            llm_response = self.llm.chat(
                messages=messages,
                temperature=self.config.temperature
            )
            latency_ms = self._elapsed_ms(start_time)

            if not llm_response.is_success:
                return AgentResponse(
                    answer="",
                    latency_ms=latency_ms,
                    retrieved_count=retrieved_count,
                    error=llm_response.error
                )

            # Step 5: assemble the result.
            logger.success(f"问答完成: {latency_ms}ms, {retrieved_count}条引用")
            return AgentResponse(
                answer=llm_response.content,
                sources=context.sources,
                model=llm_response.model,
                latency_ms=latency_ms,
                retrieved_count=retrieved_count,
                context_tokens=context.total_tokens,
                truncated=context.truncated
            )

        except Exception as e:
            # logger.exception keeps the traceback that logger.error dropped.
            logger.exception(f"问答失败: {e}")
            return AgentResponse(
                answer="",
                latency_ms=self._elapsed_ms(start_time),
                retrieved_count=retrieved_count,
                error=str(e)
            )

    def ask_with_context(
        self,
        query: str,
        documents: List[RetrievedDocument],
        prompt_template: Optional[str] = None
    ) -> AgentResponse:
        """
        Answer a question from caller-supplied documents (no retrieval step).

        Args:
            query: The user's question.
            documents: Already-retrieved documents to answer from.
            prompt_template: Prompt template name. Overrides the configured default.

        Returns:
            AgentResponse, following the same conventions as ``ask``.
            - An empty ``documents`` list yields the ``"no_retrieved_documents"``
              response.
            - A failed LLM call is reported via ``error``.
        """
        start_time = time.time()

        try:
            if not documents:
                # Same policy as ask(): never let the LLM answer without sources.
                return self._no_documents_response(self._elapsed_ms(start_time))

            context, messages = self._build_context_and_messages(query, documents, prompt_template)

            self._init_llm()
            # Pass temperature explicitly: a globally cached client may have
            # been created with a different default temperature.
            llm_response = self.llm.chat(
                messages=messages,
                temperature=self.config.temperature
            )
            latency_ms = self._elapsed_ms(start_time)

            if not llm_response.is_success:
                return AgentResponse(
                    answer="",
                    latency_ms=latency_ms,
                    retrieved_count=len(documents),
                    error=llm_response.error
                )

            return AgentResponse(
                answer=llm_response.content,
                sources=context.sources,
                model=llm_response.model,
                latency_ms=latency_ms,
                retrieved_count=len(documents),
                context_tokens=context.total_tokens,
                truncated=context.truncated
            )

        except Exception as e:
            logger.exception(f"问答失败: {e}")
            return AgentResponse(
                answer="",
                latency_ms=self._elapsed_ms(start_time),
                error=str(e)
            )

    def ask_stream(
        self,
        query: str,
        filters: Optional[str] = None,
        prompt_template: Optional[str] = None
    ) -> Generator[Dict[str, Any], None, None]:
        """
        Answer a user question as a stream of SSE-style events.

        Event shapes:
            - {"event": "status", "data": "正在检索..."}   progress update
            - {"event": "sources", "data": [...]}          cited sources
            - {"event": "content", "data": "文本片段"}      answer text fragment
            - {"event": "done", "data": {"latency_ms": ..., ...}}  completion
            - {"event": "error", "data": "..."}            failure; no "done" follows

        Args:
            query: The user's question.
            filters: Retrieval filter expression.
            prompt_template: Prompt template name.

        Yields:
            Dict: one SSE event per yield.
        """
        start_time = time.time()
        logger.info(f"收到流式问题: {query}")

        try:
            # Step 1: retrieve relevant regulations.
            yield {"event": "status", "data": "正在检索相关法规..."}
            self._init_retriever()
            documents = self.retriever.retrieve(query, filters)
            retrieved_count = len(documents)

            if retrieved_count == 0:
                yield {"event": "status", "data": "未找到相关法规"}
                yield {"event": "content", "data": "抱歉,未找到与您的问题相关的法规条款。请尝试用不同的关键词重新提问。"}
                yield {"event": "done", "data": {"latency_ms": self._elapsed_ms(start_time), "retrieved_count": 0}}
                return

            # Step 2: announce the retrieval result and the top citations.
            yield {"event": "status", "data": f"找到{retrieved_count}条相关法规,正在生成回答..."}
            sources = [
                {
                    "doc_name": doc.doc_name,
                    "doc_id": doc.doc_id,
                    "clause_number": doc.clause_number,
                    "score": doc.score
                }
                for doc in documents[:self.STREAM_MAX_SOURCES]
            ]
            yield {"event": "sources", "data": sources}

            # Steps 3-4: build the RAG context and the LLM messages.
            context, messages = self._build_context_and_messages(query, documents, prompt_template)

            # Step 5: generate the answer, streaming when the client supports it.
            self._init_llm()
            if hasattr(self.llm, 'stream_chat'):
                yield {"event": "status", "data": "思考中..."}
                for chunk in self.llm.stream_chat(
                    messages=messages,
                    temperature=self.config.temperature
                ):
                    yield {"event": "content", "data": chunk}
            else:
                # Fall back to a single blocking call.
                yield {"event": "status", "data": "生成回答中..."}
                llm_response = self.llm.chat(
                    messages=messages,
                    temperature=self.config.temperature
                )
                if not llm_response.is_success:
                    # Previously this failure was silently followed by "done".
                    logger.error(f"流式问答失败: {llm_response.error}")
                    yield {"event": "error", "data": llm_response.error}
                    return
                yield {"event": "content", "data": llm_response.content}

            # Step 6: completion event.
            latency_ms = self._elapsed_ms(start_time)
            logger.success(f"流式问答完成: {latency_ms}ms, {retrieved_count}条引用")

            yield {
                "event": "done",
                "data": {
                    "latency_ms": latency_ms,
                    "model": self.config.llm_model,
                    "retrieved_count": retrieved_count,
                    "context_tokens": context.total_tokens
                }
            }

        except Exception as e:
            logger.exception(f"流式问答失败: {e}")
            yield {"event": "error", "data": str(e)}

    def close(self):
        """Release agent-owned resources.

        Only the retriever is closed. The LLM client is left open (the original
        author notes it is globally cached). The retriever reference is cleared,
        so calling close() twice is harmless and a later ask() lazily creates a
        fresh retriever instead of reusing a closed one.
        """
        if self.retriever is not None:
            self.retriever.close()
            self.retriever = None
        logger.info("问答Agent已关闭")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def ask_compliance_question(
    query: str,
    provider: str = "deepseek",
    model: str = "deepseek-v4-flash",
    top_k: int = 10
) -> AgentResponse:
    """
    Convenience wrapper: answer one compliance question with a throwaway agent.

    Args:
        query: The user's question.
        provider: LLM provider name.
        model: LLM model name.
        top_k: Number of documents to retrieve.

    Returns:
        AgentResponse: the result of ``QAAgent.ask``.
    """
    config = AgentConfig(
        llm_provider=provider,
        llm_model=model,
        top_k=top_k
    )
    agent = QAAgent(config)
    try:
        return agent.ask(query)
    finally:
        # Release the retriever even if ask() is interrupted by a
        # BaseException (e.g. KeyboardInterrupt) that ask() does not catch.
        agent.close()
|