fix somethings

This commit is contained in:
2026-06-08 11:16:28 +08:00
parent 9fea9c6a53
commit e7963b267e
34 changed files with 5195 additions and 246 deletions

View File

@@ -0,0 +1,39 @@
"""Abstract base class for regulatory event stores."""
from __future__ import annotations
from abc import ABC, abstractmethod
class BaseEventStore(ABC):
"""Port interface for regulatory event persistence."""
@abstractmethod
def all(self) -> list[dict]:
"""Return all events, most-recent first."""
@abstractmethod
def get(self, event_id: str) -> dict | None:
"""Return a single event by ID, or None."""
@abstractmethod
def filter(
self,
*,
source: str | None = None,
impact_level: str | None = None,
limit: int = 50,
) -> list[dict]:
"""Return filtered events sorted by published_at descending."""
@abstractmethod
def stats(self) -> dict:
"""Return {total, high_impact, medium_impact, low_impact, recent_90d}."""
@abstractmethod
def upsert(self, event: dict) -> None:
"""Insert or update an event record."""
@abstractmethod
def get_by_standard_code(self, standard_code: str) -> dict | None:
"""Return the most-recent event with matching standard_code, or None."""

View File

@@ -0,0 +1,43 @@
"""Shared utility functions for crawlers."""
from __future__ import annotations
import re
from datetime import date
def parse_date(text: str) -> str:
"""Return YYYY-MM-DD from common Chinese date formats, or today's date."""
text = text.strip()
if not text:
return date.today().isoformat()
m = re.search(r"(\d{4})[/-](\d{1,2})[/-](\d{1,2})", text)
if m:
try:
return date(int(m.group(1)), int(m.group(2)), int(m.group(3))).isoformat()
except ValueError:
pass
m2 = re.search(r"(\d{4})年(\d{1,2})月(\d{1,2})日?", text)
if m2:
try:
return date(int(m2.group(1)), int(m2.group(2)), int(m2.group(3))).isoformat()
except ValueError:
pass
return date.today().isoformat()
def extract_tags(standard_code: str, title: str) -> list[str]:
"""Derive simple keyword tags from standard code and title."""
tags: list[str] = []
code_upper = standard_code.upper()
if "GB" in code_upper:
tags.append("国家标准")
if "/T" in code_upper:
tags.append("推荐性")
else:
tags.append("强制性")
keywords = ["电动", "安全", "自动驾驶", "充电", "智能网联", "碰撞", "排放", "网络安全"]
for kw in keywords:
if kw in title:
tags.append(kw)
return tags[:5]

View File

@@ -0,0 +1,32 @@
"""Shared contracts for regulatory source crawlers."""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
@dataclass
class RawEvent:
"""Raw regulatory event returned by a crawler before enrichment."""
source: str
source_label: str
standard_code: str
title: str
summary: str
full_text_url: str
status: str # 'enacted' | 'draft' | 'consultation'
published_at: str # YYYY-MM-DD string
effective_at: str | None
category: str
tags: list[str] = field(default_factory=list)
raw_text: str = "" # full crawled text for hashing + LLM
class BaseCrawler(ABC):
"""Abstract regulatory source crawler."""
@abstractmethod
def fetch(self, limit: int = 50) -> list[RawEvent]:
"""Fetch up to `limit` recent events from the data source."""

View File

