230 lines
6.5 KiB
Python
230 lines
6.5 KiB
Python
|
|
"""RAG上下文构建服务 - 构建LLM输入上下文"""
|
|||
|
|
|
|||
|
|
from typing import List, Dict, Optional
|
|||
|
|
from dataclasses import dataclass
|
|||
|
|
from loguru import logger
|
|||
|
|
|
|||
|
|
from .retriever import RetrievedDocument
|
|||
|
|
from app.config.settings import settings
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class RAGContext:
    """
    Container for everything assembled for one RAG request.

    Attributes:
        system_prompt: System prompt text to send to the LLM.
        context_text: Retrieved documents rendered as one text block.
        user_query: The user's original question.
        total_tokens: Estimated token count of the request (an estimate,
            not an exact tokenizer count).
        sources: One metadata dict per document included in context_text.
        truncated: True when documents were dropped to respect the token
            budget; defaults to False.
    """

    # Prompt sent as the system message.
    system_prompt: str
    # Formatted, numbered document excerpts.
    context_text: str
    # Raw user question.
    user_query: str
    # Heuristic token estimate.
    total_tokens: int
    # Citation metadata for the included documents.
    sources: List[Dict]
    # Whether the document list was cut short by the budget.
    truncated: bool = False
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ContextBuilder:
    """
    Builder that turns retrieved documents into LLM input for RAG.

    Responsibilities:
    - Format retrieved documents into a numbered context text block.
    - Keep that context within an estimated token budget.
    - Produce the chat-message list sent to the LLM.
    """

    def __init__(
        self,
        max_context_tokens: Optional[int] = None,
        include_metadata: bool = True,
        citation_format: str = "【条款{clause}】"
    ):
        """
        Initialize the context builder.

        Args:
            max_context_tokens: Token budget for the formatted document
                context. A falsy value (None or 0) falls back to
                ``settings.rag_max_context_tokens``.
            include_metadata: Whether to prefix each document with its
                document name, section title and clause label.
            citation_format: ``str.format`` template for the clause label;
                it is formatted with the clause number as ``clause``.
        """
        self.max_context_tokens = max_context_tokens or settings.rag_max_context_tokens
        self.include_metadata = include_metadata
        self.citation_format = citation_format

        logger.info(f"上下文构建器初始化: max_tokens={self.max_context_tokens}")

    def build(
        self,
        query: str,
        documents: List[RetrievedDocument],
        system_prompt: Optional[str] = None,
        max_tokens: Optional[int] = None
    ) -> RAGContext:
        """
        Build the RAG context for a single query.

        Only the document context is limited by the budget; the system prompt
        and query are not counted against ``max_tokens``.

        Args:
            query: The user's question.
            documents: Retrieved documents, in the order they should appear.
                Once one document would exceed the budget, it and every later
                document are dropped.
            system_prompt: Optional system prompt; a falsy value uses
                ``_default_system_prompt()``.
            max_tokens: Optional per-call budget overriding
                ``self.max_context_tokens``; a falsy value uses the default.

        Returns:
            RAGContext: The assembled context. ``total_tokens`` estimates the
            system prompt plus the full user message exactly as
            ``build_messages`` renders it.
        """
        max_tokens = max_tokens or self.max_context_tokens

        # Format the documents within the budget.
        context_text, sources, truncated = self._format_documents(
            documents,
            max_tokens
        )

        system_prompt = system_prompt or self._default_system_prompt()

        # Estimate over the text that is actually sent. The user message wraps
        # context and query with an instruction prefix and a question label.
        # Estimating only system_prompt + context_text + query would
        # undercount the real payload.
        user_message = self._format_user_message(context_text, query)
        total_tokens = self._estimate_tokens(system_prompt + user_message)

        logger.info(f"上下文构建完成: {len(documents)}条文档, {total_tokens}tokens, truncated={truncated}")

        return RAGContext(
            system_prompt=system_prompt,
            context_text=context_text,
            user_query=query,
            total_tokens=total_tokens,
            sources=sources,
            truncated=truncated
        )

    def _format_documents(
        self,
        documents: List[RetrievedDocument],
        max_tokens: int
    ) -> tuple:
        """
        Format documents in order until the token budget is exhausted.

        Args:
            documents: Documents to format, in display order.
            max_tokens: Budget for the sum of per-document token estimates.

        Returns:
            A ``(context_text, sources, truncated)`` tuple:
            - ``context_text``: included documents joined by blank lines;
            - ``sources``: one metadata dict per included document;
            - ``truncated``: True if a document was dropped for the budget.

        Note:
            The budget check sums per-document estimates. The ``"\\n\\n"``
            separators added by the final join are not counted. If the first
            document alone exceeds the budget, the context is empty.
        """
        context_parts = []
        sources = []
        current_tokens = 0
        truncated = False

        for position, doc in enumerate(documents, start=1):
            formatted = self._format_single_doc(doc, position)
            doc_tokens = self._estimate_tokens(formatted)

            # Stop at the first document that would overflow the budget. This
            # keeps the retrieval order instead of back-filling with smaller
            # documents found later in the list.
            if current_tokens + doc_tokens > max_tokens:
                truncated = True
                logger.warning(f"上下文截断: 已达到{max_tokens}tokens限制")
                break

            context_parts.append(formatted)
            current_tokens += doc_tokens

            # The index matches the "[n]" label written by _format_single_doc.
            sources.append({
                "index": position,
                "doc_id": doc.doc_id,
                "doc_name": doc.doc_name,
                "section_title": doc.section_title,
                "clause_number": doc.clause_number,
                "page_number": doc.page_number,
                "score": doc.score
            })

        context_text = "\n\n".join(context_parts)
        return context_text, sources, truncated

    def _format_single_doc(
        self,
        doc: RetrievedDocument,
        index: int
    ) -> str:
        """
        Render one document as newline-separated lines.

        The lines are, in order:
        - a ``[index]`` label;
        - an optional metadata line joined by ``" | "``;
        - the document content.
        """
        parts = [f"[{index}]"]

        # Metadata line (optional); only truthy fields are shown.
        if self.include_metadata:
            meta_parts = []

            if doc.doc_name:
                meta_parts.append(f"文档: {doc.doc_name}")

            if doc.section_title:
                meta_parts.append(f"章节: {doc.section_title}")

            if doc.clause_number:
                clause_text = self.citation_format.format(clause=doc.clause_number)
                meta_parts.append(clause_text)

            if meta_parts:
                parts.append(" | ".join(meta_parts))

        parts.append(doc.content)

        return "\n".join(parts)

    def _default_system_prompt(self) -> str:
        """Return the default system prompt (Chinese compliance-assistant instructions)."""
        return """你是合规专家助手,基于提供的法规条款回答问题。

回答要求:
1. 直接回答问题,必须引用具体条款编号(如【条款5.2.1】)
2. 如引用的条款不完整,说明需要进一步查阅原文
3. 给出明确的合规建议和操作指导
4. 如果检索内容不足以回答问题,如实说明

回答格式:
- 先给出直接结论
- 然后引用支撑条款
- 最后给出合规建议"""

    def _estimate_tokens(self, text: str) -> int:
        """
        Heuristically estimate the token count of ``text``.

        Characters in the CJK Unified Ideographs block (U+4E00..U+9FFF) count
        as 1.5 tokens. Every other character, including full-width
        punctuation, counts as 0.25 tokens. The sum is truncated to an int.
        """
        chinese_chars = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
        other_chars = len(text) - chinese_chars
        return int(chinese_chars * 1.5 + other_chars * 0.25)

    def _format_user_message(self, context_text: str, query: str) -> str:
        """
        Render the user message that wraps the retrieved context and the query.

        ``build`` (for the token estimate) and ``build_messages`` (for the
        payload) both use this helper, so the two cannot drift apart.
        """
        return f"参考以下法规条款回答问题。\n\n{context_text}\n\n问题:{query}"

    def build_messages(
        self,
        context: RAGContext
    ) -> List[Dict[str, str]]:
        """
        Build the chat-message list for the LLM.

        Args:
            context: A RAG context, typically returned by ``build``.

        Returns:
            List[Dict[str, str]]: Exactly two messages:
            - a ``system`` message containing ``context.system_prompt``;
            - a ``user`` message wrapping the context text and the query.
        """
        messages = [
            {"role": "system", "content": context.system_prompt},
            {"role": "user", "content": self._format_user_message(context.context_text, context.user_query)}
        ]

        return messages
|
|||
|
|
|
|||
|
|
|
|||
|
|
def build_rag_context(
    query: str,
    documents: List[RetrievedDocument],
    **kwargs
) -> RAGContext:
    """
    Convenience wrapper that builds a RAG context with a default-configured builder.

    A fresh ``ContextBuilder()`` is created on each call, so constructor options
    always take their defaults. Any extra keyword arguments are forwarded
    unchanged to ``ContextBuilder.build``.
    """
    return ContextBuilder().build(query, documents, **kwargs)
|