update for 1. 优化 2.中英切换
This commit is contained in:
280
backend/app/infrastructure/compliance/repository.py
Normal file
280
backend/app/infrastructure/compliance/repository.py
Normal file
@@ -0,0 +1,280 @@
|
||||
# backend/app/infrastructure/compliance/repository.py
|
||||
"""PostgreSQL-backed compliance analysis repository.
|
||||
|
||||
Follows the same psycopg2 pattern as PostgresDocumentRepository:
|
||||
ThreadedConnectionPool + RealDictCursor for reads, _ensure_schema on init.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import psycopg2
|
||||
import psycopg2.extras
|
||||
import psycopg2.pool
|
||||
from loguru import logger
|
||||
|
||||
from app.domain.compliance.ports import (
|
||||
AnalysisRecord,
|
||||
ComplianceRepository,
|
||||
FindingRecord,
|
||||
)
|
||||
|
||||
|
||||
class PostgresComplianceRepository(ComplianceRepository):
|
||||
"""Stores compliance analyses, findings, and finding chat messages in PostgreSQL."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
port: int,
|
||||
user: str,
|
||||
password: str,
|
||||
dbname: str,
|
||||
minconn: int = 1,
|
||||
maxconn: int = 5,
|
||||
) -> None:
|
||||
self._pool = psycopg2.pool.ThreadedConnectionPool(
|
||||
minconn=minconn,
|
||||
maxconn=maxconn,
|
||||
host=host,
|
||||
port=port,
|
||||
user=user,
|
||||
password=password,
|
||||
dbname=dbname,
|
||||
)
|
||||
self._ensure_schema()
|
||||
|
||||
@contextmanager
|
||||
def _conn(self):
|
||||
conn = self._pool.getconn()
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
self._pool.putconn(conn)
|
||||
|
||||
def _ensure_schema(self) -> None:
|
||||
"""Create tables if they do not exist."""
|
||||
with self._conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("""
|
||||
CREATE TABLE IF NOT EXISTS compliance_analyses (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
created_by VARCHAR(255),
|
||||
doc_name VARCHAR(500),
|
||||
standard_name VARCHAR(500),
|
||||
risk_score INTEGER,
|
||||
conclusion TEXT,
|
||||
actions JSONB,
|
||||
para_text TEXT,
|
||||
highlight_terms JSONB
|
||||
);
|
||||
""")
|
||||
cur.execute("""
|
||||
CREATE TABLE IF NOT EXISTS compliance_findings (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
analysis_id UUID NOT NULL REFERENCES compliance_analyses(id) ON DELETE CASCADE,
|
||||
seq INTEGER NOT NULL,
|
||||
title VARCHAR(500),
|
||||
description TEXT,
|
||||
status VARCHAR(50),
|
||||
clause_ref VARCHAR(200)
|
||||
);
|
||||
""")
|
||||
cur.execute("""
|
||||
CREATE TABLE IF NOT EXISTS finding_chat_messages (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
analysis_id UUID NOT NULL REFERENCES compliance_analyses(id) ON DELETE CASCADE,
|
||||
finding_id UUID NOT NULL REFERENCES compliance_findings(id) ON DELETE CASCADE,
|
||||
role VARCHAR(20) NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
""")
|
||||
conn.commit()
|
||||
|
||||
def save_analysis(self, record: AnalysisRecord) -> str:
|
||||
"""Insert analysis + findings; return the new analysis UUID."""
|
||||
with self._conn() as conn:
|
||||
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO compliance_analyses
|
||||
(created_by, doc_name, standard_name, risk_score,
|
||||
conclusion, actions, para_text, highlight_terms)
|
||||
VALUES
|
||||
(%(created_by)s, %(doc_name)s, %(standard_name)s, %(risk_score)s,
|
||||
%(conclusion)s, %(actions)s, %(para_text)s, %(highlight_terms)s)
|
||||
RETURNING id
|
||||
""",
|
||||
{
|
||||
"created_by": record.created_by,
|
||||
"doc_name": record.doc_name,
|
||||
"standard_name": record.standard_name,
|
||||
"risk_score": record.risk_score,
|
||||
"conclusion": record.conclusion,
|
||||
"actions": json.dumps(record.actions, ensure_ascii=False),
|
||||
"para_text": record.para_text,
|
||||
"highlight_terms": json.dumps(record.highlight_terms, ensure_ascii=False),
|
||||
},
|
||||
)
|
||||
row = cur.fetchone()
|
||||
analysis_id = str(row["id"])
|
||||
|
||||
if record.findings:
|
||||
with conn.cursor() as cur:
|
||||
for f in record.findings:
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO compliance_findings
|
||||
(analysis_id, seq, title, description, status, clause_ref)
|
||||
VALUES
|
||||
(%(analysis_id)s, %(seq)s, %(title)s, %(desc)s, %(status)s, %(clause_ref)s)
|
||||
""",
|
||||
{
|
||||
"analysis_id": analysis_id,
|
||||
"seq": f.seq,
|
||||
"title": f.title,
|
||||
"desc": f.description,
|
||||
"status": f.status,
|
||||
"clause_ref": f.clause_ref,
|
||||
},
|
||||
)
|
||||
conn.commit()
|
||||
return analysis_id
|
||||
|
||||
def list_analyses(self, limit: int = 50, offset: int = 0) -> list[AnalysisRecord]:
|
||||
"""Return analyses without nested findings, ordered newest first."""
|
||||
with self._conn() as conn:
|
||||
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id, created_at, created_by, doc_name, standard_name,
|
||||
risk_score, conclusion, actions, para_text, highlight_terms
|
||||
FROM compliance_analyses
|
||||
ORDER BY created_at DESC
|
||||
LIMIT %(limit)s OFFSET %(offset)s
|
||||
""",
|
||||
{"limit": limit, "offset": offset},
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
return [self._row_to_record(dict(r)) for r in rows]
|
||||
|
||||
def get_analysis(self, analysis_id: str) -> Optional[AnalysisRecord]:
|
||||
"""Return analysis with nested findings list."""
|
||||
with self._conn() as conn:
|
||||
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
|
||||
cur.execute(
|
||||
"SELECT * FROM compliance_analyses WHERE id = %(id)s",
|
||||
{"id": analysis_id},
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
return None
|
||||
record = self._row_to_record(dict(row))
|
||||
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id, analysis_id, seq, title, description, status, clause_ref
|
||||
FROM compliance_findings
|
||||
WHERE analysis_id = %(id)s
|
||||
ORDER BY seq
|
||||
""",
|
||||
{"id": analysis_id},
|
||||
)
|
||||
findings = [
|
||||
FindingRecord(
|
||||
id=str(r["id"]),
|
||||
analysis_id=str(r["analysis_id"]),
|
||||
seq=r["seq"],
|
||||
title=r["title"] or "",
|
||||
description=r["description"] or "",
|
||||
status=r["status"] or "ok",
|
||||
clause_ref=r["clause_ref"],
|
||||
)
|
||||
for r in cur.fetchall()
|
||||
]
|
||||
record.findings = findings
|
||||
return record
|
||||
|
||||
def delete_analysis(self, analysis_id: str) -> None:
|
||||
"""Delete analysis; findings and chat messages cascade automatically."""
|
||||
with self._conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"DELETE FROM compliance_analyses WHERE id = %(id)s",
|
||||
{"id": analysis_id},
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def save_message(self, analysis_id: str, finding_id: str, role: str, content: str) -> str:
|
||||
"""Persist a chat message; return its UUID."""
|
||||
with self._conn() as conn:
|
||||
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO finding_chat_messages
|
||||
(analysis_id, finding_id, role, content)
|
||||
VALUES
|
||||
(%(analysis_id)s, %(finding_id)s, %(role)s, %(content)s)
|
||||
RETURNING id
|
||||
""",
|
||||
{
|
||||
"analysis_id": analysis_id,
|
||||
"finding_id": finding_id,
|
||||
"role": role,
|
||||
"content": content,
|
||||
},
|
||||
)
|
||||
row = cur.fetchone()
|
||||
conn.commit()
|
||||
return str(row["id"])
|
||||
|
||||
def get_messages(self, finding_id: str) -> list[dict]:
|
||||
"""Return messages for a finding, oldest first."""
|
||||
with self._conn() as conn:
|
||||
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id, role, content, created_at
|
||||
FROM finding_chat_messages
|
||||
WHERE finding_id = %(finding_id)s
|
||||
ORDER BY created_at ASC
|
||||
""",
|
||||
{"finding_id": finding_id},
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
return [
|
||||
{
|
||||
"id": str(r["id"]),
|
||||
"role": r["role"],
|
||||
"content": r["content"],
|
||||
"created_at": r["created_at"].isoformat() if r["created_at"] else "",
|
||||
}
|
||||
for r in rows
|
||||
]
|
||||
|
||||
def _row_to_record(self, row: dict) -> AnalysisRecord:
|
||||
"""Convert a RealDictCursor row to an AnalysisRecord (no findings)."""
|
||||
actions = row.get("actions") or []
|
||||
if isinstance(actions, str):
|
||||
actions = json.loads(actions)
|
||||
highlight_terms = row.get("highlight_terms") or []
|
||||
if isinstance(highlight_terms, str):
|
||||
highlight_terms = json.loads(highlight_terms)
|
||||
return AnalysisRecord(
|
||||
id=str(row["id"]),
|
||||
created_at=row["created_at"] if isinstance(row["created_at"], datetime) else datetime.utcnow(),
|
||||
created_by=row.get("created_by"),
|
||||
doc_name=row.get("doc_name") or "",
|
||||
standard_name=row.get("standard_name") or "",
|
||||
risk_score=int(row.get("risk_score") or 0),
|
||||
conclusion=row.get("conclusion") or "",
|
||||
actions=actions,
|
||||
para_text=row.get("para_text") or "",
|
||||
highlight_terms=highlight_terms,
|
||||
findings=[],
|
||||
)
|
||||
Reference in New Issue
Block a user