Files
AIRegulation-DocAnalysis/backend/app/api/routes/compliance.py

494 lines
19 KiB
Python
Raw Normal View History

"""Define API routes for compliance."""
from __future__ import annotations
2026-05-14 15:07:34 +08:00
import asyncio
import json
from pathlib import Path
2026-06-05 09:00:36 +08:00
from typing import AsyncGenerator, Optional
from fastapi import APIRouter, Depends, File, Form, UploadFile
from fastapi.responses import StreamingResponse
2026-06-05 09:00:36 +08:00
from loguru import logger
from app.api.dependencies.auth import get_current_user
from app.domain.auth.models import UserClaims
2026-05-14 15:07:34 +08:00
from app.schemas.compliance import (
AnalyzeResponse,
ComplianceChatRequest,
)
from app.services.mock_data import generate_task_id, get_mock_compliance_result
2026-06-05 09:00:36 +08:00
from app.shared.bootstrap import get_agent_conversation_service, get_retrieval_service
from app.config.settings import settings
2026-05-14 15:07:34 +08:00
# Router for the compliance endpoints; every route in this module is served
# under the /compliance prefix.
router = APIRouter(prefix="/compliance", tags=["合规分析"])

# In-memory registry of analysis tasks, keyed by task id.
tasks_store: dict[str, dict] = {}

# Directory for uploaded documents: "data/raw" under the fourth ancestor
# directory of this file.
RAW_DATA_DIR = Path(__file__).resolve().parents[3].joinpath("data", "raw")
2026-05-14 15:07:34 +08:00
@router.post("/analyze", response_model=AnalyzeResponse)
async def analyze_document(file: UploadFile = File(...)):
    """Store an uploaded document and register a compliance analysis task.

    The upload is written to ``RAW_DATA_DIR`` as
    ``compliance_<task_id>_<filename>`` and an entry is recorded in
    ``tasks_store``. No real analysis runs here: the task is stored as
    completed with the value of ``get_mock_compliance_result(task_id)``.

    Args:
        file: The uploaded document.

    Returns:
        An ``AnalyzeResponse`` carrying the new task id.
    """
    task_id = generate_task_id()
    RAW_DATA_DIR.mkdir(parents=True, exist_ok=True)

    # ``file.filename`` is client-controlled (and may be missing). Keep only
    # its final path component so values such as "../../x" or "C:\\x" cannot
    # place the file outside RAW_DATA_DIR.
    raw_name = (file.filename or "").replace("\\", "/")
    safe_name = Path(raw_name).name or "upload"
    file_path = RAW_DATA_DIR / f"compliance_{task_id}_{safe_name}"

    content = await file.read()
    # Write in a worker thread so the disk I/O does not block the event loop.
    await asyncio.to_thread(file_path.write_bytes, content)

    tasks_store[task_id] = {
        "task_id": task_id,
        "file_path": str(file_path),
        "status": "completed",
        "result": get_mock_compliance_result(task_id),
    }
    return AnalyzeResponse(task_id=task_id)
@router.get("/result/{task_id}")
async def get_result(task_id: str):
    """Return the stored result for *task_id*.

    Unknown task ids fall back to ``get_mock_compliance_result(task_id)``.
    A task whose status is still ``"processing"`` yields a status payload
    instead of a result.
    """
    task = tasks_store.get(task_id)
    if task is None:
        return get_mock_compliance_result(task_id)

    still_running = task["status"] == "processing"
    if still_running:
        return {"status": "processing", "message": "分析进行中"}
    return task["result"]
