"""Define API routes for compliance.""" from __future__ import annotations import asyncio import json from pathlib import Path from typing import AsyncGenerator, Optional from fastapi import APIRouter, Depends, File, Form, UploadFile from fastapi.responses import StreamingResponse from loguru import logger from app.api.dependencies.auth import get_current_user from app.domain.auth.models import UserClaims from app.schemas.compliance import ( AnalyzeResponse, ComplianceChatRequest, ) from app.services.mock_data import generate_task_id, get_mock_compliance_result from app.shared.bootstrap import get_agent_conversation_service, get_retrieval_service from app.config.settings import settings router = APIRouter(prefix="/compliance", tags=["合规分析"]) tasks_store: dict[str, dict] = {} RAW_DATA_DIR = Path(__file__).resolve().parents[3] / "data" / "raw" @router.post("/analyze", response_model=AnalyzeResponse) async def analyze_document(file: UploadFile = File(...)): """Handle analyze document.""" task_id = generate_task_id() RAW_DATA_DIR.mkdir(parents=True, exist_ok=True) file_path = RAW_DATA_DIR / f"compliance_{task_id}_{file.filename}" content = await file.read() with file_path.open("wb") as f: f.write(content) tasks_store[task_id] = { "task_id": task_id, "file_path": str(file_path), "status": "processing", "result": None, } tasks_store[task_id]["status"] = "completed" tasks_store[task_id]["result"] = get_mock_compliance_result(task_id) return AnalyzeResponse(task_id=task_id) @router.get("/result/{task_id}") async def get_result(task_id: str): """Return result.""" if task_id not in tasks_store: return get_mock_compliance_result(task_id) task = tasks_store[task_id] if task["status"] == "processing": return {"status": "processing", "message": "分析进行中"} return task["result"] def _sse(data: dict) -> str: return f"event: message\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" @router.post("/analyze-stream") async def analyze_stream( text: Optional[str] = Form(None), doc_id: Optional[str] = Form(None), file: Optional[UploadFile] = File(None), domains: Optional[str] = Form(None), title: Optional[str] = Form(None), current_user: UserClaims = Depends(get_current_user), ): """Stream compliance analysis as SSE events. Stages: clause_split → retrieval (per clause) → gap_check → conclusion Events: stage | source | finding | done | error """ from app.application.compliance.pipeline import ( extract_text_from_doc_id, extract_text_from_file, run_clauses_parallel, split_into_clauses, synthesize_conclusion, ) from app.services.llm.llm_factory import get_llm_client from app.shared.bootstrap import get_retrieval_service # Read file content eagerly (before async generator) file_content: bytes | None = None file_name: str | None = None if file is not None: file_content = await file.read() file_name = file.filename async def generate() -> AsyncGenerator[str, None]: try: client = get_llm_client(provider=settings.llm_provider, model=settings.llm_model) retrieval_service = get_retrieval_service() # ── Stage 1: extract text ───────────────────────────────────── yield _sse({"type": "stage", "stage": "extracting", "label": "Extracting text…"}) await asyncio.sleep(0) if text: para_text = text.strip() elif doc_id: try: para_text = await asyncio.to_thread(extract_text_from_doc_id, doc_id) except Exception as exc: yield _sse({"type": "error", "text": f"Document not found: {exc}"}) return elif file_content is not None: para_text = await asyncio.to_thread( extract_text_from_file, file_content, file_name or "upload" ) else: yield _sse({"type": "error", "text": "No input provided"}) return if not para_text.strip(): yield _sse({"type": "error", "text": "Could not extract text from the provided input"}) return # ── Stage 2: split into clauses ─────────────────────────────── yield _sse({"type": "stage", "stage": "splitting", "label": "Splitting into clauses…"}) await asyncio.sleep(0) clauses: list[str] = await asyncio.to_thread(split_into_clauses, para_text, client) # ── Stage 3: retrieve + gap check (parallel across all clauses) ──────────── findings: list[dict] = [] yield _sse({ "type": "stage", "stage": "analyzing", "label": f"Analyzing {len(clauses)} clauses in parallel…", }) await asyncio.sleep(0) clause_results = await run_clauses_parallel( clauses, retrieval_service, client, top_k=5, domains=domains or None, ) for res in clause_results: i = res["index"] chunks = res["chunks"] finding = res["finding"] # Emit source events for this clause for chunk in chunks[:3]: yield _sse({ "type": "source", "standard": getattr(chunk, "doc_title", "") or getattr(chunk, "doc_name", ""), "clause": getattr(chunk, "section_title", "") or "", "score": round(float(getattr(chunk, "score", 0)), 3), "status": "retrieved", "full_content": (getattr(chunk, "text", "") or "")[:300], }) if finding: findings.append(finding) yield _sse({"type": "finding", **finding}) await asyncio.sleep(0) # ── Stage 4: synthesize conclusion ──────────────────────────── yield _sse({"type": "stage", "stage": "concluding", "label": "Generating conclusion…"}) await asyncio.sleep(0) conclusion_data = await asyncio.to_thread( synthesize_conclusion, para_text, findings, client ) yield _sse({"type": "done", **conclusion_data}) # Auto-save analysis to database try: from app.shared.bootstrap import get_compliance_repository from app.domain.compliance.ports import AnalysisRecord, FindingRecord from datetime import datetime repo = get_compliance_repository() finding_records = [ FindingRecord( id="", analysis_id="", seq=i, title=f.get("title", ""), description=f.get("desc", ""), status=f.get("status", "ok"), clause_ref=f.get("clause_ref"), ) for i, f in enumerate(findings) ] record = AnalysisRecord( id="", created_at=datetime.utcnow(), created_by=current_user.username if hasattr(current_user, "username") else None, doc_name=file_name or (title or "Pasted text"), standard_name=title or "", risk_score=conclusion_data.get("risk_score", 0), conclusion=conclusion_data.get("conclusion", ""), actions=conclusion_data.get("actions", []), para_text=conclusion_data.get("para_text", ""), highlight_terms=conclusion_data.get("highlight_terms", []), findings=finding_records, ) analysis_id = await asyncio.to_thread(repo.save_analysis, record) yield _sse({"type": "saved", "analysis_id": analysis_id}) except NotImplementedError: pass # No postgres backend configured — skip saving except Exception as exc: logger.warning("Failed to auto-save compliance analysis: {}", exc) except Exception as exc: logger.exception("analyze-stream pipeline error") yield _sse({"type": "error", "text": str(exc)}) return StreamingResponse( generate(), media_type="text/event-stream", headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"}, ) @router.post("/chat/{segment_id}") async def compliance_chat(segment_id: int, request: ComplianceChatRequest): """Stream compliance Q&A grounded in real vector retrieval.""" query = request.query if request.segment_context: query = f"[段落分析上下文]\n{request.segment_context}\n\n用户问题:{request.query}" _, event_stream = get_agent_conversation_service().stream_chat( query=query, top_k=5, prompt_template="compliance_qa", ) async def generate() -> AsyncGenerator[str, None]: """Translate agent SSE events to compliance chunk/done format.""" for event in event_stream: event_type = event.get("event", "") if event_type == "content": text = event.get("data", "") if text: yield ( "event: message\n" f"data: {json.dumps({'type': 'chunk', 'text': text}, ensure_ascii=False)}\n\n" ) elif event_type == "done": yield ( "event: message\n" f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n" ) await asyncio.sleep(0) return StreamingResponse( generate(), media_type="text/event-stream", headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"}, ) @router.get("/history") async def list_history( limit: int = 20, offset: int = 0, current_user: UserClaims = Depends(get_current_user), ): """Return paginated list of saved compliance analyses (newest first).""" from app.shared.bootstrap import get_compliance_repository try: repo = get_compliance_repository() records = await asyncio.to_thread(repo.list_analyses, limit, offset) return [ { "id": r.id, "created_at": r.created_at.isoformat(), "created_by": r.created_by, "doc_name": r.doc_name, "standard_name": r.standard_name, "risk_score": r.risk_score, "finding_count": len(r.findings), } for r in records ] except NotImplementedError: return [] @router.get("/history/{analysis_id}") async def get_history_item( analysis_id: str, current_user: UserClaims = Depends(get_current_user), ): """Return full analysis record including findings.""" from app.shared.bootstrap import get_compliance_repository from fastapi import HTTPException repo = get_compliance_repository() record = await asyncio.to_thread(repo.get_analysis, analysis_id) if not record: raise HTTPException(status_code=404, detail="Analysis not found") return { "id": record.id, "created_at": record.created_at.isoformat(), "created_by": record.created_by, "doc_name": record.doc_name, "standard_name": record.standard_name, "risk_score": record.risk_score, "conclusion": record.conclusion, "actions": record.actions, "para_text": record.para_text, "highlight_terms": record.highlight_terms, "findings": [ { "id": f.id, "seq": f.seq, "title": f.title, "description": f.description, "status": f.status, "clause_ref": f.clause_ref, } for f in record.findings ], } @router.delete("/history/{analysis_id}", status_code=204) async def delete_history_item( analysis_id: str, current_user: UserClaims = Depends(get_current_user), ): """Delete a saved analysis (cascade removes findings and chat messages).""" from app.shared.bootstrap import get_compliance_repository repo = get_compliance_repository() await asyncio.to_thread(repo.delete_analysis, analysis_id) @router.get("/history/{analysis_id}/download") async def download_history_docx( analysis_id: str, current_user: UserClaims = Depends(get_current_user), ): """Return a DOCX compliance report for the given analysis.""" from app.shared.bootstrap import get_compliance_repository from app.infrastructure.compliance.docx_export import generate_docx from fastapi import HTTPException from fastapi.responses import Response repo = get_compliance_repository() record = await asyncio.to_thread(repo.get_analysis, analysis_id) if not record: raise HTTPException(status_code=404, detail="Analysis not found") docx_bytes = await asyncio.to_thread(generate_docx, record) safe_name = (record.doc_name or "report").replace(" ", "_")[:50] filename = f"compliance_{safe_name}_{record.created_at.strftime('%Y%m%d')}.docx" return Response( content=docx_bytes, media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document", headers={"Content-Disposition": f'attachment; filename="{filename}"'}, ) @router.get("/analyses/{analysis_id}/findings/{finding_id}/chat") async def get_finding_chat_history( analysis_id: str, finding_id: str, current_user: UserClaims = Depends(get_current_user), ): """Return persisted chat messages for a finding thread, oldest first.""" from app.shared.bootstrap import get_compliance_repository try: repo = get_compliance_repository() messages = await asyncio.to_thread(repo.get_messages, finding_id) return messages except NotImplementedError: return [] @router.post("/analyses/{analysis_id}/findings/{finding_id}/suggestions") async def get_finding_suggestions( analysis_id: str, finding_id: str, current_user: UserClaims = Depends(get_current_user), ): """Generate 3 LLM-powered follow-up question suggestions for a finding.""" from app.application.compliance.pipeline import generate_suggestions from app.shared.bootstrap import get_compliance_repository from app.services.llm.llm_factory import get_llm_client from fastapi import HTTPException repo = get_compliance_repository() analysis = await asyncio.to_thread(repo.get_analysis, analysis_id) if not analysis: raise HTTPException(status_code=404, detail="Analysis not found") finding = next((f for f in analysis.findings if f.id == finding_id), None) if not finding: raise HTTPException(status_code=404, detail="Finding not found") client = get_llm_client(provider=settings.llm_provider, model=settings.llm_model) questions = await asyncio.to_thread(generate_suggestions, finding, analysis, client) return {"questions": questions} @router.post("/analyses/{analysis_id}/findings/{finding_id}/chat") async def finding_chat( analysis_id: str, finding_id: str, request: ComplianceChatRequest, current_user: UserClaims = Depends(get_current_user), ): """Stream a grounded chat response for a specific finding. Loads the finding and analysis from DB to build grounded context. Persists both user message and assistant response to finding_chat_messages. """ from app.application.compliance.pipeline import build_finding_context from app.shared.bootstrap import get_compliance_repository from fastapi import HTTPException repo = get_compliance_repository() analysis = await asyncio.to_thread(repo.get_analysis, analysis_id) if not analysis: raise HTTPException(status_code=404, detail="Analysis not found") finding = next((f for f in analysis.findings if f.id == finding_id), None) if not finding: raise HTTPException(status_code=404, detail="Finding not found") # Persist user message await asyncio.to_thread( repo.save_message, analysis_id, finding_id, "user", request.query ) # Build message history (last 10 messages = 5 turns) history = await asyncio.to_thread(repo.get_messages, finding_id) history_messages = [ {"role": m["role"], "content": m["content"]} for m in history[-10:] ] # Build grounded system context system_context = build_finding_context(finding, analysis) full_query = f"[Compliance Finding Context]\n{system_context}\n\nUser question: {request.query}" assistant_buffer: list[str] = [] async def generate() -> AsyncGenerator[str, None]: try: _, event_stream = get_agent_conversation_service().stream_chat( query=full_query, top_k=5, prompt_template="compliance_qa", ) for event in event_stream: event_type = event.get("event", "") if event_type == "content": text = event.get("data", "") if text: assistant_buffer.append(text) yield _sse({"type": "chunk", "text": text}) elif event_type == "done": yield _sse({"type": "done"}) await asyncio.sleep(0) except Exception as exc: logger.exception("finding_chat stream error") yield _sse({"type": "error", "text": str(exc)}) finally: # Persist assistant response after stream completes full_response = "".join(assistant_buffer) if full_response: try: await asyncio.to_thread( repo.save_message, analysis_id, finding_id, "assistant", full_response ) except Exception as exc: logger.warning("Failed to persist assistant message: {}", exc) return StreamingResponse( generate(), media_type="text/event-stream", headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"}, )