@@ -0,0 +1,83 @@
"""Crawler for CATARC automotive standard catalogue."""
from __future__ import annotations
from urllib.parse import urljoin
import httpx
from bs4 import BeautifulSoup
from loguru import logger
from app.infrastructure.perception.crawlers.base import BaseCrawler, RawEvent
from ._utils import extract_tags, parse_date
_BASE_URL = "https://www.catarc.org.cn/bzzxd/qcbz/index.html"
_HOST = "https://www.catarc.org.cn"
_STATUS_MAP = {
"现行": "enacted",
"即将实施": "enacted",
"废止": "enacted",
"征求意见": "consultation",
"报批": "draft",
}
class CatarcCrawler(BaseCrawler):
"""Scrape the CATARC automotive standard list page."""
def fetch(self, limit: int = 50) -> list[RawEvent]:
events: list[RawEvent] = []
page = 1
max_pages = max(10, limit)
while len(events) < limit and page <= max_pages:
url = f"{_BASE_URL}?page={page}"
try:
resp = httpx.get(url, timeout=30, follow_redirects=True)
resp.raise_for_status()
except Exception as exc:
logger.warning("CATARC fetch failed page={} err={}", page, exc)
break
soup = BeautifulSoup(resp.text, "lxml")
rows = soup.select("table tr")
if not rows:
break
batch: list[RawEvent] = []
for row in rows:
cells = row.find_all("td")
if len(cells) < 3:
continue
link = cells[0].find("a")
standard_code = link.get_text(strip=True) if link else cells[0].get_text(strip=True)
title = cells[1].get_text(strip=True) if len(cells) > 1 else standard_code
date_text = cells[2].get_text(strip=True) if len(cells) > 2 else ""
published_at = parse_date(date_text)
status_text = cells[3].get_text(strip=True) if len(cells) > 3 else ""
status = _STATUS_MAP.get(status_text, "enacted")
detail_url = urljoin(_HOST, link["href"]) if link and link.get("href") else url
raw_text = f"{standard_code} {title}"
batch.append(RawEvent(
source="CATARC",
source_label="全国汽车标准化技术委员会",
standard_code=standard_code,
title=title,
summary=title,
full_text_url=detail_url,
status=status,
published_at=published_at,
effective_at=None,
category="汽车标准",
tags=extract_tags(standard_code, title),
raw_text=raw_text,
))
if not batch:
break
events.extend(batch)
page += 1
return events[:limit]

View File

@@ -0,0 +1,117 @@
"""Crawler for EUR-Lex RSS feeds covering EU AI Act and automotive regulations."""
from __future__ import annotations
import re
from email.utils import parsedate_to_datetime
import httpx
from bs4 import BeautifulSoup
from loguru import logger
from app.infrastructure.perception.crawlers.base import BaseCrawler, RawEvent
from ._utils import parse_date
_EURLEX_RSS_URLS = [
"https://eur-lex.europa.eu/rss-feed/OJ-L.rss",
]
_AUTOMOTIVE_KEYWORDS = [
"vehicle", "automotive", "motor", "tyre", "emission", "ADAS", "autonomous",
"AI Act", "artificial intelligence", "cybersecurity", "software update",
"R155", "R156", "汽车", "车辆",
]
_AUTOMOTIVE_KEYWORDS_LOWER = [kw.lower() for kw in _AUTOMOTIVE_KEYWORDS]
def _is_automotive_relevant(title: str, description: str) -> bool:
combined = (title + " " + description).lower()
return any(kw in combined for kw in _AUTOMOTIVE_KEYWORDS_LOWER)
def _extract_celex(url: str) -> str:
m = re.search(r"CELEX[:/]([0-9A-Z]+)", url)
return m.group(1) if m else ""
def _parse_rss_date(rfc2822: str) -> str:
try:
dt = parsedate_to_datetime(rfc2822)
return dt.date().isoformat()
except Exception:
return parse_date(rfc2822)
class EurlexCrawler(BaseCrawler):
"""Fetch automotive-relevant EU regulations from EUR-Lex RSS feeds."""
def fetch(self, limit: int = 50) -> list[RawEvent]:
events: list[RawEvent] = []
for rss_url in _EURLEX_RSS_URLS:
if len(events) >= limit:
break
try:
resp = httpx.get(rss_url, timeout=30, follow_redirects=True)
resp.raise_for_status()
except Exception as exc:
logger.warning("EUR-Lex RSS fetch failed url={} err={}", rss_url, exc)
continue
soup = BeautifulSoup(resp.content, "lxml-xml")
for item in soup.find_all("item"):
if len(events) >= limit:
break
title_tag = item.find("title")
title = title_tag.get_text(strip=True) if title_tag else ""
desc_tag = item.find("description")
description = desc_tag.get_text(strip=True) if desc_tag else ""
link_tag = item.find("link")
link = link_tag.get_text(strip=True) if link_tag else ""
pub_date_tag = item.find("pubDate")
pub_date = pub_date_tag.get_text(strip=True) if pub_date_tag else ""
if not _is_automotive_relevant(title, description):
continue
celex = _extract_celex(link)
standard_code = celex if celex else title[:60]
published_at = _parse_rss_date(pub_date) if pub_date else ""
events.append(RawEvent(
source="EUR-Lex",
source_label="欧盟官方公报",
standard_code=standard_code,
title=title,
summary=description[:500],
full_text_url=link,
status="enacted",
published_at=published_at,
effective_at=None,
category="EU法规",
tags=_extract_eurlex_tags(title, description),
raw_text=f"{title}\n{description}",
))
return events[:limit]
def _extract_eurlex_tags(title: str, description: str) -> list[str]:
combined = title + " " + description
tag_map = {
"AI Act": "EU AI Act",
"artificial intelligence": "EU AI Act",
"R155": "UN R155",
"R156": "UN R156",
"cybersecurity": "网络安全",
"emission": "排放",
"autonomous": "自动驾驶",
"ADAS": "ADAS",
}
combined_lower = combined.lower()
tags = []
for kw, tag in tag_map.items():
if kw.lower() in combined_lower:
tags.append(tag)
return tags[:5]

