2026-05-18 16:32:42 +08:00
|
|
|
"""Define API routes for compliance."""
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
2026-05-14 15:07:34 +08:00
|
|
|
import asyncio
|
2026-05-18 16:32:42 +08:00
|
|
|
import json
|
|
|
|
|
from pathlib import Path
|
2026-06-05 09:00:36 +08:00
|
|
|
from typing import AsyncGenerator, Optional
|
2026-05-18 16:32:42 +08:00
|
|
|
|
2026-06-05 09:00:36 +08:00
|
|
|
from fastapi import APIRouter, File, Form, UploadFile
|
2026-05-18 16:32:42 +08:00
|
|
|
from fastapi.responses import StreamingResponse
|
2026-06-05 09:00:36 +08:00
|
|
|
from loguru import logger
|
2026-05-18 16:32:42 +08:00
|
|
|
|
2026-05-14 15:07:34 +08:00
|
|
|
from app.schemas.compliance import (
|
|
|
|
|
AnalyzeResponse,
|
|
|
|
|
ComplianceChatRequest,
|
|
|
|
|
)
|
2026-05-20 23:34:08 +08:00
|
|
|
from app.services.mock_data import generate_task_id, get_mock_compliance_result
|
2026-06-05 09:00:36 +08:00
|
|
|
from app.shared.bootstrap import get_agent_conversation_service, get_retrieval_service
|
|
|
|
|
from app.config.settings import settings
|
2026-05-18 16:32:42 +08:00
|
|
|
|
2026-05-14 15:07:34 +08:00
|
|
|
|
|
|
|
|
# Router for compliance endpoints; the OpenAPI tag reads "Compliance analysis".
router = APIRouter(prefix="/compliance", tags=["合规分析"])

# Process-local task registry keyed by task id. Nothing in this module removes
# entries, so it grows for the lifetime of the process.
tasks_store: dict[str, dict] = {}

