update
This commit is contained in:
12
backend/app/services/rag/__init__.py
Normal file
12
backend/app/services/rag/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
# src/services/rag/__init__.py
|
||||
"""RAG服务模块"""
|
||||
|
||||
from .retriever import Retriever, retrieve_regulations
|
||||
from .context_builder import ContextBuilder, build_rag_context
|
||||
from .prompt_templates import PromptTemplates, get_prompt_template
|
||||
|
||||
__all__ = [
|
||||
"Retriever", "retrieve_regulations",
|
||||
"ContextBuilder", "build_rag_context",
|
||||
"PromptTemplates", "get_prompt_template"
|
||||
]
|
||||
230
backend/app/services/rag/context_builder.py
Normal file
230
backend/app/services/rag/context_builder.py
Normal file
@@ -0,0 +1,230 @@
|
||||
# src/services/rag/context_builder.py
|
||||
"""RAG上下文构建服务 - 构建LLM输入上下文"""
|
||||
|
||||
from typing import List, Dict, Optional
|
||||
from dataclasses import dataclass
|
||||
from loguru import logger
|
||||
|
||||
from .retriever import RetrievedDocument
|
||||
from app.config.settings import settings
|
||||
|
||||
|
||||
@dataclass
|
||||
class RAGContext:
|
||||
"""RAG构建的上下文"""
|
||||
system_prompt: str
|
||||
context_text: str
|
||||
user_query: str
|
||||
total_tokens: int
|
||||
sources: List[Dict]
|
||||
truncated: bool = False
|
||||
|
||||
|
||||
class ContextBuilder:
|
||||
"""
|
||||
RAG上下文构建器
|
||||
|
||||
功能:
|
||||
- 格式化检索结果为上下文文本
|
||||
- 控制上下文长度(token限制)
|
||||
- 构建完整的LLM输入格式
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_context_tokens: int = None,
|
||||
include_metadata: bool = True,
|
||||
citation_format: str = "【条款{clause}】"
|
||||
):
|
||||
"""
|
||||
初始化上下文构建器
|
||||
|
||||
Args:
|
||||
max_context_tokens: 最大上下文token数
|
||||
include_metadata: 是否包含元数据(文档名、条款号等)
|
||||
citation_format: 引用格式模板
|
||||
"""
|
||||
self.max_context_tokens = max_context_tokens or settings.rag_max_context_tokens
|
||||
self.include_metadata = include_metadata
|
||||
self.citation_format = citation_format
|
||||
|
||||
logger.info(f"上下文构建器初始化: max_tokens={self.max_context_tokens}")
|
||||
|
||||
def build(
|
||||
self,
|
||||
query: str,
|
||||
documents: List[RetrievedDocument],
|
||||
system_prompt: Optional[str] = None,
|
||||
max_tokens: Optional[int] = None
|
||||
) -> RAGContext:
|
||||
"""
|
||||
构建RAG上下文
|
||||
|
||||
Args:
|
||||
query: 用户查询
|
||||
documents: 检索到的文档列表
|
||||
system_prompt: 系统提示词(可选)
|
||||
max_tokens: 最大token数(可选,覆盖默认值)
|
||||
|
||||
Returns:
|
||||
RAGContext: 构建的上下文对象
|
||||
"""
|
||||
max_tokens = max_tokens or self.max_context_tokens
|
||||
|
||||
# 格式化文档内容
|
||||
context_text, sources, truncated = self._format_documents(
|
||||
documents,
|
||||
max_tokens
|
||||
)
|
||||
|
||||
# 构建系统提示词
|
||||
system_prompt = system_prompt or self._default_system_prompt()
|
||||
|
||||
# 估算总token数
|
||||
total_tokens = self._estimate_tokens(system_prompt + context_text + query)
|
||||
|
||||
logger.info(f"上下文构建完成: {len(documents)}条文档, {total_tokens}tokens, truncated={truncated}")
|
||||
|
||||
return RAGContext(
|
||||
system_prompt=system_prompt,
|
||||
context_text=context_text,
|
||||
user_query=query,
|
||||
total_tokens=total_tokens,
|
||||
sources=sources,
|
||||
truncated=truncated
|
||||
)
|
||||
|
||||
def _format_documents(
|
||||
self,
|
||||
documents: List[RetrievedDocument],
|
||||
max_tokens: int
|
||||
) -> tuple:
|
||||
"""
|
||||
格式化文档内容
|
||||
|
||||
Args:
|
||||
documents: 文档列表
|
||||
max_tokens: 最大token数
|
||||
|
||||
Returns:
|
||||
(context_text, sources, truncated)
|
||||
"""
|
||||
context_parts = []
|
||||
sources = []
|
||||
current_tokens = 0
|
||||
truncated = False
|
||||
|
||||
for i, doc in enumerate(documents):
|
||||
# 格式化单个文档
|
||||
formatted = self._format_single_doc(doc, i + 1)
|
||||
|
||||
# 估算token数
|
||||
doc_tokens = self._estimate_tokens(formatted)
|
||||
|
||||
# 检查是否超出限制
|
||||
if current_tokens + doc_tokens > max_tokens:
|
||||
truncated = True
|
||||
logger.warning(f"上下文截断: 已达到{max_tokens}tokens限制")
|
||||
break
|
||||
|
||||
context_parts.append(formatted)
|
||||
current_tokens += doc_tokens
|
||||
|
||||
# 记录来源
|
||||
sources.append({
|
||||
"index": i + 1,
|
||||
"doc_id": doc.doc_id,
|
||||
"doc_name": doc.doc_name,
|
||||
"section_title": doc.section_title,
|
||||
"clause_number": doc.clause_number,
|
||||
"page_number": doc.page_number,
|
||||
"score": doc.score
|
||||
})
|
||||
|
||||
context_text = "\n\n".join(context_parts)
|
||||
return context_text, sources, truncated
|
||||
|
||||
def _format_single_doc(
|
||||
self,
|
||||
doc: RetrievedDocument,
|
||||
index: int
|
||||
) -> str:
|
||||
"""格式化单个文档"""
|
||||
parts = []
|
||||
|
||||
# 索引编号
|
||||
parts.append(f"[{index}]")
|
||||
|
||||
# 元数据(可选)
|
||||
if self.include_metadata:
|
||||
meta_parts = []
|
||||
|
||||
if doc.doc_name:
|
||||
meta_parts.append(f"文档: {doc.doc_name}")
|
||||
|
||||
if doc.section_title:
|
||||
meta_parts.append(f"章节: {doc.section_title}")
|
||||
|
||||
if doc.clause_number:
|
||||
clause_text = self.citation_format.format(clause=doc.clause_number)
|
||||
meta_parts.append(clause_text)
|
||||
|
||||
if meta_parts:
|
||||
parts.append(" | ".join(meta_parts))
|
||||
|
||||
# 内容
|
||||
parts.append(doc.content)
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
def _default_system_prompt(self) -> str:
|
||||
"""默认系统提示词"""
|
||||
return """你是合规专家助手,基于提供的法规条款回答问题。
|
||||
|
||||
回答要求:
|
||||
1. 直接回答问题,必须引用具体条款编号(如【条款5.2.1】)
|
||||
2. 如引用的条款不完整,说明需要进一步查阅原文
|
||||
3. 给出明确的合规建议和操作指导
|
||||
4. 如果检索内容不足以回答问题,如实说明
|
||||
|
||||
回答格式:
|
||||
- 先给出直接结论
|
||||
- 然后引用支撑条款
|
||||
- 最后给出合规建议"""
|
||||
|
||||
def _estimate_tokens(self, text: str) -> int:
|
||||
"""估算文本token数"""
|
||||
# 中文字符约1.5 token,英文约0.25 token
|
||||
chinese_chars = sum(1 for c in text if '一' <= c <= '鿿')
|
||||
other_chars = len(text) - chinese_chars
|
||||
return int(chinese_chars * 1.5 + other_chars * 0.25)
|
||||
|
||||
def build_messages(
|
||||
self,
|
||||
context: RAGContext
|
||||
) -> List[Dict[str, str]]:
|
||||
"""
|
||||
构建LLM消息格式
|
||||
|
||||
Args:
|
||||
context: RAG上下文对象
|
||||
|
||||
Returns:
|
||||
List[Dict]: [{"role": "system/user/assistant", "content": "..."}]
|
||||
"""
|
||||
messages = [
|
||||
{"role": "system", "content": context.system_prompt},
|
||||
{"role": "user", "content": f"参考以下法规条款回答问题。\n\n{context.context_text}\n\n问题:{context.user_query}"}
|
||||
]
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def build_rag_context(
|
||||
query: str,
|
||||
documents: List[RetrievedDocument],
|
||||
**kwargs
|
||||
) -> RAGContext:
|
||||
"""便捷函数:构建RAG上下文"""
|
||||
builder = ContextBuilder()
|
||||
return builder.build(query, documents, **kwargs)
|
||||
296
backend/app/services/rag/prompt_templates.py
Normal file
296
backend/app/services/rag/prompt_templates.py
Normal file
@@ -0,0 +1,296 @@
|
||||
# src/services/rag/prompt_templates.py
|
||||
"""RAG Prompt模板 - 合规问答专用Prompt"""
|
||||
|
||||
from typing import Dict, Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptTemplate:
|
||||
"""Prompt模板"""
|
||||
name: str
|
||||
system_prompt: str
|
||||
user_template: str
|
||||
description: str
|
||||
|
||||
|
||||
class PromptTemplates:
|
||||
"""
|
||||
合规问答Prompt模板库
|
||||
|
||||
包含多种场景的Prompt模板:
|
||||
- 合规问答(标准)
|
||||
- 条款解读(详细解释)
|
||||
- 合规检查(判断合规状态)
|
||||
- 差异对比(新旧法规对比)
|
||||
- 报告生成(合规报告)
|
||||
"""
|
||||
|
||||
# 合规问答标准模板
|
||||
COMPLIANCE_QA = PromptTemplate(
|
||||
name="compliance_qa",
|
||||
system_prompt="""你是合规专家助手,专门解答法规合规问题。
|
||||
|
||||
角色定位:
|
||||
- 深谙国家法规标准(GB标准、行业标准)
|
||||
- 熟悉车辆安全、数据安全、EHS等领域合规要求
|
||||
- 提供专业、准确、可操作的合规建议
|
||||
|
||||
回答规范:
|
||||
1. 必须引用具体条款编号(如【条款5.2.1】)
|
||||
2. 优先引用高相关性条款(score ≥ 0.5)
|
||||
3. 如条款内容不完整,明确提示需要查阅原文
|
||||
4. 给出明确的合规结论和建议
|
||||
5. 如检索内容不足以回答,如实说明
|
||||
|
||||
回答格式:
|
||||
【结论】直接给出合规判断或答案
|
||||
|
||||
【条款依据】
|
||||
- 【条款X.X.X】简要内容摘要(相关性: 高/中)
|
||||
- ...
|
||||
|
||||
【合规建议】
|
||||
1. 具体操作建议
|
||||
2. 需要注意的风险点
|
||||
3. 后续行动建议""",
|
||||
user_template="""请根据以下法规条款回答问题。
|
||||
|
||||
【法规条款】
|
||||
{context}
|
||||
|
||||
【用户问题】
|
||||
{query}""",
|
||||
description="标准合规问答模板"
|
||||
)
|
||||
|
||||
# 条款解读模板(详细解释)
|
||||
CLAUSE_INTERPRETATION = PromptTemplate(
|
||||
name="clause_interpretation",
|
||||
system_prompt="""你是法规解读专家,负责详细解释法规条款的含义和应用。
|
||||
|
||||
解读要求:
|
||||
1. 逐句解释条款原文的含义
|
||||
2. 说明条款的目的和背景
|
||||
3. 举例说明条款的实际应用场景
|
||||
4. 解释常见的误解和注意事项
|
||||
|
||||
解读格式:
|
||||
【条款原文】完整引用条款
|
||||
|
||||
【逐句解读】
|
||||
- "原文句1":解读含义
|
||||
- "原文句2":解读含义
|
||||
...
|
||||
|
||||
【应用场景】
|
||||
具体举例说明该条款在实际工作中如何应用
|
||||
|
||||
【注意事项】
|
||||
常见误解、执行难点、合规风险点""",
|
||||
user_template="""请解读以下法规条款:
|
||||
|
||||
条款编号:{clause_number}
|
||||
条款内容:{content}
|
||||
|
||||
用户关注点:{query}""",
|
||||
description="条款详细解读模板"
|
||||
)
|
||||
|
||||
# 合规检查模板(判断合规状态)
|
||||
COMPLIANCE_CHECK = PromptTemplate(
|
||||
name="compliance_check",
|
||||
system_prompt="""你是合规检查专家,负责评估企业行为或产品的合规状态。
|
||||
|
||||
检查流程:
|
||||
1. 理解企业行为/产品描述
|
||||
2. 识别相关的法规条款
|
||||
3. 逐条对照检查合规状态
|
||||
4. 给出综合合规结论和整改建议
|
||||
|
||||
合规状态分类:
|
||||
- ✅ 符合:完全满足法规要求
|
||||
- ⚠️ 需评估:需要进一步核实或补充材料
|
||||
- ❌ 不符合:明确违反法规要求
|
||||
- ❓ 无适用条款:检索内容不足以判断
|
||||
|
||||
检查格式:
|
||||
【合规检查报告】
|
||||
|
||||
一、检查对象
|
||||
{描述企业行为/产品}
|
||||
|
||||
二、条款对照检查
|
||||
| 条款编号 | 要求摘要 | 检查状态 | 说明 |
|
||||
|--------|---------|---------|------|
|
||||
| 【条款X.X.X】 | ... | ✅/⚠️/❌/❓ | ... |
|
||||
|
||||
三、综合结论
|
||||
合规等级:A/B/C/D(完全合规/基本合规/部分合规/不合规)
|
||||
|
||||
四、整改建议(如需要)
|
||||
1. ...
|
||||
2. ...""",
|
||||
user_template="""请对以下企业行为进行合规检查:
|
||||
|
||||
【行为/产品描述】
|
||||
{query}
|
||||
|
||||
【相关法规条款】
|
||||
{context}""",
|
||||
description="合规检查评估模板"
|
||||
)
|
||||
|
||||
# 差异对比模板(新旧法规对比)
|
||||
COMPARISON = PromptTemplate(
|
||||
name="comparison",
|
||||
system_prompt="""你是法规变更分析专家,负责对比新旧法规版本的差异。
|
||||
|
||||
对比任务:
|
||||
1. 识别新旧版本的条款差异
|
||||
2. 分类差异类型(新增/修改/删除)
|
||||
3. 分析差异的影响范围
|
||||
4. 给出企业应对建议
|
||||
|
||||
差异分类:
|
||||
- 🆕 新增条款:原版本不存在
|
||||
- 🔄 修改条款:内容有实质性变更
|
||||
- ❌ 删除条款:原条款被移除
|
||||
- ⚖️ 调整条款:仅格式/编号调整,实质内容不变
|
||||
|
||||
对比格式:
|
||||
【法规变更对比分析】
|
||||
|
||||
一、变更概述
|
||||
- 旧版本:{version_old}
|
||||
- 新版本:{version_new}
|
||||
- 变更条款数:{count}
|
||||
|
||||
二、差异明细
|
||||
| 类型 | 条款编号 | 旧版本内容 | 新版本内容 | 变化要点 |
|
||||
|-----|---------|-----------|-----------|---------|
|
||||
| 🆕 | X.X.X | - | ... | 新增要求... |
|
||||
|
||||
三、影响分析
|
||||
- 高影响:{条款列表}
|
||||
- 中影响:{条款列表}
|
||||
- 低影响:{条款列表}
|
||||
|
||||
四、应对建议
|
||||
1. 立即整改项
|
||||
2. 逐步调整项
|
||||
3. 持续关注项""",
|
||||
user_template="""请对比分析以下法规差异:
|
||||
|
||||
【用户问题】
|
||||
{query}
|
||||
|
||||
【旧版本条款】
|
||||
{context_old}
|
||||
|
||||
【新版本条款】
|
||||
{context_new}""",
|
||||
description="法规版本对比模板"
|
||||
)
|
||||
|
||||
# 报告生成模板
|
||||
REPORT_GENERATION = PromptTemplate(
|
||||
name="report_generation",
|
||||
system_prompt="""你是合规报告撰写专家,负责生成结构化的合规分析报告。
|
||||
|
||||
报告要求:
|
||||
1. 结构清晰、逻辑严谨
|
||||
2. 数据准确、引用规范
|
||||
3. 结论明确、建议可操作
|
||||
4. 语言专业、表达简洁
|
||||
|
||||
报告结构:
|
||||
1. 概述(背景、范围)
|
||||
2. 分析内容(主体分析)
|
||||
3. 发现问题(合规差距)
|
||||
4. 整改建议(具体措施)
|
||||
5. 附录(引用条款原文)""",
|
||||
user_template="""请生成以下合规报告:
|
||||
|
||||
【报告主题】
|
||||
{query}
|
||||
|
||||
【分析依据】
|
||||
{context}
|
||||
|
||||
【报告要求】
|
||||
{requirements}""",
|
||||
description="合规报告生成模板"
|
||||
)
|
||||
|
||||
# 文档摘要生成模板
|
||||
DOCUMENT_SUMMARY = PromptTemplate(
|
||||
name="document_summary",
|
||||
system_prompt="""你是法规文档摘要专家,负责生成法规文档的核心要点摘要。
|
||||
|
||||
摘要要求:
|
||||
1. 精炼核心内容,不超过1024字
|
||||
2. 突出关键合规要求和条款编号
|
||||
3. 说明适用范围和生效条件
|
||||
4. 列出重要定义和术语解释
|
||||
|
||||
摘要格式:
|
||||
【法规名称】{doc_name}
|
||||
|
||||
【适用范围】{适用范围描述}
|
||||
|
||||
【核心条款摘要】
|
||||
- 【条款X.X.X】{关键要求}(重要度:高)
|
||||
- ...
|
||||
|
||||
【关键术语】
|
||||
- 术语1:定义解释
|
||||
- ...
|
||||
|
||||
【合规要点】
|
||||
1. 必须满足的核心要求
|
||||
2. 需要特别注意的条款""",
|
||||
user_template="""请生成以下法规文档的摘要:
|
||||
|
||||
【文档名称】
|
||||
{doc_name}
|
||||
|
||||
【文档内容】
|
||||
{content}
|
||||
|
||||
请生成不超过1024字的摘要。""",
|
||||
description="文档摘要生成模板"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_template(cls, name: str) -> Optional[PromptTemplate]:
|
||||
"""获取指定模板"""
|
||||
templates = {
|
||||
"compliance_qa": cls.COMPLIANCE_QA,
|
||||
"clause_interpretation": cls.CLAUSE_INTERPRETATION,
|
||||
"compliance_check": cls.COMPLIANCE_CHECK,
|
||||
"comparison": cls.COMPARISON,
|
||||
"report": cls.REPORT_GENERATION,
|
||||
"document_summary": cls.DOCUMENT_SUMMARY
|
||||
}
|
||||
return templates.get(name)
|
||||
|
||||
@classmethod
|
||||
def list_templates(cls) -> Dict[str, str]:
|
||||
"""列出所有模板"""
|
||||
return {
|
||||
"compliance_qa": cls.COMPLIANCE_QA.description,
|
||||
"clause_interpretation": cls.CLAUSE_INTERPRETATION.description,
|
||||
"compliance_check": cls.COMPLIANCE_CHECK.description,
|
||||
"comparison": cls.COMPARISON.description,
|
||||
"report": cls.REPORT_GENERATION.description,
|
||||
"document_summary": cls.DOCUMENT_SUMMARY.description
|
||||
}
|
||||
|
||||
|
||||
def get_prompt_template(name: str) -> PromptTemplate:
|
||||
"""便捷函数:获取Prompt模板"""
|
||||
template = PromptTemplates.get_template(name)
|
||||
if not template:
|
||||
raise ValueError(f"不存在的模板: {name}")
|
||||
return template
|
||||
193
backend/app/services/rag/retriever.py
Normal file
193
backend/app/services/rag/retriever.py
Normal file
@@ -0,0 +1,193 @@
|
||||
# src/services/rag/retriever.py
|
||||
"""RAG检索服务 - 封装Milvus检索"""
|
||||
|
||||
from typing import List, Dict, Optional, Any
|
||||
from dataclasses import dataclass, field
|
||||
from loguru import logger
|
||||
|
||||
from app.services.embedding.bge_m3_embedder import BGEM3Embedder
|
||||
from app.services.storage.milvus_client import MilvusClient, SearchResult
|
||||
from app.config.settings import settings
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrievedDocument:
|
||||
"""检索到的文档"""
|
||||
content: str
|
||||
doc_id: str # 文档ID,用于下载
|
||||
doc_name: str
|
||||
section_title: str
|
||||
clause_number: str
|
||||
page_number: int
|
||||
score: float
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class Retriever:
|
||||
"""
|
||||
RAG检索器
|
||||
|
||||
功能:
|
||||
- 向量检索(Dense + Sparse混合)
|
||||
- 重排序(可选)
|
||||
- 过滤和筛选
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
top_k: int = None,
|
||||
rerank: bool = False,
|
||||
min_score: float = 0.3
|
||||
):
|
||||
"""
|
||||
初始化检索器
|
||||
|
||||
Args:
|
||||
top_k: 检索召回数量
|
||||
rerank: 是否启用重排序
|
||||
min_score: 最低相关性分数阈值
|
||||
"""
|
||||
self.top_k = top_k or settings.rag_top_k
|
||||
self.rerank = rerank
|
||||
self.min_score = min_score
|
||||
|
||||
# 嵌入模型(延迟加载)
|
||||
self.embedder: Optional[BGEM3Embedder] = None
|
||||
|
||||
# Milvus客户端(延迟连接)
|
||||
self.milvus: Optional[MilvusClient] = None
|
||||
|
||||
logger.info(f"检索器初始化: top_k={self.top_k}, rerank={self.rerank}")
|
||||
|
||||
def _init_embedder(self):
|
||||
"""延迟初始化嵌入模型"""
|
||||
if self.embedder is None:
|
||||
logger.info("加载嵌入模型...")
|
||||
self.embedder = BGEM3Embedder(model_name=settings.embedding_model)
|
||||
|
||||
def _init_milvus(self):
|
||||
"""延迟初始化Milvus"""
|
||||
if self.milvus is None:
|
||||
logger.info("连接Milvus...")
|
||||
self.milvus = MilvusClient()
|
||||
self.milvus.connect()
|
||||
self.milvus.create_collection(recreate=False)
|
||||
self.milvus.load_collection()
|
||||
|
||||
def retrieve(
|
||||
self,
|
||||
query: str,
|
||||
filters: Optional[str] = None,
|
||||
top_k: Optional[int] = None
|
||||
) -> List[RetrievedDocument]:
|
||||
"""
|
||||
检索相关文档
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
filters: 过滤条件(如 "regulation_type=='车辆安全'")
|
||||
top_k: 返回数量(可选,覆盖默认值)
|
||||
|
||||
Returns:
|
||||
List[RetrievedDocument]: 检索结果列表
|
||||
"""
|
||||
logger.info(f"执行检索: {query}")
|
||||
|
||||
# 初始化组件
|
||||
self._init_embedder()
|
||||
self._init_milvus()
|
||||
|
||||
# 生成查询向量
|
||||
query_embedding = self.embedder.embed_single(query)
|
||||
|
||||
# 执行混合检索
|
||||
results = self.milvus.hybrid_search(
|
||||
query_dense=query_embedding['dense'].tolist(),
|
||||
query_sparse=query_embedding['sparse'],
|
||||
top_k=top_k or self.top_k,
|
||||
filters=filters
|
||||
)
|
||||
|
||||
# 转换为RetrievedDocument格式
|
||||
documents = []
|
||||
for r in results:
|
||||
if r.score >= self.min_score:
|
||||
doc = RetrievedDocument(
|
||||
content=r.content,
|
||||
doc_id=r.metadata.get("doc_id", ""),
|
||||
doc_name=r.metadata.get("doc_name", ""),
|
||||
section_title=r.metadata.get("section_title", ""),
|
||||
clause_number=r.metadata.get("clause_number", ""),
|
||||
page_number=r.metadata.get("page_number", 0),
|
||||
score=r.score,
|
||||
metadata=r.metadata
|
||||
)
|
||||
documents.append(doc)
|
||||
|
||||
logger.success(f"检索完成,返回{len(documents)}条结果(阈值过滤后)")
|
||||
return documents
|
||||
|
||||
def retrieve_with_scores(
|
||||
self,
|
||||
query: str,
|
||||
filters: Optional[str] = None
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
检索并返回完整结果(包含分数)
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
filters: 过滤条件
|
||||
|
||||
Returns:
|
||||
List[Dict]: 包含分数的检索结果
|
||||
"""
|
||||
documents = self.retrieve(query, filters)
|
||||
return [
|
||||
{
|
||||
"content": doc.content,
|
||||
"doc_id": doc.doc_id,
|
||||
"doc_name": doc.doc_name,
|
||||
"section_title": doc.section_title,
|
||||
"clause_number": doc.clause_number,
|
||||
"page_number": doc.page_number,
|
||||
"score": doc.score
|
||||
}
|
||||
for doc in documents
|
||||
]
|
||||
|
||||
def search_by_doc_name(
|
||||
self,
|
||||
query: str,
|
||||
doc_name: str
|
||||
) -> List[RetrievedDocument]:
|
||||
"""按文档名称过滤检索"""
|
||||
filters = f'doc_name=="{doc_name}"'
|
||||
return self.retrieve(query, filters)
|
||||
|
||||
def search_by_regulation_type(
|
||||
self,
|
||||
query: str,
|
||||
regulation_type: str
|
||||
) -> List[RetrievedDocument]:
|
||||
"""按法规类型过滤检索"""
|
||||
filters = f'regulation_type=="{regulation_type}"'
|
||||
return self.retrieve(query, filters)
|
||||
|
||||
def close(self):
|
||||
"""关闭连接"""
|
||||
if self.milvus:
|
||||
self.milvus.disconnect()
|
||||
logger.info("检索器已关闭")
|
||||
|
||||
|
||||
def retrieve_regulations(
|
||||
query: str,
|
||||
top_k: int = 10,
|
||||
filters: Optional[str] = None
|
||||
) -> List[RetrievedDocument]:
|
||||
"""便捷函数:检索法规"""
|
||||
retriever = Retriever(top_k=top_k)
|
||||
results = retriever.retrieve(query, filters)
|
||||
retriever.close()
|
||||
return results
|
||||
Reference in New Issue
Block a user