update for 1. 优化 2.中英切换
This commit is contained in:
@@ -85,10 +85,9 @@ async def analyze_stream(
|
||||
Events: stage | source | finding | done | error
|
||||
"""
|
||||
from app.application.compliance.pipeline import (
|
||||
check_clause_compliance,
|
||||
extract_text_from_doc_id,
|
||||
extract_text_from_file,
|
||||
retrieve_for_clause,
|
||||
run_clauses_parallel,
|
||||
split_into_clauses,
|
||||
synthesize_conclusion,
|
||||
)
|
||||
@@ -136,22 +135,28 @@ async def analyze_stream(
|
||||
await asyncio.sleep(0)
|
||||
clauses: list[str] = await asyncio.to_thread(split_into_clauses, para_text, client)
|
||||
|
||||
# ── Stage 3: retrieve + gap check per clause ──────────────────
|
||||
# ── Stage 3: retrieve + gap check (parallel across all clauses) ────────────
|
||||
findings: list[dict] = []
|
||||
|
||||
for i, clause in enumerate(clauses):
|
||||
yield _sse({
|
||||
"type": "stage",
|
||||
"stage": "analyzing",
|
||||
"label": f"Analyzing clause {i + 1}/{len(clauses)}…",
|
||||
})
|
||||
await asyncio.sleep(0)
|
||||
yield _sse({
|
||||
"type": "stage",
|
||||
"stage": "analyzing",
|
||||
"label": f"Analyzing {len(clauses)} clauses in parallel…",
|
||||
})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
chunks = await asyncio.to_thread(
|
||||
retrieve_for_clause, clause, retrieval_service, 5, domains or None
|
||||
)
|
||||
clause_results = await run_clauses_parallel(
|
||||
clauses, retrieval_service, client,
|
||||
top_k=5,
|
||||
domains=domains or None,
|
||||
)
|
||||
|
||||
# Emit source events
|
||||
for res in clause_results:
|
||||
i = res["index"]
|
||||
chunks = res["chunks"]
|
||||
finding = res["finding"]
|
||||
|
||||
# Emit source events for this clause
|
||||
for chunk in chunks[:3]:
|
||||
yield _sse({
|
||||
"type": "source",
|
||||
@@ -161,12 +166,11 @@ async def analyze_stream(
|
||||
"status": "retrieved",
|
||||
"full_content": (getattr(chunk, "text", "") or "")[:300],
|
||||
})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
finding = await asyncio.to_thread(check_clause_compliance, clause, chunks, client)
|
||||
if finding:
|
||||
findings.append(finding)
|
||||
yield _sse({"type": "finding", **finding})
|
||||
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# ── Stage 4: synthesize conclusion ────────────────────────────
|
||||
@@ -178,6 +182,45 @@ async def analyze_stream(
|
||||
)
|
||||
yield _sse({"type": "done", **conclusion_data})
|
||||
|
||||
# Auto-save analysis to database
|
||||
try:
|
||||
from app.shared.bootstrap import get_compliance_repository
|
||||
from app.domain.compliance.ports import AnalysisRecord, FindingRecord
|
||||
from datetime import datetime
|
||||
|
||||
repo = get_compliance_repository()
|
||||
finding_records = [
|
||||
FindingRecord(
|
||||
id="",
|
||||
analysis_id="",
|
||||
seq=i,
|
||||
title=f.get("title", ""),
|
||||
description=f.get("desc", ""),
|
||||
status=f.get("status", "ok"),
|
||||
clause_ref=f.get("clause_ref"),
|
||||
)
|
||||
for i, f in enumerate(findings)
|
||||
]
|
||||
record = AnalysisRecord(
|
||||
id="",
|
||||
created_at=datetime.utcnow(),
|
||||
created_by=current_user.username if hasattr(current_user, "username") else None,
|
||||
doc_name=file_name or (title or "Pasted text"),
|
||||
standard_name=title or "",
|
||||
risk_score=conclusion_data.get("risk_score", 0),
|
||||
conclusion=conclusion_data.get("conclusion", ""),
|
||||
actions=conclusion_data.get("actions", []),
|
||||
para_text=conclusion_data.get("para_text", ""),
|
||||
highlight_terms=conclusion_data.get("highlight_terms", []),
|
||||
findings=finding_records,
|
||||
)
|
||||
analysis_id = await asyncio.to_thread(repo.save_analysis, record)
|
||||
yield _sse({"type": "saved", "analysis_id": analysis_id})
|
||||
except NotImplementedError:
|
||||
pass # No postgres backend configured — skip saving
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to auto-save compliance analysis: {}", exc)
|
||||
|
||||
except Exception as exc:
|
||||
logger.exception("analyze-stream pipeline error")
|
||||
yield _sse({"type": "error", "text": str(exc)})
|
||||
@@ -225,3 +268,226 @@ async def compliance_chat(segment_id: int, request: ComplianceChatRequest):
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/history")
async def list_history(
    limit: int = 20,
    offset: int = 0,
    current_user: UserClaims = Depends(get_current_user),
):
    """Return one page of saved compliance analyses (newest first).

    Args:
        limit: Maximum number of analyses to return. Negative values are
            treated as 0.
        offset: Number of analyses to skip. Negative values are treated as 0.
        current_user: Authenticated caller; only required so the route is
            protected, it is not read here.

    Returns:
        A list of summary dicts (id, created_at, created_by, doc_name,
        standard_name, risk_score, finding_count). An empty list is returned
        when the repository raises ``NotImplementedError``.
    """
    from app.shared.bootstrap import get_compliance_repository

    # Pagination values come straight from the query string. They are
    # presumably forwarded to SQL LIMIT/OFFSET, where a negative number is an
    # error, so clamp them instead of surfacing a 500.
    limit = max(limit, 0)
    offset = max(offset, 0)

    try:
        repo = get_compliance_repository()
        records = await asyncio.to_thread(repo.list_analyses, limit, offset)
    except NotImplementedError:
        # No persistence backend configured: behave like an empty history.
        return []

    return [
        {
            "id": r.id,
            "created_at": r.created_at.isoformat(),
            "created_by": r.created_by,
            "doc_name": r.doc_name,
            "standard_name": r.standard_name,
            "risk_score": r.risk_score,
            # NOTE(review): ComplianceRepository.list_analyses is documented
            # as returning records *without* nested findings, so this count is
            # presumably always 0 — confirm and expose a real count from the
            # repository if the UI relies on it.
            "finding_count": len(r.findings),
        }
        for r in records
    ]
|
||||
|
||||
|
||||
@router.get("/history/{analysis_id}")
async def get_history_item(
    analysis_id: str,
    current_user: UserClaims = Depends(get_current_user),
):
    """Return one saved analysis, including its findings.

    Args:
        analysis_id: Identifier of the analysis to load.
        current_user: Authenticated caller; only required so the route is
            protected, it is not read here.

    Raises:
        HTTPException: 404 when the analysis does not exist, or when the
            repository raises ``NotImplementedError`` (handled the same way
            ``list_history`` treats a missing backend: as "nothing stored").
    """
    from app.shared.bootstrap import get_compliance_repository
    from fastapi import HTTPException

    try:
        repo = get_compliance_repository()
        record = await asyncio.to_thread(repo.get_analysis, analysis_id)
    except NotImplementedError:
        # Without a persistence backend there is nothing that could be found.
        record = None
    if not record:
        raise HTTPException(status_code=404, detail="Analysis not found")

    return {
        "id": record.id,
        "created_at": record.created_at.isoformat(),
        "created_by": record.created_by,
        "doc_name": record.doc_name,
        "standard_name": record.standard_name,
        "risk_score": record.risk_score,
        "conclusion": record.conclusion,
        "actions": record.actions,
        "para_text": record.para_text,
        "highlight_terms": record.highlight_terms,
        "findings": [
            {
                "id": f.id,
                "seq": f.seq,
                "title": f.title,
                "description": f.description,
                "status": f.status,
                "clause_ref": f.clause_ref,
            }
            for f in record.findings
        ],
    }
|
||||
|
||||
|
||||
@router.delete("/history/{analysis_id}", status_code=204)
async def delete_history_item(
    analysis_id: str,
    current_user: UserClaims = Depends(get_current_user),
):
    """Delete a saved analysis (cascade removes findings and chat messages).

    The route always answers 204: the repository call returns nothing, so a
    missing id is not distinguished from a successful delete. When the
    repository raises ``NotImplementedError`` there is nothing stored to
    delete, so the request is treated as a no-op, consistent with
    ``list_history`` returning an empty list in that situation.

    Args:
        analysis_id: Identifier of the analysis to remove.
        current_user: Authenticated caller; only required so the route is
            protected, it is not read here.
    """
    from app.shared.bootstrap import get_compliance_repository

    try:
        repo = get_compliance_repository()
        await asyncio.to_thread(repo.delete_analysis, analysis_id)
    except NotImplementedError:
        # No persistence backend configured — nothing to delete.
        return
|
||||
|
||||
|
||||
def _docx_content_disposition(doc_name, created_at) -> str:
    """Build a header-safe Content-Disposition value for a report download.

    The document name is user/LLM supplied and may contain quotes, control
    characters or non-Latin-1 text, none of which can be placed verbatim in
    an HTTP header. The value therefore carries an ASCII-only ``filename``
    fallback plus an RFC 5987 ``filename*`` parameter with the real UTF-8
    name percent-encoded. Names that were already plain ASCII produce the
    same ``filename="..."`` part as before.
    """
    import re
    from urllib.parse import quote

    safe_name = (doc_name or "report").replace(" ", "_")[:50]
    filename = f"compliance_{safe_name}_{created_at.strftime('%Y%m%d')}.docx"
    # Anything outside a conservative ASCII set becomes "_" in the fallback.
    ascii_fallback = re.sub(r"[^A-Za-z0-9._-]", "_", filename)
    return (
        f'attachment; filename="{ascii_fallback}"; '
        f"filename*=UTF-8''{quote(filename, safe='')}"
    )


@router.get("/history/{analysis_id}/download")
async def download_history_docx(
    analysis_id: str,
    current_user: UserClaims = Depends(get_current_user),
):
    """Return a DOCX compliance report for the given analysis.

    Args:
        analysis_id: Identifier of the analysis to export.
        current_user: Authenticated caller; only required so the route is
            protected, it is not read here.

    Returns:
        A ``Response`` whose body is the generated DOCX, served as an
        attachment.

    Raises:
        HTTPException: 404 when the analysis does not exist, or when the
            repository raises ``NotImplementedError`` (no backend, hence
            nothing stored).
    """
    from app.shared.bootstrap import get_compliance_repository
    from app.infrastructure.compliance.docx_export import generate_docx
    from fastapi import HTTPException
    from fastapi.responses import Response

    try:
        repo = get_compliance_repository()
        record = await asyncio.to_thread(repo.get_analysis, analysis_id)
    except NotImplementedError:
        record = None
    if not record:
        raise HTTPException(status_code=404, detail="Analysis not found")

    # Document generation is synchronous; keep it off the event loop.
    docx_bytes = await asyncio.to_thread(generate_docx, record)
    return Response(
        content=docx_bytes,
        media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        headers={
            "Content-Disposition": _docx_content_disposition(
                record.doc_name, record.created_at
            )
        },
    )
|
||||
|
||||
|
||||
@router.get("/analyses/{analysis_id}/findings/{finding_id}/chat")
async def get_finding_chat_history(
    analysis_id: str,
    finding_id: str,
    current_user: UserClaims = Depends(get_current_user),
):
    """Return the stored chat thread of a finding, oldest message first.

    Only ``finding_id`` is used for the lookup; ``analysis_id`` is part of
    the URL but is not read here. An empty list is returned when the
    repository raises ``NotImplementedError``.
    """
    from app.shared.bootstrap import get_compliance_repository

    try:
        repository = get_compliance_repository()
        # The repository call is synchronous, so run it in a worker thread.
        return await asyncio.to_thread(repository.get_messages, finding_id)
    except NotImplementedError:
        return []
|
||||
|
||||
|
||||
@router.post("/analyses/{analysis_id}/findings/{finding_id}/suggestions")
async def get_finding_suggestions(
    analysis_id: str,
    finding_id: str,
    current_user: UserClaims = Depends(get_current_user),
):
    """Generate 3 LLM-powered follow-up question suggestions for a finding.

    Args:
        analysis_id: Identifier of the analysis that owns the finding.
        finding_id: Identifier of the finding to generate questions for.
        current_user: Authenticated caller; only required so the route is
            protected, it is not read here.

    Returns:
        ``{"questions": [...]}`` with the strings produced by
        ``generate_suggestions``.

    Raises:
        HTTPException: 404 when the analysis or the finding cannot be found,
            including when the repository raises ``NotImplementedError``
            (no backend, hence nothing stored).
    """
    from app.application.compliance.pipeline import generate_suggestions
    from app.shared.bootstrap import get_compliance_repository
    from app.services.llm.llm_factory import get_llm_client
    from fastapi import HTTPException

    try:
        repo = get_compliance_repository()
        analysis = await asyncio.to_thread(repo.get_analysis, analysis_id)
    except NotImplementedError:
        analysis = None
    if not analysis:
        raise HTTPException(status_code=404, detail="Analysis not found")

    # The finding must belong to the loaded analysis; ids from other analyses
    # are reported as not found.
    finding = next((f for f in analysis.findings if f.id == finding_id), None)
    if not finding:
        raise HTTPException(status_code=404, detail="Finding not found")

    client = get_llm_client(provider=settings.llm_provider, model=settings.llm_model)
    questions = await asyncio.to_thread(generate_suggestions, finding, analysis, client)
    return {"questions": questions}
|
||||
|
||||
|
||||
@router.post("/analyses/{analysis_id}/findings/{finding_id}/chat")
async def finding_chat(
    analysis_id: str,
    finding_id: str,
    request: ComplianceChatRequest,
    current_user: UserClaims = Depends(get_current_user),
):
    """Stream a grounded chat response for a specific finding.

    Loads the finding and analysis from the repository to build the grounding
    context, stores the user's message, streams the agent's answer as SSE
    events (``chunk`` / ``done`` / ``error``) and stores the assembled
    assistant answer once the stream ends.

    Args:
        analysis_id: Identifier of the analysis that owns the finding.
        finding_id: Identifier of the finding being discussed.
        request: Chat request; only ``request.query`` is read.
        current_user: Authenticated caller; only required so the route is
            protected, it is not read here.

    Raises:
        HTTPException: 404 when the analysis or the finding cannot be found,
            including when the repository raises ``NotImplementedError``
            (no backend, hence nothing stored).
    """
    from app.application.compliance.pipeline import build_finding_context
    from app.shared.bootstrap import get_compliance_repository
    from fastapi import HTTPException

    try:
        repo = get_compliance_repository()
        analysis = await asyncio.to_thread(repo.get_analysis, analysis_id)
    except NotImplementedError:
        analysis = None
    if not analysis:
        raise HTTPException(status_code=404, detail="Analysis not found")
    finding = next((f for f in analysis.findings if f.id == finding_id), None)
    if not finding:
        raise HTTPException(status_code=404, detail="Finding not found")

    # Persist user message
    await asyncio.to_thread(
        repo.save_message, analysis_id, finding_id, "user", request.query
    )

    # NOTE(review): earlier turns of the thread are not forwarded to the
    # agent — stream_chat below only receives the current question plus the
    # finding context. A history list used to be loaded here but was never
    # passed on, so the unused repository read was dropped. Load
    # repo.get_messages(finding_id) again once stream_chat can take history.

    # Build grounded system context
    system_context = build_finding_context(finding, analysis)
    full_query = f"[Compliance Finding Context]\n{system_context}\n\nUser question: {request.query}"

    # Collects streamed text so the full answer can be stored afterwards.
    assistant_buffer: list[str] = []

    async def generate() -> AsyncGenerator[str, None]:
        try:
            _, event_stream = get_agent_conversation_service().stream_chat(
                query=full_query,
                top_k=5,
                prompt_template="compliance_qa",
            )
            for event in event_stream:
                event_type = event.get("event", "")
                if event_type == "content":
                    text = event.get("data", "")
                    if text:
                        assistant_buffer.append(text)
                        yield _sse({"type": "chunk", "text": text})
                elif event_type == "done":
                    yield _sse({"type": "done"})
                # Give other tasks a chance to run between events.
                await asyncio.sleep(0)
        except Exception as exc:
            logger.exception("finding_chat stream error")
            yield _sse({"type": "error", "text": str(exc)})
        finally:
            # Persist whatever was streamed, even after an error part-way.
            full_response = "".join(assistant_buffer)
            if full_response:
                try:
                    await asyncio.to_thread(
                        repo.save_message, analysis_id, finding_id, "assistant", full_response
                    )
                except Exception as exc:
                    logger.warning("Failed to persist assistant message: {}", exc)

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
    )
|
||||
|
||||
@@ -5,6 +5,7 @@ All functions are synchronous — call them via asyncio.to_thread() in async SSE
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
@@ -12,10 +13,20 @@ import tempfile
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from loguru import logger
|
||||
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
|
||||
|
||||
# Retry policy shared by the LLM calls in this module. A call is retried only
# when it raises ValueError, TimeoutError or ConnectionError, for at most
# three attempts in total; the wait between attempts grows exponentially and
# is kept within 1–4 seconds. After the last attempt the original exception
# is re-raised to the caller (reraise=True).
_llm_retry = retry(
    retry=retry_if_exception_type((ValueError, TimeoutError, ConnectionError)),
    wait=wait_exponential(multiplier=1, min=1, max=4),
    stop=stop_after_attempt(3),
    reraise=True,
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.application.knowledge import KnowledgeRetrievalService
|
||||
from app.domain.retrieval import RetrievedChunk
|
||||
from app.domain.compliance.ports import AnalysisRecord, FindingRecord
|
||||
from app.services.llm.base_client import BaseLLMClient
|
||||
|
||||
|
||||
@@ -109,17 +120,67 @@ def retrieve_for_clause(
|
||||
return retrieval_service.retrieve(query=clause, top_k=top_k, filters=domains)
|
||||
|
||||
|
||||
def process_single_clause(
    clause: str,
    index: int,
    retrieval_service: "KnowledgeRetrievalService",
    client: "BaseLLMClient",
    top_k: int = 5,
    domains: str | None = None,
) -> dict:
    """Retrieve regulations for one clause and run the compliance check on it.

    Synchronous by design so callers can hand it to ``asyncio.to_thread()``.

    Args:
        clause: Text of the clause to analyse.
        index: Position of the clause; echoed back in the result.
        retrieval_service: Service used to fetch matching regulation chunks.
        client: LLM client used for the compliance check.
        top_k: Number of chunks to retrieve.
        domains: Optional filter forwarded to retrieval.

    Returns:
        A dict with the keys ``index``, ``chunks`` and ``finding``; the
        finding is ``None`` when the check produced no result.
    """
    retrieved = retrieve_for_clause(clause, retrieval_service, top_k, domains)
    return {
        "index": index,
        "chunks": retrieved,
        "finding": check_clause_compliance(clause, retrieved, client),
    }
|
||||
|
||||
|
||||
async def run_clauses_parallel(
    clauses: list[str],
    retrieval_service: "KnowledgeRetrievalService",
    client: "BaseLLMClient",
    top_k: int = 5,
    domains: str | None = None,
) -> list[dict]:
    """Run all clauses through retrieve+gap-check in parallel.

    Each clause is processed by ``process_single_clause`` in a worker thread
    via ``asyncio.to_thread()``. Results come back in the original clause
    order. A clause whose processing failed is represented by
    ``{"index": i, "chunks": [], "finding": None}`` so callers can keep
    iterating over the remaining clauses.

    Both retrieval_service and client must be thread-safe — they are shared
    across all asyncio.to_thread() calls without locking.

    Args:
        clauses: Clause texts to analyse.
        retrieval_service: Service used to fetch matching regulation chunks.
        client: LLM client used for the compliance check.
        top_k: Number of chunks to retrieve per clause.
        domains: Optional filter forwarded to retrieval.

    Returns:
        One result dict per clause, each with ``index``, ``chunks`` and
        ``finding``.
    """
    tasks = [
        asyncio.to_thread(
            process_single_clause,
            clause, i, retrieval_service, client, top_k, domains,
        )
        for i, clause in enumerate(clauses)
    ]
    raw = await asyncio.gather(*tasks, return_exceptions=True)

    results: list[dict] = []
    for i, outcome in enumerate(raw):
        # With return_exceptions=True a cancelled child is reported as a
        # CancelledError, which derives from BaseException, not Exception.
        # Checking BaseException keeps exception objects out of the results,
        # where callers would subscript them like dicts.
        if isinstance(outcome, BaseException):
            logger.warning("Clause {} processing failed: {}", i, outcome)
            results.append({"index": i, "chunks": [], "finding": None})
        else:
            results.append(outcome)
    return results
|
||||
|
||||
|
||||
def check_clause_compliance(
|
||||
clause: str,
|
||||
chunks: list["RetrievedChunk"],
|
||||
client: "BaseLLMClient",
|
||||
) -> dict | None:
|
||||
if not chunks:
|
||||
return None
|
||||
reg_context = "\n".join(
|
||||
f"[{i+1}] {c.doc_title} {c.section_title or ''}: {c.text[:300]}"
|
||||
for i, c in enumerate(chunks[:5])
|
||||
)
|
||||
) if chunks else "(no regulatory context retrieved)"
|
||||
prompt = (
|
||||
"You are a compliance expert. Judge whether the following business clause "
|
||||
"complies with the retrieved regulations.\n\n"
|
||||
@@ -135,9 +196,19 @@ def check_clause_compliance(
|
||||
"status: ok=compliant, warn=gap exists, risk=critical/missing\n"
|
||||
"Return ONLY the JSON object."
|
||||
)
|
||||
response = client.chat([{"role": "user", "content": prompt}], max_tokens=500)
|
||||
if not response.is_success:
|
||||
|
||||
def _do_check():
|
||||
resp = client.chat([{"role": "user", "content": prompt}], max_tokens=500)
|
||||
if not resp.is_success:
|
||||
raise ValueError("LLM returned non-success for gap check")
|
||||
return resp
|
||||
|
||||
try:
|
||||
response = _llm_retry(_do_check)()
|
||||
except Exception as exc:
|
||||
logger.warning("check_clause_compliance LLM call failed after retries: {}", exc)
|
||||
return None
|
||||
|
||||
try:
|
||||
result = _extract_json(response.content)
|
||||
if isinstance(result, dict) and "status" in result:
|
||||
@@ -182,12 +253,11 @@ def synthesize_conclusion(
|
||||
' {"label": "Priority", "value": "High/Medium/Low", "risk": true}\n'
|
||||
' ],\n'
|
||||
' "risk_score": 0-100 (integer, higher=riskier),\n'
|
||||
' "highlight_terms": ["Key terms to highlight, max 10 terms"],\n'
|
||||
' "highlight_terms": ["term1", "term2"], // up to 10 key technical/legal terms actually present in the text\n'
|
||||
' "para_text": "Original text or summary (max 600 chars)"\n'
|
||||
"}\n"
|
||||
"Return ONLY the JSON object."
|
||||
)
|
||||
response = client.chat([{"role": "user", "content": prompt}], max_tokens=1200)
|
||||
fallback = {
|
||||
"conclusion": "Compliance analysis complete. Review findings and create remediation plan.",
|
||||
"actions": [
|
||||
@@ -198,8 +268,19 @@ def synthesize_conclusion(
|
||||
"highlight_terms": [],
|
||||
"para_text": para_text[:800],
|
||||
}
|
||||
if not response.is_success:
|
||||
|
||||
def _do_synthesize():
|
||||
resp = client.chat([{"role": "user", "content": prompt}], max_tokens=1200)
|
||||
if not resp.is_success:
|
||||
raise ValueError("LLM returned non-success for synthesis")
|
||||
return resp
|
||||
|
||||
try:
|
||||
response = _llm_retry(_do_synthesize)()
|
||||
except Exception as exc:
|
||||
logger.warning("synthesize_conclusion LLM call failed after retries: {}", exc)
|
||||
return fallback
|
||||
|
||||
try:
|
||||
result = _extract_json(response.content)
|
||||
if isinstance(result, dict):
|
||||
@@ -212,4 +293,78 @@ def synthesize_conclusion(
|
||||
}
|
||||
except (ValueError, TypeError) as exc:
|
||||
logger.warning("Conclusion synthesis JSON parse failed: {}", exc)
|
||||
return fallback
|
||||
return fallback
|
||||
|
||||
|
||||
# Per-status instruction inserted into the suggestion prompt as its "Task:" line.
_SUGGESTION_FOCUS = dict(
    risk="Focus on remediation steps, required certifications, and timeline to resolve.",
    warn="Focus on identifying the specific compliance gap and how to close it.",
    ok="Focus on maintaining compliance evidence and monitoring future changes.",
)

# Static questions (three per status) used when no LLM suggestions are available.
_SUGGESTION_FALLBACK = dict(
    risk=[
        "What specific certifications or documents are required to remediate this finding?",
        "What is the typical remediation timeline for this type of non-compliance?",
        "Which regulation clause defines the exact requirement?",
    ],
    warn=[
        "What is the exact gap between the current state and the requirement?",
        "What evidence would demonstrate partial compliance?",
        "Which regulation clause applies to this warning?",
    ],
    ok=[
        "What documentation should be maintained to evidence this compliance?",
        "How should this area be monitored as regulations evolve?",
        "Are there related clauses that may affect this compliant area?",
    ],
)
|
||||
|
||||
|
||||
def build_finding_context(finding: "FindingRecord", analysis: "AnalysisRecord") -> str:
    """Render a finding and its parent analysis as a plain-text context block.

    The result has one ``Label: value`` line per field, each terminated by a
    newline, so the chat prompt can be grounded from stored data alone.

    Args:
        finding: Finding being discussed; ``seq`` is shown one-based and a
            missing ``clause_ref`` is shown as ``N/A``.
        analysis: Analysis that owns the finding.

    Returns:
        The multi-line context string.
    """
    clause_reference = finding.clause_ref or 'N/A'
    context_lines = [
        f"Document: {analysis.doc_name}",
        f"Standard: {analysis.standard_name}",
        f"Finding [{finding.seq + 1}]: {finding.title}",
        f"Status: {finding.status}",
        f"Clause reference: {clause_reference}",
        f"Description: {finding.description}",
        f"Overall conclusion: {analysis.conclusion}",
    ]
    return "".join(f"{line}\n" for line in context_lines)
|
||||
|
||||
|
||||
def generate_suggestions(
    finding: "FindingRecord",
    analysis: "AnalysisRecord",
    client: "BaseLLMClient",
) -> list[str]:
    """Generate 3 context-aware follow-up questions for a finding chat thread.

    Args:
        finding: Finding the questions are about; its ``status`` selects the
            prompt focus and the fallback set (unknown statuses use "warn").
        analysis: Analysis that owns the finding.
        client: LLM client used to generate the questions.

    Returns:
        Exactly 3 question strings. The static fallback for the finding's
        status is returned when the LLM call raises, reports failure, or its
        reply is not a JSON object with at least 3 ``questions``.
    """
    # Copy the template so callers cannot mutate the module-level list.
    fallback = list(_SUGGESTION_FALLBACK.get(finding.status, _SUGGESTION_FALLBACK["warn"]))
    context = build_finding_context(finding, analysis)
    focus = _SUGGESTION_FOCUS.get(finding.status, _SUGGESTION_FOCUS["warn"])
    prompt = (
        f"{context}\n\n"
        f"Task: {focus}\n"
        "Generate exactly 3 concise follow-up questions a compliance analyst would ask.\n"
        'Return JSON: {"questions": ["question 1", "question 2", "question 3"]}\n'
        "Return ONLY the JSON object."
    )

    # Suggestions are best-effort UI sugar: a failing LLM call must degrade
    # to the static questions instead of propagating to the caller.
    try:
        response = client.chat([{"role": "user", "content": prompt}], max_tokens=300)
    except Exception as exc:
        logger.warning("generate_suggestions LLM call failed: {}", exc)
        return fallback
    if not response.is_success:
        return fallback

    try:
        result = _extract_json(response.content)
    except (ValueError, TypeError) as exc:
        logger.warning("generate_suggestions JSON parse failed: {}", exc)
        return fallback

    # The model may return valid JSON that is not an object (e.g. a bare
    # list); only a dict can carry the "questions" key.
    questions = result.get("questions", []) if isinstance(result, dict) else None
    if isinstance(questions, list) and len(questions) >= 3:
        return [str(q) for q in questions[:3]]
    return fallback
|
||||
|
||||
0
backend/app/domain/compliance/__init__.py
Normal file
0
backend/app/domain/compliance/__init__.py
Normal file
66
backend/app/domain/compliance/ports.py
Normal file
66
backend/app/domain/compliance/ports.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""Domain ports for compliance history persistence."""
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
class FindingRecord:
    """One compliance finding that belongs to an analysis."""

    # Identifier of the finding row.
    id: str
    # Identifier of the analysis this finding belongs to.
    analysis_id: str
    # Ordering number of the finding within its analysis.
    seq: int
    # Short headline of the finding.
    title: str
    # Longer explanation of the finding.
    description: str
    # One of "ok", "warn" or "risk".
    status: str
    # Reference to the clause concerned, when one is known.
    clause_ref: str | None = None
|
||||
|
||||
|
||||
@dataclass
class AnalysisRecord:
    """A complete compliance analysis together with its findings."""

    # UUID string; an empty string means the record is not persisted yet.
    id: str
    # Creation timestamp of the analysis.
    created_at: datetime
    # Name of the user who created the analysis, when known.
    created_by: str | None
    # Name of the analysed document.
    doc_name: str
    # Name of the standard the document was checked against.
    standard_name: str
    # Overall risk score of the analysis.
    risk_score: int
    # Summary conclusion text.
    conclusion: str
    # Serialised action items (list of dicts).
    actions: list
    # Text of the analysed paragraph.
    para_text: str
    # Terms to highlight (list of strings).
    highlight_terms: list
    # Nested findings; every instance gets its own empty list by default.
    findings: list[FindingRecord] = field(default_factory=list)
|
||||
|
||||
|
||||
class ComplianceRepository(ABC):
    """Abstract storage interface for compliance analyses and finding chats."""

    @abstractmethod
    def save_analysis(self, record: AnalysisRecord) -> str:
        """Store a new analysis record; return the UUID string assigned to it."""

    @abstractmethod
    def list_analyses(self, limit: int = 50, offset: int = 0) -> list[AnalysisRecord]:
        """Return analyses sorted by created_at descending, findings not included."""

    @abstractmethod
    def get_analysis(self, analysis_id: str) -> AnalysisRecord | None:
        """Return one analysis with all of its findings, or None when absent."""

    @abstractmethod
    def delete_analysis(self, analysis_id: str) -> None:
        """Remove an analysis plus its findings and chat messages (cascade)."""

    @abstractmethod
    def save_message(self, analysis_id: str, finding_id: str, role: str, content: str) -> str:
        """Store one chat message; return the UUID string assigned to it."""

    @abstractmethod
    def get_messages(self, finding_id: str) -> list[dict]:
        """Return the chat messages of a finding sorted by created_at ascending.

        Every dict carries the keys: id, role, content, created_at (ISO string).
        """
|
||||
0
backend/app/infrastructure/compliance/__init__.py
Normal file
0
backend/app/infrastructure/compliance/__init__.py
Normal file
101
backend/app/infrastructure/compliance/docx_export.py
Normal file
101
backend/app/infrastructure/compliance/docx_export.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""DOCX report generator for compliance analysis results.
|
||||
|
||||
Uses python-docx (already in requirements.txt). Returns raw bytes so the
|
||||
caller can stream the response without writing to disk.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from io import BytesIO
|
||||
|
||||
from docx import Document
|
||||
from docx.shared import Pt, RGBColor
|
||||
from docx.enum.text import WD_ALIGN_PARAGRAPH
|
||||
|
||||
from app.domain.compliance.ports import AnalysisRecord
|
||||
|
||||
# Human-readable label for each finding status shown in the findings table.
_STATUS_LABEL = {"ok": "Compliant", "warn": "Warning", "risk": "Non-Compliant"}
# Font colour applied to the status label of each finding.
_STATUS_COLOR = {
    "ok": RGBColor(0x22, 0x8B, 0x22),
    "warn": RGBColor(0xFF, 0x8C, 0x00),
    "risk": RGBColor(0xDC, 0x14, 0x3C),
}


def generate_docx(record: AnalysisRecord) -> bytes:
    """Generate a compliance report DOCX and return its raw bytes.

    Structure:
    - Cover: document name, standard, date, risk score
    - Executive summary (conclusion)
    - Findings table (status label coloured per ``_STATUS_COLOR``)
    - Recommended actions (numbered by the "List Number" style)
    - Footer note

    Text fields that are ``None`` are rendered as empty text and a missing
    risk score as "N/A", so a partially filled record still produces a
    document.

    Args:
        record: Analysis to render, including its findings and actions.

    Returns:
        The serialised ``.docx`` file content.
    """
    doc = Document()

    # ── Cover ──────────────────────────────────────────────────────────────────
    title_para = doc.add_heading("Compliance Analysis Report", level=0)
    title_para.alignment = WD_ALIGN_PARAGRAPH.CENTER

    doc.add_paragraph("")
    risk_text = f"{record.risk_score} / 100" if record.risk_score is not None else "N/A"
    meta_rows = [
        ("Document", record.doc_name or ""),
        ("Standard", record.standard_name or ""),
        ("Date", record.created_at.strftime("%Y-%m-%d %H:%M UTC") if record.created_at else ""),
        ("Risk Score", risk_text),
    ]
    meta_table = doc.add_table(rows=len(meta_rows), cols=2)
    meta_table.style = "Table Grid"
    for row_idx, (label, value) in enumerate(meta_rows):
        meta_table.cell(row_idx, 0).text = label
        meta_table.cell(row_idx, 1).text = value

    # ── Executive Summary ──────────────────────────────────────────────────────
    doc.add_heading("Executive Summary", level=1)
    doc.add_paragraph(record.conclusion or "")

    # ── Findings ───────────────────────────────────────────────────────────────
    doc.add_heading("Findings", level=1)
    if record.findings:
        table = doc.add_table(rows=1, cols=4)
        table.style = "Table Grid"
        hdr = table.rows[0].cells
        for col_idx, heading in enumerate(["#", "Status", "Title", "Description / Clause"]):
            hdr[col_idx].text = heading
            for run in hdr[col_idx].paragraphs[0].runs:
                run.bold = True

        for f in record.findings:
            row = table.add_row().cells
            status = f.status or ""
            row[0].text = str(f.seq + 1)
            row[1].text = _STATUS_LABEL.get(status, status)
            status_color = _STATUS_COLOR.get(status)
            if status_color is not None:
                for run in row[1].paragraphs[0].runs:
                    run.font.color.rgb = status_color
            row[2].text = f.title or ""
            desc = f.description or ""
            if f.clause_ref:
                desc += f"\n[{f.clause_ref}]"
            row[3].text = desc
    else:
        doc.add_paragraph("No findings recorded.")

    # ── Recommended Actions ────────────────────────────────────────────────────
    doc.add_heading("Recommended Actions", level=1)
    for action in record.actions or []:
        if isinstance(action, dict):
            label = action.get("label", "Action")
            value = action.get("value", "")
            action_text = f"{label}: {value}"
        else:
            # Tolerate action items stored as plain strings.
            action_text = str(action)
        # "List Number" supplies the numbering itself; a manual "1." prefix
        # would show every number twice.
        doc.add_paragraph(action_text, style="List Number")

    # ── Footer note ────────────────────────────────────────────────────────────
    doc.add_paragraph("")
    footer = doc.add_paragraph(
        f"Generated by AI Regulation Analysis System — {datetime.now(timezone.utc).strftime('%Y-%m-%d')}"
    )
    footer.alignment = WD_ALIGN_PARAGRAPH.CENTER
    for run in footer.runs:
        run.font.size = Pt(8)
        run.font.color.rgb = RGBColor(0x88, 0x88, 0x88)

    # Serialise into memory so nothing is written to disk.
    buf = BytesIO()
    doc.save(buf)
    return buf.getvalue()
|
||||
280
backend/app/infrastructure/compliance/repository.py
Normal file
280
backend/app/infrastructure/compliance/repository.py
Normal file
@@ -0,0 +1,280 @@
|
||||
# backend/app/infrastructure/compliance/repository.py
|
||||
"""PostgreSQL-backed compliance analysis repository.
|
||||
|
||||
Follows the same psycopg2 pattern as PostgresDocumentRepository:
|
||||
ThreadedConnectionPool + RealDictCursor for reads, _ensure_schema on init.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import psycopg2
|
||||
import psycopg2.extras
|
||||
import psycopg2.pool
|
||||
from loguru import logger
|
||||
|
||||
from app.domain.compliance.ports import (
|
||||
AnalysisRecord,
|
||||
ComplianceRepository,
|
||||
FindingRecord,
|
||||
)
|
||||
|
||||
|
||||
class PostgresComplianceRepository(ComplianceRepository):
|
||||
"""Stores compliance analyses, findings, and finding chat messages in PostgreSQL."""
|
||||
|
||||
def __init__(
    self,
    host: str,
    port: int,
    user: str,
    password: str,
    dbname: str,
    minconn: int = 1,
    maxconn: int = 5,
) -> None:
    """Open the connection pool and make sure the tables exist.

    Args:
        host: Database server host.
        port: Database server port.
        user: Login user.
        password: Login password.
        dbname: Name of the database to connect to.
        minconn: Minimum number of pooled connections.
        maxconn: Maximum number of pooled connections.
    """
    pool_settings = {
        "minconn": minconn,
        "maxconn": maxconn,
        "host": host,
        "port": port,
        "user": user,
        "password": password,
        "dbname": dbname,
    }
    self._pool = psycopg2.pool.ThreadedConnectionPool(**pool_settings)
    # Schema creation runs once per repository instance, right after connecting.
    self._ensure_schema()
|
||||
|
||||
@contextmanager
def _conn(self):
    """Lend out a pooled connection for the duration of a ``with`` block.

    The connection goes back to the pool when the block exits, whether it
    finished normally or raised. Committing is left to the caller.
    """
    pool = self._pool
    borrowed = pool.getconn()
    try:
        yield borrowed
    finally:
        pool.putconn(borrowed)
|
||||
|
||||
def _ensure_schema(self) -> None:
    """Create the analyses, findings and chat-message tables when missing.

    The statements run in dependency order (findings reference analyses,
    chat messages reference both) and are committed together.
    """
    ddl_statements = (
        """
        CREATE TABLE IF NOT EXISTS compliance_analyses (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            created_by VARCHAR(255),
            doc_name VARCHAR(500),
            standard_name VARCHAR(500),
            risk_score INTEGER,
            conclusion TEXT,
            actions JSONB,
            para_text TEXT,
            highlight_terms JSONB
        );
        """,
        """
        CREATE TABLE IF NOT EXISTS compliance_findings (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            analysis_id UUID NOT NULL REFERENCES compliance_analyses(id) ON DELETE CASCADE,
            seq INTEGER NOT NULL,
            title VARCHAR(500),
            description TEXT,
            status VARCHAR(50),
            clause_ref VARCHAR(200)
        );
        """,
        """
        CREATE TABLE IF NOT EXISTS finding_chat_messages (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            analysis_id UUID NOT NULL REFERENCES compliance_analyses(id) ON DELETE CASCADE,
            finding_id UUID NOT NULL REFERENCES compliance_findings(id) ON DELETE CASCADE,
            role VARCHAR(20) NOT NULL,
            content TEXT NOT NULL,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );
        """,
    )
    with self._conn() as conn:
        with conn.cursor() as cur:
            for ddl in ddl_statements:
                cur.execute(ddl)
        conn.commit()
|
||||
|
||||
def save_analysis(self, record: AnalysisRecord) -> str:
    """Store an analysis and all of its findings.

    The analysis row is inserted first so that the id generated by the
    database can be used as the foreign key of each finding row. Both inserts
    are committed together by the single ``commit()`` at the end.

    Args:
        record: The analysis to persist. ``actions`` and ``highlight_terms``
            are serialized to JSON text before being bound.

    Returns:
        The id of the inserted analysis row, as a string.
    """
    insert_analysis_sql = """
        INSERT INTO compliance_analyses
            (created_by, doc_name, standard_name, risk_score,
             conclusion, actions, para_text, highlight_terms)
        VALUES
            (%(created_by)s, %(doc_name)s, %(standard_name)s, %(risk_score)s,
             %(conclusion)s, %(actions)s, %(para_text)s, %(highlight_terms)s)
        RETURNING id
    """
    insert_finding_sql = """
        INSERT INTO compliance_findings
            (analysis_id, seq, title, description, status, clause_ref)
        VALUES
            (%(analysis_id)s, %(seq)s, %(title)s, %(desc)s, %(status)s, %(clause_ref)s)
    """
    with self._conn() as conn:
        with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
            analysis_params = {
                "created_by": record.created_by,
                "doc_name": record.doc_name,
                "standard_name": record.standard_name,
                "risk_score": record.risk_score,
                "conclusion": record.conclusion,
                "actions": json.dumps(record.actions, ensure_ascii=False),
                "para_text": record.para_text,
                "highlight_terms": json.dumps(record.highlight_terms, ensure_ascii=False),
            }
            cur.execute(insert_analysis_sql, analysis_params)
            inserted = cur.fetchone()
            analysis_id = str(inserted["id"])

        if record.findings:
            with conn.cursor() as cur:
                for finding in record.findings:
                    finding_params = {
                        "analysis_id": analysis_id,
                        "seq": finding.seq,
                        "title": finding.title,
                        "desc": finding.description,
                        "status": finding.status,
                        "clause_ref": finding.clause_ref,
                    }
                    cur.execute(insert_finding_sql, finding_params)
        conn.commit()
    return analysis_id
|
||||
|
||||
def list_analyses(self, limit: int = 50, offset: int = 0) -> list[AnalysisRecord]:
    """Fetch one page of analyses, most recently created first.

    Args:
        limit: Maximum number of rows to return.
        offset: Number of rows to skip before the page starts.

    Returns:
        Records built by ``_row_to_record``; findings are not loaded here.
    """
    page_sql = """
        SELECT id, created_at, created_by, doc_name, standard_name,
               risk_score, conclusion, actions, para_text, highlight_terms
        FROM compliance_analyses
        ORDER BY created_at DESC
        LIMIT %(limit)s OFFSET %(offset)s
    """
    with self._conn() as conn:
        with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
            cur.execute(page_sql, {"limit": limit, "offset": offset})
            fetched = cur.fetchall()

    page: list[AnalysisRecord] = []
    for db_row in fetched:
        page.append(self._row_to_record(dict(db_row)))
    return page
|
||||
|
||||
def get_analysis(self, analysis_id: str) -> Optional[AnalysisRecord]:
    """Load a single analysis together with its findings.

    Args:
        analysis_id: Id of the analysis row to look up.

    Returns:
        The record with ``findings`` ordered by ``seq``, or ``None`` when no
        analysis row matches the id.
    """
    findings_sql = """
        SELECT id, analysis_id, seq, title, description, status, clause_ref
        FROM compliance_findings
        WHERE analysis_id = %(id)s
        ORDER BY seq
    """
    with self._conn() as conn:
        with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
            cur.execute(
                "SELECT * FROM compliance_analyses WHERE id = %(id)s",
                {"id": analysis_id},
            )
            header = cur.fetchone()
            if not header:
                return None
            record = self._row_to_record(dict(header))

            cur.execute(findings_sql, {"id": analysis_id})
            loaded = []
            for item in cur.fetchall():
                # NULL text columns are normalized to empty strings and a
                # missing status to "ok"; clause_ref is passed through as-is.
                loaded.append(
                    FindingRecord(
                        id=str(item["id"]),
                        analysis_id=str(item["analysis_id"]),
                        seq=item["seq"],
                        title=item["title"] or "",
                        description=item["description"] or "",
                        status=item["status"] or "ok",
                        clause_ref=item["clause_ref"],
                    )
                )
            record.findings = loaded
    return record
|
||||
|
||||
def delete_analysis(self, analysis_id: str) -> None:
    """Remove one analysis row and commit.

    Only the analysis row is deleted explicitly; dependent findings and chat
    messages are left to the database's ``ON DELETE CASCADE`` foreign keys.

    Args:
        analysis_id: Id of the analysis row to delete.
    """
    removal_sql = "DELETE FROM compliance_analyses WHERE id = %(id)s"
    with self._conn() as conn:
        with conn.cursor() as cur:
            cur.execute(removal_sql, {"id": analysis_id})
        conn.commit()
|
||||
|
||||
def save_message(self, analysis_id: str, finding_id: str, role: str, content: str) -> str:
    """Insert one chat message for a finding and commit.

    Args:
        analysis_id: Id of the analysis the message belongs to.
        finding_id: Id of the finding the message is attached to.
        role: Value stored in the ``role`` column.
        content: Message text.

    Returns:
        The id generated for the new message row, as a string.
    """
    insert_sql = """
        INSERT INTO finding_chat_messages
            (analysis_id, finding_id, role, content)
        VALUES
            (%(analysis_id)s, %(finding_id)s, %(role)s, %(content)s)
        RETURNING id
    """
    message_params = {
        "analysis_id": analysis_id,
        "finding_id": finding_id,
        "role": role,
        "content": content,
    }
    with self._conn() as conn:
        with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
            cur.execute(insert_sql, message_params)
            inserted = cur.fetchone()
        conn.commit()
    return str(inserted["id"])
|
||||
|
||||
def get_messages(self, finding_id: str) -> list[dict]:
    """List the chat messages of one finding, oldest first.

    Args:
        finding_id: Id of the finding whose messages are wanted.

    Returns:
        One dict per message with the keys ``id``, ``role``, ``content`` and
        ``created_at``. ``created_at`` is an ISO-8601 string, or ``""`` when
        the stored value is empty.
    """
    history_sql = """
        SELECT id, role, content, created_at
        FROM finding_chat_messages
        WHERE finding_id = %(finding_id)s
        ORDER BY created_at ASC
    """
    with self._conn() as conn:
        with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
            cur.execute(history_sql, {"finding_id": finding_id})
            fetched = cur.fetchall()

    messages: list[dict] = []
    for item in fetched:
        stamp = item["created_at"]
        messages.append(
            {
                "id": str(item["id"]),
                "role": item["role"],
                "content": item["content"],
                "created_at": stamp.isoformat() if stamp else "",
            }
        )
    return messages
|
||||
|
||||
def _row_to_record(self, row: dict) -> AnalysisRecord:
    """Convert one ``compliance_analyses`` row into an ``AnalysisRecord``.

    Args:
        row: Mapping of column name to value. ``id`` and ``created_at`` must
            be present; every other column is optional.

    Returns:
        A record with an empty ``findings`` list. Missing or NULL text
        columns become ``""``, a missing ``risk_score`` becomes ``0``, and
        missing JSON columns become ``[]``.
    """
    # Local import: only the timestamp fallback below needs it.
    from datetime import timezone

    def _json_list(value):
        """Return a JSON column as a Python value, defaulting to ``[]``."""
        if isinstance(value, str):
            # The column arrived as JSON text rather than a decoded object.
            value = json.loads(value)
        # Covers NULL columns as well as JSON text that decodes to null.
        return value or []

    created_at = row["created_at"]
    if not isinstance(created_at, datetime):
        # Use an aware UTC timestamp so the fallback is comparable with the
        # tz-aware values a TIMESTAMPTZ column yields; utcnow() is naive and
        # deprecated.
        created_at = datetime.now(timezone.utc)

    return AnalysisRecord(
        id=str(row["id"]),
        created_at=created_at,
        created_by=row.get("created_by"),
        doc_name=row.get("doc_name") or "",
        standard_name=row.get("standard_name") or "",
        risk_score=int(row.get("risk_score") or 0),
        conclusion=row.get("conclusion") or "",
        actions=_json_list(row.get("actions")),
        para_text=row.get("para_text") or "",
        highlight_terms=_json_list(row.get("highlight_terms")),
        findings=[],
    )
|
||||
@@ -0,0 +1,21 @@
|
||||
"""No-op reranker stub.
|
||||
|
||||
Returns the original candidate list sliced to top_k.
|
||||
Replace with CrossEncoderReranker when a local cross-encoder model is available.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from app.domain.retrieval.models import RetrievedChunk
|
||||
from app.domain.retrieval.ports import Reranker
|
||||
|
||||
|
||||
class PassThroughReranker(Reranker):
    """Reranker stub that leaves the retrieval order untouched.

    Acts as a placeholder for future cross-encoder reranking (e.g. ms-marco-MiniLM).
    Wire via bootstrap.get_compliance_reranker() when ready to swap.
    """

    def rerank(self, query: str, chunks: list[RetrievedChunk], top_k: int) -> list[RetrievedChunk]:
        """Return at most the first ``top_k`` chunks without reordering.

        Args:
            query: Not used by this implementation.
            chunks: Candidates in retrieval order.
            top_k: Maximum number of chunks to keep. Zero or a negative value
                yields an empty list.

        Returns:
            A new list holding the leading ``top_k`` chunks, or all of them
            when fewer are available.
        """
        # A negative bound would make the slice count from the end and drop
        # trailing chunks instead of returning none, so clamp it at zero.
        limit = max(top_k, 0)
        return chunks[:limit]
|
||||
@@ -40,6 +40,8 @@ from app.infrastructure.vectorstore.cross_encoder_reranker import OpenAICompatib
|
||||
from app.infrastructure.vectorstore.dense_retriever import DenseRetriever
|
||||
from app.infrastructure.vectorstore.milvus_vector_index import MilvusVectorIndex
|
||||
from app.services.llm.llm_factory import LLMFactory
|
||||
from app.domain.compliance.ports import ComplianceRepository
|
||||
from app.infrastructure.compliance.repository import PostgresComplianceRepository
|
||||
# Keep shared wiring centralized so dependency construction remains consistent.
|
||||
|
||||
|
||||
@@ -311,6 +313,28 @@ def get_event_store() -> BaseEventStore:
|
||||
return MockEventStore()
|
||||
|
||||
|
||||
@lru_cache
def get_compliance_repository() -> ComplianceRepository:
    """Build the Postgres-backed compliance repository.

    The result is memoized by ``lru_cache``, so repeated calls return the
    same repository instance.

    Raises:
        NotImplementedError: If ``settings.document_repository_backend`` is
            anything other than ``"postgres"``.
    """
    backend = settings.document_repository_backend
    if backend != "postgres":
        message = (
            "ComplianceRepository requires document_repository_backend=postgres, "
            f"got '{backend}'. "
            "Set DOCUMENT_REPOSITORY_BACKEND=postgres in your .env file."
        )
        raise NotImplementedError(message)

    connection_settings = {
        "host": settings.postgres_host,
        "port": settings.postgres_port,
        "user": settings.postgres_user,
        "password": settings.postgres_password,
        "dbname": settings.postgres_db,
    }
    return PostgresComplianceRepository(**connection_settings)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_perception_service() -> PerceptionService:
|
||||
return PerceptionService(
|
||||
|
||||
0
backend/tests/compliance/__init__.py
Normal file
0
backend/tests/compliance/__init__.py
Normal file
140
backend/tests/compliance/test_pipeline.py
Normal file
140
backend/tests/compliance/test_pipeline.py
Normal file
@@ -0,0 +1,140 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from datetime import datetime
|
||||
|
||||
from app.infrastructure.vectorstore.pass_through_reranker import PassThroughReranker
|
||||
from app.domain.retrieval.models import RetrievedChunk
|
||||
from app.domain.compliance.ports import AnalysisRecord, FindingRecord
|
||||
|
||||
|
||||
# ── helpers ──────────────────────────────────────────────────────────────────
|
||||
|
||||
def _make_chunk(score: float) -> RetrievedChunk:
    """Build a fixture chunk whose only varying field is ``score``."""
    fixed_fields = {
        "chunk_id": "c1",
        "doc_id": "d1",
        "doc_title": "Test Doc",
        "section_title": "S1",
        "text": "some text",
        "page_start": 1,
    }
    return RetrievedChunk(score=score, **fixed_fields)
|
||||
|
||||
|
||||
def _make_mock_client(content: str = '{"status":"ok","title":"T","desc":"D","clause_ref":"A1"}'):
    """Return a mock LLM client whose ``chat()`` yields a successful reply.

    Args:
        content: Text exposed as the reply's ``content`` attribute.
    """
    reply = MagicMock(is_success=True, content=content)
    llm = MagicMock()
    llm.chat.return_value = reply
    return llm
|
||||
|
||||
|
||||
def _make_mock_retrieval():
    """Return a mock retrieval service whose ``retrieve()`` finds nothing."""
    retrieval = MagicMock()
    retrieval.retrieve.return_value = []
    return retrieval
|
||||
|
||||
|
||||
# ── existing tests ────────────────────────────────────────────────────────────
|
||||
|
||||
def test_pass_through_returns_top_k():
    """Keeping two of three candidates returns two, led by the top scorer."""
    candidates = [_make_chunk(s) for s in (0.9, 0.8, 0.7)]
    kept = PassThroughReranker().rerank(query="test", chunks=candidates, top_k=2)
    assert len(kept) == 2
    assert kept[0].score == 0.9
|
||||
|
||||
|
||||
def test_pass_through_returns_all_when_top_k_exceeds():
    """A ``top_k`` larger than the candidate list returns every candidate."""
    only_candidate = [_make_chunk(0.5)]
    kept = PassThroughReranker().rerank(query="test", chunks=only_candidate, top_k=10)
    assert len(kept) == 1
|
||||
|
||||
|
||||
# ── new tests ─────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_process_single_clause_returns_finding():
    """A clause processed with a healthy client yields finding, index and chunks."""
    from app.application.compliance.pipeline import process_single_clause

    llm = _make_mock_client()
    retrieval = _make_mock_retrieval()
    outcome = process_single_clause("test clause", 0, retrieval, llm)
    assert outcome["finding"] is not None
    assert outcome["index"] == 0
    assert outcome["chunks"] == []
|
||||
|
||||
|
||||
def test_run_clauses_parallel_runs_all():
    """Every clause produces one result, returned in input order."""
    from app.application.compliance.pipeline import run_clauses_parallel

    llm = _make_mock_client()
    retrieval = _make_mock_retrieval()
    clauses = ["clause one", "clause two", "clause three"]
    outcomes = asyncio.run(run_clauses_parallel(clauses, retrieval, llm))
    assert len(outcomes) == 3
    for position, outcome in enumerate(outcomes):
        assert outcome["index"] == position
|
||||
|
||||
|
||||
def test_run_clauses_parallel_handles_clause_failure():
    """A client that raises still yields one empty result per clause."""
    from app.application.compliance.pipeline import run_clauses_parallel

    retrieval = _make_mock_retrieval()
    failing_llm = MagicMock()
    failing_llm.chat.side_effect = RuntimeError("LLM exploded")
    clauses = ["clause one", "clause two"]
    outcomes = asyncio.run(run_clauses_parallel(clauses, retrieval, failing_llm))
    assert len(outcomes) == 2
    assert all(outcome["finding"] is None for outcome in outcomes)
    assert all(outcome["chunks"] == [] for outcome in outcomes)
|
||||
|
||||
|
||||
# ── helpers for new tests ─────────────────────────────────────────────────────
|
||||
|
||||
def _sample_analysis() -> AnalysisRecord:
    """Return a fixed analysis fixture with no actions, terms or findings."""
    return AnalysisRecord(
        id="a1",
        created_at=datetime(2026, 6, 8),
        created_by="u",
        doc_name="doc.pdf",
        standard_name="EU AI Act",
        risk_score=72,
        conclusion="Gaps found.",
        actions=[],
        para_text="para",
        highlight_terms=[],
        findings=[],
    )
|
||||
|
||||
|
||||
def _sample_finding(status: str = "risk") -> FindingRecord:
    """Return a fixed finding fixture attached to analysis ``a1``.

    Args:
        status: Value stored in the finding's ``status`` field.
    """
    return FindingRecord(
        id="f1",
        analysis_id="a1",
        seq=0,
        title="Missing CSMS",
        description="No CSMS certification.",
        status=status,
        clause_ref="Art.9.1",
    )
|
||||
|
||||
|
||||
# ── new tests ─────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_build_finding_context_contains_required_fields():
    """The context mentions document, standard, finding title and clause."""
    from app.application.compliance.pipeline import build_finding_context

    context = build_finding_context(_sample_finding(), _sample_analysis())
    for expected in ("doc.pdf", "EU AI Act", "Missing CSMS", "Art.9.1"):
        assert expected in context
|
||||
|
||||
|
||||
def test_generate_suggestions_returns_three_questions():
    """A reply listing three questions is returned as three strings."""
    from app.application.compliance.pipeline import generate_suggestions

    llm = _make_mock_client('{"questions": ["Q1?", "Q2?", "Q3?"]}')
    suggested = generate_suggestions(_sample_finding("risk"), _sample_analysis(), llm)
    assert len(suggested) == 3
    for question in suggested:
        assert isinstance(question, str)
|
||||
|
||||
|
||||
def test_generate_suggestions_falls_back_on_error():
    """An unsuccessful LLM reply still produces three suggestions."""
    from app.application.compliance.pipeline import generate_suggestions

    failed_reply = MagicMock()
    failed_reply.is_success = False
    failing_llm = MagicMock()
    failing_llm.chat.return_value = failed_reply
    suggested = generate_suggestions(_sample_finding(), _sample_analysis(), failing_llm)
    assert len(suggested) == 3  # fallback always returns 3
|
||||
98
backend/tests/compliance/test_repository.py
Normal file
98
backend/tests/compliance/test_repository.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
from datetime import datetime
|
||||
from app.domain.compliance.ports import (
|
||||
AnalysisRecord,
|
||||
FindingRecord,
|
||||
ComplianceRepository,
|
||||
)
|
||||
|
||||
|
||||
def _mock_pool():
    """Build a stand-in for a psycopg2 ThreadedConnectionPool.

    Returns:
        A ``(pool, conn, cursor)`` tuple of mocks: ``pool.getconn()`` returns
        ``conn``, ``conn.cursor()`` returns ``cursor``, and ``cursor`` works
        as a context manager that yields itself.
    """
    fake_cursor = MagicMock()
    fake_cursor.__enter__ = MagicMock(return_value=fake_cursor)
    fake_cursor.__exit__ = MagicMock(return_value=False)

    fake_conn = MagicMock()
    fake_conn.cursor.return_value = fake_cursor

    fake_pool = MagicMock()
    fake_pool.getconn.return_value = fake_conn
    return fake_pool, fake_conn, fake_cursor
|
||||
|
||||
|
||||
@patch("app.infrastructure.compliance.repository.psycopg2.pool.ThreadedConnectionPool")
def test_save_analysis_returns_uuid(mock_pool_cls):
    """``save_analysis`` returns the id reported by the INSERT's cursor."""
    from app.infrastructure.compliance.repository import PostgresComplianceRepository

    fake_pool, _fake_conn, fake_cursor = _mock_pool()
    mock_pool_cls.return_value = fake_pool
    fake_cursor.fetchone.return_value = {"id": "abc-123"}

    repository = PostgresComplianceRepository(
        host="localhost",
        port=5432,
        user="u",
        password="p",
        dbname="db",
    )
    analysis = AnalysisRecord(
        id="",
        created_at=datetime.utcnow(),
        created_by="user1",
        doc_name="doc.pdf",
        standard_name="EU AI Act",
        risk_score=50,
        conclusion="OK",
        actions=[],
        para_text="p",
        highlight_terms=[],
        findings=[],
    )
    saved_id = repository.save_analysis(analysis)
    assert saved_id == "abc-123"
|
||||
|
||||
|
||||
def test_analysis_record_construction():
    """An ``AnalysisRecord`` keeps the fields and nested finding it was given."""
    nested_finding = FindingRecord(
        id="",
        analysis_id="",
        seq=0,
        title="Missing CSMS",
        description="No CSMS certification found.",
        status="risk",
        clause_ref="Art.9.1",
    )
    analysis = AnalysisRecord(
        id="",
        created_at=datetime.utcnow(),
        created_by="user1",
        doc_name="test.pdf",
        standard_name="EU AI Act",
        risk_score=72,
        conclusion="Several gaps found.",
        actions=[{"label": "Fix", "value": "Update docs"}],
        para_text="The system shall...",
        highlight_terms=["CSMS", "ISO 21434"],
        findings=[nested_finding],
    )
    assert analysis.doc_name == "test.pdf"
    assert len(analysis.findings) == 1
    assert analysis.findings[0].status == "risk"
|
||||
|
||||
|
||||
def test_compliance_repository_is_abstract():
    """The repository port must be an abstract class."""
    from inspect import isabstract

    assert isabstract(ComplianceRepository)
|
||||
|
||||
|
||||
def test_generate_docx_returns_bytes():
    """``generate_docx`` returns a non-trivial byte string that is a ZIP file."""
    import io
    import zipfile

    from app.infrastructure.compliance.docx_export import generate_docx

    nested_finding = FindingRecord(
        id="f1",
        analysis_id="test-id",
        seq=0,
        title="Missing CSMS",
        description="No CSMS cert.",
        status="risk",
        clause_ref="Art.9.1",
    )
    analysis = AnalysisRecord(
        id="test-id",
        created_at=datetime(2026, 6, 8),
        created_by="user1",
        doc_name="test.pdf",
        standard_name="EU AI Act",
        risk_score=72,
        conclusion="Several gaps found.",
        actions=[{"label": "Fix", "value": "Update CSMS docs"}],
        para_text="The system shall implement CSMS.",
        highlight_terms=["CSMS"],
        findings=[nested_finding],
    )
    payload = generate_docx(analysis)
    assert isinstance(payload, bytes)
    assert len(payload) > 1000  # DOCX is at minimum a ZIP with ~1 KB overhead
    # Verify it's a valid ZIP (DOCX = ZIP container)
    assert zipfile.is_zipfile(io.BytesIO(payload))
|
||||
Reference in New Issue
Block a user