"""PostgreSQL-backed regulatory event store.""" from __future__ import annotations import json from contextlib import contextmanager from datetime import UTC, date, datetime, timedelta from typing import Any import psycopg2 import psycopg2.extras from psycopg2.pool import ThreadedConnectionPool from app.config.settings import settings from app.infrastructure.perception.base_event_store import BaseEventStore _CREATE_TABLE = """ CREATE TABLE IF NOT EXISTS regulation_events ( id TEXT PRIMARY KEY, source TEXT NOT NULL, source_label TEXT, standard_code TEXT NOT NULL, title TEXT NOT NULL, summary TEXT, full_text_url TEXT, status TEXT, impact_level TEXT, published_at DATE, effective_at DATE, category TEXT, tags TEXT[], obligations JSONB, deadlines JSONB, scope TEXT, penalties TEXT, content_hash TEXT, previous_hash TEXT, change_summary TEXT, changed_sections JSONB, affected_docs JSONB, crawled_at TIMESTAMPTZ DEFAULT now(), processed_at TIMESTAMPTZ, raw_storage_key TEXT ); CREATE INDEX IF NOT EXISTS reg_events_source_date ON regulation_events (source, published_at DESC); CREATE INDEX IF NOT EXISTS reg_events_impact_date ON regulation_events (impact_level, published_at DESC); """ _ALL_COLUMNS = ( "id", "source", "source_label", "standard_code", "title", "summary", "full_text_url", "status", "impact_level", "published_at", "effective_at", "category", "tags", "obligations", "deadlines", "scope", "penalties", "content_hash", "previous_hash", "change_summary", "changed_sections", "affected_docs", "crawled_at", "processed_at", "raw_storage_key", ) def _row_to_dict(row: dict[str, Any]) -> dict: """Convert a psycopg2 RealDictRow to a plain dict with serialized JSON fields.""" d = dict(row) for field in ("obligations", "deadlines", "changed_sections", "affected_docs"): val = d.get(field) if isinstance(val, str): d[field] = json.loads(val) for date_field in ("published_at", "effective_at"): val = d.get(date_field) if isinstance(val, datetime): d[date_field] = val.date().isoformat() elif isinstance(val, date): d[date_field] = val.isoformat() for ts_field in ("crawled_at", "processed_at"): val = d.get(ts_field) if isinstance(val, datetime): d[ts_field] = val.isoformat() return d class PostgresEventStore(BaseEventStore): """Regulatory event store backed by PostgreSQL.""" def __init__(self) -> None: self._pool = ThreadedConnectionPool( minconn=1, maxconn=5, host=settings.postgres_host, port=settings.postgres_port, user=settings.postgres_user, password=settings.postgres_password, dbname=settings.postgres_db, ) self._ensure_schema() def _ensure_schema(self) -> None: with self._conn() as conn: try: with conn.cursor() as cur: cur.execute(_CREATE_TABLE) conn.commit() except Exception: conn.rollback() raise @contextmanager def _conn(self): conn = None try: conn = self._pool.getconn() yield conn finally: if conn is not None: self._pool.putconn(conn) def all(self) -> list[dict]: with self._conn() as conn: with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: cur.execute( "SELECT * FROM regulation_events ORDER BY published_at DESC NULLS LAST" ) return [_row_to_dict(r) for r in cur.fetchall()] def get(self, event_id: str) -> dict | None: with self._conn() as conn: with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: cur.execute( "SELECT * FROM regulation_events WHERE id = %s", (event_id,) ) row = cur.fetchone() return _row_to_dict(row) if row else None def filter( self, *, source: str | None = None, impact_level: str | None = None, limit: int = 50, ) -> list[dict]: conditions: list[str] = [] params: list[Any] = [] if source: conditions.append("source = %s") params.append(source) if impact_level: conditions.append("impact_level = %s") params.append(impact_level) where = ("WHERE " + " AND ".join(conditions)) if conditions else "" params.append(limit) sql = f""" SELECT * FROM regulation_events {where} ORDER BY published_at DESC NULLS LAST LIMIT %s """ with self._conn() as conn: with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: cur.execute(sql, params) return [_row_to_dict(r) for r in cur.fetchall()] def stats(self) -> dict: cutoff = (date.today() - timedelta(days=90)).isoformat() with self._conn() as conn: with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: cur.execute("SELECT COUNT(*) AS count FROM regulation_events") total = (cur.fetchone() or {}).get("count", 0) cur.execute( "SELECT COUNT(*) AS count FROM regulation_events WHERE impact_level = 'high'" ) high = (cur.fetchone() or {}).get("count", 0) cur.execute( "SELECT COUNT(*) AS count FROM regulation_events WHERE impact_level = 'medium'" ) medium = (cur.fetchone() or {}).get("count", 0) cur.execute( "SELECT COUNT(*) AS count FROM regulation_events WHERE published_at >= %s", (cutoff,), ) recent = (cur.fetchone() or {}).get("count", 0) return { "total": int(total), "high_impact": int(high), "medium_impact": int(medium), "recent_90d": int(recent), } def upsert(self, event: dict) -> None: """Insert or update a regulation event.""" cols = [c for c in _ALL_COLUMNS if c in event] placeholders = ", ".join(f"%({c})s" for c in cols) updates = ", ".join(f"{c} = EXCLUDED.{c}" for c in cols if c != "id") sql = f""" INSERT INTO regulation_events ({', '.join(cols)}) VALUES ({placeholders}) ON CONFLICT (id) DO UPDATE SET {updates} """ row: dict[str, Any] = {} for c in cols: val = event.get(c) if c in ("obligations", "deadlines", "changed_sections", "affected_docs") and val is not None: row[c] = json.dumps(val, ensure_ascii=False) elif c == "tags" and isinstance(val, list): row[c] = val else: row[c] = val with self._conn() as conn: try: with conn.cursor() as cur: cur.execute(sql, row) conn.commit() except Exception: conn.rollback() raise def get_by_standard_code(self, standard_code: str) -> dict | None: with self._conn() as conn: with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: cur.execute( """SELECT * FROM regulation_events WHERE standard_code = %s ORDER BY published_at DESC NULLS LAST LIMIT 1""", (standard_code,), ) row = cur.fetchone() return _row_to_dict(row) if row else None