Files
AIRegulation-DocAnalysis/backend/app/services/rag/retriever.py
2026-05-14 15:07:34 +08:00

194 lines
5.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# app/services/rag/retriever.py
"""RAG retrieval service - wraps Milvus retrieval."""
from typing import List, Dict, Optional, Any
from dataclasses import dataclass, field
from loguru import logger
from app.services.embedding.bge_m3_embedder import BGEM3Embedder
from app.services.storage.milvus_client import MilvusClient, SearchResult
from app.config.settings import settings
@dataclass
class RetrievedDocument:
    """One retrieval hit: the matched text plus where it came from.

    Fields are positional in the order declared below; only ``metadata``
    has a default.
    """

    # Retrieved text content.
    content: str
    # Document ID, used for downloading the source document.
    doc_id: str
    # Name of the document the hit belongs to.
    doc_name: str
    # Title of the section containing the hit.
    section_title: str
    # Clause number of the hit within the document.
    clause_number: str
    # Page number of the hit.
    page_number: int
    # Relevance score reported by the search.
    score: float
    # Free-form metadata; every instance gets its own fresh dict by default.
    metadata: Dict[str, Any] = field(default_factory=dict)
class Retriever:
    """RAG retriever built on Milvus hybrid (dense + sparse) search.

    Responsibilities visible in this class:
    - embed the query and run a hybrid search against Milvus;
    - drop hits whose score is below ``min_score``;
    - offer convenience searches filtered by document name or regulation type.

    The ``rerank`` flag is stored on the instance, but no method in this
    class consults it.

    The embedding model and the Milvus connection are created lazily on the
    first call to :meth:`retrieve`.
    """

    def __init__(
        self,
        top_k: Optional[int] = None,
        rerank: bool = False,
        min_score: float = 0.3,
    ):
        """Configure the retriever without loading any heavy resources.

        Args:
            top_k: Number of hits to request from the search. A falsy value
                (``None`` or ``0``) falls back to ``settings.rag_top_k``.
            rerank: Whether reranking is requested (stored only).
            min_score: Hits with a score below this threshold are discarded.
        """
        self.top_k = top_k or settings.rag_top_k
        self.rerank = rerank
        self.min_score = min_score
        # Embedding model (lazily loaded).
        self.embedder: Optional[BGEM3Embedder] = None
        # Milvus client (lazily connected).
        self.milvus: Optional[MilvusClient] = None
        logger.info(f"检索器初始化: top_k={self.top_k}, rerank={self.rerank}")

    @staticmethod
    def _eq_filter(field_name: str, value: str) -> str:
        """Build a ``field=="value"`` filter with ``value`` safely quoted.

        Backslashes and double quotes in ``value`` are backslash-escaped so
        that caller-supplied text cannot terminate the string literal and
        change the meaning of the filter expression. Values without those
        characters produce exactly ``field=="value"``.
        """
        # Escape backslashes first so the escapes added for quotes survive.
        escaped = str(value).replace("\\", "\\\\").replace('"', '\\"')
        return f'{field_name}=="{escaped}"'

    def _init_embedder(self):
        """Create the embedding model on first use."""
        if self.embedder is None:
            logger.info("加载嵌入模型...")
            self.embedder = BGEM3Embedder(model_name=settings.embedding_model)

    def _init_milvus(self):
        """Connect to Milvus and load the collection on first use.

        The client is stored on the instance only after the whole setup
        succeeded, so a failed attempt is retried on the next call instead
        of leaving a half-initialised client behind.
        """
        if self.milvus is not None:
            return
        logger.info("连接Milvus...")
        client = MilvusClient()
        client.connect()
        try:
            client.create_collection(recreate=False)
            client.load_collection()
        except Exception:
            # Do not leak a connected client that is not kept on the instance.
            client.disconnect()
            raise
        self.milvus = client

    def retrieve(
        self,
        query: str,
        filters: Optional[str] = None,
        top_k: Optional[int] = None,
    ) -> List["RetrievedDocument"]:
        """Retrieve documents relevant to ``query``.

        Args:
            query: Query text.
            filters: Optional filter expression passed through to the search
                unchanged (e.g. ``'regulation_type=="<type>"'``). Callers
                building it from user input must quote values themselves.
            top_k: Number of hits to request; a falsy value falls back to
                the instance default.

        Returns:
            The hits whose score is at least ``min_score``, in the order the
            search returned them. May be shorter than ``top_k``.
        """
        logger.info(f"执行检索: {query}")
        # Make sure the lazily created components exist.
        self._init_embedder()
        self._init_milvus()
        # Embed the query; the result exposes 'dense' and 'sparse' entries.
        query_embedding = self.embedder.embed_single(query)
        # Run the hybrid search.
        results = self.milvus.hybrid_search(
            query_dense=query_embedding['dense'].tolist(),
            query_sparse=query_embedding['sparse'],
            top_k=top_k or self.top_k,
            filters=filters,
        )
        # Convert to RetrievedDocument, applying the score threshold.
        documents = [
            RetrievedDocument(
                content=r.content,
                doc_id=r.metadata.get("doc_id", ""),
                doc_name=r.metadata.get("doc_name", ""),
                section_title=r.metadata.get("section_title", ""),
                clause_number=r.metadata.get("clause_number", ""),
                page_number=r.metadata.get("page_number", 0),
                score=r.score,
                metadata=r.metadata,
            )
            for r in results
            if r.score >= self.min_score
        ]
        logger.success(f"检索完成,返回{len(documents)}条结果(阈值过滤后)")
        return documents

    def retrieve_with_scores(
        self,
        query: str,
        filters: Optional[str] = None,
        top_k: Optional[int] = None,
    ) -> List[Dict]:
        """Retrieve and return plain dicts (including the score).

        Args:
            query: Query text.
            filters: Optional filter expression, forwarded to :meth:`retrieve`.
            top_k: Optional override of the number of hits to request.

        Returns:
            One dict per hit with the keys ``content``, ``doc_id``,
            ``doc_name``, ``section_title``, ``clause_number``,
            ``page_number`` and ``score``. The raw metadata is not included.
        """
        return [
            {
                "content": doc.content,
                "doc_id": doc.doc_id,
                "doc_name": doc.doc_name,
                "section_title": doc.section_title,
                "clause_number": doc.clause_number,
                "page_number": doc.page_number,
                "score": doc.score,
            }
            for doc in self.retrieve(query, filters, top_k)
        ]

    def search_by_doc_name(
        self,
        query: str,
        doc_name: str,
        top_k: Optional[int] = None,
    ) -> List["RetrievedDocument"]:
        """Retrieve hits restricted to the document named ``doc_name``."""
        return self.retrieve(query, self._eq_filter("doc_name", doc_name), top_k)

    def search_by_regulation_type(
        self,
        query: str,
        regulation_type: str,
        top_k: Optional[int] = None,
    ) -> List["RetrievedDocument"]:
        """Retrieve hits restricted to the given regulation type."""
        return self.retrieve(
            query, self._eq_filter("regulation_type", regulation_type), top_k
        )

    def close(self):
        """Disconnect from Milvus if connected.

        The client reference is cleared, so calling ``close`` twice is
        harmless and a later :meth:`retrieve` reconnects instead of using a
        disconnected client.
        """
        client, self.milvus = self.milvus, None
        if client is not None:
            client.disconnect()
            logger.info("检索器已关闭")
def retrieve_regulations(
    query: str,
    top_k: int = 10,
    filters: Optional[str] = None
) -> List[RetrievedDocument]:
    """Convenience function: run a one-off regulation retrieval.

    A new ``Retriever`` is constructed for every call and closed before
    returning, so nothing is reused between calls.

    Args:
        query: Query text.
        top_k: Number of hits to request.
        filters: Optional filter expression forwarded to the retriever.

    Returns:
        The documents returned by ``Retriever.retrieve``.
    """
    retriever = Retriever(top_k=top_k)
    try:
        return retriever.retrieve(query, filters)
    finally:
        # Release the connection even when retrieval raises.
        retriever.close()