add
This commit is contained in:
@@ -5,17 +5,19 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from fastapi import APIRouter, File, UploadFile
|
||||
from fastapi import APIRouter, File, Form, UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
from loguru import logger
|
||||
|
||||
from app.schemas.compliance import (
|
||||
AnalyzeResponse,
|
||||
ComplianceChatRequest,
|
||||
)
|
||||
from app.services.mock_data import generate_task_id, get_mock_compliance_result
|
||||
from app.shared.bootstrap import get_agent_conversation_service
|
||||
from app.shared.bootstrap import get_agent_conversation_service, get_retrieval_service
|
||||
from app.config.settings import settings
|
||||
|
||||
|
||||
router = APIRouter(prefix="/compliance", tags=["合规分析"])
|
||||
@@ -62,6 +64,128 @@ async def get_result(task_id: str):
|
||||
return task["result"]
|
||||
|
||||
|
||||
def _sse(data: dict) -> str:
|
||||
return f"event: message\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
|
||||
|
||||
|
||||
@router.post("/analyze-stream")
|
||||
async def analyze_stream(
|
||||
text: Optional[str] = Form(None),
|
||||
doc_id: Optional[str] = Form(None),
|
||||
file: Optional[UploadFile] = File(None),
|
||||
domains: Optional[str] = Form(None),
|
||||
title: Optional[str] = Form(None),
|
||||
):
|
||||
"""Stream compliance analysis as SSE events.
|
||||
|
||||
Stages: clause_split → retrieval (per clause) → gap_check → conclusion
|
||||
Events: stage | source | finding | done | error
|
||||
"""
|
||||
from app.application.compliance.pipeline import (
|
||||
check_clause_compliance,
|
||||
extract_text_from_doc_id,
|
||||
extract_text_from_file,
|
||||
retrieve_for_clause,
|
||||
split_into_clauses,
|
||||
synthesize_conclusion,
|
||||
)
|
||||
from app.services.llm.llm_factory import get_llm_client
|
||||
from app.shared.bootstrap import get_retrieval_service
|
||||
|
||||
# Read file content eagerly (before async generator)
|
||||
file_content: bytes | None = None
|
||||
file_name: str | None = None
|
||||
if file is not None:
|
||||
file_content = await file.read()
|
||||
file_name = file.filename
|
||||
|
||||
async def generate() -> AsyncGenerator[str, None]:
|
||||
try:
|
||||
client = get_llm_client(provider=settings.llm_provider, model=settings.llm_model)
|
||||
retrieval_service = get_retrieval_service()
|
||||
|
||||
# ── Stage 1: extract text ─────────────────────────────────────
|
||||
yield _sse({"type": "stage", "stage": "extracting", "label": "Extracting text…"})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
if text:
|
||||
para_text = text.strip()
|
||||
elif doc_id:
|
||||
try:
|
||||
para_text = await asyncio.to_thread(extract_text_from_doc_id, doc_id)
|
||||
except Exception as exc:
|
||||
yield _sse({"type": "error", "text": f"Document not found: {exc}"})
|
||||
return
|
||||
elif file_content is not None:
|
||||
para_text = await asyncio.to_thread(
|
||||
extract_text_from_file, file_content, file_name or "upload"
|
||||
)
|
||||
else:
|
||||
yield _sse({"type": "error", "text": "No input provided"})
|
||||
return
|
||||
|
||||
if not para_text.strip():
|
||||
yield _sse({"type": "error", "text": "Could not extract text from the provided input"})
|
||||
return
|
||||
|
||||
# ── Stage 2: split into clauses ───────────────────────────────
|
||||
yield _sse({"type": "stage", "stage": "splitting", "label": "Splitting into clauses…"})
|
||||
await asyncio.sleep(0)
|
||||
clauses: list[str] = await asyncio.to_thread(split_into_clauses, para_text, client)
|
||||
|
||||
# ── Stage 3: retrieve + gap check per clause ──────────────────
|
||||
findings: list[dict] = []
|
||||
|
||||
for i, clause in enumerate(clauses):
|
||||
yield _sse({
|
||||
"type": "stage",
|
||||
"stage": "analyzing",
|
||||
"label": f"Analyzing clause {i + 1}/{len(clauses)}…",
|
||||
})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
chunks = await asyncio.to_thread(
|
||||
retrieve_for_clause, clause, retrieval_service, 5, domains or None
|
||||
)
|
||||
|
||||
# Emit source events
|
||||
for chunk in chunks[:3]:
|
||||
yield _sse({
|
||||
"type": "source",
|
||||
"standard": getattr(chunk, "doc_title", "") or getattr(chunk, "doc_name", ""),
|
||||
"clause": getattr(chunk, "section_title", "") or "",
|
||||
"score": round(float(getattr(chunk, "score", 0)), 3),
|
||||
"status": "retrieved",
|
||||
"full_content": (getattr(chunk, "text", "") or "")[:300],
|
||||
})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
finding = await asyncio.to_thread(check_clause_compliance, clause, chunks, client)
|
||||
if finding:
|
||||
findings.append(finding)
|
||||
yield _sse({"type": "finding", **finding})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# ── Stage 4: synthesize conclusion ────────────────────────────
|
||||
yield _sse({"type": "stage", "stage": "concluding", "label": "Generating conclusion…"})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
conclusion_data = await asyncio.to_thread(
|
||||
synthesize_conclusion, para_text, findings, client
|
||||
)
|
||||
yield _sse({"type": "done", **conclusion_data})
|
||||
|
||||
except Exception as exc:
|
||||
logger.exception("analyze-stream pipeline error")
|
||||
yield _sse({"type": "error", "text": str(exc)})
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/chat/{segment_id}")
|
||||
async def compliance_chat(segment_id: int, request: ComplianceChatRequest):
|
||||
"""Stream compliance Q&A grounded in real vector retrieval."""
|
||||
|
||||
1
backend/app/application/compliance/__init__.py
Normal file
1
backend/app/application/compliance/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Compliance application layer."""
|
||||
215
backend/app/application/compliance/pipeline.py
Normal file
215
backend/app/application/compliance/pipeline.py
Normal file
@@ -0,0 +1,215 @@
|
||||
"""Compliance analysis pipeline helpers.
|
||||
|
||||
All functions are synchronous — call them via asyncio.to_thread() in async SSE generators.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from loguru import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.application.knowledge import KnowledgeRetrievalService
|
||||
from app.domain.retrieval import RetrievedChunk
|
||||
from app.services.llm.base_client import BaseLLMClient
|
||||
|
||||
|
||||
def _extract_json(text: str):
|
||||
"""Extract JSON from LLM response, tolerating markdown wrappers."""
|
||||
stripped = text.strip()
|
||||
match = re.search(r"```(?:json)?\s*([\s\S]*?)```", stripped)
|
||||
if match:
|
||||
stripped = match.group(1).strip()
|
||||
try:
|
||||
return json.loads(stripped)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
for pattern in (r"(\[[\s\S]*\])", r"(\{[\s\S]*\})"):
|
||||
m = re.search(pattern, stripped)
|
||||
if m:
|
||||
try:
|
||||
return json.loads(m.group(1))
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
raise ValueError(f"No valid JSON found in LLM response: {text[:300]}")
|
||||
|
||||
|
||||
def extract_text_from_doc_id(doc_id: str) -> str:
|
||||
from app.shared.bootstrap import get_document_query_service, get_retrieval_service
|
||||
doc = get_document_query_service().get(doc_id)
|
||||
if not doc:
|
||||
raise ValueError(f"Document '{doc_id}' not found")
|
||||
service = get_retrieval_service()
|
||||
chunks = service.retrieve(query=doc.doc_name, top_k=30)
|
||||
doc_chunks = [c for c in chunks if c.doc_id == doc_id]
|
||||
if not doc_chunks:
|
||||
doc_chunks = chunks[:15]
|
||||
return "\n\n".join(c.text for c in doc_chunks[:15])
|
||||
|
||||
|
||||
def extract_text_from_file(content: bytes, filename: str) -> str:
|
||||
from app.shared.bootstrap import get_document_command_service
|
||||
suffix = os.path.splitext(filename or "doc.pdf")[1] or ".pdf"
|
||||
tmp_path = ""
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
|
||||
tmp.write(content)
|
||||
tmp_path = tmp.name
|
||||
service = get_document_command_service()
|
||||
parsed = service.parser.parse(file_path=tmp_path, doc_id="tmp_analysis", doc_name=filename)
|
||||
if parsed.raw_text:
|
||||
return parsed.raw_text[:4000]
|
||||
return "\n".join(
|
||||
b.get("text", "") for b in parsed.semantic_blocks[:30] if b.get("text")
|
||||
)[:4000]
|
||||
except Exception as exc:
|
||||
logger.warning("File text extraction failed: {}", exc)
|
||||
return ""
|
||||
finally:
|
||||
if tmp_path:
|
||||
try: os.unlink(tmp_path)
|
||||
except OSError: pass
|
||||
|
||||
|
||||
def split_into_clauses(text: str, client: "BaseLLMClient") -> list[str]:
|
||||
prompt = (
|
||||
"You are a compliance analysis expert. Split the following text into 3-8 "
|
||||
"semantically complete compliance clauses. Each clause should be an independent "
|
||||
"compliance requirement or technical statement.\n"
|
||||
"Return as JSON array of strings, e.g.:\n"
|
||||
'["Clause one...", "Clause two..."]\n'
|
||||
"Return ONLY the JSON array.\n\n"
|
||||
f"Text:\n{text[:2000]}"
|
||||
)
|
||||
response = client.chat([{"role": "user", "content": prompt}], max_tokens=1000)
|
||||
if response.is_success:
|
||||
try:
|
||||
result = _extract_json(response.content)
|
||||
if isinstance(result, list):
|
||||
clauses = [str(c).strip() for c in result if str(c).strip()]
|
||||
if clauses:
|
||||
return clauses[:8]
|
||||
except (ValueError, TypeError):
|
||||
logger.warning("Clause split JSON parse failed, using fallback")
|
||||
sentences = re.split(r"[.?!;\n]+", text)
|
||||
return [s.strip() for s in sentences if len(s.strip()) > 20][:6]
|
||||
|
||||
|
||||
def retrieve_for_clause(
|
||||
clause: str,
|
||||
retrieval_service: "KnowledgeRetrievalService",
|
||||
top_k: int = 5,
|
||||
domains: str | None = None,
|
||||
) -> list["RetrievedChunk"]:
|
||||
return retrieval_service.retrieve(query=clause, top_k=top_k, filters=domains)
|
||||
|
||||
|
||||
def check_clause_compliance(
|
||||
clause: str,
|
||||
chunks: list["RetrievedChunk"],
|
||||
client: "BaseLLMClient",
|
||||
) -> dict | None:
|
||||
if not chunks:
|
||||
return None
|
||||
reg_context = "\n".join(
|
||||
f"[{i+1}] {c.doc_title} {c.section_title or ''}: {c.text[:300]}"
|
||||
for i, c in enumerate(chunks[:5])
|
||||
)
|
||||
prompt = (
|
||||
"You are a compliance expert. Judge whether the following business clause "
|
||||
"complies with the retrieved regulations.\n\n"
|
||||
f"Business clause:\n{clause}\n\n"
|
||||
f"Retrieved regulations:\n{reg_context}\n\n"
|
||||
"Return JSON:\n"
|
||||
"{\n"
|
||||
' "status": "ok" | "warn" | "risk",\n'
|
||||
' "title": "Short finding title (max 30 chars)",\n'
|
||||
' "desc": "Description (50-120 chars)",\n'
|
||||
' "clause_ref": "Regulation clause reference e.g. Art.9.1 or Sec.3.1"\n'
|
||||
"}\n"
|
||||
"status: ok=compliant, warn=gap exists, risk=critical/missing\n"
|
||||
"Return ONLY the JSON object."
|
||||
)
|
||||
response = client.chat([{"role": "user", "content": prompt}], max_tokens=500)
|
||||
if not response.is_success:
|
||||
return None
|
||||
try:
|
||||
result = _extract_json(response.content)
|
||||
if isinstance(result, dict) and "status" in result:
|
||||
return {
|
||||
"title": str(result.get("title", "Compliance finding")),
|
||||
"desc": str(result.get("desc", "")),
|
||||
"status": result.get("status", "info"),
|
||||
"clause_ref": result.get("clause_ref"),
|
||||
}
|
||||
except (ValueError, TypeError) as exc:
|
||||
logger.warning("Gap check JSON parse failed: {}", exc)
|
||||
return None
|
||||
|
||||
|
||||
def synthesize_conclusion(
|
||||
para_text: str,
|
||||
findings: list[dict],
|
||||
client: "BaseLLMClient",
|
||||
) -> dict:
|
||||
if not findings:
|
||||
return {
|
||||
"conclusion": "No significant compliance gaps found. Continue monitoring regulation updates.",
|
||||
"actions": [{"label": "Next action", "value": "Monitor regulation updates"}],
|
||||
"risk_score": 10,
|
||||
"highlight_terms": [],
|
||||
"para_text": para_text[:800],
|
||||
}
|
||||
findings_text = "\n".join(
|
||||
f"- [{f['status'].upper()}] {f['title']}: {f['desc']}"
|
||||
for f in findings
|
||||
)
|
||||
prompt = (
|
||||
"You are a compliance analysis expert. Generate a summary report "
|
||||
"based on the following compliance findings.\n\n"
|
||||
f"Original text (first 600 chars):\n{para_text[:600]}\n\n"
|
||||
f"Findings:\n{findings_text}\n\n"
|
||||
"Return JSON:\n"
|
||||
"{\n"
|
||||
' "conclusion": "Overall compliance conclusion (100-200 chars)",\n'
|
||||
' "actions": [\n'
|
||||
' {"label": "Action label", "value": "Description"},\n'
|
||||
' {"label": "Priority", "value": "High/Medium/Low", "risk": true}\n'
|
||||
' ],\n'
|
||||
' "risk_score": 0-100 (integer, higher=riskier),\n'
|
||||
' "highlight_terms": ["Key terms to highlight, max 10 terms"],\n'
|
||||
' "para_text": "Original text or summary (max 600 chars)"\n'
|
||||
"}\n"
|
||||
"Return ONLY the JSON object."
|
||||
)
|
||||
response = client.chat([{"role": "user", "content": prompt}], max_tokens=1200)
|
||||
fallback = {
|
||||
"conclusion": "Compliance analysis complete. Review findings and create remediation plan.",
|
||||
"actions": [
|
||||
{"label": "Next action", "value": "Review critical findings"},
|
||||
{"label": "Escalation", "value": "Legal review required", "risk": True},
|
||||
],
|
||||
"risk_score": 60,
|
||||
"highlight_terms": [],
|
||||
"para_text": para_text[:800],
|
||||
}
|
||||
if not response.is_success:
|
||||
return fallback
|
||||
try:
|
||||
result = _extract_json(response.content)
|
||||
if isinstance(result, dict):
|
||||
return {
|
||||
"conclusion": str(result.get("conclusion", fallback["conclusion"])),
|
||||
"actions": result.get("actions", fallback["actions"]),
|
||||
"risk_score": int(result.get("risk_score", 60)),
|
||||
"highlight_terms": result.get("highlight_terms", []),
|
||||
"para_text": str(result.get("para_text", para_text[:800])),
|
||||
}
|
||||
except (ValueError, TypeError) as exc:
|
||||
logger.warning("Conclusion synthesis JSON parse failed: {}", exc)
|
||||
return fallback
|
||||
@@ -80,4 +80,30 @@ class ComplianceChatRequest(BaseModel):
|
||||
class AnalyzeResponse(BaseModel):
|
||||
"""Define the Analyze Response API model."""
|
||||
task_id: str
|
||||
status: str = "processing"
|
||||
status: str = "processing"
|
||||
|
||||
|
||||
class AnalyzeStreamSource(BaseModel):
|
||||
"""SSE source event payload for analyze-stream."""
|
||||
standard: str
|
||||
clause: str
|
||||
score: float
|
||||
status: str
|
||||
full_content: str
|
||||
|
||||
|
||||
class AnalyzeStreamFinding(BaseModel):
|
||||
"""SSE finding event payload for analyze-stream."""
|
||||
title: str
|
||||
desc: str
|
||||
status: str
|
||||
clause_ref: Optional[str] = None
|
||||
|
||||
|
||||
class AnalyzeStreamDone(BaseModel):
|
||||
"""SSE done event payload for analyze-stream."""
|
||||
conclusion: str
|
||||
actions: list[dict]
|
||||
risk_score: int
|
||||
highlight_terms: list[str]
|
||||
para_text: str
|
||||
Reference in New Issue
Block a user