Files

494 lines
19 KiB
Python

"""Define API routes for compliance."""
from __future__ import annotations
import asyncio
import json
from pathlib import Path
from typing import AsyncGenerator, Optional
from fastapi import APIRouter, Depends, File, Form, UploadFile
from fastapi.responses import StreamingResponse
from loguru import logger
from app.api.dependencies.auth import get_current_user
from app.domain.auth.models import UserClaims
from app.schemas.compliance import (
AnalyzeResponse,
ComplianceChatRequest,
)
from app.services.mock_data import generate_task_id, get_mock_compliance_result
from app.shared.bootstrap import get_agent_conversation_service, get_retrieval_service
from app.config.settings import settings
# Router mounted under /compliance; the tag groups these routes in the OpenAPI docs.
router = APIRouter(prefix="/compliance", tags=["合规分析"])
# Per-process, in-memory registry (task_id -> task dict) used by the mock
# /analyze and /result endpoints. Nothing in this module removes entries.
tasks_store: dict[str, dict] = {}
# Upload directory: <fourth ancestor directory of this file (parents[3])>/data/raw.
RAW_DATA_DIR = Path(__file__).resolve().parents[3].joinpath("data", "raw")
@router.post("/analyze", response_model=AnalyzeResponse)
async def analyze_document(file: UploadFile = File(...)):
    """Save an uploaded document and register a mock compliance-analysis task.

    The upload is written under ``RAW_DATA_DIR``. The task is marked
    ``completed`` immediately with ``get_mock_compliance_result``; no real
    analysis runs in this endpoint.

    Returns:
        AnalyzeResponse carrying the new ``task_id`` (see ``GET /result/{task_id}``).
    """
    task_id = generate_task_id()
    # ``filename`` is client-controlled: keep only its last path component
    # (treating both "/" and "\" as separators) so names like "../../x" cannot
    # escape RAW_DATA_DIR, and fall back to "upload" when it is missing/empty.
    raw_name = (file.filename or "").replace("\\", "/")
    upload_name = Path(raw_name).name or "upload"
    file_path = RAW_DATA_DIR / f"compliance_{task_id}_{upload_name}"

    content = await file.read()
    # Filesystem calls are blocking; run them in a worker thread so the event
    # loop keeps serving other requests.
    await asyncio.to_thread(RAW_DATA_DIR.mkdir, parents=True, exist_ok=True)
    await asyncio.to_thread(file_path.write_bytes, content)

    tasks_store[task_id] = {
        "task_id": task_id,
        "file_path": str(file_path),
        "status": "processing",
        "result": None,
    }
    # Mock pipeline: the task is completed synchronously with canned data.
    tasks_store[task_id]["status"] = "completed"
    tasks_store[task_id]["result"] = get_mock_compliance_result(task_id)
    return AnalyzeResponse(task_id=task_id)
@router.get("/result/{task_id}")
async def get_result(task_id: str):
    """Return the stored result for *task_id*.

    An unknown id falls back to a freshly generated mock result. A task whose
    status is still ``processing`` yields a status placeholder instead.
    """
    task = tasks_store.get(task_id)
    if task is None:
        return get_mock_compliance_result(task_id)
    if task["status"] != "processing":
        return task["result"]
    return {"status": "processing", "message": "分析进行中"}
def _sse(data: dict) -> str:
return f"event: message\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
def _chunk_to_source_event(chunk: object) -> dict:
    """Build the ``source`` SSE payload describing one retrieved chunk.

    Metadata is read with ``getattr`` defaults so chunk objects lacking an
    attribute still yield a well-formed event. A missing or ``None`` score is
    reported as 0 instead of aborting the whole analysis.
    """
    return {
        "type": "source",
        "standard": getattr(chunk, "doc_title", "") or getattr(chunk, "doc_name", ""),
        "clause": getattr(chunk, "section_title", "") or "",
        "score": round(float(getattr(chunk, "score", 0) or 0), 3),
        "status": "retrieved",
        # Preview only: the first 300 characters of the chunk text.
        "full_content": (getattr(chunk, "text", "") or "")[:300],
    }


