update for 1. 优化 2.中英切换

This commit is contained in:
2026-06-10 11:10:36 +08:00
parent e7963b267e
commit 9212747e1b
42 changed files with 7866 additions and 278 deletions

View File

@@ -85,10 +85,9 @@ async def analyze_stream(
Events: stage | source | finding | done | error
"""
from app.application.compliance.pipeline import (
check_clause_compliance,
extract_text_from_doc_id,
extract_text_from_file,
retrieve_for_clause,
run_clauses_parallel,
split_into_clauses,
synthesize_conclusion,
)
@@ -136,22 +135,28 @@ async def analyze_stream(
await asyncio.sleep(0)
clauses: list[str] = await asyncio.to_thread(split_into_clauses, para_text, client)
# ── Stage 3: retrieve + gap check per clause ──────────────────
# ── Stage 3: retrieve + gap check (parallel across all clauses) ────────────
findings: list[dict] = []
for i, clause in enumerate(clauses):
yield _sse({
"type": "stage",
"stage": "analyzing",
"label": f"Analyzing clause {i + 1}/{len(clauses)}",
})
await asyncio.sleep(0)
yield _sse({
"type": "stage",
"stage": "analyzing",
"label": f"Analyzing {len(clauses)} clauses in parallel…",
})
await asyncio.sleep(0)
chunks = await asyncio.to_thread(
retrieve_for_clause, clause, retrieval_service, 5, domains or None
)
clause_results = await run_clauses_parallel(
clauses, retrieval_service, client,
top_k=5,
domains=domains or None,
)
# Emit source events
for res in clause_results:
i = res["index"]
chunks = res["chunks"]
finding = res["finding"]
# Emit source events for this clause
for chunk in chunks[:3]:
yield _sse({
"type": "source",
@@ -161,12 +166,11 @@ async def analyze_stream(
"status": "retrieved",
"full_content": (getattr(chunk, "text", "") or "")[:300],
})
await asyncio.sleep(0)
finding = await asyncio.to_thread(check_clause_compliance, clause, chunks, client)
if finding:
findings.append(finding)
yield _sse({"type": "finding", **finding})
await asyncio.sleep(0)
# ── Stage 4: synthesize conclusion ────────────────────────────
@@ -178,6 +182,45 @@ async def analyze_stream(
)
yield _sse({"type": "done", **conclusion_data})
# Auto-save analysis to database
try:
from app.shared.bootstrap import get_compliance_repository
from app.domain.compliance.ports import AnalysisRecord, FindingRecord
from datetime import datetime
repo = get_compliance_repository()
finding_records = [
FindingRecord(
id="",
analysis_id="",
seq=i,
title=f.get("title", ""),
description=f.get("desc", ""),
status=f.get("status", "ok"),
clause_ref=f.get("clause_ref"),
)
for i, f in enumerate(findings)
]
record = AnalysisRecord(
id="",
created_at=datetime.utcnow(),
created_by=current_user.username if hasattr(current_user, "username") else None,
doc_name=file_name or (title or "Pasted text"),
standard_name=title or "",
risk_score=conclusion_data.get("risk_score", 0),
conclusion=conclusion_data.get("conclusion", ""),
actions=conclusion_data.get("actions", []),
para_text=conclusion_data.get("para_text", ""),
highlight_terms=conclusion_data.get("highlight_terms", []),
findings=finding_records,
)
analysis_id = await asyncio.to_thread(repo.save_analysis, record)
yield _sse({"type": "saved", "analysis_id": analysis_id})
except NotImplementedError:
pass # No postgres backend configured — skip saving
except Exception as exc:
logger.warning("Failed to auto-save compliance analysis: {}", exc)
except Exception as exc:
logger.exception("analyze-stream pipeline error")
yield _sse({"type": "error", "text": str(exc)})
@@ -225,3 +268,226 @@ async def compliance_chat(segment_id: int, request: ComplianceChatRequest):
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
)
@router.get("/history")
async def list_history(
    limit: int = 20,
    offset: int = 0,
    current_user: UserClaims = Depends(get_current_user),
):
    """Return a paginated list of saved compliance analyses (newest first).

    Args:
        limit: Maximum number of summaries to return. Negative values are
            treated as 0.
        offset: Number of records to skip. Negative values are treated as 0.
        current_user: Resolved by the ``get_current_user`` dependency; it is
            not read in the body.

    Returns:
        A list of summary dicts with keys ``id``, ``created_at`` (ISO string),
        ``created_by``, ``doc_name``, ``standard_name``, ``risk_score`` and
        ``finding_count``. An empty list is returned when the repository
        lookup raises ``NotImplementedError``.
    """
    from app.shared.bootstrap import get_compliance_repository

    # The SQL-backed repository forwards these values straight into
    # LIMIT/OFFSET, and PostgreSQL rejects negative values there. Clamp them so
    # a bad query string yields an empty page instead of a server error.
    limit = max(limit, 0)
    offset = max(offset, 0)

    try:
        repo = get_compliance_repository()
        records = await asyncio.to_thread(repo.list_analyses, limit, offset)
    except NotImplementedError:
        # No persistence backend available: behave as "nothing saved yet".
        return []

    return [
        {
            "id": r.id,
            "created_at": r.created_at.isoformat(),
            "created_by": r.created_by,
            "doc_name": r.doc_name,
            "standard_name": r.standard_name,
            "risk_score": r.risk_score,
            # NOTE(review): ComplianceRepository.list_analyses is documented as
            # returning records WITHOUT nested findings, so this count is
            # presumably always 0 — confirm and expose a real count if needed.
            "finding_count": len(r.findings),
        }
        for r in records
    ]
@router.get("/history/{analysis_id}")
async def get_history_item(
    analysis_id: str,
    current_user: UserClaims = Depends(get_current_user),
):
    """Return the full analysis record, including its findings.

    Args:
        analysis_id: Identifier of the saved analysis (path parameter).
        current_user: Resolved by the ``get_current_user`` dependency; it is
            not read in the body.

    Returns:
        A dict with the analysis fields plus a ``findings`` list of dicts
        (``id``, ``seq``, ``title``, ``description``, ``status``,
        ``clause_ref``).

    Raises:
        HTTPException: 404 when the analysis does not exist, or when the
            repository lookup raises ``NotImplementedError`` (no backend).
    """
    from app.shared.bootstrap import get_compliance_repository
    from fastapi import HTTPException

    try:
        repo = get_compliance_repository()
        record = await asyncio.to_thread(repo.get_analysis, analysis_id)
    except NotImplementedError:
        # Same policy as the list endpoint: without a backend nothing is
        # stored, so any id is simply "not found" rather than a server error.
        record = None

    if not record:
        raise HTTPException(status_code=404, detail="Analysis not found")

    findings_payload = [
        {
            "id": f.id,
            "seq": f.seq,
            "title": f.title,
            "description": f.description,
            "status": f.status,
            "clause_ref": f.clause_ref,
        }
        for f in record.findings
    ]
    return {
        "id": record.id,
        "created_at": record.created_at.isoformat(),
        "created_by": record.created_by,
        "doc_name": record.doc_name,
        "standard_name": record.standard_name,
        "risk_score": record.risk_score,
        "conclusion": record.conclusion,
        "actions": record.actions,
        "para_text": record.para_text,
        "highlight_terms": record.highlight_terms,
        "findings": findings_payload,
    }
@router.delete("/history/{analysis_id}", status_code=204)
async def delete_history_item(
    analysis_id: str,
    current_user: UserClaims = Depends(get_current_user),
):
    """Delete a saved analysis (cascade removes findings and chat messages).

    Args:
        analysis_id: Identifier of the analysis to delete (path parameter).
        current_user: Resolved by the ``get_current_user`` dependency; it is
            not read in the body.

    Returns:
        ``None``; the route is declared with status code 204. When the
        repository lookup raises ``NotImplementedError`` there is nothing to
        delete and the request is treated as a successful no-op.
    """
    from app.shared.bootstrap import get_compliance_repository

    try:
        repo = get_compliance_repository()
        await asyncio.to_thread(repo.delete_analysis, analysis_id)
    except NotImplementedError:
        # No persistence backend: nothing was ever stored, so the desired end
        # state (record absent) already holds.
        return
@router.get("/history/{analysis_id}/download")
async def download_history_docx(
    analysis_id: str,
    current_user: UserClaims = Depends(get_current_user),
):
    """Return a DOCX compliance report for the given analysis.

    Args:
        analysis_id: Identifier of the saved analysis (path parameter).
        current_user: Resolved by the ``get_current_user`` dependency; it is
            not read in the body.

    Returns:
        A ``Response`` carrying the DOCX bytes as an attachment.

    Raises:
        HTTPException: 404 when the analysis does not exist, or when the
            repository lookup raises ``NotImplementedError`` (no backend).
    """
    import re
    from urllib.parse import quote

    from app.shared.bootstrap import get_compliance_repository
    from app.infrastructure.compliance.docx_export import generate_docx
    from fastapi import HTTPException
    from fastapi.responses import Response

    try:
        repo = get_compliance_repository()
        record = await asyncio.to_thread(repo.get_analysis, analysis_id)
    except NotImplementedError:
        record = None
    if not record:
        raise HTTPException(status_code=404, detail="Analysis not found")

    docx_bytes = await asyncio.to_thread(generate_docx, record)

    base_name = (record.doc_name or "report").replace(" ", "_")[:50]
    stamp = record.created_at.strftime("%Y%m%d")

    # doc_name is user-supplied. Quotes, control characters or non-latin-1 text
    # (e.g. Chinese document names) would corrupt the header or make it
    # impossible to encode. Send a strictly-ASCII fallback in `filename` and the
    # real name, percent-encoded as UTF-8, in `filename*` (RFC 6266 / RFC 5987).
    ascii_base = re.sub(r"[^A-Za-z0-9._-]", "_", base_name)
    fallback_filename = f"compliance_{ascii_base}_{stamp}.docx"
    encoded_filename = quote(f"compliance_{base_name}_{stamp}.docx", safe="")
    disposition = (
        f'attachment; filename="{fallback_filename}"; '
        f"filename*=UTF-8''{encoded_filename}"
    )

    return Response(
        content=docx_bytes,
        media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        headers={"Content-Disposition": disposition},
    )
@router.get("/analyses/{analysis_id}/findings/{finding_id}/chat")
async def get_finding_chat_history(
    analysis_id: str,
    finding_id: str,
    current_user: UserClaims = Depends(get_current_user),
):
    """Return the stored chat messages of one finding thread, oldest first.

    Only ``finding_id`` is used for the lookup; ``analysis_id`` is part of the
    route but is not read here. An empty list is returned when the repository
    lookup raises ``NotImplementedError``.
    """
    from app.shared.bootstrap import get_compliance_repository

    try:
        store = get_compliance_repository()
        return await asyncio.to_thread(store.get_messages, finding_id)
    except NotImplementedError:
        # No persistence backend: there is no stored conversation to return.
        return []
@router.post("/analyses/{analysis_id}/findings/{finding_id}/suggestions")
async def get_finding_suggestions(
    analysis_id: str,
    finding_id: str,
    current_user: UserClaims = Depends(get_current_user),
):
    """Produce three LLM-generated follow-up questions for one finding.

    Loads the analysis, locates the finding by id inside it, then delegates to
    ``generate_suggestions`` on a worker thread.

    Returns:
        ``{"questions": [...]}`` with whatever ``generate_suggestions`` returns.

    Raises:
        HTTPException: 404 when the analysis or the finding cannot be found.
    """
    from app.application.compliance.pipeline import generate_suggestions
    from app.shared.bootstrap import get_compliance_repository
    from app.services.llm.llm_factory import get_llm_client
    from fastapi import HTTPException

    store = get_compliance_repository()
    analysis = await asyncio.to_thread(store.get_analysis, analysis_id)
    if not analysis:
        raise HTTPException(status_code=404, detail="Analysis not found")

    # Linear scan for the first finding whose id matches the path parameter.
    finding = None
    for candidate in analysis.findings:
        if candidate.id == finding_id:
            finding = candidate
            break
    if not finding:
        raise HTTPException(status_code=404, detail="Finding not found")

    llm = get_llm_client(provider=settings.llm_provider, model=settings.llm_model)
    suggested = await asyncio.to_thread(generate_suggestions, finding, analysis, llm)
    return {"questions": suggested}
@router.post("/analyses/{analysis_id}/findings/{finding_id}/chat")
async def finding_chat(
    analysis_id: str,
    finding_id: str,
    request: ComplianceChatRequest,
    current_user: UserClaims = Depends(get_current_user),
):
    """Stream a grounded chat response for a specific finding.

    Loads the finding and analysis from the repository to build grounded
    context, includes the most recent stored turns of this finding thread in
    the prompt, and persists both the user message and the assistant response.

    Args:
        analysis_id: Identifier of the saved analysis (path parameter).
        finding_id: Identifier of the finding inside that analysis.
        request: Chat request; only ``request.query`` is read here.
        current_user: Resolved by the ``get_current_user`` dependency; it is
            not read in the body.

    Returns:
        A ``StreamingResponse`` of server-sent events with ``type`` values
        ``chunk``, ``done`` and ``error``.

    Raises:
        HTTPException: 404 when the analysis or the finding cannot be found.
    """
    from app.application.compliance.pipeline import build_finding_context
    from app.shared.bootstrap import get_compliance_repository
    from fastapi import HTTPException

    repo = get_compliance_repository()
    analysis = await asyncio.to_thread(repo.get_analysis, analysis_id)
    if not analysis:
        raise HTTPException(status_code=404, detail="Analysis not found")
    finding = next((f for f in analysis.findings if f.id == finding_id), None)
    if not finding:
        raise HTTPException(status_code=404, detail="Finding not found")

    # Persist user message
    await asyncio.to_thread(
        repo.save_message, analysis_id, finding_id, "user", request.query
    )

    # Build message history (last 10 messages = 5 turns). Previously this list
    # was computed and then dropped, so follow-up questions reached the model
    # with no memory of earlier turns.
    history = await asyncio.to_thread(repo.get_messages, finding_id)
    recent = history[-10:]
    # The message saved just above is normally the newest entry; leave it out
    # of the transcript so the question is not repeated in the prompt.
    if recent and recent[-1]["role"] == "user" and recent[-1]["content"] == request.query:
        recent = recent[:-1]
    transcript = "\n".join(f"{m['role']}: {m['content']}" for m in recent)

    # Build grounded system context
    system_context = build_finding_context(finding, analysis)
    prompt_sections = [f"[Compliance Finding Context]\n{system_context}"]
    if transcript:
        # NOTE(review): history is inlined as text because only the
        # query/top_k/prompt_template parameters of stream_chat are visible
        # here — switch to a dedicated history argument if one exists.
        prompt_sections.append(f"[Conversation so far]\n{transcript}")
    prompt_sections.append(f"User question: {request.query}")
    # With no prior turns this is exactly the original two-section prompt.
    full_query = "\n\n".join(prompt_sections)

    assistant_buffer: list[str] = []

    async def generate() -> AsyncGenerator[str, None]:
        try:
            _, event_stream = get_agent_conversation_service().stream_chat(
                query=full_query,
                top_k=5,
                prompt_template="compliance_qa",
            )
            for event in event_stream:
                event_type = event.get("event", "")
                if event_type == "content":
                    text = event.get("data", "")
                    if text:
                        assistant_buffer.append(text)
                        yield _sse({"type": "chunk", "text": text})
                elif event_type == "done":
                    yield _sse({"type": "done"})
                # Hand control back to the event loop between events.
                await asyncio.sleep(0)
        except Exception as exc:
            logger.exception("finding_chat stream error")
            yield _sse({"type": "error", "text": str(exc)})
        finally:
            # Persist whatever was streamed, even if the stream ended early.
            full_response = "".join(assistant_buffer)
            if full_response:
                try:
                    await asyncio.to_thread(
                        repo.save_message, analysis_id, finding_id, "assistant", full_response
                    )
                except Exception as exc:
                    logger.warning("Failed to persist assistant message: {}", exc)

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
    )

View File

@@ -5,6 +5,7 @@ All functions are synchronous — call them via asyncio.to_thread() in async SSE
from __future__ import annotations
import asyncio
import json
import os
import re
@@ -12,10 +13,20 @@ import tempfile
from typing import TYPE_CHECKING
from loguru import logger
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
# Shared retry policy for LLM calls: up to 3 attempts, with an exponential
# back-off bounded to 1–4 s between attempts. Only ValueError, TimeoutError and
# ConnectionError trigger a retry; once attempts are exhausted the last
# exception is re-raised to the caller (reraise=True).
_llm_retry = retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=1, max=4),
    retry=retry_if_exception_type((ValueError, TimeoutError, ConnectionError)),
    reraise=True,
)
if TYPE_CHECKING:
from app.application.knowledge import KnowledgeRetrievalService
from app.domain.retrieval import RetrievedChunk
from app.domain.compliance.ports import AnalysisRecord, FindingRecord
from app.services.llm.base_client import BaseLLMClient
@@ -109,17 +120,67 @@ def retrieve_for_clause(
return retrieval_service.retrieve(query=clause, top_k=top_k, filters=domains)
def process_single_clause(
    clause: str,
    index: int,
    retrieval_service: "KnowledgeRetrievalService",
    client: "BaseLLMClient",
    top_k: int = 5,
    domains: str | None = None,
) -> dict:
    """Run retrieval and the compliance check for a single clause.

    The clause is first passed to ``retrieve_for_clause``; the retrieved chunks
    are then handed to ``check_clause_compliance``. The function is synchronous
    so callers can schedule it on a worker thread.

    Returns:
        A dict with keys ``index`` (the value passed in), ``chunks`` (the
        retrieval result) and ``finding`` (the compliance check result, which
        may be ``None``).
    """
    retrieved = retrieve_for_clause(clause, retrieval_service, top_k, domains)
    verdict = check_clause_compliance(clause, retrieved, client)
    outcome = {}
    outcome["index"] = index
    outcome["chunks"] = retrieved
    outcome["finding"] = verdict
    return outcome
async def run_clauses_parallel(
clauses: list[str],
retrieval_service: "KnowledgeRetrievalService",
client: "BaseLLMClient",
top_k: int = 5,
domains: str | None = None,
) -> list[dict]:
"""Run all clauses through retrieve+gap-check in parallel.
Results are returned in the original clause order even though processing
is concurrent. Exceptions in individual clauses are caught and returned as
dicts with finding=None so the stream continues for remaining clauses.
Both retrieval_service and client must be thread-safe — they are shared
across all asyncio.to_thread() calls without locking.
"""
tasks = [
asyncio.to_thread(
process_single_clause,
clause, i, retrieval_service, client, top_k, domains,
)
for i, clause in enumerate(clauses)
]
raw = await asyncio.gather(*tasks, return_exceptions=True)
results = []
for i, r in enumerate(raw):
if isinstance(r, Exception):
logger.warning("Clause {} processing failed: {}", i, r)
results.append({"index": i, "chunks": [], "finding": None})
else:
results.append(r)
return results
def check_clause_compliance(
clause: str,
chunks: list["RetrievedChunk"],
client: "BaseLLMClient",
) -> dict | None:
if not chunks:
return None
reg_context = "\n".join(
f"[{i+1}] {c.doc_title} {c.section_title or ''}: {c.text[:300]}"
for i, c in enumerate(chunks[:5])
)
) if chunks else "(no regulatory context retrieved)"
prompt = (
"You are a compliance expert. Judge whether the following business clause "
"complies with the retrieved regulations.\n\n"
@@ -135,9 +196,19 @@ def check_clause_compliance(
"status: ok=compliant, warn=gap exists, risk=critical/missing\n"
"Return ONLY the JSON object."
)
response = client.chat([{"role": "user", "content": prompt}], max_tokens=500)
if not response.is_success:
def _do_check():
resp = client.chat([{"role": "user", "content": prompt}], max_tokens=500)
if not resp.is_success:
raise ValueError("LLM returned non-success for gap check")
return resp
try:
response = _llm_retry(_do_check)()
except Exception as exc:
logger.warning("check_clause_compliance LLM call failed after retries: {}", exc)
return None
try:
result = _extract_json(response.content)
if isinstance(result, dict) and "status" in result:
@@ -182,12 +253,11 @@ def synthesize_conclusion(
' {"label": "Priority", "value": "High/Medium/Low", "risk": true}\n'
' ],\n'
' "risk_score": 0-100 (integer, higher=riskier),\n'
' "highlight_terms": ["Key terms to highlight, max 10 terms"],\n'
' "highlight_terms": ["term1", "term2"], // up to 10 key technical/legal terms actually present in the text\n'
' "para_text": "Original text or summary (max 600 chars)"\n'
"}\n"
"Return ONLY the JSON object."
)
response = client.chat([{"role": "user", "content": prompt}], max_tokens=1200)
fallback = {
"conclusion": "Compliance analysis complete. Review findings and create remediation plan.",
"actions": [
@@ -198,8 +268,19 @@ def synthesize_conclusion(
"highlight_terms": [],
"para_text": para_text[:800],
}
if not response.is_success:
def _do_synthesize():
resp = client.chat([{"role": "user", "content": prompt}], max_tokens=1200)
if not resp.is_success:
raise ValueError("LLM returned non-success for synthesis")
return resp
try:
response = _llm_retry(_do_synthesize)()
except Exception as exc:
logger.warning("synthesize_conclusion LLM call failed after retries: {}", exc)
return fallback
try:
result = _extract_json(response.content)
if isinstance(result, dict):
@@ -212,4 +293,78 @@ def synthesize_conclusion(
}
except (ValueError, TypeError) as exc:
logger.warning("Conclusion synthesis JSON parse failed: {}", exc)
return fallback
return fallback
# Per-status instruction inserted into the suggestion prompt ("Task: ...").
# Keys are finding status values; callers fall back to the "warn" entry for
# any status that is not listed here.
_SUGGESTION_FOCUS = {
    "risk": "Focus on remediation steps, required certifications, and timeline to resolve.",
    "warn": "Focus on identifying the specific compliance gap and how to close it.",
    "ok": "Focus on maintaining compliance evidence and monitoring future changes.",
}
# Static follow-up questions (three per status) used when the LLM-generated
# suggestions are unavailable or unusable. Keyed the same way as
# _SUGGESTION_FOCUS, with "warn" as the default for unknown statuses.
_SUGGESTION_FALLBACK = {
    "risk": [
        "What specific certifications or documents are required to remediate this finding?",
        "What is the typical remediation timeline for this type of non-compliance?",
        "Which regulation clause defines the exact requirement?",
    ],
    "warn": [
        "What is the exact gap between the current state and the requirement?",
        "What evidence would demonstrate partial compliance?",
        "Which regulation clause applies to this warning?",
    ],
    "ok": [
        "What documentation should be maintained to evidence this compliance?",
        "How should this area be monitored as regulations evolve?",
        "Are there related clauses that may affect this compliant area?",
    ],
}
def build_finding_context(finding: "FindingRecord", analysis: "AnalysisRecord") -> str:
    """Render a finding and its parent analysis as a plain-text context block.

    The result is a sequence of ``Label: value`` lines, each terminated by a
    newline: document name, standard name, the finding (numbered from 1 using
    ``finding.seq + 1``), its status, clause reference (``N/A`` when falsy),
    description, and the overall conclusion of the analysis.
    """
    clause_reference = finding.clause_ref or "N/A"
    rows = [
        f"Document: {analysis.doc_name}",
        f"Standard: {analysis.standard_name}",
        f"Finding [{finding.seq + 1}]: {finding.title}",
        f"Status: {finding.status}",
        f"Clause reference: {clause_reference}",
        f"Description: {finding.description}",
        f"Overall conclusion: {analysis.conclusion}",
    ]
    # Every row, including the last one, ends with a newline.
    return "".join(row + "\n" for row in rows)
def generate_suggestions(
    finding: "FindingRecord",
    analysis: "AnalysisRecord",
    client: "BaseLLMClient",
) -> list[str]:
    """Generate 3 context-aware follow-up questions for a finding chat thread.

    Args:
        finding: The finding the questions are about; ``finding.status``
            selects the prompt focus and the fallback set.
        analysis: Parent analysis, used to build the prompt context.
        client: LLM client used for a single ``chat`` call.

    Returns:
        Exactly 3 question strings. The static template for the finding's
        status (or the ``"warn"`` template for unknown statuses) is returned
        when the LLM call raises, reports non-success, returns unparseable
        JSON, or returns JSON without at least 3 questions.
    """
    # Copy so callers can never mutate the shared module-level template list.
    fallback = list(_SUGGESTION_FALLBACK.get(finding.status, _SUGGESTION_FALLBACK["warn"]))
    context = build_finding_context(finding, analysis)
    focus = _SUGGESTION_FOCUS.get(finding.status, _SUGGESTION_FOCUS["warn"])
    prompt = (
        f"{context}\n\n"
        f"Task: {focus}\n"
        "Generate exactly 3 concise follow-up questions a compliance analyst would ask.\n"
        'Return JSON: {"questions": ["question 1", "question 2", "question 3"]}\n'
        "Return ONLY the JSON object."
    )

    try:
        response = client.chat([{"role": "user", "content": prompt}], max_tokens=300)
    except Exception as exc:
        # Suggestions are best-effort: a transport/provider failure must not
        # turn into an error for the caller when static templates exist.
        logger.warning("generate_suggestions LLM call failed: {}", exc)
        return fallback
    if not response.is_success:
        return fallback

    try:
        result = _extract_json(response.content)
    except (ValueError, TypeError) as exc:
        logger.warning("generate_suggestions JSON parse failed: {}", exc)
        return fallback

    # The model may return valid JSON that is not an object (e.g. a bare
    # list); calling .get() on it would raise AttributeError.
    questions = result.get("questions", []) if isinstance(result, dict) else []
    if isinstance(questions, list) and len(questions) >= 3:
        return [str(q) for q in questions[:3]]
    return fallback

View File

@@ -0,0 +1,66 @@
"""Domain ports for compliance history persistence."""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import datetime
from typing import Optional
@dataclass
class FindingRecord:
    """Single finding row linked to an analysis.

    ``clause_ref`` is the only optional field and defaults to ``None``.
    """

    # Identifier of the finding itself.
    # NOTE(review): presumably an empty string before the row is persisted,
    # mirroring AnalysisRecord.id — confirm against the repository.
    id: str
    # Identifier of the parent analysis this finding belongs to.
    analysis_id: str
    # Position of the finding within its analysis.
    seq: int
    title: str
    description: str
    status: str  # "ok" | "warn" | "risk"
    # Reference to the clause the finding relates to, when one is known.
    clause_ref: Optional[str] = None
@dataclass
class AnalysisRecord:
    """Full compliance analysis record with nested findings.

    ``findings`` defaults to a fresh empty list per instance.
    """

    id: str  # UUID string; empty string means not yet persisted
    created_at: datetime
    # Name of the user that created the analysis, when known.
    created_by: Optional[str]
    doc_name: str
    standard_name: str
    risk_score: int
    conclusion: str
    actions: list[dict]  # serialised action items
    para_text: str
    highlight_terms: list[str]
    # Nested findings; a new list is created for every record.
    findings: list[FindingRecord] = field(default_factory=list)
class ComplianceRepository(ABC):
    """Port for persisting and retrieving compliance analysis records.

    Concrete adapters implement storage for analyses, their findings, and the
    chat messages attached to individual findings. Every method is abstract.
    """

    @abstractmethod
    def save_analysis(self, record: AnalysisRecord) -> str:
        """Persist a new analysis record and return the assigned UUID string.

        The record's nested ``findings`` are expected to be stored together
        with the analysis.
        """

    @abstractmethod
    def list_analyses(self, limit: int = 50, offset: int = 0) -> list[AnalysisRecord]:
        """Return analyses ordered by created_at DESC, without nested findings.

        ``limit`` bounds the page size and ``offset`` skips that many of the
        newest records.
        """

    @abstractmethod
    def get_analysis(self, analysis_id: str) -> Optional[AnalysisRecord]:
        """Return a single analysis with all nested findings, or None."""

    @abstractmethod
    def delete_analysis(self, analysis_id: str) -> None:
        """Delete an analysis and all related findings and chat messages (cascade)."""

    @abstractmethod
    def save_message(self, analysis_id: str, finding_id: str, role: str, content: str) -> str:
        """Persist a chat message and return its UUID string.

        ``role`` identifies the author of the message and ``content`` is the
        message text; the message is attached to ``finding_id`` within
        ``analysis_id``.
        """

    @abstractmethod
    def get_messages(self, finding_id: str) -> list[dict]:
        """Return chat messages for a finding ordered by created_at ASC.

        Each dict has keys: id, role, content, created_at (ISO string).
        """

View File

@@ -0,0 +1,101 @@
"""DOCX report generator for compliance analysis results.
Uses python-docx (already in requirements.txt). Returns raw bytes so the
caller can stream the response without writing to disk.
"""
from __future__ import annotations
from datetime import datetime, timezone
from io import BytesIO
from docx import Document
from docx.shared import Pt, RGBColor
from docx.enum.text import WD_ALIGN_PARAGRAPH
from app.domain.compliance.ports import AnalysisRecord
_STATUS_LABEL = {"ok": "Compliant", "warn": "Warning", "risk": "Non-Compliant"}
_STATUS_COLOR = {
"ok": RGBColor(0x22, 0x8B, 0x22),
"warn": RGBColor(0xFF, 0x8C, 0x00),
"risk": RGBColor(0xDC, 0x14, 0x3C),
}
def generate_docx(record: AnalysisRecord) -> bytes:
    """Generate a compliance report DOCX and return its raw bytes.

    Structure:
    - Cover: document name, standard, date, risk score
    - Executive summary (conclusion)
    - Findings table
    - Recommended actions
    - Footer note

    Args:
        record: The analysis to render, including its nested findings and
            action items.

    Returns:
        The serialized ``.docx`` document as bytes (built in memory).
    """
    doc = Document()

    # ── Cover ──────────────────────────────────────────────────────────────────
    title_para = doc.add_heading("Compliance Analysis Report", level=0)
    title_para.alignment = WD_ALIGN_PARAGRAPH.CENTER
    doc.add_paragraph("")

    # The cover labels the timestamp as UTC, so convert timezone-aware values
    # first. Naive values are left untouched: their zone is unknown and
    # astimezone() would interpret them as local time.
    created_at = record.created_at
    if created_at is not None and created_at.tzinfo is not None:
        created_at = created_at.astimezone(timezone.utc)

    meta_table = doc.add_table(rows=4, cols=2)
    meta_table.style = "Table Grid"
    labels = ["Document", "Standard", "Date", "Risk Score"]
    values = [
        record.doc_name,
        record.standard_name,
        created_at.strftime("%Y-%m-%d %H:%M UTC") if created_at else "",
        f"{record.risk_score} / 100",
    ]
    for i, (label, value) in enumerate(zip(labels, values)):
        meta_table.cell(i, 0).text = label
        meta_table.cell(i, 1).text = value

    # ── Executive Summary ──────────────────────────────────────────────────────
    doc.add_heading("Executive Summary", level=1)
    doc.add_paragraph(record.conclusion)

    # ── Findings ───────────────────────────────────────────────────────────────
    doc.add_heading("Findings", level=1)
    if record.findings:
        table = doc.add_table(rows=1, cols=4)
        table.style = "Table Grid"
        hdr = table.rows[0].cells
        for i, h in enumerate(["#", "Status", "Title", "Description / Clause"]):
            hdr[i].text = h
            for run in hdr[i].paragraphs[0].runs:
                run.bold = True
        for f in record.findings:
            row = table.add_row().cells
            row[0].text = str(f.seq + 1)  # stored 0-based, shown 1-based
            row[1].text = _STATUS_LABEL.get(f.status, f.status)
            row[2].text = f.title
            desc = f.description
            if f.clause_ref:
                desc += f"\n[{f.clause_ref}]"
            row[3].text = desc
    else:
        doc.add_paragraph("No findings recorded.")

    # ── Recommended Actions ────────────────────────────────────────────────────
    doc.add_heading("Recommended Actions", level=1)
    actions = record.actions or []
    if actions:
        for action in actions:
            # Action items originate from model output; tolerate entries that
            # are not dicts instead of failing the whole export.
            if isinstance(action, dict):
                text = f"{action.get('label', 'Action')}: {action.get('value', '')}"
            else:
                text = str(action)
            # "List Number" numbers paragraphs automatically, so no manual
            # "1." prefix is added (it would render as "1. 1. ...").
            doc.add_paragraph(text, style="List Number")
    else:
        doc.add_paragraph("No actions recorded.")

    # ── Footer note ────────────────────────────────────────────────────────────
    doc.add_paragraph("")
    footer = doc.add_paragraph(
        f"Generated by AI Regulation Analysis System — {datetime.now(timezone.utc).strftime('%Y-%m-%d')}"
    )
    footer.alignment = WD_ALIGN_PARAGRAPH.CENTER
    for run in footer.runs:
        run.font.size = Pt(8)
        run.font.color.rgb = RGBColor(0x88, 0x88, 0x88)

    buf = BytesIO()
    doc.save(buf)
    return buf.getvalue()

View File

@@ -0,0 +1,280 @@
# backend/app/infrastructure/compliance/repository.py
"""PostgreSQL-backed compliance analysis repository.
Follows the same psycopg2 pattern as PostgresDocumentRepository:
ThreadedConnectionPool + RealDictCursor for reads, _ensure_schema on init.
"""
from __future__ import annotations
import json
from contextlib import contextmanager
from datetime import datetime
from typing import Optional
import psycopg2
import psycopg2.extras
import psycopg2.pool
from loguru import logger
from app.domain.compliance.ports import (
AnalysisRecord,
ComplianceRepository,
FindingRecord,
)
class PostgresComplianceRepository(ComplianceRepository):
    """Stores compliance analyses, findings, and finding chat messages in PostgreSQL.

    Connections come from a ``ThreadedConnectionPool``; the three tables are
    created on construction if they do not exist yet. Identifiers are UUIDs in
    the database and plain strings at the API of this class. Lookups, deletes
    and message reads given a string that is not a valid UUID behave as
    "no such row" instead of surfacing a database error.
    """

    def __init__(
        self,
        host: str,
        port: int,
        user: str,
        password: str,
        dbname: str,
        minconn: int = 1,
        maxconn: int = 5,
    ) -> None:
        """Open the connection pool and make sure the schema exists."""
        self._pool = psycopg2.pool.ThreadedConnectionPool(
            minconn=minconn,
            maxconn=maxconn,
            host=host,
            port=port,
            user=user,
            password=password,
            dbname=dbname,
        )
        self._ensure_schema()

    @staticmethod
    def _normalize_uuid(value: str) -> Optional[str]:
        """Return the canonical UUID string for ``value``, or None if invalid."""
        from uuid import UUID

        try:
            return str(UUID(value))
        except (ValueError, AttributeError, TypeError):
            return None

    @contextmanager
    def _conn(self):
        """Borrow a pooled connection and always hand it back afterwards."""
        conn = self._pool.getconn()
        try:
            yield conn
        finally:
            self._pool.putconn(conn)

    def _ensure_schema(self) -> None:
        """Create tables if they do not exist."""
        # NOTE(review): gen_random_uuid() is built in from PostgreSQL 13; older
        # servers need the pgcrypto extension — confirm the deployment target.
        with self._conn() as conn:
            with conn.cursor() as cur:
                cur.execute("""
                    CREATE TABLE IF NOT EXISTS compliance_analyses (
                        id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                        created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
                        created_by VARCHAR(255),
                        doc_name VARCHAR(500),
                        standard_name VARCHAR(500),
                        risk_score INTEGER,
                        conclusion TEXT,
                        actions JSONB,
                        para_text TEXT,
                        highlight_terms JSONB
                    );
                """)
                cur.execute("""
                    CREATE TABLE IF NOT EXISTS compliance_findings (
                        id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                        analysis_id UUID NOT NULL REFERENCES compliance_analyses(id) ON DELETE CASCADE,
                        seq INTEGER NOT NULL,
                        title VARCHAR(500),
                        description TEXT,
                        status VARCHAR(50),
                        clause_ref VARCHAR(200)
                    );
                """)
                cur.execute("""
                    CREATE TABLE IF NOT EXISTS finding_chat_messages (
                        id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                        analysis_id UUID NOT NULL REFERENCES compliance_analyses(id) ON DELETE CASCADE,
                        finding_id UUID NOT NULL REFERENCES compliance_findings(id) ON DELETE CASCADE,
                        role VARCHAR(20) NOT NULL,
                        content TEXT NOT NULL,
                        created_at TIMESTAMPTZ NOT NULL DEFAULT now()
                    );
                """)
            conn.commit()

    def save_analysis(self, record: AnalysisRecord) -> str:
        """Insert analysis + findings; return the new analysis UUID.

        ``record.id`` and ``record.created_at`` are not written: both columns
        are filled by their database defaults. The analysis row and all finding
        rows are committed together.
        """
        with self._conn() as conn:
            with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                cur.execute(
                    """
                    INSERT INTO compliance_analyses
                        (created_by, doc_name, standard_name, risk_score,
                         conclusion, actions, para_text, highlight_terms)
                    VALUES
                        (%(created_by)s, %(doc_name)s, %(standard_name)s, %(risk_score)s,
                         %(conclusion)s, %(actions)s, %(para_text)s, %(highlight_terms)s)
                    RETURNING id
                    """,
                    {
                        "created_by": record.created_by,
                        "doc_name": record.doc_name,
                        "standard_name": record.standard_name,
                        "risk_score": record.risk_score,
                        "conclusion": record.conclusion,
                        "actions": json.dumps(record.actions, ensure_ascii=False),
                        "para_text": record.para_text,
                        "highlight_terms": json.dumps(record.highlight_terms, ensure_ascii=False),
                    },
                )
                row = cur.fetchone()
                analysis_id = str(row["id"])
            if record.findings:
                with conn.cursor() as cur:
                    for f in record.findings:
                        cur.execute(
                            """
                            INSERT INTO compliance_findings
                                (analysis_id, seq, title, description, status, clause_ref)
                            VALUES
                                (%(analysis_id)s, %(seq)s, %(title)s, %(desc)s, %(status)s, %(clause_ref)s)
                            """,
                            {
                                "analysis_id": analysis_id,
                                "seq": f.seq,
                                "title": f.title,
                                "desc": f.description,
                                "status": f.status,
                                "clause_ref": f.clause_ref,
                            },
                        )
            conn.commit()
        return analysis_id

    def list_analyses(self, limit: int = 50, offset: int = 0) -> list[AnalysisRecord]:
        """Return analyses without nested findings, ordered newest first."""
        with self._conn() as conn:
            with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                cur.execute(
                    """
                    SELECT id, created_at, created_by, doc_name, standard_name,
                           risk_score, conclusion, actions, para_text, highlight_terms
                    FROM compliance_analyses
                    ORDER BY created_at DESC
                    LIMIT %(limit)s OFFSET %(offset)s
                    """,
                    {"limit": limit, "offset": offset},
                )
                rows = cur.fetchall()
        return [self._row_to_record(dict(r)) for r in rows]

    def get_analysis(self, analysis_id: str) -> Optional[AnalysisRecord]:
        """Return analysis with nested findings list, or None if absent.

        A string that is not a valid UUID can never match a row, so it yields
        None without querying (the UUID column would otherwise reject it with
        a database error).
        """
        normalized_id = self._normalize_uuid(analysis_id)
        if normalized_id is None:
            return None
        with self._conn() as conn:
            with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                cur.execute(
                    "SELECT * FROM compliance_analyses WHERE id = %(id)s",
                    {"id": normalized_id},
                )
                row = cur.fetchone()
                if not row:
                    return None
                record = self._row_to_record(dict(row))
                cur.execute(
                    """
                    SELECT id, analysis_id, seq, title, description, status, clause_ref
                    FROM compliance_findings
                    WHERE analysis_id = %(id)s
                    ORDER BY seq
                    """,
                    {"id": normalized_id},
                )
                findings = [
                    FindingRecord(
                        id=str(r["id"]),
                        analysis_id=str(r["analysis_id"]),
                        seq=r["seq"],
                        title=r["title"] or "",
                        description=r["description"] or "",
                        status=r["status"] or "ok",
                        clause_ref=r["clause_ref"],
                    )
                    for r in cur.fetchall()
                ]
        record.findings = findings
        return record

    def delete_analysis(self, analysis_id: str) -> None:
        """Delete analysis; findings and chat messages cascade automatically.

        Deleting an id that is not a valid UUID is a no-op, the same as
        deleting an id that does not exist.
        """
        normalized_id = self._normalize_uuid(analysis_id)
        if normalized_id is None:
            return
        with self._conn() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    "DELETE FROM compliance_analyses WHERE id = %(id)s",
                    {"id": normalized_id},
                )
            conn.commit()

    def save_message(self, analysis_id: str, finding_id: str, role: str, content: str) -> str:
        """Persist a chat message; return its UUID."""
        with self._conn() as conn:
            with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                cur.execute(
                    """
                    INSERT INTO finding_chat_messages
                        (analysis_id, finding_id, role, content)
                    VALUES
                        (%(analysis_id)s, %(finding_id)s, %(role)s, %(content)s)
                    RETURNING id
                    """,
                    {
                        "analysis_id": analysis_id,
                        "finding_id": finding_id,
                        "role": role,
                        "content": content,
                    },
                )
                row = cur.fetchone()
            conn.commit()
        return str(row["id"])

    def get_messages(self, finding_id: str) -> list[dict]:
        """Return messages for a finding, oldest first.

        Each dict has keys ``id``, ``role``, ``content`` and ``created_at``
        (ISO string, or an empty string when the column is NULL). An id that
        is not a valid UUID yields an empty list.
        """
        normalized_id = self._normalize_uuid(finding_id)
        if normalized_id is None:
            return []
        with self._conn() as conn:
            with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                cur.execute(
                    """
                    SELECT id, role, content, created_at
                    FROM finding_chat_messages
                    WHERE finding_id = %(finding_id)s
                    ORDER BY created_at ASC
                    """,
                    {"finding_id": normalized_id},
                )
                rows = cur.fetchall()
        return [
            {
                "id": str(r["id"]),
                "role": r["role"],
                "content": r["content"],
                "created_at": r["created_at"].isoformat() if r["created_at"] else "",
            }
            for r in rows
        ]

    def _row_to_record(self, row: dict) -> AnalysisRecord:
        """Convert a RealDictCursor row to an AnalysisRecord (no findings)."""
        from datetime import timezone

        # JSONB columns may arrive already decoded or as raw JSON text;
        # accept both.
        actions = row.get("actions") or []
        if isinstance(actions, str):
            actions = json.loads(actions)
        highlight_terms = row.get("highlight_terms") or []
        if isinstance(highlight_terms, str):
            highlight_terms = json.loads(highlight_terms)

        created_at = row["created_at"]
        if not isinstance(created_at, datetime):
            # Fall back to an aware UTC "now". The column is TIMESTAMPTZ, so a
            # naive utcnow() would mix naive and aware datetimes in one list.
            created_at = datetime.now(timezone.utc)

        return AnalysisRecord(
            id=str(row["id"]),
            created_at=created_at,
            created_by=row.get("created_by"),
            doc_name=row.get("doc_name") or "",
            standard_name=row.get("standard_name") or "",
            risk_score=int(row.get("risk_score") or 0),
            conclusion=row.get("conclusion") or "",
            actions=actions,
            para_text=row.get("para_text") or "",
            highlight_terms=highlight_terms,
            findings=[],
        )

View File

@@ -0,0 +1,21 @@
"""No-op reranker stub.
Returns the original candidate list sliced to top_k.
Replace with CrossEncoderReranker when a local cross-encoder model is available.
"""
from __future__ import annotations
from app.domain.retrieval.models import RetrievedChunk
from app.domain.retrieval.ports import Reranker
class PassThroughReranker(Reranker):
    """Pass-through reranker that preserves original retrieval order.

    Acts as a placeholder for future cross-encoder reranking (e.g. ms-marco-MiniLM).
    Wire via bootstrap.get_compliance_reranker() when ready to swap.
    """

    def rerank(self, query: str, chunks: list[RetrievedChunk], top_k: int) -> list[RetrievedChunk]:
        """Return the first ``top_k`` chunks without reordering.

        ``query`` is accepted for interface compatibility and is not used.
        A ``top_k`` of zero or less yields an empty list; without the clamp a
        negative value would slice from the end (``chunks[:-1]`` drops only
        the last chunk), which is never what a "top k" caller means.
        """
        return chunks[:max(top_k, 0)]

View File

@@ -40,6 +40,8 @@ from app.infrastructure.vectorstore.cross_encoder_reranker import OpenAICompatib
from app.infrastructure.vectorstore.dense_retriever import DenseRetriever
from app.infrastructure.vectorstore.milvus_vector_index import MilvusVectorIndex
from app.services.llm.llm_factory import LLMFactory
from app.domain.compliance.ports import ComplianceRepository
from app.infrastructure.compliance.repository import PostgresComplianceRepository
# Keep shared wiring centralized so dependency construction remains consistent.
@@ -311,6 +313,28 @@ def get_event_store() -> BaseEventStore:
return MockEventStore()
@lru_cache
def get_compliance_repository() -> ComplianceRepository:
    """Build (once) and return the PostgreSQL compliance repository.

    The repository is only available when ``settings.document_repository_backend``
    is ``"postgres"``; connection parameters come from the ``postgres_*``
    settings. Successful results are memoised by ``lru_cache``.

    Raises:
        NotImplementedError: For any other backend value.
    """
    backend = settings.document_repository_backend
    if backend == "postgres":
        return PostgresComplianceRepository(
            host=settings.postgres_host,
            port=settings.postgres_port,
            user=settings.postgres_user,
            password=settings.postgres_password,
            dbname=settings.postgres_db,
        )

    message = (
        f"ComplianceRepository requires document_repository_backend=postgres, "
        f"got '{backend}'. "
        "Set DOCUMENT_REPOSITORY_BACKEND=postgres in your .env file."
    )
    raise NotImplementedError(message)
@lru_cache
def get_perception_service() -> PerceptionService:
return PerceptionService(