"""Define API routes for compliance.""" from __future__ import annotations import asyncio import json from pathlib import Path from typing import AsyncGenerator, Optional from fastapi import APIRouter, Depends, File, Form, UploadFile from fastapi.responses import StreamingResponse from loguru import logger from app.api.dependencies.auth import get_current_user from app.domain.auth.models import UserClaims from app.schemas.compliance import ( AnalyzeResponse, ComplianceChatRequest, ) from app.services.mock_data import generate_task_id, get_mock_compliance_result from app.shared.bootstrap import get_agent_conversation_service, get_retrieval_service from app.config.settings import settings router = APIRouter(prefix="/compliance", tags=["合规分析"]) tasks_store: dict[str, dict] = {} RAW_DATA_DIR = Path(__file__).resolve().parents[3] / "data" / "raw" @router.post("/analyze", response_model=AnalyzeResponse) async def analyze_document(file: UploadFile = File(...)): """Handle analyze document.""" task_id = generate_task_id() RAW_DATA_DIR.mkdir(parents=True, exist_ok=True) file_path = RAW_DATA_DIR / f"compliance_{task_id}_{file.filename}" content = await file.read() with file_path.open("wb") as f: f.write(content) tasks_store[task_id] = { "task_id": task_id, "file_path": str(file_path), "status": "processing", "result": None, } tasks_store[task_id]["status"] = "completed" tasks_store[task_id]["result"] = get_mock_compliance_result(task_id) return AnalyzeResponse(task_id=task_id) @router.get("/result/{task_id}") async def get_result(task_id: str): """Return result.""" if task_id not in tasks_store: return get_mock_compliance_result(task_id) task = tasks_store[task_id] if task["status"] == "processing": return {"status": "processing", "message": "分析进行中"} return task["result"] def _sse(data: dict) -> str: return f"event: message\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" @router.post("/analyze-stream") async def analyze_stream( text: Optional[str] = Form(None), doc_id: Optional[str] = Form(None), file: Optional[UploadFile] = File(None), domains: Optional[str] = Form(None), title: Optional[str] = Form(None), current_user: UserClaims = Depends(get_current_user), ): """Stream compliance analysis as SSE events. Stages: clause_split → retrieval (per clause) → gap_check → conclusion Events: stage | source | finding | done | error """ from app.application.compliance.pipeline import ( check_clause_compliance, extract_text_from_doc_id, extract_text_from_file, retrieve_for_clause, split_into_clauses, synthesize_conclusion, ) from app.services.llm.llm_factory import get_llm_client from app.shared.bootstrap import get_retrieval_service # Read file content eagerly (before async generator) file_content: bytes | None = None file_name: str | None = None if file is not None: file_content = await file.read() file_name = file.filename async def generate() -> AsyncGenerator[str, None]: try: client = get_llm_client(provider=settings.llm_provider, model=settings.llm_model) retrieval_service = get_retrieval_service() # ── Stage 1: extract text ───────────────────────────────────── yield _sse({"type": "stage", "stage": "extracting", "label": "Extracting text…"}) await asyncio.sleep(0) if text: para_text = text.strip() elif doc_id: try: para_text = await asyncio.to_thread(extract_text_from_doc_id, doc_id) except Exception as exc: yield _sse({"type": "error", "text": f"Document not found: {exc}"}) return elif file_content is not None: para_text = await asyncio.to_thread( extract_text_from_file, file_content, file_name or "upload" ) else: yield _sse({"type": "error", "text": "No input provided"}) return if not para_text.strip(): yield _sse({"type": "error", "text": "Could not extract text from the provided input"}) return # ── Stage 2: split into clauses ─────────────────────────────── yield _sse({"type": "stage", "stage": "splitting", "label": "Splitting into clauses…"}) await asyncio.sleep(0) clauses: list[str] = await asyncio.to_thread(split_into_clauses, para_text, client) # ── Stage 3: retrieve + gap check per clause ────────────────── findings: list[dict] = [] for i, clause in enumerate(clauses): yield _sse({ "type": "stage", "stage": "analyzing", "label": f"Analyzing clause {i + 1}/{len(clauses)}…", }) await asyncio.sleep(0) chunks = await asyncio.to_thread( retrieve_for_clause, clause, retrieval_service, 5, domains or None ) # Emit source events for chunk in chunks[:3]: yield _sse({ "type": "source", "standard": getattr(chunk, "doc_title", "") or getattr(chunk, "doc_name", ""), "clause": getattr(chunk, "section_title", "") or "", "score": round(float(getattr(chunk, "score", 0)), 3), "status": "retrieved", "full_content": (getattr(chunk, "text", "") or "")[:300], }) await asyncio.sleep(0) finding = await asyncio.to_thread(check_clause_compliance, clause, chunks, client) if finding: findings.append(finding) yield _sse({"type": "finding", **finding}) await asyncio.sleep(0) # ── Stage 4: synthesize conclusion ──────────────────────────── yield _sse({"type": "stage", "stage": "concluding", "label": "Generating conclusion…"}) await asyncio.sleep(0) conclusion_data = await asyncio.to_thread( synthesize_conclusion, para_text, findings, client ) yield _sse({"type": "done", **conclusion_data}) except Exception as exc: logger.exception("analyze-stream pipeline error") yield _sse({"type": "error", "text": str(exc)}) return StreamingResponse( generate(), media_type="text/event-stream", headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"}, ) @router.post("/chat/{segment_id}") async def compliance_chat(segment_id: int, request: ComplianceChatRequest): """Stream compliance Q&A grounded in real vector retrieval.""" query = request.query if request.segment_context: query = f"[段落分析上下文]\n{request.segment_context}\n\n用户问题:{request.query}" _, event_stream = get_agent_conversation_service().stream_chat( query=query, top_k=5, prompt_template="compliance_qa", ) async def generate() -> AsyncGenerator[str, None]: """Translate agent SSE events to compliance chunk/done format.""" for event in event_stream: event_type = event.get("event", "") if event_type == "content": text = event.get("data", "") if text: yield ( "event: message\n" f"data: {json.dumps({'type': 'chunk', 'text': text}, ensure_ascii=False)}\n\n" ) elif event_type == "done": yield ( "event: message\n" f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n" ) await asyncio.sleep(0) return StreamingResponse( generate(), media_type="text/event-stream", headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"}, )