2026-06-05 09:00:36 +08:00
def _sse(data: dict) -> str:
return f"event: message\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
@router.post("/analyze-stream")
async def analyze_stream(
    text: Optional[str] = Form(None),
    doc_id: Optional[str] = Form(None),
    file: Optional[UploadFile] = File(None),
    domains: Optional[str] = Form(None),
    title: Optional[str] = Form(None),
    current_user: UserClaims = Depends(get_current_user),
):
    """Stream a compliance analysis as SSE events.

    Input is taken from the first of ``text``, ``doc_id`` or ``file`` that is
    provided, in that order.

    Stage events emitted, in order: ``extracting``, ``splitting``,
    ``analyzing``, ``concluding``.
    Event types: ``stage`` | ``source`` | ``finding`` | ``done`` | ``saved`` |
    ``error``. ``saved`` is emitted after ``done`` when the analysis was
    persisted; persistence failures are logged and never reach the client.

    Args:
        text: Raw text to analyze.
        doc_id: Identifier passed to ``extract_text_from_doc_id``.
        file: Uploaded document passed to ``extract_text_from_file``.
        domains: Forwarded to ``run_clauses_parallel`` (``None`` when empty).
        title: Used as ``standard_name`` and as a fallback ``doc_name`` of the
            saved record.
        current_user: Authenticated caller; ``username`` is stored as
            ``created_by`` when present.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream``.
    """
    from app.application.compliance.pipeline import (
        extract_text_from_doc_id,
        extract_text_from_file,
        run_clauses_parallel,
        split_into_clauses,
        synthesize_conclusion,
    )
    from app.services.llm.llm_factory import get_llm_client

    # Read the upload eagerly, before the async generator starts running.
    file_content: bytes | None = None
    file_name: str | None = None
    if file is not None:
        file_content = await file.read()
        file_name = file.filename

    async def generate() -> AsyncGenerator[str, None]:
        try:
            client = get_llm_client(provider=settings.llm_provider, model=settings.llm_model)
            retrieval_service = get_retrieval_service()

            # Stage 1: extract text.
            yield _sse({"type": "stage", "stage": "extracting", "label": "Extracting text…"})
            # sleep(0) hands control back to the event loop between stages.
            await asyncio.sleep(0)

            if text:
                para_text = text.strip()
            elif doc_id:
                try:
                    para_text = await asyncio.to_thread(extract_text_from_doc_id, doc_id)
                except Exception as exc:
                    yield _sse({"type": "error", "text": f"Document not found: {exc}"})
                    return
            elif file_content is not None:
                para_text = await asyncio.to_thread(
                    extract_text_from_file, file_content, file_name or "upload"
                )
            else:
                yield _sse({"type": "error", "text": "No input provided"})
                return

            if not para_text.strip():
                yield _sse({"type": "error", "text": "Could not extract text from the provided input"})
                return

            # Stage 2: split into clauses.
            yield _sse({"type": "stage", "stage": "splitting", "label": "Splitting into clauses…"})
            await asyncio.sleep(0)
            clauses: list[str] = await asyncio.to_thread(split_into_clauses, para_text, client)

            # Stage 3: retrieval + gap check for all clauses in one call.
            findings: list[dict] = []
            yield _sse({
                "type": "stage",
                "stage": "analyzing",
                "label": f"Analyzing {len(clauses)} clauses in parallel…",
            })
            await asyncio.sleep(0)

            clause_results = await run_clauses_parallel(
                clauses, retrieval_service, client,
                top_k=5,
                domains=domains or None,
            )

            for res in clause_results:
                chunks = res["chunks"]
                finding = res["finding"]

                # Emit at most three source events per clause.
                for chunk in chunks[:3]:
                    yield _sse({
                        "type": "source",
                        "standard": getattr(chunk, "doc_title", "") or getattr(chunk, "doc_name", ""),
                        "clause": getattr(chunk, "section_title", "") or "",
                        "score": round(float(getattr(chunk, "score", 0)), 3),
                        "status": "retrieved",
                        "full_content": (getattr(chunk, "text", "") or "")[:300],
                    })

                if finding:
                    findings.append(finding)
                    yield _sse({"type": "finding", **finding})

                await asyncio.sleep(0)

            # Stage 4: synthesize conclusion.
            yield _sse({"type": "stage", "stage": "concluding", "label": "Generating conclusion…"})
            await asyncio.sleep(0)
            conclusion_data = await asyncio.to_thread(
                synthesize_conclusion, para_text, findings, client
            )
            yield _sse({"type": "done", **conclusion_data})

            # Auto-save the analysis. This is best-effort: the client already
            # has the "done" event, so a failure here is only logged.
            try:
                from app.shared.bootstrap import get_compliance_repository
                from app.domain.compliance.ports import AnalysisRecord, FindingRecord
                from datetime import datetime

                repo = get_compliance_repository()
                finding_records = [
                    FindingRecord(
                        id="",
                        analysis_id="",
                        seq=seq,
                        title=item.get("title", ""),
                        description=item.get("desc", ""),
                        status=item.get("status", "ok"),
                        clause_ref=item.get("clause_ref"),
                    )
                    for seq, item in enumerate(findings)
                ]
                record = AnalysisRecord(
                    id="",
                    # NOTE(review): datetime.utcnow() is deprecated since
                    # Python 3.12. It is kept because the repository's
                    # timezone expectations are not visible here; confirm
                    # before switching to an aware datetime.
                    created_at=datetime.utcnow(),
                    created_by=getattr(current_user, "username", None),
                    doc_name=file_name or (title or "Pasted text"),
                    standard_name=title or "",
                    risk_score=conclusion_data.get("risk_score", 0),
                    conclusion=conclusion_data.get("conclusion", ""),
                    actions=conclusion_data.get("actions", []),
                    para_text=conclusion_data.get("para_text", ""),
                    highlight_terms=conclusion_data.get("highlight_terms", []),
                    findings=finding_records,
                )
                analysis_id = await asyncio.to_thread(repo.save_analysis, record)
                yield _sse({"type": "saved", "analysis_id": analysis_id})
            except NotImplementedError:
                pass  # No postgres backend configured — skip saving
            except Exception as exc:
                logger.warning("Failed to auto-save compliance analysis: {}", exc)

        except Exception as exc:
            # Top-level boundary of the stream: log with traceback and tell
            # the client instead of silently dropping the connection.
            logger.exception("analyze-stream pipeline error")
            yield _sse({"type": "error", "text": str(exc)})

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
    )