# Destination for uploaded documents: <4 directories above this file>/data/raw.
RAW_DATA_DIR = Path(__file__).resolve().parents[3].joinpath("data", "raw")
|
|
|
|
|
|
2026-05-14 15:07:34 +08:00
|
|
|
|
|
|
|
|
@router.post("/analyze", response_model=AnalyzeResponse)
async def analyze_document(file: UploadFile = File(...)):
    """Persist an uploaded document and register a (mock) compliance task.

    The upload is written to ``RAW_DATA_DIR`` as
    ``compliance_<task_id>_<basename>`` and a task entry is recorded in the
    in-memory ``tasks_store``. Analysis is currently mocked: the entry is
    stored as ``"completed"`` with ``get_mock_compliance_result(task_id)``.

    Args:
        file: The uploaded document.

    Returns:
        AnalyzeResponse carrying the generated ``task_id``.
    """
    task_id = generate_task_id()

    # The client-supplied filename is untrusted. Keep only its final path
    # component so separators cannot create nested paths or step outside
    # RAW_DATA_DIR, and fall back to a fixed name when none was sent.
    safe_name = Path(file.filename or "").name or "upload"
    RAW_DATA_DIR.mkdir(parents=True, exist_ok=True)
    file_path = RAW_DATA_DIR / f"compliance_{task_id}_{safe_name}"

    content = await file.read()
    # Blocking disk write is offloaded so the event loop stays responsive.
    await asyncio.to_thread(file_path.write_bytes, content)

    # Mock pipeline: the result is available immediately, so the task is
    # stored directly in its final "completed" state.
    tasks_store[task_id] = {
        "task_id": task_id,
        "file_path": str(file_path),
        "status": "completed",
        "result": get_mock_compliance_result(task_id),
    }

    return AnalyzeResponse(task_id=task_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/result/{task_id}")
async def get_result(task_id: str):
    """Look up the analysis outcome for ``task_id``.

    Unknown ids fall back to a freshly generated mock result. A task still
    marked ``"processing"`` yields a status placeholder. Otherwise the stored
    result payload is returned.
    """
    try:
        task = tasks_store[task_id]
    except KeyError:
        return get_mock_compliance_result(task_id)

    if task["status"] != "processing":
        return task["result"]
    return {"status": "processing", "message": "分析进行中"}
|
|
|
|
|
|
|
|
|
|
|
2026-06-05 09:00:36 +08:00
|
|
|
def _sse(data: dict) -> str:
|
|
|
|
|
return f"event: message\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _chunk_source_event(chunk) -> dict:
    """Build the ``source`` SSE payload describing one retrieved chunk.

    Attributes are read with ``getattr`` fallbacks, and ``None`` values are
    coerced to ``""`` (text fields) or ``0`` (score) so the payload always
    serializes to plain strings and numbers.
    """
    score = getattr(chunk, "score", 0) or 0
    return {
        "type": "source",
        "standard": getattr(chunk, "doc_title", "") or getattr(chunk, "doc_name", "") or "",
        "clause": getattr(chunk, "section_title", "") or "",
        "score": round(float(score), 3),
        "status": "retrieved",
        # Only a 300-character preview of the chunk text is sent to the client.
        "full_content": (getattr(chunk, "text", "") or "")[:300],
    }


@router.post("/analyze-stream")
async def analyze_stream(
    text: Optional[str] = Form(None),
    doc_id: Optional[str] = Form(None),
    file: Optional[UploadFile] = File(None),
    domains: Optional[str] = Form(None),
    title: Optional[str] = Form(None),
):
    """Stream compliance analysis as SSE events.

    Input precedence: non-blank ``text``, then ``doc_id``, then the uploaded
    ``file``. ``title`` is accepted but not used by this handler.

    Stages: extracting -> splitting -> analyzing (per clause: retrieval and
    gap check) -> concluding.
    Events (``type`` field): stage | source | finding | done | error

    Args:
        text: Raw document text to analyze.
        doc_id: Identifier whose text is loaded via ``extract_text_from_doc_id``.
        file: Uploaded document; text is extracted with ``extract_text_from_file``.
        domains: Optional value forwarded to ``retrieve_for_clause``; an empty
            string is forwarded as ``None``.
        title: Unused.

    Returns:
        A ``text/event-stream`` StreamingResponse.
    """
    # Imported inside the handler (presumably to keep module import light or
    # avoid import cycles — confirm before hoisting).
    from app.application.compliance.pipeline import (
        check_clause_compliance,
        extract_text_from_doc_id,
        extract_text_from_file,
        retrieve_for_clause,
        split_into_clauses,
        synthesize_conclusion,
    )
    from app.services.llm.llm_factory import get_llm_client

    # Read the upload eagerly, before streaming starts, so the generator below
    # only works with plain bytes and never touches the UploadFile.
    file_content: bytes | None = None
    file_name: str | None = None
    if file is not None:
        file_content = await file.read()
        file_name = file.filename

    async def generate() -> AsyncGenerator[str, None]:
        try:
            client = get_llm_client(provider=settings.llm_provider, model=settings.llm_model)
            retrieval_service = get_retrieval_service()

            # ── Stage 1: extract text ─────────────────────────────────────
            yield _sse({"type": "stage", "stage": "extracting", "label": "Extracting text…"})
            await asyncio.sleep(0)

            # Whitespace-only text is treated as absent so that a supplied
            # doc_id or file is still consulted.
            if text and text.strip():
                para_text = text.strip()
            elif doc_id:
                try:
                    para_text = await asyncio.to_thread(extract_text_from_doc_id, doc_id)
                except Exception as exc:
                    yield _sse({"type": "error", "text": f"Document not found: {exc}"})
                    return
            elif file_content is not None:
                para_text = await asyncio.to_thread(
                    extract_text_from_file, file_content, file_name or "upload"
                )
            else:
                yield _sse({"type": "error", "text": "No input provided"})
                return

            # Guards against an extractor returning None as well as blank text.
            if not para_text or not para_text.strip():
                yield _sse({"type": "error", "text": "Could not extract text from the provided input"})
                return

            # ── Stage 2: split into clauses ───────────────────────────────
            yield _sse({"type": "stage", "stage": "splitting", "label": "Splitting into clauses…"})
            await asyncio.sleep(0)
            clauses: list[str] = await asyncio.to_thread(split_into_clauses, para_text, client)

            # ── Stage 3: retrieve + gap check per clause ──────────────────
            findings: list[dict] = []
            total = len(clauses)
            for position, clause in enumerate(clauses, start=1):
                yield _sse({
                    "type": "stage",
                    "stage": "analyzing",
                    "label": f"Analyzing clause {position}/{total}…",
                })
                await asyncio.sleep(0)

                chunks = await asyncio.to_thread(
                    retrieve_for_clause, clause, retrieval_service, 5, domains or None
                )

                # Surface the top three retrieved chunks as sources.
                for chunk in chunks[:3]:
                    yield _sse(_chunk_source_event(chunk))
                    await asyncio.sleep(0)

                finding = await asyncio.to_thread(check_clause_compliance, clause, chunks, client)
                if finding:
                    findings.append(finding)
                    yield _sse({"type": "finding", **finding})
                    await asyncio.sleep(0)

            # ── Stage 4: synthesize conclusion ────────────────────────────
            yield _sse({"type": "stage", "stage": "concluding", "label": "Generating conclusion…"})
            await asyncio.sleep(0)

            conclusion_data = await asyncio.to_thread(
                synthesize_conclusion, para_text, findings, client
            )
            yield _sse({"type": "done", **conclusion_data})

        except Exception as exc:
            # Report pipeline failures to the client as a terminal error event
            # instead of cutting the stream off silently.
            logger.exception("analyze-stream pipeline error")
            yield _sse({"type": "error", "text": str(exc)})

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
    )
