193 lines
5.4 KiB
Python
193 lines
5.4 KiB
Python
|
|
"""RAG检索服务 - 封装Milvus检索"""
|
|||
|
|
|
|||
|
|
from typing import List, Dict, Optional, Any
|
|||
|
|
from dataclasses import dataclass, field
|
|||
|
|
from loguru import logger
|
|||
|
|
|
|||
|
|
from app.services.embedding.bge_m3_embedder import BGEM3Embedder
|
|||
|
|
from app.services.storage.milvus_client import MilvusClient, SearchResult
|
|||
|
|
from app.config.settings import settings
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class RetrievedDocument:
    """A single retrieved document chunk together with its source location.

    All fields except ``metadata`` are required and positional, in the
    order declared below.
    """

    # --- chunk payload ---
    content: str

    # --- source document identification ---
    doc_id: str  # document ID, used for downloading the document
    doc_name: str

    # --- position of the chunk inside the document ---
    section_title: str
    clause_number: str
    page_number: int

    # --- ranking ---
    score: float

    # Extra key/value data attached to the hit; every instance gets its own
    # fresh dict when none is supplied.
    metadata: Dict[str, Any] = field(default_factory=dict)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class Retriever:
    """RAG retriever wrapping query embedding and Milvus hybrid search.

    Features:
        - Vector retrieval (dense + sparse hybrid search)
        - Re-ranking flag (optional; see the note on ``rerank`` below)
        - Filtering by expression and by minimum score

    The embedding model and the Milvus connection are created lazily on the
    first call to :meth:`retrieve`. Call :meth:`close` to release the Milvus
    connection; a later :meth:`retrieve` reconnects on demand.
    """

    def __init__(
        self,
        top_k: Optional[int] = None,
        rerank: bool = False,
        min_score: float = 0.3
    ):
        """Initialize the retriever without loading any heavy resources.

        Args:
            top_k: Number of hits to recall. A falsy value (``None`` or 0)
                falls back to ``settings.rag_top_k``.
            rerank: Whether re-ranking is enabled. NOTE: the flag is stored
                and logged, but no method of this class reads it.
            min_score: Minimum relevance score; hits scoring lower are
                dropped by :meth:`retrieve`.
        """
        self.top_k = top_k or settings.rag_top_k
        self.rerank = rerank
        self.min_score = min_score

        # Embedding model (lazily loaded)
        self.embedder: Optional[BGEM3Embedder] = None

        # Milvus client (lazily connected)
        self.milvus: Optional[MilvusClient] = None

        logger.info(f"检索器初始化: top_k={self.top_k}, rerank={self.rerank}")

    def _init_embedder(self):
        """Load the embedding model on first use."""
        if self.embedder is None:
            logger.info("加载嵌入模型...")
            self.embedder = BGEM3Embedder(model_name=settings.embedding_model)

    def _init_milvus(self):
        """Connect to Milvus and load the collection on first use.

        The client is stored on ``self.milvus`` only after every setup step
        has succeeded. If any step raises, ``self.milvus`` stays ``None`` so
        the next call retries from scratch instead of reusing a client that
        was never fully initialized.
        """
        if self.milvus is None:
            logger.info("连接Milvus...")
            client = MilvusClient()
            client.connect()
            client.create_collection(recreate=False)
            client.load_collection()
            self.milvus = client

    @staticmethod
    def _quote_filter_value(value: str) -> str:
        """Return ``value`` as a double-quoted filter string literal.

        Backslashes and double quotes are backslash-escaped so that a value
        containing them cannot terminate the literal early and change the
        filter expression. Values without those characters yield exactly
        ``"value"``.

        NOTE(review): assumes Milvus filter expressions accept
        backslash-escaped characters inside string literals -- confirm
        against the deployed Milvus version.
        """
        escaped = str(value).replace("\\", "\\\\").replace('"', '\\"')
        return f'"{escaped}"'

    def retrieve(
        self,
        query: str,
        filters: Optional[str] = None,
        top_k: Optional[int] = None
    ) -> List[RetrievedDocument]:
        """Retrieve documents relevant to ``query``.

        Args:
            query: Query text.
            filters: Optional filter expression passed through to the hybrid
                search unchanged (e.g. ``'regulation_type=="<type>"'``).
            top_k: Number of hits to request; a falsy value falls back to
                the instance default.

        Returns:
            List[RetrievedDocument]: Hits whose score is at least
            ``self.min_score``, in the order returned by the search.
        """
        logger.info(f"执行检索: {query}")

        # Make sure the lazily created components exist.
        self._init_embedder()
        self._init_milvus()

        # Build the query vectors.
        query_embedding = self.embedder.embed_single(query)

        # Run the hybrid (dense + sparse) search.
        results = self.milvus.hybrid_search(
            query_dense=query_embedding['dense'].tolist(),
            query_sparse=query_embedding['sparse'],
            top_k=top_k or self.top_k,
            filters=filters
        )

        # Convert hits that pass the score threshold to RetrievedDocument.
        documents = []
        for r in results:
            if r.score >= self.min_score:
                doc = RetrievedDocument(
                    content=r.content,
                    doc_id=r.metadata.get("doc_id", ""),
                    doc_name=r.metadata.get("doc_name", ""),
                    section_title=r.metadata.get("section_title", ""),
                    clause_number=r.metadata.get("clause_number", ""),
                    page_number=r.metadata.get("page_number", 0),
                    score=r.score,
                    metadata=r.metadata
                )
                documents.append(doc)

        logger.success(f"检索完成,返回{len(documents)}条结果(阈值过滤后)")
        return documents

    def retrieve_with_scores(
        self,
        query: str,
        filters: Optional[str] = None
    ) -> List[Dict]:
        """Retrieve documents and return them as plain dicts including scores.

        Args:
            query: Query text.
            filters: Optional filter expression.

        Returns:
            List[Dict]: One dict per hit with the keys ``content``,
            ``doc_id``, ``doc_name``, ``section_title``, ``clause_number``,
            ``page_number`` and ``score`` (the ``metadata`` field is not
            included).
        """
        documents = self.retrieve(query, filters)
        return [
            {
                "content": doc.content,
                "doc_id": doc.doc_id,
                "doc_name": doc.doc_name,
                "section_title": doc.section_title,
                "clause_number": doc.clause_number,
                "page_number": doc.page_number,
                "score": doc.score
            }
            for doc in documents
        ]

    def search_by_doc_name(
        self,
        query: str,
        doc_name: str
    ) -> List[RetrievedDocument]:
        """Retrieve documents restricted to an exact document name."""
        filters = f"doc_name=={self._quote_filter_value(doc_name)}"
        return self.retrieve(query, filters)

    def search_by_regulation_type(
        self,
        query: str,
        regulation_type: str
    ) -> List[RetrievedDocument]:
        """Retrieve documents restricted to an exact regulation type."""
        filters = f"regulation_type=={self._quote_filter_value(regulation_type)}"
        return self.retrieve(query, filters)

    def close(self):
        """Disconnect from Milvus, if connected, and forget the client.

        The client reference is cleared even when ``disconnect()`` raises,
        so a later :meth:`retrieve` opens a fresh connection instead of
        reusing a disconnected client.
        """
        if self.milvus:
            try:
                self.milvus.disconnect()
            finally:
                self.milvus = None
        logger.info("检索器已关闭")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def retrieve_regulations(
    query: str,
    top_k: int = 10,
    filters: Optional[str] = None
) -> List[RetrievedDocument]:
    """Convenience function: run a one-off regulation retrieval.

    A new ``Retriever`` is created for the call and always closed
    afterwards, including when the retrieval raises, so the connection it
    opened is not leaked on failure.

    Args:
        query: Query text.
        top_k: Number of hits to recall.
        filters: Optional filter expression passed to ``Retriever.retrieve``.

    Returns:
        List[RetrievedDocument]: The retrieval results.
    """
    retriever = Retriever(top_k=top_k)
    try:
        return retriever.retrieve(query, filters)
    finally:
        # Release the connection on both the success and the error path.
        retriever.close()
|