# backend/app/infrastructure/compliance/repository.py """PostgreSQL-backed compliance analysis repository. Follows the same psycopg2 pattern as PostgresDocumentRepository: ThreadedConnectionPool + RealDictCursor for reads, _ensure_schema on init. """ from __future__ import annotations import json from contextlib import contextmanager from datetime import datetime from typing import Optional import psycopg2 import psycopg2.extras import psycopg2.pool from loguru import logger from app.domain.compliance.ports import ( AnalysisRecord, ComplianceRepository, FindingRecord, ) class PostgresComplianceRepository(ComplianceRepository): """Stores compliance analyses, findings, and finding chat messages in PostgreSQL.""" def __init__( self, host: str, port: int, user: str, password: str, dbname: str, minconn: int = 1, maxconn: int = 5, ) -> None: self._pool = psycopg2.pool.ThreadedConnectionPool( minconn=minconn, maxconn=maxconn, host=host, port=port, user=user, password=password, dbname=dbname, ) self._ensure_schema() @contextmanager def _conn(self): conn = self._pool.getconn() try: yield conn finally: self._pool.putconn(conn) def _ensure_schema(self) -> None: """Create tables if they do not exist.""" with self._conn() as conn: with conn.cursor() as cur: cur.execute(""" CREATE TABLE IF NOT EXISTS compliance_analyses ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), created_at TIMESTAMPTZ NOT NULL DEFAULT now(), created_by VARCHAR(255), doc_name VARCHAR(500), standard_name VARCHAR(500), risk_score INTEGER, conclusion TEXT, actions JSONB, para_text TEXT, highlight_terms JSONB ); """) cur.execute(""" CREATE TABLE IF NOT EXISTS compliance_findings ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), analysis_id UUID NOT NULL REFERENCES compliance_analyses(id) ON DELETE CASCADE, seq INTEGER NOT NULL, title VARCHAR(500), description TEXT, status VARCHAR(50), clause_ref VARCHAR(200) ); """) cur.execute(""" CREATE TABLE IF NOT EXISTS finding_chat_messages ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), analysis_id UUID NOT NULL REFERENCES compliance_analyses(id) ON DELETE CASCADE, finding_id UUID NOT NULL REFERENCES compliance_findings(id) ON DELETE CASCADE, role VARCHAR(20) NOT NULL, content TEXT NOT NULL, created_at TIMESTAMPTZ NOT NULL DEFAULT now() ); """) conn.commit() def save_analysis(self, record: AnalysisRecord) -> str: """Insert analysis + findings; return the new analysis UUID.""" with self._conn() as conn: with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: cur.execute( """ INSERT INTO compliance_analyses (created_by, doc_name, standard_name, risk_score, conclusion, actions, para_text, highlight_terms) VALUES (%(created_by)s, %(doc_name)s, %(standard_name)s, %(risk_score)s, %(conclusion)s, %(actions)s, %(para_text)s, %(highlight_terms)s) RETURNING id """, { "created_by": record.created_by, "doc_name": record.doc_name, "standard_name": record.standard_name, "risk_score": record.risk_score, "conclusion": record.conclusion, "actions": json.dumps(record.actions, ensure_ascii=False), "para_text": record.para_text, "highlight_terms": json.dumps(record.highlight_terms, ensure_ascii=False), }, ) row = cur.fetchone() analysis_id = str(row["id"]) if record.findings: with conn.cursor() as cur: for f in record.findings: cur.execute( """ INSERT INTO compliance_findings (analysis_id, seq, title, description, status, clause_ref) VALUES (%(analysis_id)s, %(seq)s, %(title)s, %(desc)s, %(status)s, %(clause_ref)s) """, { "analysis_id": analysis_id, "seq": f.seq, "title": f.title, "desc": f.description, "status": f.status, "clause_ref": f.clause_ref, }, ) conn.commit() return analysis_id def list_analyses(self, limit: int = 50, offset: int = 0) -> list[AnalysisRecord]: """Return analyses without nested findings, ordered newest first.""" with self._conn() as conn: with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: cur.execute( """ SELECT id, created_at, created_by, doc_name, standard_name, risk_score, conclusion, actions, para_text, highlight_terms FROM compliance_analyses ORDER BY created_at DESC LIMIT %(limit)s OFFSET %(offset)s """, {"limit": limit, "offset": offset}, ) rows = cur.fetchall() return [self._row_to_record(dict(r)) for r in rows] def get_analysis(self, analysis_id: str) -> Optional[AnalysisRecord]: """Return analysis with nested findings list.""" with self._conn() as conn: with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: cur.execute( "SELECT * FROM compliance_analyses WHERE id = %(id)s", {"id": analysis_id}, ) row = cur.fetchone() if not row: return None record = self._row_to_record(dict(row)) cur.execute( """ SELECT id, analysis_id, seq, title, description, status, clause_ref FROM compliance_findings WHERE analysis_id = %(id)s ORDER BY seq """, {"id": analysis_id}, ) findings = [ FindingRecord( id=str(r["id"]), analysis_id=str(r["analysis_id"]), seq=r["seq"], title=r["title"] or "", description=r["description"] or "", status=r["status"] or "ok", clause_ref=r["clause_ref"], ) for r in cur.fetchall() ] record.findings = findings return record def delete_analysis(self, analysis_id: str) -> None: """Delete analysis; findings and chat messages cascade automatically.""" with self._conn() as conn: with conn.cursor() as cur: cur.execute( "DELETE FROM compliance_analyses WHERE id = %(id)s", {"id": analysis_id}, ) conn.commit() def save_message(self, analysis_id: str, finding_id: str, role: str, content: str) -> str: """Persist a chat message; return its UUID.""" with self._conn() as conn: with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: cur.execute( """ INSERT INTO finding_chat_messages (analysis_id, finding_id, role, content) VALUES (%(analysis_id)s, %(finding_id)s, %(role)s, %(content)s) RETURNING id """, { "analysis_id": analysis_id, "finding_id": finding_id, "role": role, "content": content, }, ) row = cur.fetchone() conn.commit() return str(row["id"]) def get_messages(self, finding_id: str) -> list[dict]: """Return messages for a finding, oldest first.""" with self._conn() as conn: with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: cur.execute( """ SELECT id, role, content, created_at FROM finding_chat_messages WHERE finding_id = %(finding_id)s ORDER BY created_at ASC """, {"finding_id": finding_id}, ) rows = cur.fetchall() return [ { "id": str(r["id"]), "role": r["role"], "content": r["content"], "created_at": r["created_at"].isoformat() if r["created_at"] else "", } for r in rows ] def _row_to_record(self, row: dict) -> AnalysisRecord: """Convert a RealDictCursor row to an AnalysisRecord (no findings).""" actions = row.get("actions") or [] if isinstance(actions, str): actions = json.loads(actions) highlight_terms = row.get("highlight_terms") or [] if isinstance(highlight_terms, str): highlight_terms = json.loads(highlight_terms) return AnalysisRecord( id=str(row["id"]), created_at=row["created_at"] if isinstance(row["created_at"], datetime) else datetime.utcnow(), created_by=row.get("created_by"), doc_name=row.get("doc_name") or "", standard_name=row.get("standard_name") or "", risk_score=int(row.get("risk_score") or 0), conclusion=row.get("conclusion") or "", actions=actions, para_text=row.get("para_text") or "", highlight_terms=highlight_terms, findings=[], )