Files
AIRegulation-DocAnalysis/backend/app/services/rag/context_builder.py

230 lines
6.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""RAG上下文构建服务 - 构建LLM输入上下文"""
from typing import List, Dict, Optional
from dataclasses import dataclass
from loguru import logger
from .retriever import RetrievedDocument
from app.config.settings import settings
@dataclass
class RAGContext:
    """
    Assembled RAG context, ready to be turned into LLM chat messages.

    Fields are listed in constructor order; only ``truncated`` has a default.
    """
    # System prompt text handed to the LLM.
    system_prompt: str
    # Formatted text of the retrieved documents that fit the token budget.
    context_text: str
    # The user's original question.
    user_query: str
    # Heuristic token estimate (not an exact tokenizer count).
    total_tokens: int
    # One metadata dict per document included in context_text.
    sources: List[Dict]
    # True when at least one document was dropped because of the token budget.
    truncated: bool = False
class ContextBuilder:
    """
    RAG context builder.

    Responsibilities:
    - Format retrieved documents into a numbered context text.
    - Keep the context within an (estimated) token budget.
    - Build the chat-message list sent to the LLM.
    """

    def __init__(
        self,
        max_context_tokens: Optional[int] = None,
        include_metadata: bool = True,
        citation_format: str = "【条款{clause}】"
    ):
        """
        Initialize the context builder.

        Args:
            max_context_tokens: Token budget for the document context. Any falsy
                value (None or 0) falls back to settings.rag_max_context_tokens.
            include_metadata: Whether to emit a metadata line (document name,
                section title, clause label) above each document's content.
            citation_format: str.format template for the clause label; it is
                formatted with a single ``clause`` keyword argument. The default
                closes the 【 bracket so labels match the 【条款X.Y.Z】 style
                requested by the default system prompt.
        """
        self.max_context_tokens = max_context_tokens or settings.rag_max_context_tokens
        self.include_metadata = include_metadata
        self.citation_format = citation_format
        logger.info(f"上下文构建器初始化: max_tokens={self.max_context_tokens}")

    def build(
        self,
        query: str,
        documents: List[RetrievedDocument],
        system_prompt: Optional[str] = None,
        max_tokens: Optional[int] = None
    ) -> RAGContext:
        """
        Build the RAG context for a query.

        Args:
            query: The user's question.
            documents: Retrieved documents, in the order they should appear.
            system_prompt: Optional system prompt; the built-in default is used
                when this is falsy.
            max_tokens: Optional override of the context token budget; falls
                back to self.max_context_tokens when falsy.

        Returns:
            RAGContext: The assembled context. ``total_tokens`` is a heuristic
            estimate over system prompt + context text + query only; it does
            not include the instruction wrapper added by build_messages.
        """
        max_tokens = max_tokens or self.max_context_tokens

        # Format documents until the budget is exhausted.
        context_text, sources, truncated = self._format_documents(documents, max_tokens)

        system_prompt = system_prompt or self._default_system_prompt()

        total_tokens = self._estimate_tokens(system_prompt + context_text + query)

        # Note: len(documents) is the number of *input* documents, which may
        # exceed the number kept when truncated is True.
        logger.info(f"上下文构建完成: {len(documents)}条文档, {total_tokens}tokens, truncated={truncated}")
        return RAGContext(
            system_prompt=system_prompt,
            context_text=context_text,
            user_query=query,
            total_tokens=total_tokens,
            sources=sources,
            truncated=truncated
        )

    def _format_documents(
        self,
        documents: List[RetrievedDocument],
        max_tokens: int
    ) -> tuple:
        """
        Format documents into context text while respecting the token budget.

        Documents are added in order. At the first document whose estimated
        size would push the running total past ``max_tokens``, formatting stops
        and all remaining documents are dropped (smaller later documents are
        not tried), so the kept set is always a prefix of ``documents``. The
        blank-line separators between documents are not counted against the
        budget.

        Args:
            documents: Documents to format.
            max_tokens: Token budget for the combined context.

        Returns:
            (context_text, sources, truncated): the kept documents joined by
            blank lines, one source-metadata dict per kept document, and whether
            any document was dropped.
        """
        context_parts = []
        sources = []
        current_tokens = 0
        truncated = False

        for i, doc in enumerate(documents):
            # Citation indices are 1-based so they read naturally as [1], [2], ...
            formatted = self._format_single_doc(doc, i + 1)
            doc_tokens = self._estimate_tokens(formatted)

            if current_tokens + doc_tokens > max_tokens:
                truncated = True
                logger.warning(f"上下文截断: 已达到{max_tokens}tokens限制")
                break

            context_parts.append(formatted)
            current_tokens += doc_tokens

            # Record source metadata so answers can be traced back to documents.
            sources.append({
                "index": i + 1,
                "doc_id": doc.doc_id,
                "doc_name": doc.doc_name,
                "section_title": doc.section_title,
                "clause_number": doc.clause_number,
                "page_number": doc.page_number,
                "score": doc.score
            })

        context_text = "\n\n".join(context_parts)
        return context_text, sources, truncated

    def _format_single_doc(
        self,
        doc: RetrievedDocument,
        index: int
    ) -> str:
        """
        Format one document as newline-separated lines.

        Layout: a ``[index]`` marker line, then (when include_metadata is set and
        at least one field is present) a " | "-joined metadata line, then the
        document content.

        Args:
            doc: The retrieved document.
            index: 1-based position of the document in the context.
        """
        parts = [f"[{index}]"]

        if self.include_metadata:
            meta_parts = []
            if doc.doc_name:
                meta_parts.append(f"文档: {doc.doc_name}")
            if doc.section_title:
                meta_parts.append(f"章节: {doc.section_title}")
            if doc.clause_number:
                meta_parts.append(self.citation_format.format(clause=doc.clause_number))
            if meta_parts:
                parts.append(" | ".join(meta_parts))

        parts.append(doc.content)
        return "\n".join(parts)

    def _default_system_prompt(self) -> str:
        """
        Return the default system prompt (in Chinese) for the compliance assistant.

        It asks the model to answer directly, cite specific clause numbers in
        the 【条款X.Y.Z】 style, give compliance advice, and admit when the
        retrieved content is insufficient.
        """
        return (
            "你是合规专家助手,基于提供的法规条款回答问题。\n"
            "回答要求:\n"
            "1. 直接回答问题,必须引用具体条款编号(如【条款5.2.1】)\n"
            "2. 如引用的条款不完整,说明需要进一步查阅原文\n"
            "3. 给出明确的合规建议和操作指导\n"
            "4. 如果检索内容不足以回答问题,如实说明\n"
            "回答格式:\n"
            "- 先给出直接结论\n"
            "- 然后引用支撑条款\n"
            "- 最后给出合规建议"
        )

    def _estimate_tokens(self, text: str) -> int:
        """
        Roughly estimate the token count of ``text``.

        Heuristic: each CJK Unified Ideograph (U+4E00..U+9FFF) counts as 1.5
        tokens and every other character as 0.25 tokens; the sum is truncated
        to an int.
        """
        # Explicit escapes: a literal ideograph bound is easy to lose or mangle
        # in editors/viewers, and an empty-string lower bound would wrongly
        # classify ASCII as CJK.
        cjk_chars = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
        other_chars = len(text) - cjk_chars
        return int(cjk_chars * 1.5 + other_chars * 0.25)

    def build_messages(
        self,
        context: RAGContext
    ) -> List[Dict[str, str]]:
        """
        Build the chat-message list for the LLM.

        Args:
            context: The RAG context produced by build().

        Returns:
            List[Dict]: A system message with the system prompt, followed by a
            user message that wraps the document context and the question.
        """
        messages = [
            {"role": "system", "content": context.system_prompt},
            {"role": "user", "content": f"参考以下法规条款回答问题。\n\n{context.context_text}\n\n问题:{context.user_query}"}
        ]
        return messages
def build_rag_context(
    query: str,
    documents: List[RetrievedDocument],
    **kwargs
) -> RAGContext:
    """
    Convenience wrapper: build a RAG context with a default-configured builder.

    A fresh ``ContextBuilder()`` (constructor defaults) is created on every
    call; any extra keyword arguments (e.g. ``system_prompt``, ``max_tokens``)
    are forwarded unchanged to ``ContextBuilder.build``.
    """
    return ContextBuilder().build(query, documents, **kwargs)