def _build_analysis_record(
    findings: list[dict],
    conclusion_data: dict,
    *,
    doc_name: str,
    standard_name: str,
    created_by: Optional[str],
):
    """Assemble the ``AnalysisRecord`` persisted after a streamed analysis.

    ``id``/``analysis_id`` are left empty; presumably the repository assigns
    them on save (TODO confirm). ``seq`` preserves the order in which the
    findings were streamed.
    """
    from datetime import datetime, timezone

    from app.domain.compliance.ports import AnalysisRecord, FindingRecord

    finding_records = [
        FindingRecord(
            id="",
            analysis_id="",
            seq=seq,
            title=item.get("title", ""),
            description=item.get("desc", ""),
            status=item.get("status", "ok"),
            clause_ref=item.get("clause_ref"),
        )
        for seq, item in enumerate(findings)
    ]
    return AnalysisRecord(
        id="",
        # Naive UTC timestamp: the same value the deprecated datetime.utcnow() gave.
        created_at=datetime.now(timezone.utc).replace(tzinfo=None),
        created_by=created_by,
        doc_name=doc_name,
        standard_name=standard_name,
        risk_score=conclusion_data.get("risk_score", 0),
        conclusion=conclusion_data.get("conclusion", ""),
        actions=conclusion_data.get("actions", []),
        para_text=conclusion_data.get("para_text", ""),
        highlight_terms=conclusion_data.get("highlight_terms", []),
        findings=finding_records,
    )


@router.post("/analyze-stream")
async def analyze_stream(
    text: Optional[str] = Form(None),
    doc_id: Optional[str] = Form(None),
    file: Optional[UploadFile] = File(None),
    domains: Optional[str] = Form(None),
    title: Optional[str] = Form(None),
    current_user: UserClaims = Depends(get_current_user),
):
    """Stream compliance analysis as SSE events.

    Stages: extracting → splitting → analyzing (retrieval + gap check for all
    clauses in parallel) → concluding.
    Events: stage | source | finding | done | saved | error

    Input precedence is ``text``, then ``doc_id``, then ``file``. With none of
    them, a single ``error`` event is emitted. After ``done`` the analysis is
    saved on a best-effort basis, and a ``saved`` event carries the id
    returned by the repository.
    """
    from app.application.compliance.pipeline import (
        extract_text_from_doc_id,
        extract_text_from_file,
        run_clauses_parallel,
        split_into_clauses,
        synthesize_conclusion,
    )
    from app.services.llm.llm_factory import get_llm_client

    # Read the upload eagerly: the generator below only runs after this
    # handler has returned the StreamingResponse.
    file_content: bytes | None = None
    file_name: str | None = None
    if file is not None:
        file_content = await file.read()
        file_name = file.filename

    async def generate() -> AsyncGenerator[str, None]:
        try:
            client = get_llm_client(provider=settings.llm_provider, model=settings.llm_model)
            retrieval_service = get_retrieval_service()

            # ── Stage 1: extract text ─────────────────────────────────────
            yield _sse({"type": "stage", "stage": "extracting", "label": "Extracting text…"})
            await asyncio.sleep(0)
            if text:
                para_text = text.strip()
            elif doc_id:
                try:
                    para_text = await asyncio.to_thread(extract_text_from_doc_id, doc_id)
                except Exception as exc:
                    yield _sse({"type": "error", "text": f"Document not found: {exc}"})
                    return
            elif file_content is not None:
                para_text = await asyncio.to_thread(
                    extract_text_from_file, file_content, file_name or "upload"
                )
            else:
                yield _sse({"type": "error", "text": "No input provided"})
                return
            # Treat a None result from an extractor the same as blank text.
            if not para_text or not para_text.strip():
                yield _sse({"type": "error", "text": "Could not extract text from the provided input"})
                return

            # ── Stage 2: split into clauses ───────────────────────────────
            yield _sse({"type": "stage", "stage": "splitting", "label": "Splitting into clauses…"})
            await asyncio.sleep(0)
            clauses: list[str] = await asyncio.to_thread(split_into_clauses, para_text, client)

            # ── Stage 3: retrieve + gap check (parallel across all clauses) ──
            findings: list[dict] = []
            yield _sse({
                "type": "stage",
                "stage": "analyzing",
                "label": f"Analyzing {len(clauses)} clauses in parallel…",
            })
            await asyncio.sleep(0)
            clause_results = await run_clauses_parallel(
                clauses, retrieval_service, client,
                top_k=5,
                domains=domains or None,
            )
            for res in clause_results:
                # Only the top 3 retrieved chunks per clause are surfaced.
                for chunk in res["chunks"][:3]:
                    yield _sse(_chunk_to_source_event(chunk))
                finding = res["finding"]
                if finding:
                    findings.append(finding)
                    yield _sse({"type": "finding", **finding})
                await asyncio.sleep(0)

            # ── Stage 4: synthesize conclusion ────────────────────────────
            yield _sse({"type": "stage", "stage": "concluding", "label": "Generating conclusion…"})
            await asyncio.sleep(0)
            conclusion_data = await asyncio.to_thread(
                synthesize_conclusion, para_text, findings, client
            )
            yield _sse({"type": "done", **conclusion_data})

            # Auto-save is best-effort: the client already has its result, so
            # persistence problems are logged rather than sent as errors.
            try:
                from app.shared.bootstrap import get_compliance_repository

                repo = get_compliance_repository()
                record = _build_analysis_record(
                    findings,
                    conclusion_data,
                    doc_name=file_name or (title or "Pasted text"),
                    standard_name=title or "",
                    created_by=getattr(current_user, "username", None),
                )
                analysis_id = await asyncio.to_thread(repo.save_analysis, record)
                yield _sse({"type": "saved", "analysis_id": analysis_id})
            except NotImplementedError:
                pass  # No postgres backend configured — skip saving
            except Exception as exc:
                logger.warning("Failed to auto-save compliance analysis: {}", exc)
        except Exception as exc:
            logger.exception("analyze-stream pipeline error")
            yield _sse({"type": "error", "text": str(exc)})

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
    )
