"""RAG检索服务 - 封装Milvus检索""" from typing import List, Dict, Optional, Any from dataclasses import dataclass, field from loguru import logger from app.services.embedding.bge_m3_embedder import BGEM3Embedder from app.services.storage.milvus_client import MilvusClient, SearchResult from app.config.settings import settings @dataclass class RetrievedDocument: """检索到的文档""" content: str doc_id: str # 文档ID,用于下载 doc_name: str section_title: str clause_number: str page_number: int score: float metadata: Dict[str, Any] = field(default_factory=dict) class Retriever: """ RAG检索器 功能: - 向量检索(Dense + Sparse混合) - 重排序(可选) - 过滤和筛选 """ def __init__( self, top_k: int = None, rerank: bool = False, min_score: float = 0.3 ): """ 初始化检索器 Args: top_k: 检索召回数量 rerank: 是否启用重排序 min_score: 最低相关性分数阈值 """ self.top_k = top_k or settings.rag_top_k self.rerank = rerank self.min_score = min_score # 嵌入模型(延迟加载) self.embedder: Optional[BGEM3Embedder] = None # Milvus客户端(延迟连接) self.milvus: Optional[MilvusClient] = None logger.info(f"检索器初始化: top_k={self.top_k}, rerank={self.rerank}") def _init_embedder(self): """延迟初始化嵌入模型""" if self.embedder is None: logger.info("加载嵌入模型...") self.embedder = BGEM3Embedder(model_name=settings.embedding_model) def _init_milvus(self): """延迟初始化Milvus""" if self.milvus is None: logger.info("连接Milvus...") self.milvus = MilvusClient() self.milvus.connect() self.milvus.create_collection(recreate=False) self.milvus.load_collection() def retrieve( self, query: str, filters: Optional[str] = None, top_k: Optional[int] = None ) -> List[RetrievedDocument]: """ 检索相关文档 Args: query: 查询文本 filters: 过滤条件(如 "regulation_type=='车辆安全'") top_k: 返回数量(可选,覆盖默认值) Returns: List[RetrievedDocument]: 检索结果列表 """ logger.info(f"执行检索: {query}") # 初始化组件 self._init_embedder() self._init_milvus() # 生成查询向量 query_embedding = self.embedder.embed_single(query) # 执行混合检索 results = self.milvus.hybrid_search( query_dense=query_embedding['dense'].tolist(), query_sparse=query_embedding['sparse'], top_k=top_k or self.top_k, filters=filters ) # 转换为RetrievedDocument格式 documents = [] for r in results: if r.score >= self.min_score: doc = RetrievedDocument( content=r.content, doc_id=r.metadata.get("doc_id", ""), doc_name=r.metadata.get("doc_name", ""), section_title=r.metadata.get("section_title", ""), clause_number=r.metadata.get("clause_number", ""), page_number=r.metadata.get("page_number", 0), score=r.score, metadata=r.metadata ) documents.append(doc) logger.success(f"检索完成,返回{len(documents)}条结果(阈值过滤后)") return documents def retrieve_with_scores( self, query: str, filters: Optional[str] = None ) -> List[Dict]: """ 检索并返回完整结果(包含分数) Args: query: 查询文本 filters: 过滤条件 Returns: List[Dict]: 包含分数的检索结果 """ documents = self.retrieve(query, filters) return [ { "content": doc.content, "doc_id": doc.doc_id, "doc_name": doc.doc_name, "section_title": doc.section_title, "clause_number": doc.clause_number, "page_number": doc.page_number, "score": doc.score } for doc in documents ] def search_by_doc_name( self, query: str, doc_name: str ) -> List[RetrievedDocument]: """按文档名称过滤检索""" filters = f'doc_name=="{doc_name}"' return self.retrieve(query, filters) def search_by_regulation_type( self, query: str, regulation_type: str ) -> List[RetrievedDocument]: """按法规类型过滤检索""" filters = f'regulation_type=="{regulation_type}"' return self.retrieve(query, filters) def close(self): """关闭连接""" if self.milvus: self.milvus.disconnect() logger.info("检索器已关闭") def retrieve_regulations( query: str, top_k: int = 10, filters: Optional[str] = None ) -> List[RetrievedDocument]: """便捷函数:检索法规""" retriever = Retriever(top_k=top_k) results = retriever.retrieve(query, filters) retriever.close() return results