Update: 1. optimization; 2. Chinese/English language switching
This commit is contained in:
@@ -85,10 +85,9 @@ async def analyze_stream(
|
||||
Events: stage | source | finding | done | error
|
||||
"""
|
||||
from app.application.compliance.pipeline import (
|
||||
check_clause_compliance,
|
||||
extract_text_from_doc_id,
|
||||
extract_text_from_file,
|
||||
retrieve_for_clause,
|
||||
run_clauses_parallel,
|
||||
split_into_clauses,
|
||||
synthesize_conclusion,
|
||||
)
|
||||
@@ -136,22 +135,28 @@ async def analyze_stream(
|
||||
await asyncio.sleep(0)
|
||||
clauses: list[str] = await asyncio.to_thread(split_into_clauses, para_text, client)
|
||||
|
||||
# ── Stage 3: retrieve + gap check per clause ──────────────────
|
||||
# ── Stage 3: retrieve + gap check (parallel across all clauses) ────────────
|
||||
findings: list[dict] = []
|
||||
|
||||
for i, clause in enumerate(clauses):
|
||||
yield _sse({
|
||||
"type": "stage",
|
||||
"stage": "analyzing",
|
||||
"label": f"Analyzing clause {i + 1}/{len(clauses)}…",
|
||||
})
|
||||
await asyncio.sleep(0)
|
||||
yield _sse({
|
||||
"type": "stage",
|
||||
"stage": "analyzing",
|
||||
"label": f"Analyzing {len(clauses)} clauses in parallel…",
|
||||
})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
chunks = await asyncio.to_thread(
|
||||
retrieve_for_clause, clause, retrieval_service, 5, domains or None
|
||||
)
|
||||
clause_results = await run_clauses_parallel(
|
||||
clauses, retrieval_service, client,
|
||||
top_k=5,
|
||||
domains=domains or None,
|
||||
)
|
||||
|
||||
# Emit source events
|
||||
for res in clause_results:
|
||||
i = res["index"]
|
||||
chunks = res["chunks"]
|
||||
finding = res["finding"]
|
||||
|
||||
# Emit source events for this clause
|
||||
for chunk in chunks[:3]:
|
||||
yield _sse({
|
||||
"type": "source",
|
||||
@@ -161,12 +166,11 @@ async def analyze_stream(
|
||||
"status": "retrieved",
|
||||
"full_content": (getattr(chunk, "text", "") or "")[:300],
|
||||
})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
finding = await asyncio.to_thread(check_clause_compliance, clause, chunks, client)
|
||||
if finding:
|
||||
findings.append(finding)
|
||||
yield _sse({"type": "finding", **finding})
|
||||
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# ── Stage 4: synthesize conclusion ────────────────────────────
|
||||
@@ -178,6 +182,45 @@ async def analyze_stream(
|
||||
)
|
||||
yield _sse({"type": "done", **conclusion_data})
|
||||
|
||||
# Auto-save analysis to database
|
||||
try:
|
||||
from app.shared.bootstrap import get_compliance_repository
|
||||
from app.domain.compliance.ports import AnalysisRecord, FindingRecord
|
||||
from datetime import datetime
|
||||
|
||||
repo = get_compliance_repository()
|
||||
finding_records = [
|
||||
FindingRecord(
|
||||
id="",
|
||||
analysis_id="",
|
||||
seq=i,
|
||||
title=f.get("title", ""),
|
||||
description=f.get("desc", ""),
|
||||
status=f.get("status", "ok"),
|
||||
clause_ref=f.get("clause_ref"),
|
||||
)
|
||||
for i, f in enumerate(findings)
|
||||
]
|
||||
record = AnalysisRecord(
|
||||
id="",
|
||||
created_at=datetime.utcnow(),
|
||||
created_by=current_user.username if hasattr(current_user, "username") else None,
|
||||
doc_name=file_name or (title or "Pasted text"),
|
||||
standard_name=title or "",
|
||||
risk_score=conclusion_data.get("risk_score", 0),
|
||||
conclusion=conclusion_data.get("conclusion", ""),
|
||||
actions=conclusion_data.get("actions", []),
|
||||
para_text=conclusion_data.get("para_text", ""),
|
||||
highlight_terms=conclusion_data.get("highlight_terms", []),
|
||||
findings=finding_records,
|
||||
)
|
||||
analysis_id = await asyncio.to_thread(repo.save_analysis, record)
|
||||
yield _sse({"type": "saved", "analysis_id": analysis_id})
|
||||
except NotImplementedError:
|
||||
pass # No postgres backend configured — skip saving
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to auto-save compliance analysis: {}", exc)
|
||||
|
||||
except Exception as exc:
|
||||
logger.exception("analyze-stream pipeline error")
|
||||
yield _sse({"type": "error", "text": str(exc)})
|
||||
@@ -225,3 +268,226 @@ async def compliance_chat(segment_id: int, request: ComplianceChatRequest):
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/history")
async def list_history(
    limit: int = 20,
    offset: int = 0,
    current_user: UserClaims = Depends(get_current_user),
):
    """List saved compliance analyses as lightweight summary rows.

    Args:
        limit: Page size, forwarded unchanged to ``repo.list_analyses``.
        offset: Page start, forwarded unchanged to ``repo.list_analyses``.
        current_user: Injected by ``get_current_user``; not read in the body.

    Returns:
        A list of dicts with ``id``, ``created_at`` (ISO 8601 string),
        ``created_by``, ``doc_name``, ``standard_name``, ``risk_score`` and
        ``finding_count``. An empty list is returned when the repository
        lookup or listing raises ``NotImplementedError``.

    NOTE(review): ordering is whatever the repository returns; this endpoint
    was previously documented as "newest first" -- confirm in the repository.
    """
    from app.shared.bootstrap import get_compliance_repository

    try:
        repository = get_compliance_repository()
        # The repository call is synchronous, so it runs in a worker thread.
        analyses = await asyncio.to_thread(repository.list_analyses, limit, offset)

        summaries = []
        for analysis in analyses:
            summaries.append(
                {
                    "id": analysis.id,
                    "created_at": analysis.created_at.isoformat(),
                    "created_by": analysis.created_by,
                    "doc_name": analysis.doc_name,
                    "standard_name": analysis.standard_name,
                    "risk_score": analysis.risk_score,
                    "finding_count": len(analysis.findings),
                }
            )
        return summaries
    except NotImplementedError:
        return []
|
||||
|
||||
|
||||
@router.get("/history/{analysis_id}")
async def get_history_item(
    analysis_id: str,
    current_user: UserClaims = Depends(get_current_user),
):
    """Return one saved analysis together with all of its findings.

    Args:
        analysis_id: Identifier passed to ``repo.get_analysis``.
        current_user: Injected by ``get_current_user``; not read in the body.

    Returns:
        A dict with the analysis fields (``created_at`` as an ISO 8601
        string) and a ``findings`` list of per-finding dicts.

    Raises:
        HTTPException: 404 when the repository returns a falsy record.
    """
    from app.shared.bootstrap import get_compliance_repository
    from fastapi import HTTPException

    repository = get_compliance_repository()
    analysis = await asyncio.to_thread(repository.get_analysis, analysis_id)
    if not analysis:
        raise HTTPException(status_code=404, detail="Analysis not found")

    # Top-level fields first, findings last, so key order in the response
    # stays the same as before.
    payload = {
        "id": analysis.id,
        "created_at": analysis.created_at.isoformat(),
        "created_by": analysis.created_by,
        "doc_name": analysis.doc_name,
        "standard_name": analysis.standard_name,
        "risk_score": analysis.risk_score,
        "conclusion": analysis.conclusion,
        "actions": analysis.actions,
        "para_text": analysis.para_text,
        "highlight_terms": analysis.highlight_terms,
    }

    finding_rows = []
    for item in analysis.findings:
        finding_rows.append(
            {
                "id": item.id,
                "seq": item.seq,
                "title": item.title,
                "description": item.description,
                "status": item.status,
                "clause_ref": item.clause_ref,
            }
        )
    payload["findings"] = finding_rows
    return payload
|
||||
|
||||
|
||||
@router.delete("/history/{analysis_id}", status_code=204)
async def delete_history_item(
    analysis_id: str,
    current_user: UserClaims = Depends(get_current_user),
):
    """Delete a saved analysis via ``repo.delete_analysis``.

    Responds with 204 and no body. No existence check is made here.

    NOTE(review): the earlier docstring stated that deletion cascades to
    findings and chat messages; that would happen inside the repository or
    database and is not visible in this function -- confirm there.
    """
    from app.shared.bootstrap import get_compliance_repository

    remove_analysis = get_compliance_repository().delete_analysis
    # Synchronous repository call, so it is pushed to a worker thread.
    await asyncio.to_thread(remove_analysis, analysis_id)
|
||||
|
||||
|
||||
def _docx_content_disposition(doc_name, created_at) -> str:
    """Build a header-safe ``Content-Disposition`` value for a DOCX report.

    The file name keeps the original shape
    ``compliance_<doc name, spaces -> "_", max 50 chars>_<YYYYMMDD>.docx``.

    Args:
        doc_name: Stored document name; ``None`` or empty falls back to
            ``"report"``.
        created_at: Object with ``strftime`` (the analysis creation time).

    Returns:
        ``attachment; filename="<ascii name>"``, followed by an RFC 5987
        ``filename*=UTF-8''<percent-encoded name>`` parameter only when the
        real name contains characters outside printable ASCII.
    """
    import re
    from urllib.parse import quote

    stem = (doc_name or "report").replace(" ", "_")[:50]
    # Control characters (incl. CR/LF), double quotes and path separators
    # would break the quoted header value or the saved file name.
    stem = re.sub(r'[\x00-\x1f\x7f"\\/]', "_", stem)
    filename = f"compliance_{stem}_{created_at.strftime('%Y%m%d')}.docx"

    # HTTP header values are encoded as latin-1 by the framework, so the
    # plain ``filename`` parameter must be restricted to printable ASCII.
    ascii_name = re.sub(r"[^\x20-\x7e]", "_", filename)
    header = f'attachment; filename="{ascii_name}"'
    if ascii_name != filename:
        # Carry the real (e.g. Chinese) name in the extended parameter.
        header += f"; filename*=UTF-8''{quote(filename, safe='')}"
    return header


@router.get("/history/{analysis_id}/download")
async def download_history_docx(
    analysis_id: str,
    current_user: UserClaims = Depends(get_current_user),
):
    """Return a DOCX compliance report for the given analysis.

    Args:
        analysis_id: Identifier passed to ``repo.get_analysis``.
        current_user: Injected by ``get_current_user``; not read in the body.

    Returns:
        A ``Response`` carrying the bytes from ``generate_docx`` with the
        Word media type and an ``attachment`` Content-Disposition header.

    Raises:
        HTTPException: 404 when the repository returns a falsy record.
    """
    from app.shared.bootstrap import get_compliance_repository
    from app.infrastructure.compliance.docx_export import generate_docx
    from fastapi import HTTPException
    from fastapi.responses import Response

    repo = get_compliance_repository()
    record = await asyncio.to_thread(repo.get_analysis, analysis_id)
    if not record:
        raise HTTPException(status_code=404, detail="Analysis not found")

    # Document generation is synchronous, so it runs in a worker thread.
    docx_bytes = await asyncio.to_thread(generate_docx, record)
    return Response(
        content=docx_bytes,
        media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        headers={
            "Content-Disposition": _docx_content_disposition(
                record.doc_name, record.created_at
            )
        },
    )
|
||||
|
||||
|
||||
@router.get("/analyses/{analysis_id}/findings/{finding_id}/chat")
async def get_finding_chat_history(
    analysis_id: str,
    finding_id: str,
    current_user: UserClaims = Depends(get_current_user),
):
    """Return the persisted chat messages of one finding thread.

    The result of ``repo.get_messages(finding_id)`` is returned unchanged;
    an empty list is returned when ``NotImplementedError`` is raised.

    NOTE(review): ``analysis_id`` is part of the route but is not used in
    the lookup, so the finding is not checked against the analysis here.
    NOTE(review): the earlier docstring promised "oldest first"; ordering
    is decided by the repository -- confirm there.
    """
    from app.shared.bootstrap import get_compliance_repository

    try:
        fetch_messages = get_compliance_repository().get_messages
        # Synchronous repository call, executed in a worker thread.
        return await asyncio.to_thread(fetch_messages, finding_id)
    except NotImplementedError:
        return []
|
||||
|
||||
|
||||
@router.post("/analyses/{analysis_id}/findings/{finding_id}/suggestions")
async def get_finding_suggestions(
    analysis_id: str,
    finding_id: str,
    current_user: UserClaims = Depends(get_current_user),
):
    """Produce LLM-generated follow-up questions for one finding.

    Loads the analysis, picks the finding whose ``id`` equals
    ``finding_id``, and hands both to ``generate_suggestions`` together
    with an LLM client built from ``settings``.

    Returns:
        ``{"questions": <result of generate_suggestions>}``.

    Raises:
        HTTPException: 404 when the analysis or the finding is not found.

    NOTE(review): the earlier docstring said exactly 3 suggestions are
    generated; the count is decided inside ``generate_suggestions``.
    """
    from app.application.compliance.pipeline import generate_suggestions
    from app.shared.bootstrap import get_compliance_repository
    from app.services.llm.llm_factory import get_llm_client
    from fastapi import HTTPException

    repository = get_compliance_repository()
    analysis = await asyncio.to_thread(repository.get_analysis, analysis_id)
    if not analysis:
        raise HTTPException(status_code=404, detail="Analysis not found")

    # First finding with a matching id wins.
    finding = None
    for candidate in analysis.findings:
        if candidate.id == finding_id:
            finding = candidate
            break
    if not finding:
        raise HTTPException(status_code=404, detail="Finding not found")

    client = get_llm_client(provider=settings.llm_provider, model=settings.llm_model)
    # The pipeline helper is synchronous, so it runs in a worker thread.
    questions = await asyncio.to_thread(generate_suggestions, finding, analysis, client)
    return {"questions": questions}
|
||||
|
||||
|
||||
@router.post("/analyses/{analysis_id}/findings/{finding_id}/chat")
async def finding_chat(
    analysis_id: str,
    finding_id: str,
    request: ComplianceChatRequest,
    current_user: UserClaims = Depends(get_current_user),
):
    """Stream a chat answer about one finding as server-sent events.

    Steps, in order:
      1. Load the analysis and the finding (404 if either is missing).
      2. Persist the user's message via ``repo.save_message``.
      3. Read the thread via ``repo.get_messages`` and build the context
         string with ``build_finding_context``.
      4. Stream ``chunk`` / ``done`` / ``error`` events from the agent
         conversation service.
      5. When the stream ends, persist the concatenated assistant text
         (only if it is non-empty).

    Raises:
        HTTPException: 404 when the analysis or the finding is not found.
    """
    from app.application.compliance.pipeline import build_finding_context
    from app.shared.bootstrap import get_compliance_repository
    from fastapi import HTTPException

    repo = get_compliance_repository()
    analysis = await asyncio.to_thread(repo.get_analysis, analysis_id)
    if not analysis:
        raise HTTPException(status_code=404, detail="Analysis not found")

    finding = None
    for candidate in analysis.findings:
        if candidate.id == finding_id:
            finding = candidate
            break
    if not finding:
        raise HTTPException(status_code=404, detail="Finding not found")

    # Store the user's turn before anything is generated.
    await asyncio.to_thread(
        repo.save_message, analysis_id, finding_id, "user", request.query
    )

    # Last 10 stored messages of this thread, reduced to role/content pairs.
    # NOTE(review): this list is built but never passed to the chat service
    # below -- confirm whether the history was meant to reach the prompt.
    stored_messages = await asyncio.to_thread(repo.get_messages, finding_id)
    history_messages = []
    for message in stored_messages[-10:]:
        history_messages.append(
            {"role": message["role"], "content": message["content"]}
        )

    # The finding context is prepended to the user's question as plain text.
    system_context = build_finding_context(finding, analysis)
    full_query = f"[Compliance Finding Context]\n{system_context}\n\nUser question: {request.query}"

    assistant_buffer: list[str] = []

    async def _persist_reply() -> None:
        # Best effort: a failed save is logged and does not break the stream.
        reply = "".join(assistant_buffer)
        if not reply:
            return
        try:
            await asyncio.to_thread(
                repo.save_message, analysis_id, finding_id, "assistant", reply
            )
        except Exception as exc:
            logger.warning("Failed to persist assistant message: {}", exc)

    async def generate() -> AsyncGenerator[str, None]:
        try:
            stream_result = get_agent_conversation_service().stream_chat(
                query=full_query,
                top_k=5,
                prompt_template="compliance_qa",
            )
            _, events = stream_result
            for event in events:
                kind = event.get("event", "")
                if kind == "done":
                    yield _sse({"type": "done"})
                elif kind == "content":
                    piece = event.get("data", "")
                    if piece:
                        assistant_buffer.append(piece)
                        yield _sse({"type": "chunk", "text": piece})
                # Hand control back to the event loop after every event.
                await asyncio.sleep(0)
        except Exception as exc:
            logger.exception("finding_chat stream error")
            yield _sse({"type": "error", "text": str(exc)})
        finally:
            # Runs on normal completion and on error alike.
            await _persist_reply()

    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers=sse_headers,
    )
|
||||
|
||||
Reference in New Issue
Block a user