@router.post("/chat/{segment_id}")
async def compliance_chat(segment_id: int, request: ComplianceChatRequest):
    """Stream compliance Q&A grounded in vector retrieval, as SSE.

    ``segment_id`` is only part of the route and is not used to build the
    query. When ``request.segment_context`` is set, it is prepended to the
    question.

    Events emitted: ``{"type": "chunk", "text": ...}`` per content piece,
    ``{"type": "done"}`` when the agent reports completion, and
    ``{"type": "error", "text": ...}`` if the agent stream fails mid-way.

    NOTE(review): unlike the other endpoints, this route has no
    ``get_current_user`` dependency; confirm it is meant to be public.
    """
    query = request.query
    if request.segment_context:
        query = f"[段落分析上下文]\n{request.segment_context}\n\n用户问题:{request.query}"
    # stream_chat may do blocking work (e.g. retrieval) before returning, so
    # it is called in a worker thread.
    _, event_stream = await asyncio.to_thread(
        get_agent_conversation_service().stream_chat,
        query=query,
        top_k=5,
        prompt_template="compliance_qa",
    )

    async def generate() -> AsyncGenerator[str, None]:
        """Translate agent events into compliance ``chunk``/``done`` frames."""
        events = iter(event_stream)
        exhausted = object()
        try:
            while True:
                # The agent stream is iterated synchronously and each step may
                # block on the LLM; pull items in a worker thread so the event
                # loop is never stalled.
                event = await asyncio.to_thread(next, events, exhausted)
                if event is exhausted:
                    break
                event_type = event.get("event", "")
                if event_type == "content":
                    text = event.get("data", "")
                    if text:
                        yield _sse({"type": "chunk", "text": text})
                elif event_type == "done":
                    yield _sse({"type": "done"})
        except Exception as exc:
            logger.exception("compliance_chat stream error")
            yield _sse({"type": "error", "text": str(exc)})

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
    )
@router.get("/history")
async def list_history(
    limit: int = 20,
    offset: int = 0,
    current_user: UserClaims = Depends(get_current_user),
):
    """List saved compliance analyses as lightweight summaries.

    ``limit``/``offset`` are forwarded unchanged to the repository's
    ``list_analyses``, which determines the ordering (documented as newest
    first). Returns ``[]`` when the repository raises ``NotImplementedError``.
    """
    from app.shared.bootstrap import get_compliance_repository

    def summarize(rec) -> dict:
        # Summary view: a finding count instead of the findings themselves.
        return {
            "id": rec.id,
            "created_at": rec.created_at.isoformat(),
            "created_by": rec.created_by,
            "doc_name": rec.doc_name,
            "standard_name": rec.standard_name,
            "risk_score": rec.risk_score,
            "finding_count": len(rec.findings),
        }

    try:
        records = await asyncio.to_thread(
            get_compliance_repository().list_analyses, limit, offset
        )
        return list(map(summarize, records))
    except NotImplementedError:
        return []
