494 lines
19 KiB
Python
494 lines
19 KiB
Python
"""Define API routes for compliance."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
from pathlib import Path
|
|
from typing import AsyncGenerator, Optional
|
|
|
|
from fastapi import APIRouter, Depends, File, Form, UploadFile
|
|
from fastapi.responses import StreamingResponse
|
|
from loguru import logger
|
|
|
|
from app.api.dependencies.auth import get_current_user
|
|
from app.domain.auth.models import UserClaims
|
|
from app.schemas.compliance import (
|
|
AnalyzeResponse,
|
|
ComplianceChatRequest,
|
|
)
|
|
from app.services.mock_data import generate_task_id, get_mock_compliance_result
|
|
from app.shared.bootstrap import get_agent_conversation_service, get_retrieval_service
|
|
from app.config.settings import settings
|
|
|
|
|
|
# Router for every compliance endpoint defined in this module (mounted at /compliance).
router = APIRouter(tags=["合规分析"], prefix="/compliance")

# Process-local registry for the mock /analyze flow, keyed by task id.
# NOTE(review): nothing in this module evicts entries and a module-level dict
# is not shared between worker processes — fine for mocks, not for production.
tasks_store: dict[str, dict] = dict()

# Directory that /analyze writes raw uploads into: <parents[3] of this file>/data/raw.
RAW_DATA_DIR = Path(__file__).resolve().parents[3].joinpath("data", "raw")
|
|
|
|
|
|
@router.post("/analyze", response_model=AnalyzeResponse)
async def analyze_document(file: UploadFile = File(...)):
    """Persist an uploaded document and register a (mock) compliance task.

    The upload is written under ``RAW_DATA_DIR`` and an entry is recorded in
    ``tasks_store``. No real analysis runs here: the task is marked
    ``completed`` immediately with ``get_mock_compliance_result(task_id)``.

    Args:
        file: The uploaded document.

    Returns:
        AnalyzeResponse carrying the generated ``task_id``; the stored result
        can be fetched via ``GET /result/{task_id}``.
    """
    task_id = generate_task_id()

    # ``file.filename`` is client-controlled: keep only its final path
    # component so values such as "../../etc/x" cannot escape RAW_DATA_DIR.
    # It may also be None/empty, in which case a neutral name is used.
    original_name = Path(file.filename or "").name or "upload"

    RAW_DATA_DIR.mkdir(parents=True, exist_ok=True)
    file_path = RAW_DATA_DIR / f"compliance_{task_id}_{original_name}"

    content = await file.read()
    # Blocking disk write — run it in a worker thread, off the event loop.
    await asyncio.to_thread(file_path.write_bytes, content)

    tasks_store[task_id] = {
        "task_id": task_id,
        "file_path": str(file_path),
        "status": "processing",
        "result": None,
    }

    # Mock pipeline: complete synchronously with canned data.
    tasks_store[task_id]["status"] = "completed"
    tasks_store[task_id]["result"] = get_mock_compliance_result(task_id)

    return AnalyzeResponse(task_id=task_id)
|
|
|
|
|
|
@router.get("/result/{task_id}")
async def get_result(task_id: str):
    """Look up the stored outcome of an ``/analyze`` task.

    Unknown task ids fall back to ``get_mock_compliance_result``. A task that
    is still ``processing`` yields a small status payload; any other task
    returns its stored ``result`` unchanged.
    """
    task = tasks_store.get(task_id)
    if task is None:
        return get_mock_compliance_result(task_id)

    if task["status"] != "processing":
        return task["result"]
    return {"status": "processing", "message": "分析进行中"}
|
|
|
|
|
|
def _sse(data: dict) -> str:
|
|
return f"event: message\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
|
|
|
|
|
|
@router.post("/analyze-stream")
async def analyze_stream(
    text: Optional[str] = Form(None),
    doc_id: Optional[str] = Form(None),
    file: Optional[UploadFile] = File(None),
    domains: Optional[str] = Form(None),
    title: Optional[str] = Form(None),
    current_user: UserClaims = Depends(get_current_user),
):
    """Stream compliance analysis as SSE events.

    One input source is used, checked in this order: ``text``, ``doc_id``,
    then ``file``.

    Stages: clause_split → retrieval (per clause) → gap_check → conclusion
    Events: stage | source | finding | done | saved | error

    After ``done`` the analysis is persisted on a best-effort basis. On
    success a ``saved`` event with ``analysis_id`` follows. The save is
    skipped silently when the repository raises ``NotImplementedError``, and
    any other save failure is only logged.
    """
    from app.application.compliance.pipeline import (
        extract_text_from_doc_id,
        extract_text_from_file,
        run_clauses_parallel,
        split_into_clauses,
        synthesize_conclusion,
    )
    from app.services.llm.llm_factory import get_llm_client
    from app.shared.bootstrap import get_retrieval_service

    # Read the upload eagerly: the generator below runs while the response
    # is streamed, after this handler has returned, when the upload may no
    # longer be readable.
    file_content: bytes | None = None
    file_name: str | None = None
    if file is not None:
        file_content = await file.read()
        file_name = file.filename

    async def generate() -> AsyncGenerator[str, None]:
        try:
            client = get_llm_client(provider=settings.llm_provider, model=settings.llm_model)
            retrieval_service = get_retrieval_service()

            # ── Stage 1: extract text ─────────────────────────────────────
            yield _sse({"type": "stage", "stage": "extracting", "label": "Extracting text…"})
            await asyncio.sleep(0)

            if text:
                para_text = text.strip()
            elif doc_id:
                try:
                    para_text = await asyncio.to_thread(extract_text_from_doc_id, doc_id)
                except Exception as exc:
                    yield _sse({"type": "error", "text": f"Document not found: {exc}"})
                    return
            elif file_content is not None:
                para_text = await asyncio.to_thread(
                    extract_text_from_file, file_content, file_name or "upload"
                )
            else:
                yield _sse({"type": "error", "text": "No input provided"})
                return

            # Guard against a falsy (e.g. None) extractor result as well as
            # whitespace-only text, instead of failing on ``.strip()``.
            if not para_text or not para_text.strip():
                yield _sse({"type": "error", "text": "Could not extract text from the provided input"})
                return

            # ── Stage 2: split into clauses ───────────────────────────────
            yield _sse({"type": "stage", "stage": "splitting", "label": "Splitting into clauses…"})
            await asyncio.sleep(0)
            clauses: list[str] = await asyncio.to_thread(split_into_clauses, para_text, client)

            # ── Stage 3: retrieve + gap check (parallel across all clauses) ────────────
            findings: list[dict] = []

            yield _sse({
                "type": "stage",
                "stage": "analyzing",
                "label": f"Analyzing {len(clauses)} clauses in parallel…",
            })
            await asyncio.sleep(0)

            clause_results = await run_clauses_parallel(
                clauses, retrieval_service, client,
                top_k=5,
                domains=domains or None,
            )

            for res in clause_results:
                chunks = res["chunks"]
                finding = res["finding"]

                # Emit up to three retrieved sources for this clause.
                for chunk in chunks[:3]:
                    yield _sse({
                        "type": "source",
                        "standard": getattr(chunk, "doc_title", "") or getattr(chunk, "doc_name", ""),
                        "clause": getattr(chunk, "section_title", "") or "",
                        "score": round(float(getattr(chunk, "score", 0)), 3),
                        "status": "retrieved",
                        "full_content": (getattr(chunk, "text", "") or "")[:300],
                    })

                if finding:
                    findings.append(finding)
                    yield _sse({"type": "finding", **finding})

                await asyncio.sleep(0)

            # ── Stage 4: synthesize conclusion ────────────────────────────
            yield _sse({"type": "stage", "stage": "concluding", "label": "Generating conclusion…"})
            await asyncio.sleep(0)

            conclusion_data = await asyncio.to_thread(
                synthesize_conclusion, para_text, findings, client
            )
            yield _sse({"type": "done", **conclusion_data})

            # Auto-save analysis to database (best effort).
            try:
                from app.shared.bootstrap import get_compliance_repository
                from app.domain.compliance.ports import AnalysisRecord, FindingRecord
                from datetime import datetime

                repo = get_compliance_repository()
                finding_records = [
                    FindingRecord(
                        id="",
                        analysis_id="",
                        seq=seq,
                        title=f.get("title", ""),
                        description=f.get("desc", ""),
                        status=f.get("status", "ok"),
                        clause_ref=f.get("clause_ref"),
                    )
                    for seq, f in enumerate(findings)
                ]
                record = AnalysisRecord(
                    id="",
                    created_at=datetime.utcnow(),  # naive UTC timestamp
                    created_by=getattr(current_user, "username", None),
                    doc_name=file_name or (title or "Pasted text"),
                    standard_name=title or "",
                    risk_score=conclusion_data.get("risk_score", 0),
                    conclusion=conclusion_data.get("conclusion", ""),
                    actions=conclusion_data.get("actions", []),
                    # Prefer text echoed back by the conclusion step, but fall
                    # back to the text actually analysed so the saved record
                    # is never stored without its source paragraph.
                    para_text=conclusion_data.get("para_text") or para_text,
                    highlight_terms=conclusion_data.get("highlight_terms", []),
                    findings=finding_records,
                )
                analysis_id = await asyncio.to_thread(repo.save_analysis, record)
                yield _sse({"type": "saved", "analysis_id": analysis_id})
            except NotImplementedError:
                pass  # No postgres backend configured — skip saving
            except Exception as exc:
                logger.warning("Failed to auto-save compliance analysis: {}", exc)

        except Exception as exc:
            logger.exception("analyze-stream pipeline error")
            yield _sse({"type": "error", "text": str(exc)})

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
    )
|
|
|
|
|
|
@router.post("/chat/{segment_id}")
async def compliance_chat(segment_id: int, request: ComplianceChatRequest):
    """Stream compliance Q&A grounded in real vector retrieval.

    ``segment_id`` is part of the route but is not used by this handler.
    When ``request.segment_context`` is set, it is prepended to the question
    (with Chinese section labels) before the query is sent to the agent
    conversation service with the ``compliance_qa`` prompt template.

    Emits SSE ``message`` frames whose JSON ``type`` is ``chunk``
    (incremental text), ``done``, or ``error`` if the agent stream raises.
    """
    query = request.query
    if request.segment_context:
        query = f"[段落分析上下文]\n{request.segment_context}\n\n用户问题:{request.query}"

    _, event_stream = get_agent_conversation_service().stream_chat(
        query=query,
        top_k=5,
        prompt_template="compliance_qa",
    )

    async def generate() -> AsyncGenerator[str, None]:
        """Translate agent events into compliance ``chunk``/``done`` SSE frames."""
        try:
            for event in event_stream:
                event_type = event.get("event", "")
                if event_type == "content":
                    text = event.get("data", "")
                    if text:
                        yield _sse({"type": "chunk", "text": text})
                elif event_type == "done":
                    yield _sse({"type": "done"})
                await asyncio.sleep(0)
        except Exception as exc:
            # Once streaming has started the status line is already sent, so
            # an escaping exception would just truncate the body; report it
            # in-band like the other streaming endpoints in this module.
            logger.exception("compliance_chat stream error")
            yield _sse({"type": "error", "text": str(exc)})

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
    )
|
|
|
|
|
|
@router.get("/history")
async def list_history(
    limit: int = 20,
    offset: int = 0,
    current_user: UserClaims = Depends(get_current_user),
):
    """Return summaries of saved compliance analyses, one page at a time.

    Paging is delegated to ``repo.list_analyses(limit, offset)``; the order
    is whatever the repository returns (intended to be newest first). An
    empty list is returned when the repository raises ``NotImplementedError``
    (no backend configured). Authentication is required, but results are not
    filtered by ``current_user``.
    """
    from app.shared.bootstrap import get_compliance_repository

    def summarize(rec):
        # Lightweight row: only the finding count, not the findings themselves.
        return {
            "id": rec.id,
            "created_at": rec.created_at.isoformat(),
            "created_by": rec.created_by,
            "doc_name": rec.doc_name,
            "standard_name": rec.standard_name,
            "risk_score": rec.risk_score,
            "finding_count": len(rec.findings),
        }

    try:
        repository = get_compliance_repository()
        rows = await asyncio.to_thread(repository.list_analyses, limit, offset)
        return [summarize(rec) for rec in rows]
    except NotImplementedError:
        return []
|
|
|
|
|
|
@router.get("/history/{analysis_id}")
async def get_history_item(
    analysis_id: str,
    current_user: UserClaims = Depends(get_current_user),
):
    """Return one saved analysis, including all of its findings.

    Raises:
        HTTPException: 404 when the repository returns nothing for the id.
    """
    from app.shared.bootstrap import get_compliance_repository
    from fastapi import HTTPException

    repository = get_compliance_repository()
    record = await asyncio.to_thread(repository.get_analysis, analysis_id)
    if not record:
        raise HTTPException(status_code=404, detail="Analysis not found")

    findings_payload = [
        {
            "id": item.id,
            "seq": item.seq,
            "title": item.title,
            "description": item.description,
            "status": item.status,
            "clause_ref": item.clause_ref,
        }
        for item in record.findings
    ]

    return {
        "id": record.id,
        "created_at": record.created_at.isoformat(),
        "created_by": record.created_by,
        "doc_name": record.doc_name,
        "standard_name": record.standard_name,
        "risk_score": record.risk_score,
        "conclusion": record.conclusion,
        "actions": record.actions,
        "para_text": record.para_text,
        "highlight_terms": record.highlight_terms,
        "findings": findings_payload,
    }
|
|
|
|
|
|
@router.delete("/history/{analysis_id}", status_code=204)
async def delete_history_item(
    analysis_id: str,
    current_user: UserClaims = Depends(get_current_user),
):
    """Delete one saved analysis by id and respond with 204 No Content.

    The delete runs in a worker thread via ``delete_analysis``. Removal of
    dependent findings and chat messages is expected to happen by cascade in
    the storage layer; it is not enforced here. No existence check is made.
    """
    from app.shared.bootstrap import get_compliance_repository

    repository = get_compliance_repository()
    await asyncio.to_thread(repository.delete_analysis, analysis_id)
|
|
|
|
|
|
@router.get("/history/{analysis_id}/download")
async def download_history_docx(
    analysis_id: str,
    current_user: UserClaims = Depends(get_current_user),
):
    """Return a DOCX compliance report for the given analysis.

    The download name is derived from ``doc_name`` and ``created_at``.
    Document names are user-supplied and often non-ASCII, so the
    ``Content-Disposition`` header carries an ASCII-only ``filename``
    fallback plus an RFC 5987 ``filename*`` holding the UTF-8 name.

    Raises:
        HTTPException: 404 when the analysis does not exist.
    """
    from urllib.parse import quote

    from app.shared.bootstrap import get_compliance_repository
    from app.infrastructure.compliance.docx_export import generate_docx
    from fastapi import HTTPException
    from fastapi.responses import Response

    repo = get_compliance_repository()
    record = await asyncio.to_thread(repo.get_analysis, analysis_id)
    if not record:
        raise HTTPException(status_code=404, detail="Analysis not found")

    docx_bytes = await asyncio.to_thread(generate_docx, record)

    # Replace characters that would close the quoted header value, act as
    # path separators, or inject header lines (CR/LF and other controls).
    safe_name = "".join(
        ch if ch.isprintable() and ch not in '\\/"' else "_"
        for ch in (record.doc_name or "report").replace(" ", "_")[:50]
    )
    filename = f"compliance_{safe_name}_{record.created_at.strftime('%Y%m%d')}.docx"

    # Starlette encodes header values as latin-1, so placing a raw non-ASCII
    # name (e.g. a Chinese title) in the header raises UnicodeEncodeError.
    ascii_name = filename.encode("ascii", "replace").decode("ascii").replace("?", "_")
    disposition = (
        f'attachment; filename="{ascii_name}"; '
        f"filename*=UTF-8''{quote(filename, safe='')}"
    )

    return Response(
        content=docx_bytes,
        media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        headers={"Content-Disposition": disposition},
    )
|
|
|
|
|
|
@router.get("/analyses/{analysis_id}/findings/{finding_id}/chat")
async def get_finding_chat_history(
    analysis_id: str,
    finding_id: str,
    current_user: UserClaims = Depends(get_current_user),
):
    """Return the persisted chat thread for one finding (oldest first).

    Messages are looked up by ``finding_id`` only; ``analysis_id`` is not
    used to scope the lookup. An empty list is returned when the repository
    raises ``NotImplementedError`` (no backend configured).
    """
    from app.shared.bootstrap import get_compliance_repository

    try:
        repository = get_compliance_repository()
        return await asyncio.to_thread(repository.get_messages, finding_id)
    except NotImplementedError:
        return []
|
|
|
|
|
|
@router.post("/analyses/{analysis_id}/findings/{finding_id}/suggestions")
async def get_finding_suggestions(
    analysis_id: str,
    finding_id: str,
    current_user: UserClaims = Depends(get_current_user),
):
    """Ask the LLM for follow-up question suggestions about one finding.

    Returns ``{"questions": ...}`` with whatever ``generate_suggestions``
    produces (intended to be three questions).

    Raises:
        HTTPException: 404 if the analysis or the finding cannot be found.
    """
    from app.application.compliance.pipeline import generate_suggestions
    from app.shared.bootstrap import get_compliance_repository
    from app.services.llm.llm_factory import get_llm_client
    from fastapi import HTTPException

    repository = get_compliance_repository()
    analysis = await asyncio.to_thread(repository.get_analysis, analysis_id)
    if not analysis:
        raise HTTPException(status_code=404, detail="Analysis not found")

    target = None
    for candidate in analysis.findings:
        if candidate.id == finding_id:
            target = candidate
            break
    if not target:
        raise HTTPException(status_code=404, detail="Finding not found")

    llm = get_llm_client(provider=settings.llm_provider, model=settings.llm_model)
    suggested = await asyncio.to_thread(generate_suggestions, target, analysis, llm)
    return {"questions": suggested}
|
|
|
|
|
|
@router.post("/analyses/{analysis_id}/findings/{finding_id}/chat")
async def finding_chat(
    analysis_id: str,
    finding_id: str,
    request: ComplianceChatRequest,
    current_user: UserClaims = Depends(get_current_user),
):
    """Stream a grounded chat response for a specific finding.

    Loads the analysis and finding from the repository (404 if either is
    missing) and builds the prompt from ``build_finding_context``. Up to the
    last 10 earlier messages of this finding's thread are included so that
    follow-up questions keep their context. The agent's answer is streamed
    as SSE ``chunk``/``done``/``error`` events.

    The user message is persisted before streaming starts. The assistant
    reply is persisted (best effort) when the stream ends, if any text was
    produced.
    """
    from app.application.compliance.pipeline import build_finding_context
    from app.shared.bootstrap import get_compliance_repository
    from fastapi import HTTPException

    repo = get_compliance_repository()
    analysis = await asyncio.to_thread(repo.get_analysis, analysis_id)
    if not analysis:
        raise HTTPException(status_code=404, detail="Analysis not found")
    finding = next((f for f in analysis.findings if f.id == finding_id), None)
    if not finding:
        raise HTTPException(status_code=404, detail="Finding not found")

    # Load prior turns BEFORE saving the new message so the current question
    # is not duplicated in the history section (last 10 messages ≈ 5 turns).
    history = await asyncio.to_thread(repo.get_messages, finding_id)
    history_lines = [f"{m['role']}: {m['content']}" for m in history[-10:]]

    # Persist user message
    await asyncio.to_thread(
        repo.save_message, analysis_id, finding_id, "user", request.query
    )

    # Build grounded prompt: finding context, optional history, then the question.
    # With no history this is identical to the context + question prompt.
    system_context = build_finding_context(finding, analysis)
    prompt_parts = [f"[Compliance Finding Context]\n{system_context}"]
    if history_lines:
        prompt_parts.append("[Conversation History]\n" + "\n".join(history_lines))
    prompt_parts.append(f"User question: {request.query}")
    full_query = "\n\n".join(prompt_parts)

    async def generate() -> AsyncGenerator[str, None]:
        assistant_buffer: list[str] = []
        try:
            _, event_stream = get_agent_conversation_service().stream_chat(
                query=full_query,
                top_k=5,
                prompt_template="compliance_qa",
            )
            for event in event_stream:
                event_type = event.get("event", "")
                if event_type == "content":
                    text = event.get("data", "")
                    if text:
                        assistant_buffer.append(text)
                        yield _sse({"type": "chunk", "text": text})
                elif event_type == "done":
                    yield _sse({"type": "done"})
                await asyncio.sleep(0)
        except Exception as exc:
            logger.exception("finding_chat stream error")
            yield _sse({"type": "error", "text": str(exc)})
        finally:
            # Persist whatever assistant text was produced.
            # NOTE(review): if the client disconnects, the task is cancelled and
            # this await may itself be cancelled, so a partial reply can go unsaved.
            full_response = "".join(assistant_buffer)
            if full_response:
                try:
                    await asyncio.to_thread(
                        repo.save_message, analysis_id, finding_id, "assistant", full_response
                    )
                except Exception as exc:
                    logger.warning("Failed to persist assistant message: {}", exc)

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
    )
|