229 lines
9.0 KiB
Python
229 lines
9.0 KiB
Python
"""PostgreSQL-backed infrastructure implementation of the document repository."""
|
|
|
|
from __future__ import annotations

import json
from collections.abc import Iterator
from contextlib import contextmanager
from datetime import UTC, datetime
from typing import Any

import psycopg2
import psycopg2.extras
from psycopg2.pool import ThreadedConnectionPool

from app.config.settings import settings
from app.domain.documents import Document, DocumentRepository, DocumentStatus
|
|
|
|
_CREATE_TABLE = """
|
|
CREATE TABLE IF NOT EXISTS documents (
|
|
doc_id VARCHAR(128) PRIMARY KEY,
|
|
doc_name VARCHAR(512) NOT NULL DEFAULT '',
|
|
file_name VARCHAR(512) NOT NULL DEFAULT '',
|
|
object_name VARCHAR(1024) NOT NULL DEFAULT '',
|
|
content_type VARCHAR(128) NOT NULL DEFAULT '',
|
|
size_bytes BIGINT NOT NULL DEFAULT 0,
|
|
status VARCHAR(32) NOT NULL DEFAULT 'pending',
|
|
regulation_type VARCHAR(128) NOT NULL DEFAULT '',
|
|
version VARCHAR(64) NOT NULL DEFAULT '',
|
|
summary TEXT NOT NULL DEFAULT '',
|
|
summary_latency_ms INTEGER NOT NULL DEFAULT 0,
|
|
chunk_count INTEGER NOT NULL DEFAULT 0,
|
|
parser_name VARCHAR(128) NOT NULL DEFAULT '',
|
|
index_name VARCHAR(128) NOT NULL DEFAULT '',
|
|
error_message TEXT NOT NULL DEFAULT '',
|
|
metadata JSONB NOT NULL DEFAULT '{}',
|
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
|
);
|
|
"""
|
|
|
|
_COLUMNS = (
|
|
"doc_id", "doc_name", "file_name", "object_name", "content_type",
|
|
"size_bytes", "status", "regulation_type", "version", "summary",
|
|
"summary_latency_ms", "chunk_count", "parser_name", "index_name",
|
|
"error_message", "metadata", "created_at", "updated_at",
|
|
)
|
|
|
|
|
|
class PostgresDocumentRepository(DocumentRepository):
    """DocumentRepository implementation backed by PostgreSQL.

    Connections are borrowed from a psycopg2 ``ThreadedConnectionPool``; each
    public method uses one pooled connection for its statement and hands it
    back afterwards. The ``documents`` table is created on construction when
    it does not exist yet.
    """

    # The SQL below is derived from ``_COLUMNS`` so the column lists, the
    # ``%(name)s`` placeholders and the keys produced by ``_to_params`` cannot
    # drift apart.
    _SELECT_SQL = f"SELECT {', '.join(_COLUMNS)} FROM documents"
    _INSERT_SQL = (
        f"INSERT INTO documents ({', '.join(_COLUMNS)}) "
        f"VALUES ({', '.join('%(' + col + ')s' for col in _COLUMNS)}) "
        "ON CONFLICT (doc_id) DO NOTHING"
    )
    # ``update`` rewrites every column except the primary key and the
    # creation timestamp.
    _UPDATE_SQL = (
        "UPDATE documents SET "
        + ", ".join(
            f"{col}=%({col})s"
            for col in _COLUMNS
            if col not in ("doc_id", "created_at")
        )
        + " WHERE doc_id=%(doc_id)s"
    )

    def __init__(self, *, minconn: int = 1, maxconn: int = 5) -> None:
        """Open the connection pool and make sure the schema exists.

        Args:
            minconn: ``minconn`` passed to ``ThreadedConnectionPool``.
            maxconn: ``maxconn`` passed to ``ThreadedConnectionPool``.

        Raises:
            Whatever psycopg2 raises while connecting or creating the table;
            the pool is closed before the error propagates.
        """
        self._pool = ThreadedConnectionPool(
            minconn=minconn,
            maxconn=maxconn,
            host=settings.postgres_host,
            port=settings.postgres_port,
            user=settings.postgres_user,
            password=settings.postgres_password,
            dbname=settings.postgres_db,
        )
        try:
            self._ensure_schema()
        except Exception:
            # Don't leave the pool's already-open connections dangling when
            # schema bootstrap fails; the half-built object is unusable anyway.
            self._pool.closeall()
            raise

    def close(self) -> None:
        """Close every pooled connection; safe to call more than once."""
        if not self._pool.closed:
            self._pool.closeall()

    def _ensure_schema(self) -> None:
        """Create the ``documents`` table if it does not exist."""
        with self._conn() as conn:
            with conn.cursor() as cur:
                cur.execute(_CREATE_TABLE)
            conn.commit()

    @contextmanager
    def _conn(self) -> Iterator[Any]:
        """Borrow a connection from the pool and always give it back.

        Callers commit explicitly on success. Reads and failed writes return
        the connection without committing; NOTE(review): this relies on the
        psycopg2 pool to reset/discard such connections in ``putconn`` —
        confirm if the pool implementation changes.
        """
        conn = self._pool.getconn()
        try:
            yield conn
        finally:
            self._pool.putconn(conn)

    # ------------------------------------------------------------------
    # Serialization helpers
    # ------------------------------------------------------------------

    def _row_to_document(self, row: dict[str, Any]) -> Document:
        """Build a ``Document`` from a row mapping keyed by column name."""
        metadata = row["metadata"]
        if not isinstance(metadata, dict):
            # JSONB usually arrives already decoded; otherwise parse the text,
            # treating NULL / empty string as an empty object.
            metadata = json.loads(metadata or "{}")
        return Document(
            doc_id=row["doc_id"],
            doc_name=row["doc_name"],
            file_name=row["file_name"],
            object_name=row["object_name"],
            content_type=row["content_type"],
            size_bytes=row["size_bytes"],
            status=DocumentStatus(row["status"]),
            regulation_type=row["regulation_type"],
            version=row["version"],
            summary=row["summary"],
            summary_latency_ms=row["summary_latency_ms"],
            chunk_count=row["chunk_count"],
            parser_name=row["parser_name"],
            index_name=row["index_name"],
            error_message=row["error_message"],
            metadata=metadata,
            created_at=row["created_at"],
            updated_at=row["updated_at"],
        )

    # ------------------------------------------------------------------
    # DocumentRepository interface
    # ------------------------------------------------------------------

    def create(self, document: Document) -> Document:
        """Insert *document* and return the same object.

        When a row with the same ``doc_id`` already exists the insert is
        skipped (``ON CONFLICT DO NOTHING``) without raising, so the returned
        object may not match what is stored.
        """
        with self._conn() as conn:
            with conn.cursor() as cur:
                cur.execute(self._INSERT_SQL, self._to_params(document))
            conn.commit()
        return document

    def update(self, document: Document) -> Document:
        """Write every column except ``doc_id``/``created_at`` and return *document*.

        ``document.updated_at`` is set to the current UTC time on the passed
        object before the write. If no row has ``document.doc_id`` nothing is
        changed and no error is raised.
        """
        document.updated_at = datetime.now(UTC)
        with self._conn() as conn:
            with conn.cursor() as cur:
                cur.execute(self._UPDATE_SQL, self._to_params(document))
            conn.commit()
        return document

    def get(self, doc_id: str) -> Document | None:
        """Return the document stored under *doc_id*, or ``None`` if absent."""
        sql = f"{self._SELECT_SQL} WHERE doc_id = %s"
        with self._conn() as conn:
            with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                cur.execute(sql, (doc_id,))
                row = cur.fetchone()
        return self._row_to_document(dict(row)) if row else None

    def list(self, limit: int | None = None) -> list[Document]:
        """Return documents ordered by ``updated_at``, newest first.

        Args:
            limit: Maximum number of rows to return, or ``None`` for all rows.
                The value is coerced with ``int()``.
        """
        sql = f"{self._SELECT_SQL} ORDER BY updated_at DESC"
        params = None
        if limit is not None:
            # Bind the limit as a query parameter rather than splicing it
            # into the SQL text.
            sql += " LIMIT %s"
            params = (int(limit),)
        with self._conn() as conn:
            with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                cur.execute(sql, params)
                rows = cur.fetchall()
        return [self._row_to_document(dict(row)) for row in rows]

    def delete(self, doc_id: str) -> bool:
        """Delete the row for *doc_id*; return ``True`` if a row was removed."""
        sql = "DELETE FROM documents WHERE doc_id = %s"
        with self._conn() as conn:
            with conn.cursor() as cur:
                cur.execute(sql, (doc_id,))
                deleted = cur.rowcount > 0
            conn.commit()
        return deleted

    def update_status(
        self,
        doc_id: str,
        status: DocumentStatus,
        *,
        error_message: str = "",
        chunk_count: int | None = None,
        summary: str | None = None,
        summary_latency_ms: int | None = None,
        parser_name: str | None = None,
        index_name: str | None = None,
        metadata: dict | None = None,
    ) -> Document | None:
        """Set *status* (and optional fields) on a stored document.

        ``status`` and ``error_message`` are always written (so the default
        clears any previous error). The other keyword fields are written only
        when not ``None``; a non-empty *metadata* is merged into the existing
        metadata dict rather than replacing it.

        Returns:
            The updated document, or ``None`` if *doc_id* does not exist.
        """
        # NOTE(review): the read (get) and the write (update) run on separate
        # pooled connections/transactions, so concurrent calls for the same
        # doc_id can overwrite each other's changes (e.g. metadata merges).
        document = self.get(doc_id)
        if not document:
            return None
        document.status = status
        document.error_message = error_message
        if chunk_count is not None:
            document.chunk_count = chunk_count
        if summary is not None:
            document.summary = summary
        if summary_latency_ms is not None:
            document.summary_latency_ms = summary_latency_ms
        if parser_name is not None:
            document.parser_name = parser_name
        if index_name is not None:
            document.index_name = index_name
        if metadata:
            document.metadata.update(metadata)
        return self.update(document)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _to_params(self, document: Document) -> dict[str, Any]:
        """Map *document* to named SQL parameters, one key per column."""
        return {
            "doc_id": document.doc_id,
            "doc_name": document.doc_name,
            "file_name": document.file_name,
            "object_name": document.object_name,
            "content_type": document.content_type,
            "size_bytes": document.size_bytes,
            "status": document.status.value,
            "regulation_type": document.regulation_type,
            "version": document.version,
            "summary": document.summary,
            "summary_latency_ms": document.summary_latency_ms,
            "chunk_count": document.chunk_count,
            "parser_name": document.parser_name,
            "index_name": document.index_name,
            "error_message": document.error_message,
            # Sent as JSON text; PostgreSQL coerces it to JSONB on write.
            "metadata": json.dumps(document.metadata, ensure_ascii=False),
            "created_at": document.created_at,
            "updated_at": document.updated_at,
        }
|