Files

190 lines
6.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Provide service-layer logic for context builder."""
from typing import List, Dict, Optional
from dataclasses import dataclass
from loguru import logger
from .retriever import RetrievedDocument
from app.config.settings import settings
@dataclass()
class RAGContext:
    """Bundle of everything needed to send one retrieval-augmented request.

    When produced by ``ContextBuilder.build``, ``sources`` holds one metadata
    dict per document that was rendered into ``context_text``.
    """

    # Instruction text intended for the "system" chat role.
    system_prompt: str
    # Rendered, numbered document snippets.
    context_text: str
    # The user's question, passed through unchanged.
    user_query: str
    # Heuristic token estimate, not an exact tokenizer count.
    total_tokens: int
    # Citation metadata for the included documents.
    sources: List[Dict]
    # True when documents were dropped to stay within the token budget.
    truncated: bool = False
class ContextBuilder:
    """Assemble retrieved regulation snippets into an LLM-ready RAG context.

    Documents are rendered in the order given, each prefixed with a ``[n]``
    index and (optionally) a metadata line. They are appended until the
    estimated token budget would be exceeded.
    """

    def __init__(
        self,
        max_context_tokens: Optional[int] = None,
        include_metadata: bool = True,
        citation_format: str = "【条款{clause}】"
    ):
        """Configure the builder.

        Args:
            max_context_tokens: Token budget for the rendered documents. A falsy
                value (``None`` or ``0``) falls back to
                ``settings.rag_max_context_tokens``.
            include_metadata: Whether to emit a document/section/clause line
                above each snippet.
            citation_format: ``str.format`` template for the clause label. It
                receives the clause number as ``clause`` and by default matches
                the ``【条款5.2.1】`` form requested in the default system prompt.
        """
        self.max_context_tokens = max_context_tokens or settings.rag_max_context_tokens
        self.include_metadata = include_metadata
        self.citation_format = citation_format
        logger.info(f"上下文构建器初始化: max_tokens={self.max_context_tokens}")

    def build(
        self,
        query: str,
        documents: List[RetrievedDocument],
        system_prompt: Optional[str] = None,
        max_tokens: Optional[int] = None
    ) -> RAGContext:
        """Build a RAGContext for ``query`` from ``documents``.

        Args:
            query: The user's question.
            documents: Retrieved documents, rendered in the given order.
            system_prompt: Replaces the default system prompt when truthy.
            max_tokens: Per-call context budget. Falsy values fall back to
                ``self.max_context_tokens``.

        Returns:
            A RAGContext holding:
            - the rendered context;
            - source metadata for every included document;
            - an estimated token count over system prompt + context + query;
            - a flag telling whether documents were dropped.
        """
        max_tokens = max_tokens or self.max_context_tokens
        context_text, sources, truncated = self._format_documents(documents, max_tokens)
        system_prompt = system_prompt or self._default_system_prompt()
        # The fixed wrapper text added later by build_messages is not included here.
        total_tokens = self._estimate_tokens(system_prompt + context_text + query)
        logger.info(f"上下文构建完成: {len(documents)}条文档, {total_tokens}tokens, truncated={truncated}")
        return RAGContext(
            system_prompt=system_prompt,
            context_text=context_text,
            user_query=query,
            total_tokens=total_tokens,
            sources=sources,
            truncated=truncated
        )

    def _format_documents(
        self,
        documents: List[RetrievedDocument],
        max_tokens: int
    ) -> tuple:
        """Render documents in order until the token budget would be exceeded.

        Args:
            documents: Retrieved documents in the order they should appear.
            max_tokens: Budget for the summed estimates of the rendered
                snippets. The blank-line separators are not counted.

        Returns:
            ``(context_text, sources, truncated)``:
            - the included snippets joined by blank lines;
            - one metadata dict per included document;
            - whether rendering stopped early because the next document did
              not fit.
        """
        context_parts = []
        sources = []
        current_tokens = 0
        truncated = False
        for index, doc in enumerate(documents, start=1):
            formatted = self._format_single_doc(doc, index)
            doc_tokens = self._estimate_tokens(formatted)
            # Stop at the first document that does not fit instead of skipping
            # ahead to smaller ones, so the included set stays a prefix of the
            # given order and citation indices stay contiguous.
            if current_tokens + doc_tokens > max_tokens:
                truncated = True
                logger.warning(f"上下文截断: 已达到{max_tokens}tokens限制")
                break
            context_parts.append(formatted)
            current_tokens += doc_tokens
            # "index" matches the [n] label rendered in the context text.
            sources.append({
                "index": index,
                "doc_id": doc.doc_id,
                "doc_name": doc.doc_name,
                "section_title": doc.section_title,
                "clause_number": doc.clause_number,
                "page_number": doc.page_number,
                "score": doc.score
            })
        context_text = "\n\n".join(context_parts)
        return context_text, sources, truncated

    def _format_single_doc(
        self,
        doc: RetrievedDocument,
        index: int
    ) -> str:
        """Render one document as ``[index]``, an optional metadata line, then its content.

        Falsy metadata fields (document name, section title, clause number)
        are left out of the metadata line. The line itself is omitted when
        metadata is disabled or every field is falsy.
        """
        parts = [f"[{index}]"]
        if self.include_metadata:
            meta_parts = []
            if doc.doc_name:
                meta_parts.append(f"文档: {doc.doc_name}")
            if doc.section_title:
                meta_parts.append(f"章节: {doc.section_title}")
            if doc.clause_number:
                meta_parts.append(self.citation_format.format(clause=doc.clause_number))
            if meta_parts:
                parts.append(" | ".join(meta_parts))
        parts.append(doc.content)
        return "\n".join(parts)

    def _default_system_prompt(self) -> str:
        """Return the default compliance-assistant system prompt (Chinese).

        The prompt asks the model to cite clause numbers in the ``【条款X】``
        form, flag incomplete clauses, give compliance advice, and admit when
        the retrieved content is insufficient.
        """
        # Continuation lines start at column 0 on purpose: any indentation
        # would become part of the prompt text.
        return """你是合规专家助手,基于提供的法规条款回答问题。