@router.get("/history/{analysis_id}")
async def get_history_item(
    analysis_id: str,
    current_user: UserClaims = Depends(get_current_user),
):
    """Return one saved analysis in full, including its findings.

    Raises:
        HTTPException: 404 when the repository returns no record.
    """
    from app.shared.bootstrap import get_compliance_repository
    from fastapi import HTTPException

    record = await asyncio.to_thread(
        get_compliance_repository().get_analysis, analysis_id
    )
    if not record:
        raise HTTPException(status_code=404, detail="Analysis not found")

    payload = {
        "id": record.id,
        "created_at": record.created_at.isoformat(),
        "created_by": record.created_by,
        "doc_name": record.doc_name,
        "standard_name": record.standard_name,
        "risk_score": record.risk_score,
        "conclusion": record.conclusion,
        "actions": record.actions,
        "para_text": record.para_text,
        "highlight_terms": record.highlight_terms,
    }
    payload["findings"] = [
        {
            "id": item.id,
            "seq": item.seq,
            "title": item.title,
            "description": item.description,
            "status": item.status,
            "clause_ref": item.clause_ref,
        }
        for item in record.findings
    ]
    return payload
@router.delete("/history/{analysis_id}", status_code=204)
async def delete_history_item(
    analysis_id: str,
    current_user: UserClaims = Depends(get_current_user),
):
    """Delete one saved analysis; responds 204 with no body.

    NOTE(review): removal of the analysis's findings and chat messages is
    expected to cascade inside ``repo.delete_analysis``, which is not visible here.
    """
    from app.shared.bootstrap import get_compliance_repository as _get_repo

    await asyncio.to_thread(_get_repo().delete_analysis, analysis_id)
@router.get("/history/{analysis_id}/download")
async def download_history_docx(
    analysis_id: str,
    current_user: UserClaims = Depends(get_current_user),
):
    """Return a DOCX compliance report for the given analysis as an attachment.

    The download name is ``compliance_<doc_name>_<YYYYMMDD>.docx``, with spaces
    replaced by underscores and the doc name cut to 50 characters.

    Raises:
        HTTPException: 404 when the repository returns no record.
    """
    from urllib.parse import quote

    from app.shared.bootstrap import get_compliance_repository
    from app.infrastructure.compliance.docx_export import generate_docx
    from fastapi import HTTPException
    from fastapi.responses import Response

    repo = get_compliance_repository()
    record = await asyncio.to_thread(repo.get_analysis, analysis_id)
    if not record:
        raise HTTPException(status_code=404, detail="Analysis not found")
    docx_bytes = await asyncio.to_thread(generate_docx, record)

    safe_name = (record.doc_name or "report").replace(" ", "_")[:50]
    filename = f"compliance_{safe_name}_{record.created_at.strftime('%Y%m%d')}.docx"
    # Header values are encoded as latin-1, so a non-ASCII doc name (e.g.
    # Chinese) would fail with UnicodeEncodeError, and quotes or CR/LF would
    # corrupt the header. Send a sanitised printable-ASCII fallback plus the
    # exact name in RFC 5987 ``filename*`` form, which modern browsers prefer.
    ascii_name = "".join(ch for ch in filename if " " <= ch <= "~" and ch not in '"\\')
    disposition = (
        f'attachment; filename="{ascii_name}"; '
        f"filename*=UTF-8''{quote(filename, safe='')}"
    )
    return Response(
        content=docx_bytes,
        media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        headers={"Content-Disposition": disposition},
    )
@router.get("/analyses/{analysis_id}/findings/{finding_id}/chat")
async def get_finding_chat_history(
    analysis_id: str,
    finding_id: str,
    current_user: UserClaims = Depends(get_current_user),
):
    """Return the persisted chat messages of one finding thread (oldest first).

    Messages are looked up by ``finding_id`` only; ``analysis_id`` is part of
    the route but not used. Returns ``[]`` when the repository raises
    ``NotImplementedError``.
    """
    from app.shared.bootstrap import get_compliance_repository

    try:
        return await asyncio.to_thread(
            get_compliance_repository().get_messages, finding_id
        )
    except NotImplementedError:
        return []
