# src/services/agent/qa_agent.py
"""RAG question-answering agent: core implementation of compliance Q&A."""

import time
from typing import List, Dict, Optional, Any, Generator, Tuple
from dataclasses import dataclass, field

from loguru import logger

from app.services.llm import get_llm_client, BaseLLMClient, LLMResponse
from app.services.llm.llm_factory import LLMFactory
from app.services.rag.retriever import Retriever, RetrievedDocument
from app.services.rag.context_builder import ContextBuilder, RAGContext
from app.services.rag.prompt_templates import get_prompt_template, PromptTemplate
from app.config.settings import settings


# Maximum number of citations pushed to the client in the streaming "sources" event.
_MAX_STREAM_SOURCES = 5


def _elapsed_ms(start: float) -> int:
    """Return whole milliseconds elapsed since ``start`` (a ``time.perf_counter()`` value)."""
    return int((time.perf_counter() - start) * 1000)


@dataclass
class AgentResponse:
    """Result of a single agent Q&A call."""

    answer: str
    # Citation dicts for the documents used to build the answer.
    sources: List[Dict] = field(default_factory=list)
    # Model name reported by the LLM response ("" when no LLM call succeeded).
    model: str = ""
    latency_ms: int = 0
    # Number of documents returned by retrieval (or supplied by the caller).
    retrieved_count: int = 0
    context_tokens: int = 0
    # Whether the context builder had to truncate the retrieved context.
    truncated: bool = False
    # Error description; None means the call succeeded.
    error: Optional[str] = None

    @property
    def is_success(self) -> bool:
        """True when no error was recorded."""
        return self.error is None


@dataclass
class AgentConfig:
    """Agent configuration."""

    llm_provider: str = "deepseek"
    llm_model: str = "deepseek-v4-flash"
    top_k: int = 5
    min_score: float = 0.3
    max_context_tokens: int = 2000
    temperature: float = 0.7
    prompt_template: str = "compliance_qa"
    include_metadata: bool = True