2026-05-14 15:07:34 +08:00
@router.post("/chat/{segment_id}")
async def compliance_chat(segment_id: int, request: ComplianceChatRequest):
    """Stream a compliance Q&A answer as SSE ``chunk`` / ``done`` events.

    The user query is optionally prefixed with ``request.segment_context`` and
    sent to the agent conversation service with the ``compliance_qa`` prompt
    template. ``segment_id`` is part of the route but is not used in the body.

    NOTE(review): unlike the other streaming routes in this module, this one
    declares no ``get_current_user`` dependency — confirm that is intended.

    Args:
        segment_id: Path parameter identifying the segment (unused here).
        request: Chat request carrying ``query`` and optional
            ``segment_context``.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream``.
    """
    query = request.query
    if request.segment_context:
        query = f"[段落分析上下文]\n{request.segment_context}\n\n用户问题:{request.query}"

    _, event_stream = get_agent_conversation_service().stream_chat(
        query=query,
        top_k=5,
        prompt_template="compliance_qa",
    )

    async def generate() -> AsyncGenerator[str, None]:
        """Translate agent events into compliance chunk/done SSE frames."""
        try:
            for event in event_stream:
                event_type = event.get("event", "")
                if event_type == "content":
                    text = event.get("data", "")
                    if text:
                        yield _sse({"type": "chunk", "text": text})
                elif event_type == "done":
                    yield _sse({"type": "done"})
                # Hand control back to the event loop between agent events.
                await asyncio.sleep(0)
        except Exception as exc:
            # Same handling as finding_chat: report a mid-stream failure to
            # the client instead of just dropping the connection.
            logger.exception("compliance_chat stream error")
            yield _sse({"type": "error", "text": str(exc)})

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
    )
2026-06-10 11:10:36 +08:00
@router.get("/history")
async def list_history(
    limit: int = 20,
    offset: int = 0,
    current_user: UserClaims = Depends(get_current_user),
):
    """Return one page of saved compliance analyses as summary dicts.

    Args:
        limit: Maximum number of records to return; negative values are
            treated as 0.
        offset: Number of records to skip; negative values are treated as 0.
        current_user: Authenticated caller (required by the dependency, not
            otherwise used here).

    Returns:
        A list of summary dicts (id, timestamps, names, risk score and
        finding count), or an empty list when the repository raises
        ``NotImplementedError``.
    """
    from app.shared.bootstrap import get_compliance_repository

    # Negative paging values are never meaningful; clamp them so they are not
    # passed on to the repository.
    limit = max(limit, 0)
    offset = max(offset, 0)

    try:
        repo = get_compliance_repository()
        records = await asyncio.to_thread(repo.list_analyses, limit, offset)
        return [
            {
                "id": r.id,
                "created_at": r.created_at.isoformat(),
                "created_by": r.created_by,
                "doc_name": r.doc_name,
                "standard_name": r.standard_name,
                "risk_score": r.risk_score,
                "finding_count": len(r.findings),
            }
            for r in records
        ]
    except NotImplementedError:
        return []
