"""Implement infrastructure support for postgres document repository.""" from __future__ import annotations import json from contextlib import contextmanager from datetime import UTC, datetime from typing import Any import psycopg2 import psycopg2.extras from psycopg2.pool import ThreadedConnectionPool from app.config.settings import settings from app.domain.documents import Document, DocumentRepository, DocumentStatus _CREATE_TABLE = """ CREATE TABLE IF NOT EXISTS documents ( doc_id VARCHAR(128) PRIMARY KEY, doc_name VARCHAR(512) NOT NULL DEFAULT '', file_name VARCHAR(512) NOT NULL DEFAULT '', object_name VARCHAR(1024) NOT NULL DEFAULT '', content_type VARCHAR(128) NOT NULL DEFAULT '', size_bytes BIGINT NOT NULL DEFAULT 0, status VARCHAR(32) NOT NULL DEFAULT 'pending', regulation_type VARCHAR(128) NOT NULL DEFAULT '', version VARCHAR(64) NOT NULL DEFAULT '', summary TEXT NOT NULL DEFAULT '', summary_latency_ms INTEGER NOT NULL DEFAULT 0, chunk_count INTEGER NOT NULL DEFAULT 0, parser_name VARCHAR(128) NOT NULL DEFAULT '', index_name VARCHAR(128) NOT NULL DEFAULT '', error_message TEXT NOT NULL DEFAULT '', metadata JSONB NOT NULL DEFAULT '{}', created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); """ _COLUMNS = ( "doc_id", "doc_name", "file_name", "object_name", "content_type", "size_bytes", "status", "regulation_type", "version", "summary", "summary_latency_ms", "chunk_count", "parser_name", "index_name", "error_message", "metadata", "created_at", "updated_at", ) class PostgresDocumentRepository(DocumentRepository): """DocumentRepository implementation backed by PostgreSQL.""" def __init__(self) -> None: self._pool = ThreadedConnectionPool( minconn=1, maxconn=5, host=settings.postgres_host, port=settings.postgres_port, user=settings.postgres_user, password=settings.postgres_password, dbname=settings.postgres_db, ) self._ensure_schema() def _ensure_schema(self) -> None: with self._conn() as conn: with conn.cursor() as cur: cur.execute(_CREATE_TABLE) conn.commit() @contextmanager def _conn(self): conn = self._pool.getconn() try: yield conn finally: self._pool.putconn(conn) # ------------------------------------------------------------------ # Serialization helpers # ------------------------------------------------------------------ def _row_to_document(self, row: dict[str, Any]) -> Document: return Document( doc_id=row["doc_id"], doc_name=row["doc_name"], file_name=row["file_name"], object_name=row["object_name"], content_type=row["content_type"], size_bytes=row["size_bytes"], status=DocumentStatus(row["status"]), regulation_type=row["regulation_type"], version=row["version"], summary=row["summary"], summary_latency_ms=row["summary_latency_ms"], chunk_count=row["chunk_count"], parser_name=row["parser_name"], index_name=row["index_name"], error_message=row["error_message"], metadata=row["metadata"] if isinstance(row["metadata"], dict) else json.loads(row["metadata"] or "{}"), created_at=row["created_at"], updated_at=row["updated_at"], ) # ------------------------------------------------------------------ # DocumentRepository interface # ------------------------------------------------------------------ def create(self, document: Document) -> Document: sql = """ INSERT INTO documents (doc_id, doc_name, file_name, object_name, content_type, size_bytes, status, regulation_type, version, summary, summary_latency_ms, chunk_count, parser_name, index_name, error_message, metadata, created_at, updated_at) VALUES (%(doc_id)s, %(doc_name)s, %(file_name)s, %(object_name)s, %(content_type)s, %(size_bytes)s, %(status)s, %(regulation_type)s, %(version)s, %(summary)s, %(summary_latency_ms)s, %(chunk_count)s, %(parser_name)s, %(index_name)s, %(error_message)s, %(metadata)s, %(created_at)s, %(updated_at)s) ON CONFLICT (doc_id) DO NOTHING """ with self._conn() as conn: with conn.cursor() as cur: cur.execute(sql, self._to_params(document)) conn.commit() return document def update(self, document: Document) -> Document: document.updated_at = datetime.now(UTC) sql = """ UPDATE documents SET doc_name=%(doc_name)s, file_name=%(file_name)s, object_name=%(object_name)s, content_type=%(content_type)s, size_bytes=%(size_bytes)s, status=%(status)s, regulation_type=%(regulation_type)s, version=%(version)s, summary=%(summary)s, summary_latency_ms=%(summary_latency_ms)s, chunk_count=%(chunk_count)s, parser_name=%(parser_name)s, index_name=%(index_name)s, error_message=%(error_message)s, metadata=%(metadata)s, updated_at=%(updated_at)s WHERE doc_id=%(doc_id)s """ with self._conn() as conn: with conn.cursor() as cur: cur.execute(sql, self._to_params(document)) conn.commit() return document def get(self, doc_id: str) -> Document | None: sql = "SELECT * FROM documents WHERE doc_id = %s" with self._conn() as conn: with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: cur.execute(sql, (doc_id,)) row = cur.fetchone() return self._row_to_document(dict(row)) if row else None def list(self, limit: int | None = None) -> list[Document]: sql = "SELECT * FROM documents ORDER BY updated_at DESC" if limit is not None: sql += f" LIMIT {int(limit)}" with self._conn() as conn: with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: cur.execute(sql) rows = cur.fetchall() return [self._row_to_document(dict(r)) for r in rows] def delete(self, doc_id: str) -> bool: sql = "DELETE FROM documents WHERE doc_id = %s" with self._conn() as conn: with conn.cursor() as cur: cur.execute(sql, (doc_id,)) deleted = cur.rowcount > 0 conn.commit() return deleted def update_status( self, doc_id: str, status: DocumentStatus, *, error_message: str = "", chunk_count: int | None = None, summary: str | None = None, summary_latency_ms: int | None = None, parser_name: str | None = None, index_name: str | None = None, metadata: dict | None = None, ) -> Document | None: document = self.get(doc_id) if not document: return None document.status = status document.error_message = error_message if chunk_count is not None: document.chunk_count = chunk_count if summary is not None: document.summary = summary if summary_latency_ms is not None: document.summary_latency_ms = summary_latency_ms if parser_name is not None: document.parser_name = parser_name if index_name is not None: document.index_name = index_name if metadata: document.metadata.update(metadata) return self.update(document) # ------------------------------------------------------------------ # Internal helpers # ------------------------------------------------------------------ def _to_params(self, document: Document) -> dict[str, Any]: return { "doc_id": document.doc_id, "doc_name": document.doc_name, "file_name": document.file_name, "object_name": document.object_name, "content_type": document.content_type, "size_bytes": document.size_bytes, "status": document.status.value, "regulation_type": document.regulation_type, "version": document.version, "summary": document.summary, "summary_latency_ms": document.summary_latency_ms, "chunk_count": document.chunk_count, "parser_name": document.parser_name, "index_name": document.index_name, "error_message": document.error_message, "metadata": json.dumps(document.metadata, ensure_ascii=False), "created_at": document.created_at, "updated_at": document.updated_at, }