Files
AIRegulation-DocAnalysis/backend/app/services/rag/context_builder.py

190 lines
6.5 KiB
Python
Raw Permalink Normal View History

"""Provide service-layer logic for context builder."""
2026-05-14 15:07:34 +08:00
from typing import List, Dict, Optional
from dataclasses import dataclass
from loguru import logger
from .retriever import RetrievedDocument
from app.config.settings import settings
# Keep service responsibilities explicit so downstream behavior stays predictable.
2026-05-14 15:07:34 +08:00
@dataclass
class RAGContext:
    """Hold the assembled prompt context for one RAG request.

    Instances are produced by ``ContextBuilder.build`` and consumed by
    ``ContextBuilder.build_messages``.
    """

    # System prompt text sent as the "system" message.
    system_prompt: str
    # Formatted retrieved documents, joined into a single string.
    context_text: str
    # The user's original question.
    user_query: str
    # Estimated token count of system prompt + context + query.
    total_tokens: int
    # One metadata dict per document that was included in ``context_text``.
    sources: List[Dict]
    # True when at least one document was dropped to respect the token budget.
    truncated: bool = False
class ContextBuilder:
    """Assemble retrieved documents into a token-budgeted RAG prompt context.

    Each retrieved document is rendered as a numbered entry (optionally with
    a metadata header). Entries are appended in order until the estimated
    token budget would be exceeded; the remaining documents are dropped and
    the result is flagged as truncated.
    """

    def __init__(
        self,
        max_context_tokens: Optional[int] = None,
        include_metadata: bool = True,
        # NOTE(review): the default has an opening "【" with no matching
        # closing bracket -- confirm whether this is intentional.
        citation_format: str = "【条款{clause}"
    ):
        """Initialize the builder.

        Args:
            max_context_tokens: Token budget for the document context. A
                falsy value (``None`` or ``0``) falls back to
                ``settings.rag_max_context_tokens``.
            include_metadata: Whether each entry gets a header line with the
                document name, section title and clause citation.
            citation_format: ``str.format`` template used for the clause
                citation; it receives the keyword ``clause``.
        """
        self.max_context_tokens = max_context_tokens or settings.rag_max_context_tokens
        self.include_metadata = include_metadata
        self.citation_format = citation_format
        logger.info(f"上下文构建器初始化: max_tokens={self.max_context_tokens}")

    def build(
        self,
        query: str,
        documents: "List[RetrievedDocument]",
        system_prompt: Optional[str] = None,
        max_tokens: Optional[int] = None
    ) -> "RAGContext":
        """Build the full RAG context for a query.

        Args:
            query: The user's question.
            documents: Retrieved documents, in the order they should be
                considered for inclusion.
            system_prompt: Custom system prompt; a falsy value selects the
                built-in default prompt.
            max_tokens: Per-call token budget for the document context; a
                falsy value selects the budget configured on the instance.

        Returns:
            A ``RAGContext`` holding the prompt pieces, the estimated total
            token count, the included sources and the truncation flag.
        """
        max_tokens = max_tokens or self.max_context_tokens

        context_text, sources, truncated = self._format_documents(
            documents,
            max_tokens
        )

        system_prompt = system_prompt or self._default_system_prompt()

        # The total covers everything that ends up in the prompt, not just
        # the document context that the budget above applies to.
        total_tokens = self._estimate_tokens(system_prompt + context_text + query)
        logger.info(f"上下文构建完成: {len(documents)}条文档, {total_tokens}tokens, truncated={truncated}")

        return RAGContext(
            system_prompt=system_prompt,
            context_text=context_text,
            user_query=query,
            total_tokens=total_tokens,
            sources=sources,
            truncated=truncated
        )

    def _format_documents(
        self,
        documents: "List[RetrievedDocument]",
        max_tokens: int
    ) -> tuple:
        """Render documents into one context string within a token budget.

        Args:
            documents: Retrieved documents, in inclusion order.
            max_tokens: Maximum estimated tokens for the rendered entries.

        Returns:
            A ``(context_text, sources, truncated)`` tuple: the entries joined
            by blank lines, one metadata dict per included document, and
            whether any document was left out.
        """
        context_parts = []
        sources = []
        current_tokens = 0
        truncated = False

        for index, doc in enumerate(documents, start=1):
            formatted = self._format_single_doc(doc, index)
            doc_tokens = self._estimate_tokens(formatted)

            # Stop at the first document that does not fit; later documents
            # are not tried even if they would be small enough.
            if current_tokens + doc_tokens > max_tokens:
                truncated = True
                logger.warning(f"上下文截断: 已达到{max_tokens}tokens限制")
                break

            context_parts.append(formatted)
            current_tokens += doc_tokens

            # "index" matches the "[n]" label used in the rendered entry.
            sources.append({
                "index": index,
                "doc_id": doc.doc_id,
                "doc_name": doc.doc_name,
                "section_title": doc.section_title,
                "clause_number": doc.clause_number,
                "page_number": doc.page_number,
                "score": doc.score
            })

        context_text = "\n\n".join(context_parts)
        return context_text, sources, truncated

    def _format_single_doc(
        self,
        doc: "RetrievedDocument",
        index: int
    ) -> str:
        """Render one document as a numbered, newline-separated entry.

        Args:
            doc: The retrieved document to render.
            index: 1-based label shown as ``[index]`` on the first line.

        Returns:
            The label line, an optional metadata line, then the content.
        """
        parts = [f"[{index}]"]

        if self.include_metadata:
            meta_parts = []
            if doc.doc_name:
                meta_parts.append(f"文档: {doc.doc_name}")
            if doc.section_title:
                meta_parts.append(f"章节: {doc.section_title}")
            if doc.clause_number:
                meta_parts.append(self.citation_format.format(clause=doc.clause_number))
            # Skip the header line entirely when no metadata field is set.
            if meta_parts:
                parts.append(" | ".join(meta_parts))

        parts.append(doc.content)
        return "\n".join(parts)

    def _default_system_prompt(self) -> str:
        """Return the built-in system prompt used when the caller gives none."""
        # NOTE(review): the prompt text appears to have lost some punctuation
        # (e.g. after the section headings); it is kept verbatim here --
        # confirm against the intended wording.
        return """你是合规专家助手,基于提供的法规条款回答问题。
回答要求
1. 直接回答问题必须引用具体条款编号条款5.2.1
2. 如引用的条款不完整说明需要进一步查阅原文
3. 给出明确的合规建议和操作指导
4. 如果检索内容不足以回答问题如实说明
回答格式
- 先给出直接结论
- 然后引用支撑条款
- 最后给出合规建议"""

    def _estimate_tokens(self, text: str) -> int:
        """Estimate the token count of ``text`` with a character heuristic.

        Characters in the CJK Unified Ideographs block (U+4E00..U+9FFF) are
        weighted 1.5 tokens each; every other character is weighted 0.25.

        Args:
            text: The text to measure.

        Returns:
            The estimated token count, truncated to an integer.
        """
        # The lower bound must be U+4E00: an empty-string bound compares
        # less than every character and would classify ASCII as Chinese.
        chinese_chars = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
        other_chars = len(text) - chinese_chars
        return int(chinese_chars * 1.5 + other_chars * 0.25)

    def build_messages(
        self,
        context: "RAGContext"
    ) -> List[Dict[str, str]]:
        """Convert a built context into a chat message list.

        Args:
            context: The context produced by ``build``.

        Returns:
            Two messages: a "system" message with the system prompt and a
            "user" message containing the document context and the question.
        """
        messages = [
            {"role": "system", "content": context.system_prompt},
            {"role": "user", "content": f"参考以下法规条款回答问题。\n\n{context.context_text}\n\n问题:{context.user_query}"}
        ]
        return messages
def build_rag_context(
    query: str,
    documents: "List[RetrievedDocument]",
    **kwargs
) -> "RAGContext":
    """Build a RAG context using a default-configured ``ContextBuilder``.

    Args:
        query: The user's question.
        documents: Retrieved documents, in inclusion order.
        **kwargs: Forwarded to ``ContextBuilder.build`` (for example
            ``system_prompt`` or ``max_tokens``). They are not passed to the
            ``ContextBuilder`` constructor, which always uses its defaults.

    Returns:
        The ``RAGContext`` produced by ``ContextBuilder.build``.
    """
    builder = ContextBuilder()
    return builder.build(query, documents, **kwargs)