View File

@@ -0,0 +1,92 @@
"""Crawlers for the 国标委 (SAMR) standard information platform."""
from __future__ import annotations
import httpx
from loguru import logger
from app.infrastructure.perception.crawlers.base import BaseCrawler, RawEvent
from ._utils import extract_tags, parse_date
_BASE_URL = "https://openstd.samr.gov.cn/bzgk/std/std_list_type"
_HEADERS = {"User-Agent": "Mozilla/5.0 (compatible; RegulatoryBot/1.0)"}
def _fetch_page(std_type: int, page: int, page_size: int) -> list[dict]:
params = {
"p.p1": std_type,
"p.p2": "",
"p.p90": "circulation_date",
"p.p91": "desc",
"p.p6": page,
"p.p7": page_size,
}
try:
resp = httpx.get(_BASE_URL, params=params, headers=_HEADERS, timeout=30)
resp.raise_for_status()
data = resp.json()
return data.get("rows", []) or []
except Exception as exc:
logger.warning("国标委 fetch failed type={} page={} err={}", std_type, page, exc)
return []
def _row_to_raw_event(row: dict, source_label: str) -> RawEvent:
standard_code = row.get("std_code", "")
title = row.get("std_name", standard_code)
published_at = parse_date(row.get("release_date", ""))
effective_at_raw = row.get("implement_date", "")
effective_at = parse_date(effective_at_raw) if effective_at_raw else None
status_text = row.get("std_status", "")
if "征求意见" in status_text:
status = "consultation"
elif "报批" in status_text or "草案" in status_text:
status = "draft"
else:
status = "enacted"
return RawEvent(
source="国标委",
source_label=source_label,
standard_code=standard_code,
title=title,
summary=title,
full_text_url=f"https://openstd.samr.gov.cn/bzgk/std/detail?id={row.get('id', '')}",
status=status,
published_at=published_at,
effective_at=effective_at,
category=row.get("std_type", "国家标准"),
tags=extract_tags(standard_code, title),
raw_text=f"{standard_code} {title}",
)
class GuobiaoMandatoryCrawler(BaseCrawler):
"""Fetch mandatory national standards (强制性) related to vehicles."""
def fetch(self, limit: int = 50) -> list[RawEvent]:
events: list[RawEvent] = []
page = 1
max_pages = max(10, limit)
while len(events) < limit and page <= max_pages:
rows = _fetch_page(std_type=1, page=page, page_size=20)
if not rows:
break
events.extend(_row_to_raw_event(r, "国标委·强制性") for r in rows)
page += 1
return events[:limit]
class GuobiaoRecommendedCrawler(BaseCrawler):
"""Fetch recommended national standards (推荐性) related to vehicles."""
def fetch(self, limit: int = 50) -> list[RawEvent]:
events: list[RawEvent] = []
page = 1
max_pages = max(10, limit)
while len(events) < limit and page <= max_pages:
rows = _fetch_page(std_type=2, page=page, page_size=20)
if not rows:
break
events.extend(_row_to_raw_event(r, "国标委·推荐性") for r in rows)
page += 1
return events[:limit]

View File