@router.get("/history/{analysis_id}")
async def get_history_item(
    analysis_id: str,
    current_user: UserClaims = Depends(get_current_user),
):
    """Return one saved analysis, including its findings, as a plain dict.

    Raises:
        HTTPException: 404 when the repository returns no record for
            ``analysis_id``.
    """
    from fastapi import HTTPException
    from app.shared.bootstrap import get_compliance_repository

    repository = get_compliance_repository()
    record = await asyncio.to_thread(repository.get_analysis, analysis_id)
    if not record:
        raise HTTPException(status_code=404, detail="Analysis not found")

    # Serialize the findings first so the response dict below stays flat.
    findings_payload = []
    for item in record.findings:
        findings_payload.append(
            {
                "id": item.id,
                "seq": item.seq,
                "title": item.title,
                "description": item.description,
                "status": item.status,
                "clause_ref": item.clause_ref,
            }
        )

    return {
        "id": record.id,
        "created_at": record.created_at.isoformat(),
        "created_by": record.created_by,
        "doc_name": record.doc_name,
        "standard_name": record.standard_name,
        "risk_score": record.risk_score,
        "conclusion": record.conclusion,
        "actions": record.actions,
        "para_text": record.para_text,
        "highlight_terms": record.highlight_terms,
        "findings": findings_payload,
    }
@router.delete("/history/{analysis_id}", status_code=204)
async def delete_history_item(
    analysis_id: str,
    current_user: UserClaims = Depends(get_current_user),
):
    """Delete the saved analysis identified by ``analysis_id``.

    The deletion is delegated to the repository's ``delete_analysis`` and run
    in a worker thread; the route responds with 204 and no body.
    """
    from app.shared.bootstrap import get_compliance_repository

    repository = get_compliance_repository()
    remove = repository.delete_analysis
    await asyncio.to_thread(remove, analysis_id)
@router.get("/history/{analysis_id}/download")
async def download_history_docx(
    analysis_id: str,
    current_user: UserClaims = Depends(get_current_user),
):
    """Return a DOCX compliance report for the given analysis.

    The download name is ``compliance_<doc_name>_<YYYYMMDD>.docx`` where
    ``doc_name`` has spaces replaced by underscores and is cut to 50
    characters.

    Raises:
        HTTPException: 404 when the repository returns no record for
            ``analysis_id``.
    """
    from urllib.parse import quote

    from app.shared.bootstrap import get_compliance_repository
    from app.infrastructure.compliance.docx_export import generate_docx
    from fastapi import HTTPException
    from fastapi.responses import Response

    repo = get_compliance_repository()
    record = await asyncio.to_thread(repo.get_analysis, analysis_id)
    if not record:
        raise HTTPException(status_code=404, detail="Analysis not found")

    docx_bytes = await asyncio.to_thread(generate_docx, record)

    safe_name = (record.doc_name or "report").replace(" ", "_")[:50]
    filename = f"compliance_{safe_name}_{record.created_at.strftime('%Y%m%d')}.docx"

    # ``doc_name`` is user-supplied. Non-ASCII characters cannot be encoded in
    # a plain header value, and quotes or control characters would corrupt the
    # quoted ``filename`` parameter. Send a conservative ASCII fallback plus
    # the exact name percent-encoded via RFC 5987 ``filename*``.
    ascii_fallback = "".join(
        ch if ch.isascii() and (ch.isalnum() or ch in "._-") else "_"
        for ch in filename
    )
    content_disposition = (
        f'attachment; filename="{ascii_fallback}"; '
        f"filename*=UTF-8''{quote(filename, safe='')}"
    )

    return Response(
        content=docx_bytes,
        media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        headers={"Content-Disposition": content_disposition},
    )
@router.get("/analyses/{analysis_id}/findings/{finding_id}/chat")
async def get_finding_chat_history(
    analysis_id: str,
    finding_id: str,
    current_user: UserClaims = Depends(get_current_user),
):
    """Return the persisted chat messages of a finding thread.

    Messages are looked up by ``finding_id`` only; ``analysis_id`` is part of
    the route but is not used in the lookup. An empty list is returned when
    the repository raises ``NotImplementedError``.
    """
    from app.shared.bootstrap import get_compliance_repository

    try:
        repository = get_compliance_repository()
        return await asyncio.to_thread(repository.get_messages, finding_id)
    except NotImplementedError:
        return []