|
|
|
|
|
|
|
|
|
|
|
2026-05-14 15:07:34 +08:00
|
|
|
@router.post("/chat/{segment_id}")
async def compliance_chat(segment_id: int, request: ComplianceChatRequest):
    """Stream compliance Q&A grounded in real vector retrieval.

    When ``request.segment_context`` is set, it is prepended to the user
    question before the query is sent to the agent conversation service
    (``top_k=5``, ``prompt_template="compliance_qa"``). ``segment_id`` is part
    of the route but is not used by this handler.

    Agent events are re-emitted as SSE frames:
    ``{"type": "chunk", "text": ...}`` for non-empty ``content`` events,
    ``{"type": "done"}`` for ``done`` events, and ``{"type": "error", ...}``
    if iterating the agent stream raises. Other event types are dropped.
    """
    query = request.query
    if request.segment_context:
        query = f"[段落分析上下文]\n{request.segment_context}\n\n用户问题:{request.query}"

    _, event_stream = get_agent_conversation_service().stream_chat(
        query=query,
        top_k=5,
        prompt_template="compliance_qa",
    )

    async def generate() -> AsyncGenerator[str, None]:
        """Translate agent SSE events to compliance chunk/done format."""
        try:
            # NOTE(review): event_stream is a plain (synchronous) iterable, so
            # each step runs on the event-loop thread. If it blocks on network
            # I/O this stalls other requests; consider offloading.
            for event in event_stream:
                event_type = event.get("event", "")
                if event_type == "content":
                    chunk_text = event.get("data", "")
                    if chunk_text:
                        yield _sse({"type": "chunk", "text": chunk_text})
                elif event_type == "done":
                    yield _sse({"type": "done"})
                await asyncio.sleep(0)
        except Exception as exc:
            # Mirror analyze_stream: tell the client the stream failed rather
            # than ending it without a terminal event.
            logger.exception("compliance chat stream error")
            yield _sse({"type": "error", "text": str(exc)})

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
    )
|