This commit is contained in:
2026-06-05 09:00:36 +08:00
parent 746513cc54
commit 06e0967128
13 changed files with 4560 additions and 239 deletions

View File

@@ -5,17 +5,19 @@ from __future__ import annotations
import asyncio
import json
from pathlib import Path
from typing import AsyncGenerator
from typing import AsyncGenerator, Optional
from fastapi import APIRouter, File, UploadFile
from fastapi import APIRouter, File, Form, UploadFile
from fastapi.responses import StreamingResponse
from loguru import logger
from app.schemas.compliance import (
AnalyzeResponse,
ComplianceChatRequest,
)
from app.services.mock_data import generate_task_id, get_mock_compliance_result
from app.shared.bootstrap import get_agent_conversation_service
from app.shared.bootstrap import get_agent_conversation_service, get_retrieval_service
from app.config.settings import settings
router = APIRouter(prefix="/compliance", tags=["合规分析"])
@@ -62,6 +64,128 @@ async def get_result(task_id: str):
return task["result"]
def _sse(data: dict) -> str:
return f"event: message\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
@router.post("/analyze-stream")
async def analyze_stream(
text: Optional[str] = Form(None),
doc_id: Optional[str] = Form(None),
file: Optional[UploadFile] = File(None),
domains: Optional[str] = Form(None),
title: Optional[str] = Form(None),
):
"""Stream compliance analysis as SSE events.
Stages: clause_split → retrieval (per clause) → gap_check → conclusion
Events: stage | source | finding | done | error
"""
from app.application.compliance.pipeline import (
check_clause_compliance,
extract_text_from_doc_id,
extract_text_from_file,
retrieve_for_clause,
split_into_clauses,
synthesize_conclusion,
)
from app.services.llm.llm_factory import get_llm_client
from app.shared.bootstrap import get_retrieval_service
# Read file content eagerly (before async generator)
file_content: bytes | None = None
file_name: str | None = None
if file is not None:
file_content = await file.read()
file_name = file.filename
async def generate() -> AsyncGenerator[str, None]:
try:
client = get_llm_client(provider=settings.llm_provider, model=settings.llm_model)
retrieval_service = get_retrieval_service()
# ── Stage 1: extract text ─────────────────────────────────────
yield _sse({"type": "stage", "stage": "extracting", "label": "Extracting text…"})
await asyncio.sleep(0)
if text:
para_text = text.strip()
elif doc_id:
try:
para_text = await asyncio.to_thread(extract_text_from_doc_id, doc_id)
except Exception as exc:
yield _sse({"type": "error", "text": f"Document not found: {exc}"})
return
elif file_content is not None:
para_text = await asyncio.to_thread(
extract_text_from_file, file_content, file_name or "upload"
)
else:
yield _sse({"type": "error", "text": "No input provided"})
return
if not para_text.strip():
yield _sse({"type": "error", "text": "Could not extract text from the provided input"})
return
# ── Stage 2: split into clauses ───────────────────────────────
yield _sse({"type": "stage", "stage": "splitting", "label": "Splitting into clauses…"})
await asyncio.sleep(0)
clauses: list[str] = await asyncio.to_thread(split_into_clauses, para_text, client)
# ── Stage 3: retrieve + gap check per clause ──────────────────
findings: list[dict] = []
for i, clause in enumerate(clauses):
yield _sse({
"type": "stage",
"stage": "analyzing",
"label": f"Analyzing clause {i + 1}/{len(clauses)}",
})
await asyncio.sleep(0)
chunks = await asyncio.to_thread(
retrieve_for_clause, clause, retrieval_service, 5, domains or None
)
# Emit source events
for chunk in chunks[:3]:
yield _sse({
"type": "source",
"standard": getattr(chunk, "doc_title", "") or getattr(chunk, "doc_name", ""),
"clause": getattr(chunk, "section_title", "") or "",
"score": round(float(getattr(chunk, "score", 0)), 3),
"status": "retrieved",
"full_content": (getattr(chunk, "text", "") or "")[:300],
})
await asyncio.sleep(0)
finding = await asyncio.to_thread(check_clause_compliance, clause, chunks, client)
if finding:
findings.append(finding)
yield _sse({"type": "finding", **finding})
await asyncio.sleep(0)
# ── Stage 4: synthesize conclusion ────────────────────────────
yield _sse({"type": "stage", "stage": "concluding", "label": "Generating conclusion…"})
await asyncio.sleep(0)
conclusion_data = await asyncio.to_thread(
synthesize_conclusion, para_text, findings, client
)
yield _sse({"type": "done", **conclusion_data})
except Exception as exc:
logger.exception("analyze-stream pipeline error")
yield _sse({"type": "error", "text": str(exc)})
return StreamingResponse(
generate(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
)
@router.post("/chat/{segment_id}")
async def compliance_chat(segment_id: int, request: ComplianceChatRequest):
"""Stream compliance Q&A grounded in real vector retrieval."""

View File

@@ -0,0 +1 @@
"""Compliance application layer."""

View File

@@ -0,0 +1,215 @@
"""Compliance analysis pipeline helpers.
All functions are synchronous — call them via asyncio.to_thread() in async SSE generators.
"""
from __future__ import annotations
import json
import os
import re
import tempfile
from typing import TYPE_CHECKING
from loguru import logger
if TYPE_CHECKING:
from app.application.knowledge import KnowledgeRetrievalService
from app.domain.retrieval import RetrievedChunk
from app.services.llm.base_client import BaseLLMClient
def _extract_json(text: str):
"""Extract JSON from LLM response, tolerating markdown wrappers."""
stripped = text.strip()
match = re.search(r"```(?:json)?\s*([\s\S]*?)```", stripped)
if match:
stripped = match.group(1).strip()
try:
return json.loads(stripped)
except json.JSONDecodeError:
pass
for pattern in (r"(\[[\s\S]*\])", r"(\{[\s\S]*\})"):
m = re.search(pattern, stripped)
if m:
try:
return json.loads(m.group(1))
except json.JSONDecodeError:
continue
raise ValueError(f"No valid JSON found in LLM response: {text[:300]}")
def extract_text_from_doc_id(doc_id: str) -> str:
from app.shared.bootstrap import get_document_query_service, get_retrieval_service
doc = get_document_query_service().get(doc_id)
if not doc:
raise ValueError(f"Document '{doc_id}' not found")
service = get_retrieval_service()
chunks = service.retrieve(query=doc.doc_name, top_k=30)
doc_chunks = [c for c in chunks if c.doc_id == doc_id]
if not doc_chunks:
doc_chunks = chunks[:15]
return "\n\n".join(c.text for c in doc_chunks[:15])
def extract_text_from_file(content: bytes, filename: str) -> str:
from app.shared.bootstrap import get_document_command_service
suffix = os.path.splitext(filename or "doc.pdf")[1] or ".pdf"
tmp_path = ""
try:
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
tmp.write(content)
tmp_path = tmp.name
service = get_document_command_service()
parsed = service.parser.parse(file_path=tmp_path, doc_id="tmp_analysis", doc_name=filename)
if parsed.raw_text:
return parsed.raw_text[:4000]
return "\n".join(
b.get("text", "") for b in parsed.semantic_blocks[:30] if b.get("text")
)[:4000]
except Exception as exc:
logger.warning("File text extraction failed: {}", exc)
return ""
finally:
if tmp_path:
try: os.unlink(tmp_path)
except OSError: pass
def split_into_clauses(text: str, client: "BaseLLMClient") -> list[str]:
prompt = (
"You are a compliance analysis expert. Split the following text into 3-8 "
"semantically complete compliance clauses. Each clause should be an independent "
"compliance requirement or technical statement.\n"
"Return as JSON array of strings, e.g.:\n"
'["Clause one...", "Clause two..."]\n'
"Return ONLY the JSON array.\n\n"
f"Text:\n{text[:2000]}"
)
response = client.chat([{"role": "user", "content": prompt}], max_tokens=1000)
if response.is_success:
try:
result = _extract_json(response.content)
if isinstance(result, list):
clauses = [str(c).strip() for c in result if str(c).strip()]
if clauses:
return clauses[:8]
except (ValueError, TypeError):
logger.warning("Clause split JSON parse failed, using fallback")
sentences = re.split(r"[.?!;\n]+", text)
return [s.strip() for s in sentences if len(s.strip()) > 20][:6]
def retrieve_for_clause(
clause: str,
retrieval_service: "KnowledgeRetrievalService",
top_k: int = 5,
domains: str | None = None,
) -> list["RetrievedChunk"]:
return retrieval_service.retrieve(query=clause, top_k=top_k, filters=domains)
def check_clause_compliance(
clause: str,
chunks: list["RetrievedChunk"],
client: "BaseLLMClient",
) -> dict | None:
if not chunks:
return None
reg_context = "\n".join(
f"[{i+1}] {c.doc_title} {c.section_title or ''}: {c.text[:300]}"
for i, c in enumerate(chunks[:5])
)
prompt = (
"You are a compliance expert. Judge whether the following business clause "
"complies with the retrieved regulations.\n\n"
f"Business clause:\n{clause}\n\n"
f"Retrieved regulations:\n{reg_context}\n\n"
"Return JSON:\n"
"{\n"
' "status": "ok" | "warn" | "risk",\n'
' "title": "Short finding title (max 30 chars)",\n'
' "desc": "Description (50-120 chars)",\n'
' "clause_ref": "Regulation clause reference e.g. Art.9.1 or Sec.3.1"\n'
"}\n"
"status: ok=compliant, warn=gap exists, risk=critical/missing\n"
"Return ONLY the JSON object."
)
response = client.chat([{"role": "user", "content": prompt}], max_tokens=500)
if not response.is_success:
return None
try:
result = _extract_json(response.content)
if isinstance(result, dict) and "status" in result:
return {
"title": str(result.get("title", "Compliance finding")),
"desc": str(result.get("desc", "")),
"status": result.get("status", "info"),
"clause_ref": result.get("clause_ref"),
}
except (ValueError, TypeError) as exc:
logger.warning("Gap check JSON parse failed: {}", exc)
return None
def synthesize_conclusion(
para_text: str,
findings: list[dict],
client: "BaseLLMClient",
) -> dict:
if not findings:
return {
"conclusion": "No significant compliance gaps found. Continue monitoring regulation updates.",
"actions": [{"label": "Next action", "value": "Monitor regulation updates"}],
"risk_score": 10,
"highlight_terms": [],
"para_text": para_text[:800],
}
findings_text = "\n".join(
f"- [{f['status'].upper()}] {f['title']}: {f['desc']}"
for f in findings
)
prompt = (
"You are a compliance analysis expert. Generate a summary report "
"based on the following compliance findings.\n\n"
f"Original text (first 600 chars):\n{para_text[:600]}\n\n"
f"Findings:\n{findings_text}\n\n"
"Return JSON:\n"
"{\n"
' "conclusion": "Overall compliance conclusion (100-200 chars)",\n'
' "actions": [\n'
' {"label": "Action label", "value": "Description"},\n'
' {"label": "Priority", "value": "High/Medium/Low", "risk": true}\n'
' ],\n'
' "risk_score": 0-100 (integer, higher=riskier),\n'
' "highlight_terms": ["Key terms to highlight, max 10 terms"],\n'
' "para_text": "Original text or summary (max 600 chars)"\n'
"}\n"
"Return ONLY the JSON object."
)
response = client.chat([{"role": "user", "content": prompt}], max_tokens=1200)
fallback = {
"conclusion": "Compliance analysis complete. Review findings and create remediation plan.",
"actions": [
{"label": "Next action", "value": "Review critical findings"},
{"label": "Escalation", "value": "Legal review required", "risk": True},
],
"risk_score": 60,
"highlight_terms": [],
"para_text": para_text[:800],
}
if not response.is_success:
return fallback
try:
result = _extract_json(response.content)
if isinstance(result, dict):
return {
"conclusion": str(result.get("conclusion", fallback["conclusion"])),
"actions": result.get("actions", fallback["actions"]),
"risk_score": int(result.get("risk_score", 60)),
"highlight_terms": result.get("highlight_terms", []),
"para_text": str(result.get("para_text", para_text[:800])),
}
except (ValueError, TypeError) as exc:
logger.warning("Conclusion synthesis JSON parse failed: {}", exc)
return fallback

View File

@@ -80,4 +80,30 @@ class ComplianceChatRequest(BaseModel):
class AnalyzeResponse(BaseModel):
"""Define the Analyze Response API model."""
task_id: str
status: str = "processing"
status: str = "processing"
class AnalyzeStreamSource(BaseModel):
"""SSE source event payload for analyze-stream."""
standard: str
clause: str
score: float
status: str
full_content: str
class AnalyzeStreamFinding(BaseModel):
"""SSE finding event payload for analyze-stream."""
title: str
desc: str
status: str
clause_ref: Optional[str] = None
class AnalyzeStreamDone(BaseModel):
"""SSE done event payload for analyze-stream."""
conclusion: str
actions: list[dict]
risk_score: int
highlight_terms: list[str]
para_text: str