Files
AIRegulation-DocAnalysis/backend/app/services/llm/document_summarizer.py

196 lines
6.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Provide service-layer logic for document summarizer."""
from typing import Dict, Optional
from dataclasses import dataclass
from loguru import logger
from app.services.llm.base_client import BaseLLMClient
from app.services.llm.llm_factory import get_llm_client
from app.services.rag.prompt_templates import get_prompt_template
from app.config.settings import settings
# Result record returned by DocumentSummarizer.summarize().
@dataclass
class DocumentSummary:
    """Outcome of summarizing a single document.

    ``error`` doubles as the success flag: a summary counts as successful
    exactly when no error message was recorded (``error is None``).
    """

    doc_name: str  # name of the document that was summarized
    summary: str  # summary text
    applicable_scope: str  # scope-of-applicability text, "" if none
    key_clauses: list  # key clause entries
    key_terms: list  # key term entries
    compliance_points: list  # compliance point entries
    model: str  # model name recorded for this result
    latency_ms: int  # elapsed time in milliseconds
    error: Optional[str] = None  # failure message; None means success

    @property
    def is_success(self) -> bool:
        """Return True when no error message is attached to this summary."""
        has_error = self.error is not None
        return not has_error
class DocumentSummarizer:
    """Generate structured summaries of regulation documents with an LLM.

    The LLM client is not created in ``__init__``: ``_init_llm`` builds it on
    the first ``summarize`` call and reuses it on later calls.
    """

    # Only this many leading characters of a document are sent to the LLM.
    # NOTE(review): presumably sized to fit the model's context window —
    # confirm against the configured model.
    MAX_CONTENT_CHARS = 8000

    # Sampling temperature for summary generation (kept low for less varied
    # wording).
    SUMMARY_TEMPERATURE = 0.3

    # Full-width colon (U+FF1A), the separator normally used in Chinese text.
    # Spelled as an escape on purpose: tooling that strips "ambiguous" Unicode
    # characters must not be able to turn it into "" — an empty separator makes
    # str.split() raise ValueError, which failed every summary whose text
    # mentioned 适用范围/适用对象.
    _FULLWIDTH_COLON = "\uff1a"

    # Line prefixes that mark a key clause in the LLM output.
    _CLAUSE_PREFIXES = ("- 【条款", "【条款")

    def __init__(
        self,
        provider: Optional[str] = None,
        model: Optional[str] = None,
        max_tokens: Optional[int] = None
    ):
        """Configure the summarizer, falling back to application settings.

        Args:
            provider: LLM provider name; defaults to ``settings.llm_provider``.
            model: Model name; defaults to ``settings.llm_model``.
            max_tokens: Default completion token budget; defaults to
                ``settings.rag_summary_max_tokens``.

        Any falsy argument (``None``, ``""``, ``0``) selects the setting.
        """
        self.provider = provider or settings.llm_provider
        self.model = model or settings.llm_model
        self.max_tokens = max_tokens or settings.rag_summary_max_tokens
        # Created on demand by _init_llm() so construction stays cheap.
        self.llm: Optional[BaseLLMClient] = None
        logger.info(f"摘要生成器初始化: provider={self.provider}, model={self.model}")

    def _init_llm(self):
        """Create the LLM client on first use; later calls are no-ops."""
        if self.llm is None:
            self.llm = get_llm_client(
                provider=self.provider,
                model=self.model
            )

    def _failed_summary(
        self,
        doc_name: str,
        latency_ms: int,
        error: str
    ) -> "DocumentSummary":
        """Build an empty DocumentSummary that records ``error``.

        Args:
            doc_name: Name of the document that could not be summarized.
            latency_ms: Time spent before the failure, in milliseconds.
            error: Non-empty failure message (so ``is_success`` is False).
        """
        return DocumentSummary(
            doc_name=doc_name,
            summary="",
            applicable_scope="",
            key_clauses=[],
            key_terms=[],
            compliance_points=[],
            model=self.model,
            latency_ms=latency_ms,
            error=error
        )

    def summarize(
        self,
        doc_name: str,
        content: str,
        regulation_type: str = "",
        max_tokens: Optional[int] = None
    ) -> "DocumentSummary":
        """Summarize one document with the configured LLM.

        Failures are not raised. An unsuccessful LLM response, or any
        ``Exception`` raised while building the prompt, calling the LLM or
        parsing its output, is reported through ``DocumentSummary.error``.

        Args:
            doc_name: Document name; inserted into the prompt and the result.
            content: Document text; only the first ``MAX_CONTENT_CHARS``
                characters are sent to the LLM.
            regulation_type: Accepted for interface compatibility; not
                currently used when building the prompt.
            max_tokens: Completion budget for this call; falls back to
                ``self.max_tokens`` when falsy.

        Returns:
            A DocumentSummary. ``is_success`` is False on any failure.
        """
        import time

        # perf_counter is monotonic, so wall-clock adjustments cannot skew latency.
        start_time = time.perf_counter()
        logger.info(f"生成文档摘要: {doc_name}")
        try:
            self._init_llm()
            template = get_prompt_template("document_summary")
            user_content = template.user_template.format(
                doc_name=doc_name,
                content=content[:self.MAX_CONTENT_CHARS]
            )
            response = self.llm.chat(
                messages=[
                    {"role": "system", "content": template.system_prompt},
                    {"role": "user", "content": user_content}
                ],
                max_tokens=max_tokens or self.max_tokens,
                temperature=self.SUMMARY_TEMPERATURE
            )
            latency_ms = int((time.perf_counter() - start_time) * 1000)
            if not response.is_success:
                # Guarantee a non-None error; otherwise the result would
                # report is_success=True with an empty summary.
                return self._failed_summary(
                    doc_name,
                    latency_ms,
                    response.error or "LLM request failed without an error message"
                )
            summary_data = self._parse_summary(response.content)
            logger.success(f"摘要生成完成: {doc_name}, {latency_ms}ms")
            return DocumentSummary(
                doc_name=doc_name,
                summary=summary_data.get("summary", response.content),
                applicable_scope=summary_data.get("applicable_scope", ""),
                key_clauses=summary_data.get("key_clauses", []),
                key_terms=summary_data.get("key_terms", []),
                compliance_points=summary_data.get("compliance_points", []),
                model=response.model,
                latency_ms=latency_ms
            )
        except Exception as e:
            # logger.exception keeps the traceback, which logger.error dropped.
            logger.exception(f"摘要生成失败: {e}")
            latency_ms = int((time.perf_counter() - start_time) * 1000)
            return self._failed_summary(doc_name, latency_ms, str(e))

    def _parse_summary(self, content: str) -> Dict:
        """Extract structured fields from the LLM's free-text summary.

        Args:
            content: Raw text returned by the LLM.

        Returns:
            A dict with these keys:

            - ``summary``: the whole ``content``, unchanged.
            - ``applicable_scope``: taken from the last line mentioning
              适用范围 or 适用对象. It is the text after the last full-width
              colon, or after the last ASCII ``:`` when there is no full-width
              colon. It is the whole stripped line when neither colon exists.
            - ``key_clauses``: every stripped line starting with
              ``- 【条款`` or ``【条款``.
            - ``key_terms`` and ``compliance_points``: always empty for now.
        """
        result = {
            "summary": content,
            "applicable_scope": "",
            "key_clauses": [],
            "key_terms": [],
            "compliance_points": []
        }
        for raw_line in content.split("\n"):
            line = raw_line.strip()
            if "适用范围" in line or "适用对象" in line:
                # Prefer the full-width colon common in Chinese text; fall back to ASCII.
                separator = self._FULLWIDTH_COLON if self._FULLWIDTH_COLON in line else ":"
                result["applicable_scope"] = line.split(separator)[-1].strip()
            if line.startswith(self._CLAUSE_PREFIXES):
                result["key_clauses"].append(line)
            # TODO: lines mentioning 关键术语/术语定义 (key terms) or
            # 合规要点/必须满足 (compliance points) are not extracted yet; the
            # expected output format for them is not defined here.
        return result

    def batch_summarize(
        self,
        documents: list
    ) -> list:
        """Summarize documents sequentially, preserving input order.

        Args:
            documents: Mappings that each provide ``doc_name`` and ``content``.

        Returns:
            One DocumentSummary per input document. Failures are captured per
            document by ``summarize`` and do not stop the batch.
        """
        return [self.summarize(doc["doc_name"], doc["content"]) for doc in documents]
def summarize_document(
    doc_name: str,
    content: str,
    **kwargs
) -> DocumentSummary:
    """Summarize one document using a freshly constructed DocumentSummarizer.

    Args:
        doc_name: Name of the document.
        content: Document text.
        **kwargs: Passed unchanged to the DocumentSummarizer constructor
            (``provider``, ``model``, ``max_tokens``).

    Returns:
        The DocumentSummary produced by ``DocumentSummarizer.summarize``.
    """
    return DocumentSummarizer(**kwargs).summarize(doc_name, content)