@router.post("/analyses/{analysis_id}/findings/{finding_id}/suggestions")
async def get_finding_suggestions(
    analysis_id: str,
    finding_id: str,
    current_user: UserClaims = Depends(get_current_user),
):
    """Ask the LLM for follow-up question suggestions about one finding.

    The number of questions is decided by ``generate_suggestions``
    (presumably three, per this endpoint's original description).

    Returns:
        ``{"questions": [...]}``.

    Raises:
        HTTPException: 404 if the analysis or the finding does not exist.
    """
    from app.application.compliance.pipeline import generate_suggestions
    from app.shared.bootstrap import get_compliance_repository
    from app.services.llm.llm_factory import get_llm_client
    from fastapi import HTTPException

    analysis = await asyncio.to_thread(
        get_compliance_repository().get_analysis, analysis_id
    )
    if not analysis:
        raise HTTPException(status_code=404, detail="Analysis not found")

    finding = next(filter(lambda f: f.id == finding_id, analysis.findings), None)
    if not finding:
        raise HTTPException(status_code=404, detail="Finding not found")

    llm = get_llm_client(provider=settings.llm_provider, model=settings.llm_model)
    suggestions = await asyncio.to_thread(generate_suggestions, finding, analysis, llm)
    return {"questions": suggestions}
@router.post("/analyses/{analysis_id}/findings/{finding_id}/chat")
async def finding_chat(
    analysis_id: str,
    finding_id: str,
    request: ComplianceChatRequest,
    current_user: UserClaims = Depends(get_current_user),
):
    """Stream a grounded chat response for one finding, as SSE.

    The finding and its analysis are loaded from the repository to build the
    grounding context. Up to the last 10 earlier messages of the same finding
    thread are included in the prompt. The user message is persisted before
    streaming starts; whatever assistant text was streamed is persisted
    afterwards, even if the stream failed part-way.

    Raises:
        HTTPException: 404 if the analysis or the finding does not exist.
    """
    from app.application.compliance.pipeline import build_finding_context
    from app.shared.bootstrap import get_compliance_repository
    from fastapi import HTTPException

    history_window = 10  # last 10 messages ≈ 5 user/assistant turns

    repo = get_compliance_repository()
    analysis = await asyncio.to_thread(repo.get_analysis, analysis_id)
    if not analysis:
        raise HTTPException(status_code=404, detail="Analysis not found")
    finding = next((f for f in analysis.findings if f.id == finding_id), None)
    if not finding:
        raise HTTPException(status_code=404, detail="Finding not found")

    # Read earlier turns before saving the new question, so the question is
    # not duplicated inside the history block of the prompt.
    history = await asyncio.to_thread(repo.get_messages, finding_id)
    await asyncio.to_thread(
        repo.save_message, analysis_id, finding_id, "user", request.query
    )

    # Prompt layout: context, optional prior-turn transcript, then the question.
    # With no prior turns this is exactly "<context>\n\nUser question: <q>".
    system_context = build_finding_context(finding, analysis)
    sections = [f"[Compliance Finding Context]\n{system_context}"]
    recent = history[-history_window:]
    if recent:
        transcript = "\n".join(f"{m['role']}: {m['content']}" for m in recent)
        sections.append(f"[Conversation History]\n{transcript}")
    sections.append(f"User question: {request.query}")
    full_query = "\n\n".join(sections)

    async def generate() -> AsyncGenerator[str, None]:
        assistant_buffer: list[str] = []
        try:
            # stream_chat may do blocking work before returning; keep it off the loop.
            _, event_stream = await asyncio.to_thread(
                get_agent_conversation_service().stream_chat,
                query=full_query,
                top_k=5,
                prompt_template="compliance_qa",
            )
            events = iter(event_stream)
            exhausted = object()
            while True:
                # The agent stream is synchronous and each step may block on the
                # LLM; pull items in a worker thread so the event loop stays free.
                event = await asyncio.to_thread(next, events, exhausted)
                if event is exhausted:
                    break
                event_type = event.get("event", "")
                if event_type == "content":
                    text = event.get("data", "")
                    if text:
                        assistant_buffer.append(text)
                        yield _sse({"type": "chunk", "text": text})
                elif event_type == "done":
                    yield _sse({"type": "done"})
        except Exception as exc:
            logger.exception("finding_chat stream error")
            yield _sse({"type": "error", "text": str(exc)})
        finally:
            # Persist whatever assistant text was produced, even on failure.
            full_response = "".join(assistant_buffer)
            if full_response:
                try:
                    await asyncio.to_thread(
                        repo.save_message, analysis_id, finding_id, "assistant", full_response
                    )
                except Exception as exc:
                    logger.warning("Failed to persist assistant message: {}", exc)

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
    )