Files
AIRegulation-DocAnalysis/backend/app/infrastructure/perception/postgres_event_store.py

226 lines
8.0 KiB
Python
Raw Normal View History

2026-06-08 11:16:28 +08:00
"""PostgreSQL-backed regulatory event store."""
from __future__ import annotations
import json
from contextlib import contextmanager
from datetime import UTC, date, datetime, timedelta
from typing import Any
import psycopg2
import psycopg2.extras
from psycopg2.pool import ThreadedConnectionPool
from app.config.settings import settings
from app.infrastructure.perception.base_event_store import BaseEventStore
# Idempotent schema DDL, executed by ``PostgresEventStore._ensure_schema`` when
# a store is constructed. Every statement uses IF NOT EXISTS, so re-running it
# is harmless. It never alters an existing table, so column changes made here
# are NOT applied to databases where ``regulation_events`` already exists.
_CREATE_TABLE = """
CREATE TABLE IF NOT EXISTS regulation_events (
    id TEXT PRIMARY KEY,
    source TEXT NOT NULL,
    source_label TEXT,
    standard_code TEXT NOT NULL,
    title TEXT NOT NULL,
    summary TEXT,
    full_text_url TEXT,
    status TEXT,
    impact_level TEXT,
    published_at DATE,
    effective_at DATE,
    category TEXT,
    tags TEXT[],
    obligations JSONB,
    deadlines JSONB,
    scope TEXT,
    penalties TEXT,
    content_hash TEXT,
    previous_hash TEXT,
    change_summary TEXT,
    changed_sections JSONB,
    affected_docs JSONB,
    crawled_at TIMESTAMPTZ DEFAULT now(),
    processed_at TIMESTAMPTZ,
    raw_storage_key TEXT
);
CREATE INDEX IF NOT EXISTS reg_events_source_date
    ON regulation_events (source, published_at DESC);
CREATE INDEX IF NOT EXISTS reg_events_impact_date
    ON regulation_events (impact_level, published_at DESC);
-- Lookup by standard code, newest publication first.
CREATE INDEX IF NOT EXISTS reg_events_code_date
    ON regulation_events (standard_code, published_at DESC);
"""
# Every column of ``regulation_events``, listed in the same order as the
# CREATE TABLE statement. ``upsert`` only interpolates event keys that appear
# in this tuple into its SQL, so it doubles as an identifier whitelist.
_ALL_COLUMNS = (
    "id",
    "source",
    "source_label",
    "standard_code",
    "title",
    "summary",
    "full_text_url",
    "status",
    "impact_level",
    "published_at",
    "effective_at",
    "category",
    "tags",
    "obligations",
    "deadlines",
    "scope",
    "penalties",
    "content_hash",
    "previous_hash",
    "change_summary",
    "changed_sections",
    "affected_docs",
    "crawled_at",
    "processed_at",
    "raw_storage_key",
)
def _row_to_dict(row: dict[str, Any]) -> dict:
    """Convert a psycopg2 row mapping into a plain, JSON-friendly dict.

    The row is shallow-copied and the copy is normalised:

    * JSON columns (``obligations``, ``deadlines``, ``changed_sections``,
      ``affected_docs``) that hold a ``str`` are decoded with ``json.loads``;
      a string that is not valid JSON text is kept unchanged. Values that are
      already Python objects are left alone.
    * ``published_at`` / ``effective_at`` become ``YYYY-MM-DD`` strings
      (a ``datetime`` is reduced to its date first).
    * ``crawled_at`` / ``processed_at`` datetimes become ISO-8601 strings.

    Missing keys are not added and ``None`` values are left as ``None``.

    Args:
        row: A mapping of column name to value, e.g. a ``RealDictRow``.

    Returns:
        A new ``dict``; the input mapping is not modified.
    """
    d = dict(row)
    for field in ("obligations", "deadlines", "changed_sections", "affected_docs"):
        val = d.get(field)
        if isinstance(val, str):
            # psycopg2 normally decodes json/jsonb columns itself, so a str here
            # is either undecoded JSON text or an already-decoded JSON string
            # scalar. The latter (e.g. ``abc``) is not valid JSON text and would
            # make json.loads raise, so it is kept as-is.
            try:
                val = json.loads(val)
            except json.JSONDecodeError:
                pass
            d[field] = val
    for date_field in ("published_at", "effective_at"):
        val = d.get(date_field)
        # datetime subclasses date, so it must be checked first to drop the time.
        if isinstance(val, datetime):
            d[date_field] = val.date().isoformat()
        elif isinstance(val, date):
            d[date_field] = val.isoformat()
    for ts_field in ("crawled_at", "processed_at"):
        val = d.get(ts_field)
        if isinstance(val, datetime):
            d[ts_field] = val.isoformat()
    return d