@router.post("/analyses/{analysis_id}/findings/{finding_id}/suggestions")
async def get_finding_suggestions(
    analysis_id: str,
    finding_id: str,
    current_user: UserClaims = Depends(get_current_user),
):
    """Generate follow-up question suggestions for one finding.

    Returns:
        ``{"questions": ...}`` with the value produced by
        ``generate_suggestions``.

    Raises:
        HTTPException: 404 when the analysis, or the finding within it, does
            not exist.
    """
    from fastapi import HTTPException
    from app.application.compliance.pipeline import generate_suggestions
    from app.services.llm.llm_factory import get_llm_client
    from app.shared.bootstrap import get_compliance_repository

    repository = get_compliance_repository()
    analysis = await asyncio.to_thread(repository.get_analysis, analysis_id)
    if not analysis:
        raise HTTPException(status_code=404, detail="Analysis not found")

    # First finding of this analysis whose id matches, or None.
    matching = (item for item in analysis.findings if item.id == finding_id)
    finding = next(matching, None)
    if not finding:
        raise HTTPException(status_code=404, detail="Finding not found")

    llm = get_llm_client(provider=settings.llm_provider, model=settings.llm_model)
    suggested = await asyncio.to_thread(generate_suggestions, finding, analysis, llm)
    return {"questions": suggested}
@router.post("/analyses/{analysis_id}/findings/{finding_id}/chat")
async def finding_chat(
    analysis_id: str,
    finding_id: str,
    request: ComplianceChatRequest,
    current_user: UserClaims = Depends(get_current_user),
):
    """Stream a chat answer grounded in one specific finding.

    The analysis and finding are loaded from the repository to build the
    context that is prepended to the user's question. The user message is
    persisted before streaming starts; the assistant's answer is persisted
    once the stream ends, if any text was produced.

    Raises:
        HTTPException: 404 when the analysis, or the finding within it, does
            not exist.
    """
    from fastapi import HTTPException
    from app.application.compliance.pipeline import build_finding_context
    from app.shared.bootstrap import get_compliance_repository

    repo = get_compliance_repository()

    analysis = await asyncio.to_thread(repo.get_analysis, analysis_id)
    if not analysis:
        raise HTTPException(status_code=404, detail="Analysis not found")

    finding = next((f for f in analysis.findings if f.id == finding_id), None)
    if not finding:
        raise HTTPException(status_code=404, detail="Finding not found")

    # Store the user's message first.
    await asyncio.to_thread(
        repo.save_message, analysis_id, finding_id, "user", request.query
    )

    # Last 10 stored messages (5 turns) of this finding thread.
    # NOTE(review): history_messages is built but never passed on to
    # stream_chat below — confirm whether it should be wired in.
    history = await asyncio.to_thread(repo.get_messages, finding_id)
    recent = history[-10:]
    history_messages = [
        {"role": entry["role"], "content": entry["content"]} for entry in recent
    ]

    # Prepend the finding/analysis context to the user's question.
    system_context = build_finding_context(finding, analysis)
    full_query = f"[Compliance Finding Context]\n{system_context}\n\nUser question: {request.query}"

    # Collects streamed text chunks so the full answer can be saved afterwards.
    assistant_buffer: list[str] = []

    async def generate() -> AsyncGenerator[str, None]:
        try:
            _, event_stream = get_agent_conversation_service().stream_chat(
                query=full_query,
                top_k=5,
                prompt_template="compliance_qa",
            )
            for event in event_stream:
                kind = event.get("event", "")
                if kind == "content":
                    piece = event.get("data", "")
                    if piece:
                        assistant_buffer.append(piece)
                        yield _sse({"type": "chunk", "text": piece})
                elif kind == "done":
                    yield _sse({"type": "done"})
                await asyncio.sleep(0)
        except Exception as exc:
            logger.exception("finding_chat stream error")
            yield _sse({"type": "error", "text": str(exc)})
        finally:
            # Runs on normal completion and on error alike; saves whatever
            # text was streamed.
            full_response = "".join(assistant_buffer)
            if full_response:
                try:
                    await asyncio.to_thread(
                        repo.save_message, analysis_id, finding_id, "assistant", full_response
                    )
                except Exception as exc:
                    logger.warning("Failed to persist assistant message: {}", exc)

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
    )