回答要求:
1. 直接回答问题,必须引用具体条款编号(如【条款5.2.1】)
2. 如引用的条款不完整,说明需要进一步查阅原文
3. 给出明确的合规建议和操作指导
4. 如果检索内容不足以回答问题,如实说明
回答格式:
- 先给出直接结论
- 然后引用支撑条款
- 最后给出合规建议"""

    def _estimate_tokens(self, text: str) -> int:
        """Roughly estimate the token count of ``text``.

        Characters in the CJK Unified Ideographs block (U+4E00..U+9FFF) are
        weighted at 1.5 tokens each and every other character at 0.25 tokens.
        The result is truncated to an int.
        """
        # Escaped bounds keep the range visible and immune to editor/encoding
        # mangling. An empty-string lower bound would match nearly every
        # character, including all ASCII.
        chinese_chars = sum(1 for c in text if "\u4e00" <= c <= "\u9fff")
        other_chars = len(text) - chinese_chars
        return int(chinese_chars * 1.5 + other_chars * 0.25)

    def build_messages(
        self,
        context: RAGContext
    ) -> List[Dict[str, str]]:
        """Convert a RAGContext into a system + user chat message list.

        The user message wraps the rendered context and the question in a
        fixed Chinese instruction template.
        """
        messages = [
            {"role": "system", "content": context.system_prompt},
            {"role": "user", "content": f"参考以下法规条款回答问题。\n\n{context.context_text}\n\n问题:{context.user_query}"}
        ]
        return messages
def build_rag_context(
    query: str,
    documents: List[RetrievedDocument],
    **kwargs
) -> RAGContext:
    """Build a RAGContext using a freshly created, default-configured ContextBuilder.

    Any extra keyword arguments are forwarded unchanged to
    ``ContextBuilder.build`` (e.g. ``system_prompt`` or ``max_tokens``).
    """
    return ContextBuilder().build(query, documents, **kwargs)