class PostgresEventStore(BaseEventStore):
    """Regulatory event store backed by PostgreSQL.

    Events live in the ``regulation_events`` table, which is created together
    with its indexes on construction if it does not exist yet. Connections are
    drawn from a psycopg2 ``ThreadedConnectionPool`` configured from
    ``settings.postgres_*``; call :meth:`close` to release them.
    """

    # JSONB columns; ``upsert`` serializes their (non-None) values with json.dumps.
    _JSON_COLUMNS = ("obligations", "deadlines", "changed_sections", "affected_docs")

    def __init__(self, *, minconn: int = 1, maxconn: int = 5) -> None:
        """Open the connection pool and create the schema if needed.

        Args:
            minconn: Minimum number of connections the pool keeps open.
            maxconn: Maximum number of connections the pool hands out at once.

        Raises:
            Database errors from connecting or from creating the schema are
            propagated. If schema creation fails, the pool is closed first.
        """
        self._pool = ThreadedConnectionPool(
            minconn=minconn,
            maxconn=maxconn,
            host=settings.postgres_host,
            port=settings.postgres_port,
            user=settings.postgres_user,
            password=settings.postgres_password,
            dbname=settings.postgres_db,
        )
        try:
            self._ensure_schema()
        except Exception:
            # No caller ever receives this half-built store, so release the
            # pool's open connections here instead of leaking them.
            self._pool.closeall()
            raise

    def close(self) -> None:
        """Close every pooled connection; calling it again is a no-op."""
        if not self._pool.closed:
            self._pool.closeall()

    def _ensure_schema(self) -> None:
        """Run the idempotent ``_CREATE_TABLE`` DDL in its own transaction."""
        self._execute_write(_CREATE_TABLE)

    @contextmanager
    def _conn(self):
        """Borrow a connection from the pool and always hand it back on exit."""
        conn = self._pool.getconn()
        try:
            yield conn
        finally:
            self._pool.putconn(conn)

    def _execute_write(self, sql: str, params: Any = None) -> None:
        """Execute one statement and commit it, rolling back on any error."""
        with self._conn() as conn:
            try:
                with conn.cursor() as cur:
                    cur.execute(sql, params)
                conn.commit()
            except Exception:
                conn.rollback()
                raise

    def _fetch_all(self, sql: str, params: Any = None) -> list[dict]:
        """Run a query and return every row converted by ``_row_to_dict``."""
        with self._conn() as conn:
            with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                cur.execute(sql, params)
                return [_row_to_dict(r) for r in cur.fetchall()]

    def _fetch_one(self, sql: str, params: Any = None) -> dict | None:
        """Run a query and return its first row converted by ``_row_to_dict``, or None."""
        with self._conn() as conn:
            with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                cur.execute(sql, params)
                row = cur.fetchone()
        return _row_to_dict(row) if row else None

    def all(self) -> list[dict]:
        """Return every event, newest ``published_at`` first (undated events last)."""
        return self._fetch_all(
            "SELECT * FROM regulation_events ORDER BY published_at DESC NULLS LAST"
        )

    def get(self, event_id: str) -> dict | None:
        """Return the event whose primary key is ``event_id``, or ``None``."""
        return self._fetch_one(
            "SELECT * FROM regulation_events WHERE id = %s", (event_id,)
        )

    def filter(
        self,
        *,
        source: str | None = None,
        impact_level: str | None = None,
        limit: int = 50,
    ) -> list[dict]:
        """Return up to ``limit`` events matching every given filter.

        Falsy ``source`` / ``impact_level`` values (``None`` or ``""``) are
        ignored. Results are ordered newest ``published_at`` first, undated
        events last.
        """
        conditions: list[str] = []
        params: list[Any] = []
        if source:
            conditions.append("source = %s")
            params.append(source)
        if impact_level:
            conditions.append("impact_level = %s")
            params.append(impact_level)
        # Only fixed predicate text is interpolated; all values stay parameterized.
        where = ("WHERE " + " AND ".join(conditions)) if conditions else ""
        params.append(limit)
        sql = f"""
            SELECT * FROM regulation_events
            {where}
            ORDER BY published_at DESC NULLS LAST
            LIMIT %s
        """
        return self._fetch_all(sql, params)

    def stats(self) -> dict:
        """Return event counts, computed in a single query.

        Keys: ``total``, ``high_impact`` / ``medium_impact`` (rows whose
        ``impact_level`` is ``'high'`` / ``'medium'``), and ``recent_90d``
        (rows with ``published_at`` on or after the host's local date minus
        90 days).
        """
        cutoff = date.today() - timedelta(days=90)
        sql = """
            SELECT
                COUNT(*) AS total,
                COUNT(*) FILTER (WHERE impact_level = 'high') AS high_impact,
                COUNT(*) FILTER (WHERE impact_level = 'medium') AS medium_impact,
                COUNT(*) FILTER (WHERE published_at >= %s) AS recent_90d
            FROM regulation_events
        """
        with self._conn() as conn:
            with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                cur.execute(sql, (cutoff,))
                row = cur.fetchone() or {}
        return {
            key: int(row.get(key) or 0)
            for key in ("total", "high_impact", "medium_impact", "recent_90d")
        }

    def upsert(self, event: dict) -> None:
        """Insert a regulation event, or update it if its ``id`` already exists.

        Only keys listed in ``_ALL_COLUMNS`` are written; other keys are
        ignored. On conflict, only the columns present in ``event`` are
        overwritten. The event must include ``"id"`` (the primary key),
        otherwise PostgreSQL rejects the insert.
        """
        cols = [c for c in _ALL_COLUMNS if c in event]
        placeholders = ", ".join(f"%({c})s" for c in cols)
        updates = ", ".join(f"{c} = EXCLUDED.{c}" for c in cols if c != "id")
        # An event carrying only "id" has nothing to update, and an empty SET
        # list is a SQL syntax error, so fall back to DO NOTHING.
        conflict_action = f"DO UPDATE SET {updates}" if updates else "DO NOTHING"
        sql = (
            f"INSERT INTO regulation_events ({', '.join(cols)}) "
            f"VALUES ({placeholders}) "
            f"ON CONFLICT (id) {conflict_action}"
        )
        row: dict[str, Any] = {}
        for c in cols:
            val = event[c]
            if c in self._JSON_COLUMNS and val is not None:
                val = json.dumps(val, ensure_ascii=False)
            row[c] = val
        self._execute_write(sql, row)

    def get_by_standard_code(self, standard_code: str) -> dict | None:
        """Return the most recently published event for ``standard_code``, or None."""
        return self._fetch_one(
            """SELECT * FROM regulation_events
            WHERE standard_code = %s
            ORDER BY published_at DESC NULLS LAST
            LIMIT 1""",
            (standard_code,),
        )