fix somethings
This commit is contained in:
225
backend/app/infrastructure/perception/postgres_event_store.py
Normal file
225
backend/app/infrastructure/perception/postgres_event_store.py
Normal file
@@ -0,0 +1,225 @@
|
||||
"""PostgreSQL-backed regulatory event store."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from contextlib import contextmanager
|
||||
from datetime import UTC, date, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
import psycopg2
|
||||
import psycopg2.extras
|
||||
from psycopg2.pool import ThreadedConnectionPool
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.infrastructure.perception.base_event_store import BaseEventStore
|
||||
|
||||
_CREATE_TABLE = """
|
||||
CREATE TABLE IF NOT EXISTS regulation_events (
|
||||
id TEXT PRIMARY KEY,
|
||||
source TEXT NOT NULL,
|
||||
source_label TEXT,
|
||||
standard_code TEXT NOT NULL,
|
||||
title TEXT NOT NULL,
|
||||
summary TEXT,
|
||||
full_text_url TEXT,
|
||||
status TEXT,
|
||||
impact_level TEXT,
|
||||
published_at DATE,
|
||||
effective_at DATE,
|
||||
category TEXT,
|
||||
tags TEXT[],
|
||||
obligations JSONB,
|
||||
deadlines JSONB,
|
||||
scope TEXT,
|
||||
penalties TEXT,
|
||||
content_hash TEXT,
|
||||
previous_hash TEXT,
|
||||
change_summary TEXT,
|
||||
changed_sections JSONB,
|
||||
affected_docs JSONB,
|
||||
crawled_at TIMESTAMPTZ DEFAULT now(),
|
||||
processed_at TIMESTAMPTZ,
|
||||
raw_storage_key TEXT
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS reg_events_source_date
|
||||
ON regulation_events (source, published_at DESC);
|
||||
CREATE INDEX IF NOT EXISTS reg_events_impact_date
|
||||
ON regulation_events (impact_level, published_at DESC);
|
||||
"""
|
||||
|
||||
_ALL_COLUMNS = (
|
||||
"id", "source", "source_label", "standard_code", "title", "summary",
|
||||
"full_text_url", "status", "impact_level", "published_at", "effective_at",
|
||||
"category", "tags", "obligations", "deadlines", "scope", "penalties",
|
||||
"content_hash", "previous_hash", "change_summary", "changed_sections",
|
||||
"affected_docs", "crawled_at", "processed_at", "raw_storage_key",
|
||||
)
|
||||
|
||||
|
||||
def _row_to_dict(row: dict[str, Any]) -> dict:
|
||||
"""Convert a psycopg2 RealDictRow to a plain dict with serialized JSON fields."""
|
||||
d = dict(row)
|
||||
for field in ("obligations", "deadlines", "changed_sections", "affected_docs"):
|
||||
val = d.get(field)
|
||||
if isinstance(val, str):
|
||||
d[field] = json.loads(val)
|
||||
for date_field in ("published_at", "effective_at"):
|
||||
val = d.get(date_field)
|
||||
if isinstance(val, datetime):
|
||||
d[date_field] = val.date().isoformat()
|
||||
elif isinstance(val, date):
|
||||
d[date_field] = val.isoformat()
|
||||
for ts_field in ("crawled_at", "processed_at"):
|
||||
val = d.get(ts_field)
|
||||
if isinstance(val, datetime):
|
||||
d[ts_field] = val.isoformat()
|
||||
return d
|
||||
|
||||
|
||||
class PostgresEventStore(BaseEventStore):
|
||||
"""Regulatory event store backed by PostgreSQL."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._pool = ThreadedConnectionPool(
|
||||
minconn=1,
|
||||
maxconn=5,
|
||||
host=settings.postgres_host,
|
||||
port=settings.postgres_port,
|
||||
user=settings.postgres_user,
|
||||
password=settings.postgres_password,
|
||||
dbname=settings.postgres_db,
|
||||
)
|
||||
self._ensure_schema()
|
||||
|
||||
def _ensure_schema(self) -> None:
|
||||
with self._conn() as conn:
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(_CREATE_TABLE)
|
||||
conn.commit()
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
|
||||
@contextmanager
|
||||
def _conn(self):
|
||||
conn = None
|
||||
try:
|
||||
conn = self._pool.getconn()
|
||||
yield conn
|
||||
finally:
|
||||
if conn is not None:
|
||||
self._pool.putconn(conn)
|
||||
|
||||
def all(self) -> list[dict]:
|
||||
with self._conn() as conn:
|
||||
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
|
||||
cur.execute(
|
||||
"SELECT * FROM regulation_events ORDER BY published_at DESC NULLS LAST"
|
||||
)
|
||||
return [_row_to_dict(r) for r in cur.fetchall()]
|
||||
|
||||
def get(self, event_id: str) -> dict | None:
|
||||
with self._conn() as conn:
|
||||
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
|
||||
cur.execute(
|
||||
"SELECT * FROM regulation_events WHERE id = %s", (event_id,)
|
||||
)
|
||||
row = cur.fetchone()
|
||||
return _row_to_dict(row) if row else None
|
||||
|
||||
def filter(
|
||||
self,
|
||||
*,
|
||||
source: str | None = None,
|
||||
impact_level: str | None = None,
|
||||
limit: int = 50,
|
||||
) -> list[dict]:
|
||||
conditions: list[str] = []
|
||||
params: list[Any] = []
|
||||
if source:
|
||||
conditions.append("source = %s")
|
||||
params.append(source)
|
||||
if impact_level:
|
||||
conditions.append("impact_level = %s")
|
||||
params.append(impact_level)
|
||||
where = ("WHERE " + " AND ".join(conditions)) if conditions else ""
|
||||
params.append(limit)
|
||||
sql = f"""
|
||||
SELECT * FROM regulation_events
|
||||
{where}
|
||||
ORDER BY published_at DESC NULLS LAST
|
||||
LIMIT %s
|
||||
"""
|
||||
with self._conn() as conn:
|
||||
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
|
||||
cur.execute(sql, params)
|
||||
return [_row_to_dict(r) for r in cur.fetchall()]
|
||||
|
||||
def stats(self) -> dict:
|
||||
cutoff = (date.today() - timedelta(days=90)).isoformat()
|
||||
with self._conn() as conn:
|
||||
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
|
||||
cur.execute("SELECT COUNT(*) AS count FROM regulation_events")
|
||||
total = (cur.fetchone() or {}).get("count", 0)
|
||||
cur.execute(
|
||||
"SELECT COUNT(*) AS count FROM regulation_events WHERE impact_level = 'high'"
|
||||
)
|
||||
high = (cur.fetchone() or {}).get("count", 0)
|
||||
cur.execute(
|
||||
"SELECT COUNT(*) AS count FROM regulation_events WHERE impact_level = 'medium'"
|
||||
)
|
||||
medium = (cur.fetchone() or {}).get("count", 0)
|
||||
cur.execute(
|
||||
"SELECT COUNT(*) AS count FROM regulation_events WHERE published_at >= %s",
|
||||
(cutoff,),
|
||||
)
|
||||
recent = (cur.fetchone() or {}).get("count", 0)
|
||||
return {
|
||||
"total": int(total),
|
||||
"high_impact": int(high),
|
||||
"medium_impact": int(medium),
|
||||
"recent_90d": int(recent),
|
||||
}
|
||||
|
||||
def upsert(self, event: dict) -> None:
|
||||
"""Insert or update a regulation event."""
|
||||
cols = [c for c in _ALL_COLUMNS if c in event]
|
||||
placeholders = ", ".join(f"%({c})s" for c in cols)
|
||||
updates = ", ".join(f"{c} = EXCLUDED.{c}" for c in cols if c != "id")
|
||||
sql = f"""
|
||||
INSERT INTO regulation_events ({', '.join(cols)})
|
||||
VALUES ({placeholders})
|
||||
ON CONFLICT (id) DO UPDATE SET {updates}
|
||||
"""
|
||||
row: dict[str, Any] = {}
|
||||
for c in cols:
|
||||
val = event.get(c)
|
||||
if c in ("obligations", "deadlines", "changed_sections", "affected_docs") and val is not None:
|
||||
row[c] = json.dumps(val, ensure_ascii=False)
|
||||
elif c == "tags" and isinstance(val, list):
|
||||
row[c] = val
|
||||
else:
|
||||
row[c] = val
|
||||
with self._conn() as conn:
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(sql, row)
|
||||
conn.commit()
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
|
||||
def get_by_standard_code(self, standard_code: str) -> dict | None:
|
||||
with self._conn() as conn:
|
||||
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
|
||||
cur.execute(
|
||||
"""SELECT * FROM regulation_events
|
||||
WHERE standard_code = %s
|
||||
ORDER BY published_at DESC NULLS LAST
|
||||
LIMIT 1""",
|
||||
(standard_code,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
return _row_to_dict(row) if row else None
|
||||
Reference in New Issue
Block a user