add
This commit is contained in:
@@ -5,17 +5,19 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from fastapi import APIRouter, File, UploadFile
|
||||
from fastapi import APIRouter, File, Form, UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
from loguru import logger
|
||||
|
||||
from app.schemas.compliance import (
|
||||
AnalyzeResponse,
|
||||
ComplianceChatRequest,
|
||||
)
|
||||
from app.services.mock_data import generate_task_id, get_mock_compliance_result
|
||||
from app.shared.bootstrap import get_agent_conversation_service
|
||||
from app.shared.bootstrap import get_agent_conversation_service, get_retrieval_service
|
||||
from app.config.settings import settings
|
||||
|
||||
|
||||
router = APIRouter(prefix="/compliance", tags=["合规分析"])
|
||||
@@ -62,6 +64,128 @@ async def get_result(task_id: str):
|
||||
return task["result"]
|
||||
|
||||
|
||||
def _sse(data: dict) -> str:
|
||||
return f"event: message\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
|
||||
|
||||
|
||||
@router.post("/analyze-stream")
|
||||
async def analyze_stream(
|
||||
text: Optional[str] = Form(None),
|
||||
doc_id: Optional[str] = Form(None),
|
||||
file: Optional[UploadFile] = File(None),
|
||||
domains: Optional[str] = Form(None),
|
||||
title: Optional[str] = Form(None),
|
||||
):
|
||||
"""Stream compliance analysis as SSE events.
|
||||
|
||||
Stages: clause_split → retrieval (per clause) → gap_check → conclusion
|
||||
Events: stage | source | finding | done | error
|
||||
"""
|
||||
from app.application.compliance.pipeline import (
|
||||
check_clause_compliance,
|
||||
extract_text_from_doc_id,
|
||||
extract_text_from_file,
|
||||
retrieve_for_clause,
|
||||
split_into_clauses,
|
||||
synthesize_conclusion,
|
||||
)
|
||||
from app.services.llm.llm_factory import get_llm_client
|
||||
from app.shared.bootstrap import get_retrieval_service
|
||||
|
||||
# Read file content eagerly (before async generator)
|
||||
file_content: bytes | None = None
|
||||
file_name: str | None = None
|
||||
if file is not None:
|
||||
file_content = await file.read()
|
||||
file_name = file.filename
|
||||
|
||||
async def generate() -> AsyncGenerator[str, None]:
|
||||
try:
|
||||
client = get_llm_client(provider=settings.llm_provider, model=settings.llm_model)
|
||||
retrieval_service = get_retrieval_service()
|
||||
|
||||
# ── Stage 1: extract text ─────────────────────────────────────
|
||||
yield _sse({"type": "stage", "stage": "extracting", "label": "Extracting text…"})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
if text:
|
||||
para_text = text.strip()
|
||||
elif doc_id:
|
||||
try:
|
||||
para_text = await asyncio.to_thread(extract_text_from_doc_id, doc_id)
|
||||
except Exception as exc:
|
||||
yield _sse({"type": "error", "text": f"Document not found: {exc}"})
|
||||
return
|
||||
elif file_content is not None:
|
||||
para_text = await asyncio.to_thread(
|
||||
extract_text_from_file, file_content, file_name or "upload"
|
||||
)
|
||||
else:
|
||||
yield _sse({"type": "error", "text": "No input provided"})
|
||||
return
|
||||
|
||||
if not para_text.strip():
|
||||
yield _sse({"type": "error", "text": "Could not extract text from the provided input"})
|
||||
return
|
||||
|
||||
# ── Stage 2: split into clauses ───────────────────────────────
|
||||
yield _sse({"type": "stage", "stage": "splitting", "label": "Splitting into clauses…"})
|
||||
await asyncio.sleep(0)
|
||||
clauses: list[str] = await asyncio.to_thread(split_into_clauses, para_text, client)
|
||||
|
||||
# ── Stage 3: retrieve + gap check per clause ──────────────────
|
||||
findings: list[dict] = []
|
||||
|
||||
for i, clause in enumerate(clauses):
|
||||
yield _sse({
|
||||
"type": "stage",
|
||||
"stage": "analyzing",
|
||||
"label": f"Analyzing clause {i + 1}/{len(clauses)}…",
|
||||
})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
chunks = await asyncio.to_thread(
|
||||
retrieve_for_clause, clause, retrieval_service, 5, domains or None
|
||||
)
|
||||
|
||||
# Emit source events
|
||||
for chunk in chunks[:3]:
|
||||
yield _sse({
|
||||
"type": "source",
|
||||
"standard": getattr(chunk, "doc_title", "") or getattr(chunk, "doc_name", ""),
|
||||
"clause": getattr(chunk, "section_title", "") or "",
|
||||
"score": round(float(getattr(chunk, "score", 0)), 3),
|
||||
"status": "retrieved",
|
||||
"full_content": (getattr(chunk, "text", "") or "")[:300],
|
||||
})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
finding = await asyncio.to_thread(check_clause_compliance, clause, chunks, client)
|
||||
if finding:
|
||||
findings.append(finding)
|
||||
yield _sse({"type": "finding", **finding})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# ── Stage 4: synthesize conclusion ────────────────────────────
|
||||
yield _sse({"type": "stage", "stage": "concluding", "label": "Generating conclusion…"})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
conclusion_data = await asyncio.to_thread(
|
||||
synthesize_conclusion, para_text, findings, client
|
||||
)
|
||||
yield _sse({"type": "done", **conclusion_data})
|
||||
|
||||
except Exception as exc:
|
||||
logger.exception("analyze-stream pipeline error")
|
||||
yield _sse({"type": "error", "text": str(exc)})
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/chat/{segment_id}")
|
||||
async def compliance_chat(segment_id: int, request: ComplianceChatRequest):
|
||||
"""Stream compliance Q&A grounded in real vector retrieval."""
|
||||
|
||||
Reference in New Issue
Block a user