@@ -0,0 +1,241 @@
"""LLM-driven pipeline for regulatory event enrichment."""
from __future__ import annotations
import json
import math
from typing import Any
from loguru import logger
from app.config.settings import settings
from app.infrastructure.embedding.openai_compatible_embedding_provider import (
OpenAICompatibleEmbeddingProvider,
)
from app.services.llm.llm_factory import get_llm_client
_EXTRACT_SYSTEM = (
"You are a regulatory compliance expert specialising in automotive standards "
"(GB, UN-ECE, ISO, EU). Extract structured information from regulation text. "
"Return valid JSON only — no markdown fences, no extra keys."
)
_ASSESS_SYSTEM = (
"You are an automotive compliance analyst. Given a regulation and related document excerpts, "
"identify which documents are affected and what actions are required. "
"Return a JSON array only."
)
_DIFF_SYSTEM = (
"You are a regulatory change analyst. Given an old and new version of a regulation paragraph, "
"classify the type of change and summarise it. "
"Return JSON only: {\"change_type\": \"tightened|relaxed|added|removed\", \"summary\": \"...\"}"
)
_SIMILARITY_THRESHOLD = 0.85
def _cosine(a: list[float], b: list[float]) -> float:
dot = sum(x * y for x, y in zip(a, b))
norm_a = math.sqrt(sum(x * x for x in a))
norm_b = math.sqrt(sum(x * x for x in b))
if norm_a == 0 or norm_b == 0:
return 0.0
return dot / (norm_a * norm_b)
def _llm_json(client: Any, messages: list[dict]) -> Any:
"""Call LLM and parse JSON response; return None on failure."""
try:
resp = client.chat(messages)
text = (resp.content or "").strip()
if text.startswith("```"):
text = text.split("```")[1]
if text.startswith("json"):
text = text[4:]
return json.loads(text)
except Exception as exc:
logger.warning("LLM JSON parse failed: {}", exc)
return None
class LlmPipeline:
"""Three-step enrichment pipeline for crawled regulatory events."""
def __init__(self) -> None:
self._client = get_llm_client(
provider=settings.llm_provider,
model=settings.llm_model,
)
self._embedder = OpenAICompatibleEmbeddingProvider()
# ------------------------------------------------------------------
# Step 1: Structure extraction
# ------------------------------------------------------------------
def extract_structure(self, event: dict) -> dict:
"""Extract obligations, deadlines, scope, penalties, impact_level from event text."""
prompt = f"""Extract structured compliance information from this regulation:
Standard: {event.get('standard_code', '')}
Title: {event.get('title', '')}
Source: {event.get('source_label', '')}
Summary: {event.get('summary', '')}
Tags: {', '.join(event.get('tags') or [])}
Return JSON with exactly these keys:
{{
"obligations": [{{"text": "...", "deontic": "must|shall|may|prohibited", "subject": "...", "object": "...", "condition": ""}}],
"deadlines": [{{"date": "YYYY-MM-DD or null", "description": "..."}}],
"scope": "one sentence describing who/what this applies to",
"penalties": "one sentence on consequences of non-compliance, or null",
"impact_level": "high|medium|low"
}}"""
messages = [
{"role": "system", "content": _EXTRACT_SYSTEM},
{"role": "user", "content": prompt},
]
result = _llm_json(self._client, messages)
if not isinstance(result, dict):
return {
"obligations": [],
"deadlines": [],
"scope": "",
"penalties": "",
"impact_level": "medium",
}
return result
# ------------------------------------------------------------------
# Step 2: Impact assessment
# ------------------------------------------------------------------
def assess_impact(self, event: dict, retrieval_service: Any) -> list[dict]:
"""Use RAG to find affected documents and generate recommendations."""
obligations = event.get("obligations") or []
obligation_texts = " ".join(o.get("text", "") for o in obligations[:3])
query = f"{event.get('standard_code', '')} {event.get('title', '')} {obligation_texts}"
try:
chunks = retrieval_service.retrieve(query=query, top_k=5)
except Exception as exc:
logger.warning("RAG retrieval failed: {}", exc)
return []
if not chunks:
return []
seen: set[str] = set()
doc_excerpts: list[dict] = []
for chunk in chunks:
if chunk.doc_id not in seen:
seen.add(chunk.doc_id)
doc_excerpts.append({
"doc_id": chunk.doc_id,
"doc_name": chunk.doc_title,
"score": round(float(chunk.score if chunk.score is not None else 0), 4),
"snippet": (chunk.text or "")[:300],
"clause": getattr(chunk, "section_title", "") or "",
})
context = "\n".join(
f"[{d['doc_name']} {d['clause']}] score={d['score']}: {d['snippet']}"
for d in doc_excerpts
)
prompt = f"""Regulation: {event.get('standard_code')}{event.get('title')}
Obligations: {obligation_texts or event.get('summary', '')}
Affected documents found in knowledge base:
{context}
For each document, assess impact and recommend action. Return JSON array:
[{{"doc_id":"...","doc_name":"...","score":0.0,"key_clauses":"...","recommendation":"one sentence action"}}]"""
messages = [
{"role": "system", "content": _ASSESS_SYSTEM},
{"role": "user", "content": prompt},
]
result = _llm_json(self._client, messages)
if isinstance(result, list):
score_map = {d["doc_id"]: d["score"] for d in doc_excerpts}
for item in result:
if isinstance(item, dict) and item.get("doc_id") in score_map:
item["score"] = score_map[item["doc_id"]]
return result
return doc_excerpts
# ------------------------------------------------------------------
# Step 3: Semantic diff
# ------------------------------------------------------------------
def compute_diff(self, old_text: str, new_text: str) -> dict:
"""Compare old and new regulation text; return changed sections and summary."""
old_paras = [p.strip() for p in old_text.split("\n") if p.strip()]
new_paras = [p.strip() for p in new_text.split("\n") if p.strip()]
if not old_paras or not new_paras:
return {"changed_sections": [], "change_summary": "No comparable text."}
all_paras = old_paras + new_paras
try:
all_embeddings = self._embedder.embed_texts(all_paras)
except Exception as exc:
logger.warning("Embedding for diff failed: {}", exc)
return {"changed_sections": [], "change_summary": "Diff unavailable (embedding error)."}
old_embeddings = all_embeddings[: len(old_paras)]
new_embeddings = all_embeddings[len(old_paras):]
changed_sections: list[dict] = []
max_len = max(len(old_paras), len(new_paras))
for i in range(max_len):
if i >= len(old_paras):
# New paragraph added
changed_sections.append({
"old_text": "",
"new_text": new_paras[i][:300],
"similarity": 0.0,
"change_type": "added",
"summary": "New paragraph added.",
})
continue
if i >= len(new_paras):
# Old paragraph removed
changed_sections.append({
"old_text": old_paras[i][:300],
"new_text": "",
"similarity": 0.0,
"change_type": "removed",
"summary": "Paragraph removed.",
})
continue
# Both exist — compare via embeddings
sim = _cosine(old_embeddings[i], new_embeddings[i])
if sim < _SIMILARITY_THRESHOLD:
messages = [
{"role": "system", "content": _DIFF_SYSTEM},
{"role": "user", "content": f"OLD: {old_paras[i][:500]}\nNEW: {new_paras[i][:500]}"},
]
classification = _llm_json(self._client, messages) or {}
changed_sections.append({
"old_text": old_paras[i][:300],
"new_text": new_paras[i][:300],
"similarity": round(sim, 3),
"change_type": classification.get("change_type", "modified"),
"summary": classification.get("summary", ""),
})
if not changed_sections:
change_summary = "No substantive changes detected between versions."
else:
types = [s["change_type"] for s in changed_sections]
change_summary = (
f"{len(changed_sections)} paragraph(s) changed: "
+ ", ".join(f"{t}" for t in set(types))
+ ". "
+ (changed_sections[0].get("summary", "") if changed_sections else "")
)
return {"changed_sections": changed_sections, "change_summary": change_summary}