class QAAgent:
    """
    Compliance question-answering agent.

    Pipeline:
        1. Receive the user question.
        2. Retrieve relevant regulation clauses (Milvus hybrid search via ``Retriever``).
        3. Build the RAG context.
        4. Call the LLM to generate an answer.
        5. Return the answer together with its cited sources.

    Example:
        agent = QAAgent()
        response = agent.ask("机动车安全技术检验有哪些要求?")
        print(response.answer)
        for source in response.sources:
            print(f"引用: {source['doc_name']} - {source['clause_number']}")
    """

    def __init__(self, config: Optional[AgentConfig] = None):
        """
        Initialize the Q&A agent.

        Args:
            config: Agent configuration. When omitted, provider/model/top_k/
                max_context_tokens are taken from ``settings`` and the remaining
                fields use the ``AgentConfig`` defaults.
        """
        self.config = config or AgentConfig(
            llm_provider=settings.llm_provider,
            llm_model=settings.llm_model,
            top_k=settings.rag_top_k,
            max_context_tokens=settings.rag_max_context_tokens
        )

        # Components are created lazily on first use.
        self.llm: Optional[BaseLLMClient] = None
        self.retriever: Optional[Retriever] = None
        self.context_builder: Optional[ContextBuilder] = None

        logger.info(f"问答Agent初始化: provider={self.config.llm_provider}, model={self.config.llm_model}")

    def _init_llm(self):
        """Lazily create the LLM client, preferring the globally cached one."""
        if self.llm is None:
            cached = LLMFactory.get_global_client(self.config.llm_provider, self.config.llm_model)
            if cached:
                self.llm = cached
                logger.debug(f"使用全局缓存的LLM客户端: {self.config.llm_provider} - {self.config.llm_model}")
            else:
                logger.info("创建新的LLM客户端...")
                self.llm = get_llm_client(
                    provider=self.config.llm_provider,
                    model=self.config.llm_model,
                    temperature=self.config.temperature
                )

    def _init_retriever(self):
        """Lazily create the retriever."""
        if self.retriever is None:
            logger.info("初始化检索器...")
            self.retriever = Retriever(
                top_k=self.config.top_k,
                min_score=self.config.min_score
            )

    def _init_context_builder(self):
        """Lazily create the context builder."""
        if self.context_builder is None:
            logger.info("初始化上下文构建器...")
            self.context_builder = ContextBuilder(
                max_context_tokens=self.config.max_context_tokens,
                include_metadata=self.config.include_metadata
            )

    def _prepare(
        self,
        query: str,
        documents: List[RetrievedDocument],
        prompt_template: Optional[str],
    ) -> Tuple[RAGContext, List[Dict[str, str]]]:
        """
        Build the RAG context and the LLM chat messages for ``query``.

        Args:
            query: User question.
            documents: Documents to ground the answer in.
            prompt_template: Template name overriding ``config.prompt_template``.

        Returns:
            ``(context, messages)``: the built RAG context and the chat messages.
        """
        self._init_context_builder()
        template = get_prompt_template(prompt_template or self.config.prompt_template)
        context = self.context_builder.build(
            query=query,
            documents=documents,
            system_prompt=template.system_prompt
        )
        return context, self._build_messages(template, context)

    def ask(
        self,
        query: str,
        filters: Optional[str] = None,
        prompt_template: Optional[str] = None
    ) -> AgentResponse:
        """
        Answer a user question (retrieve, build context, call the LLM).

        Args:
            query: User question.
            filters: Retrieval filter expression (e.g. "regulation_type=='车辆安全'").
            prompt_template: Prompt template name (optional, overrides the config).

        Returns:
            AgentResponse with the answer and cited sources. Failures never raise;
            they are reported via ``AgentResponse.error``.
        """
        start = time.perf_counter()
        logger.info(f"收到问题: {query}")
        retrieved_count = 0

        try:
            # Step 1: retrieve relevant regulations.
            self._init_retriever()
            documents = self.retriever.retrieve(query, filters)
            retrieved_count = len(documents)

            if retrieved_count == 0:
                return AgentResponse(
                    answer="抱歉,未找到与您的问题相关的法规条款。请尝试用不同的关键词重新提问,或提供更具体的法规名称。",
                    latency_ms=_elapsed_ms(start),
                    retrieved_count=0,
                    error="no_retrieved_documents"
                )

            # Steps 2-3: build the RAG context and the LLM messages.
            context, messages = self._prepare(query, documents, prompt_template)

            # Step 4: generate the answer.
            self._init_llm()
            llm_response = self.llm.chat(
                messages=messages,
                temperature=self.config.temperature
            )

            if not llm_response.is_success:
                return AgentResponse(
                    answer="",
                    latency_ms=_elapsed_ms(start),
                    retrieved_count=retrieved_count,
                    error=llm_response.error
                )

            # Step 5: assemble the result.
            latency_ms = _elapsed_ms(start)
            logger.success(f"问答完成: {latency_ms}ms, {retrieved_count}条引用")
            return AgentResponse(
                answer=llm_response.content,
                sources=context.sources,
                model=llm_response.model,
                latency_ms=latency_ms,
                retrieved_count=retrieved_count,
                context_tokens=context.total_tokens,
                truncated=context.truncated
            )

        except Exception as e:
            # Top-level boundary: log with traceback and report via the response.
            logger.exception(f"问答失败: {e}")
            return AgentResponse(
                answer="",
                latency_ms=_elapsed_ms(start),
                retrieved_count=retrieved_count,
                error=str(e)
            )

    def ask_with_context(
        self,
        query: str,
        documents: List[RetrievedDocument],
        prompt_template: Optional[str] = None
    ) -> AgentResponse:
        """
        Answer a question using caller-supplied documents (no retrieval).

        Args:
            query: User question.
            documents: Already-retrieved documents.
            prompt_template: Prompt template name (optional, overrides the config).

        Returns:
            AgentResponse. LLM failures and exceptions are reported via
            ``AgentResponse.error`` rather than raised.
        """
        start = time.perf_counter()

        try:
            self._init_llm()
            context, messages = self._prepare(query, documents, prompt_template)

            # Pass the configured temperature explicitly, as ask()/ask_stream() do;
            # a globally cached client may have been created with a different one.
            llm_response = self.llm.chat(
                messages=messages,
                temperature=self.config.temperature
            )

            if not llm_response.is_success:
                return AgentResponse(
                    answer="",
                    latency_ms=_elapsed_ms(start),
                    retrieved_count=len(documents),
                    error=llm_response.error
                )

            return AgentResponse(
                answer=llm_response.content,
                sources=context.sources,
                model=llm_response.model,
                latency_ms=_elapsed_ms(start),
                retrieved_count=len(documents),
                context_tokens=context.total_tokens,
                truncated=context.truncated
            )

        except Exception as e:
            logger.exception(f"问答失败: {e}")
            return AgentResponse(answer="", latency_ms=_elapsed_ms(start), error=str(e))

    def _build_messages(
        self,
        template: PromptTemplate,
        context: RAGContext
    ) -> List[Dict[str, str]]:
        """Build the system + user chat messages from a template and RAG context."""
        user_content = template.user_template.format(
            context=context.context_text,
            query=context.user_query
        )
        return [
            {"role": "system", "content": template.system_prompt},
            {"role": "user", "content": user_content}
        ]

    def ask_stream(
        self,
        query: str,
        filters: Optional[str] = None,
        prompt_template: Optional[str] = None
    ) -> Generator[Dict[str, Any], None, None]:
        """
        Answer a user question as a stream of events (SSE mode).

        Event types:
            - {"event": "status", "data": "..."}    - progress update
            - {"event": "sources", "data": [...]}   - cited sources
            - {"event": "content", "data": "..."}   - answer text fragment
            - {"event": "done", "data": {...}}      - completion (latency_ms, model, ...)
            - {"event": "error", "data": "..."}     - failure; no "done" follows

        Args:
            query: User question.
            filters: Retrieval filter expression.
            prompt_template: Prompt template name.

        Yields:
            Dict: SSE event payloads.
        """
        start = time.perf_counter()
        logger.info(f"收到流式问题: {query}")

        try:
            # Step 1: retrieve relevant regulations.
            yield {"event": "status", "data": "正在检索相关法规..."}
            self._init_retriever()
            documents = self.retriever.retrieve(query, filters)
            retrieved_count = len(documents)

            if retrieved_count == 0:
                yield {"event": "status", "data": "未找到相关法规"}
                yield {"event": "content", "data": "抱歉,未找到与您的问题相关的法规条款。请尝试用不同的关键词重新提问。"}
                yield {"event": "done", "data": {"latency_ms": _elapsed_ms(start), "retrieved_count": 0}}
                return

            # Step 2: publish the top citations.
            yield {"event": "status", "data": f"找到{retrieved_count}条相关法规,正在生成回答..."}
            yield {
                "event": "sources",
                "data": [
                    {
                        "doc_name": doc.doc_name,
                        "doc_id": doc.doc_id,
                        "clause_number": doc.clause_number,
                        "score": doc.score
                    }
                    for doc in documents[:_MAX_STREAM_SOURCES]
                ]
            }

            # Steps 3-4: build the RAG context and the LLM messages.
            context, messages = self._prepare(query, documents, prompt_template)

            # Step 5: generate the answer, streaming when the client supports it.
            self._init_llm()
            if hasattr(self.llm, 'stream_chat'):
                yield {"event": "status", "data": "思考中..."}
                for chunk in self.llm.stream_chat(
                    messages=messages,
                    temperature=self.config.temperature
                ):
                    yield {"event": "content", "data": chunk}
            else:
                # Fallback: a single non-streaming call.
                yield {"event": "status", "data": "生成回答中..."}
                llm_response = self.llm.chat(
                    messages=messages,
                    temperature=self.config.temperature
                )
                if not llm_response.is_success:
                    # Surface the failure instead of silently finishing with no content.
                    logger.error(f"流式问答失败: {llm_response.error}")
                    yield {"event": "error", "data": llm_response.error}
                    return
                yield {"event": "content", "data": llm_response.content}

            # Step 6: completion event.
            latency_ms = _elapsed_ms(start)
            logger.success(f"流式问答完成: {latency_ms}ms, {retrieved_count}条引用")
            yield {
                "event": "done",
                "data": {
                    "latency_ms": latency_ms,
                    "model": self.config.llm_model,
                    "retrieved_count": retrieved_count,
                    "context_tokens": context.total_tokens
                }
            }

        except Exception as e:
            logger.exception(f"流式问答失败: {e}")
            yield {"event": "error", "data": str(e)}

    def close(self):
        """
        Release agent-owned resources.

        The LLM client is deliberately left open because it may be shared through
        the global cache. The retriever reference is dropped after closing so that a
        later call lazily creates a fresh one instead of reusing a closed instance.
        """
        if self.retriever is not None:
            self.retriever.close()
            self.retriever = None
        logger.info("问答Agent已关闭")


def ask_compliance_question(
    query: str,
    provider: str = "deepseek",
    model: str = "deepseek-v4-flash",
    top_k: int = 10
) -> AgentResponse:
    """
    Convenience wrapper: answer one compliance question with a throwaway agent.

    Args:
        query: User question.
        provider: LLM provider.
        model: LLM model.
        top_k: Number of documents to retrieve.

    Returns:
        AgentResponse: the agent's response.
    """
    config = AgentConfig(
        llm_provider=provider,
        llm_model=model,
        top_k=top_k
    )
    agent = QAAgent(config)
    try:
        return agent.ask(query)
    finally:
        # Always release the retriever, even if ask() raises unexpectedly.
        agent.close()