View File

@@ -4,6 +4,8 @@ from __future__ import annotations
from typing import Any
from app.infrastructure.perception.base_event_store import BaseEventStore
MOCK_EVENTS: list[dict[str, Any]] = [
# ------------------------------------------------------------------ HIGH
{
@@ -379,18 +381,18 @@ MOCK_EVENTS: list[dict[str, Any]] = [
},
]
# Index for fast lookup
_EVENT_INDEX: dict[str, dict] = {e["id"]: e for e in MOCK_EVENTS}
class MockEventStore:
class MockEventStore(BaseEventStore):
"""In-memory mock store for regulatory events."""
def __init__(self) -> None:
self._events: list[dict] = [dict(e) for e in MOCK_EVENTS]
self._index: dict[str, dict] = {e["id"]: e for e in self._events}
def all(self) -> list[dict]:
return list(MOCK_EVENTS)
return list(self._events)
def get(self, event_id: str) -> dict | None:
return _EVENT_INDEX.get(event_id)
return self._index.get(event_id)
def filter(
self,
@@ -399,23 +401,39 @@ class MockEventStore:
impact_level: str | None = None,
limit: int = 50,
) -> list[dict]:
events = list(MOCK_EVENTS)
events = list(self._events)
if source:
events = [e for e in events if e["source"] == source]
if impact_level:
events = [e for e in events if e["impact_level"] == impact_level]
events.sort(key=lambda e: e["published_at"], reverse=True)
events.sort(key=lambda e: e.get("published_at") or "", reverse=True)
return events[:limit]
def stats(self) -> dict:
from datetime import date, timedelta
events = MOCK_EVENTS
events = self._events
cutoff = (date.today() - timedelta(days=90)).isoformat()
return {
"total": len(events),
"high_impact": sum(1 for e in events if e["impact_level"] == "high"),
"medium_impact": sum(1 for e in events if e["impact_level"] == "medium"),
"low_impact": sum(1 for e in events if e["impact_level"] == "low"),
"recent_90d": sum(1 for e in events if e["published_at"] >= cutoff),
"recent_90d": sum(1 for e in events if (e.get("published_at") or "") >= cutoff),
}
def upsert(self, event: dict) -> None:
"""Insert or update event in the in-memory list (used in tests)."""
existing = self._index.get(event["id"])
if existing:
existing.update(event)
else:
self._events.append(event)
self._index[event["id"]] = event
def get_by_standard_code(self, standard_code: str) -> dict | None:
"""Return most-recent event with matching standard_code."""
matches = [e for e in self._events if e.get("standard_code") == standard_code]
if not matches:
return None
return max(matches, key=lambda e: e.get("published_at", ""))

View File

@@ -0,0 +1,225 @@
"""PostgreSQL-backed regulatory event store."""
from __future__ import annotations
import json
from contextlib import contextmanager
from datetime import UTC, date, datetime, timedelta
from typing import Any
import psycopg2
import psycopg2.extras
from psycopg2.pool import ThreadedConnectionPool
from app.config.settings import settings
from app.infrastructure.perception.base_event_store import BaseEventStore
_CREATE_TABLE = """
CREATE TABLE IF NOT EXISTS regulation_events (
id TEXT PRIMARY KEY,
source TEXT NOT NULL,
source_label TEXT,
standard_code TEXT NOT NULL,
title TEXT NOT NULL,
summary TEXT,
full_text_url TEXT,
status TEXT,
impact_level TEXT,
published_at DATE,
effective_at DATE,
category TEXT,
tags TEXT[],
obligations JSONB,
deadlines JSONB,
scope TEXT,
penalties TEXT,
content_hash TEXT,
previous_hash TEXT,
change_summary TEXT,
changed_sections JSONB,
affected_docs JSONB,
crawled_at TIMESTAMPTZ DEFAULT now(),
processed_at TIMESTAMPTZ,
raw_storage_key TEXT
);
CREATE INDEX IF NOT EXISTS reg_events_source_date
ON regulation_events (source, published_at DESC);
CREATE INDEX IF NOT EXISTS reg_events_impact_date
ON regulation_events (impact_level, published_at DESC);
"""
_ALL_COLUMNS = (
"id", "source", "source_label", "standard_code", "title", "summary",
"full_text_url", "status", "impact_level", "published_at", "effective_at",
"category", "tags", "obligations", "deadlines", "scope", "penalties",
"content_hash", "previous_hash", "change_summary", "changed_sections",
"affected_docs", "crawled_at", "processed_at", "raw_storage_key",
)
def _row_to_dict(row: dict[str, Any]) -> dict:
"""Convert a psycopg2 RealDictRow to a plain dict with serialized JSON fields."""
d = dict(row)
for field in ("obligations", "deadlines", "changed_sections", "affected_docs"):
val = d.get(field)
if isinstance(val, str):
d[field] = json.loads(val)
for date_field in ("published_at", "effective_at"):
val = d.get(date_field)
if isinstance(val, datetime):
d[date_field] = val.date().isoformat()
elif isinstance(val, date):
d[date_field] = val.isoformat()
for ts_field in ("crawled_at", "processed_at"):
val = d.get(ts_field)
if isinstance(val, datetime):
d[ts_field] = val.isoformat()
return d
class PostgresEventStore(BaseEventStore):
"""Regulatory event store backed by PostgreSQL."""
def __init__(self) -> None:
self._pool = ThreadedConnectionPool(
minconn=1,
maxconn=5,
host=settings.postgres_host,
port=settings.postgres_port,
user=settings.postgres_user,
password=settings.postgres_password,
dbname=settings.postgres_db,
)
self._ensure_schema()
def _ensure_schema(self) -> None:
with self._conn() as conn:
try:
with conn.cursor() as cur:
cur.execute(_CREATE_TABLE)
conn.commit()
except Exception:
conn.rollback()
raise
@contextmanager
def _conn(self):
conn = None
try:
conn = self._pool.getconn()
yield conn
finally:
if conn is not None:
self._pool.putconn(conn)
def all(self) -> list[dict]:
with self._conn() as conn:
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
cur.execute(
"SELECT * FROM regulation_events ORDER BY published_at DESC NULLS LAST"
)
return [_row_to_dict(r) for r in cur.fetchall()]
def get(self, event_id: str) -> dict | None:
with self._conn() as conn:
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
cur.execute(
"SELECT * FROM regulation_events WHERE id = %s", (event_id,)
)
row = cur.fetchone()
return _row_to_dict(row) if row else None
def filter(
self,
*,
source: str | None = None,
impact_level: str | None = None,
limit: int = 50,
) -> list[dict]:
conditions: list[str] = []
params: list[Any] = []
if source:
conditions.append("source = %s")
params.append(source)
if impact_level:
conditions.append("impact_level = %s")
params.append(impact_level)
where = ("WHERE " + " AND ".join(conditions)) if conditions else ""
params.append(limit)
sql = f"""
SELECT * FROM regulation_events
{where}
ORDER BY published_at DESC NULLS LAST
LIMIT %s
"""
with self._conn() as conn:
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
cur.execute(sql, params)
return [_row_to_dict(r) for r in cur.fetchall()]
def stats(self) -> dict:
cutoff = (date.today() - timedelta(days=90)).isoformat()
with self._conn() as conn:
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
cur.execute("SELECT COUNT(*) AS count FROM regulation_events")
total = (cur.fetchone() or {}).get("count", 0)
cur.execute(
"SELECT COUNT(*) AS count FROM regulation_events WHERE impact_level = 'high'"
)
high = (cur.fetchone() or {}).get("count", 0)
cur.execute(
"SELECT COUNT(*) AS count FROM regulation_events WHERE impact_level = 'medium'"
)
medium = (cur.fetchone() or {}).get("count", 0)
cur.execute(
"SELECT COUNT(*) AS count FROM regulation_events WHERE published_at >= %s",
(cutoff,),
)
recent = (cur.fetchone() or {}).get("count", 0)
return {
"total": int(total),
"high_impact": int(high),
"medium_impact": int(medium),
"recent_90d": int(recent),
}
def upsert(self, event: dict) -> None:
"""Insert or update a regulation event."""
cols = [c for c in _ALL_COLUMNS if c in event]
placeholders = ", ".join(f"%({c})s" for c in cols)
updates = ", ".join(f"{c} = EXCLUDED.{c}" for c in cols if c != "id")
sql = f"""
INSERT INTO regulation_events ({', '.join(cols)})
VALUES ({placeholders})
ON CONFLICT (id) DO UPDATE SET {updates}
"""
row: dict[str, Any] = {}
for c in cols:
val = event.get(c)
if c in ("obligations", "deadlines", "changed_sections", "affected_docs") and val is not None:
row[c] = json.dumps(val, ensure_ascii=False)
elif c == "tags" and isinstance(val, list):
row[c] = val
else:
row[c] = val
with self._conn() as conn:
try:
with conn.cursor() as cur:
cur.execute(sql, row)
conn.commit()
except Exception:
conn.rollback()
raise
def get_by_standard_code(self, standard_code: str) -> dict | None:
with self._conn() as conn:
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
cur.execute(
"""SELECT * FROM regulation_events
WHERE standard_code = %s
ORDER BY published_at DESC NULLS LAST
LIMIT 1""",
(standard_code,),
)
row = cur.fetchone()
return _row_to_dict(row) if row else None