fix 文档管理模块 & 法规对话模块

This commit is contained in:
2026-05-20 23:34:08 +08:00
parent c22b03dc07
commit b065d55c86
39 changed files with 1671 additions and 540 deletions

View File

@@ -289,7 +289,7 @@ def build_vector_chunks(
{
"doc_id": doc_id,
"doc_title": doc_title,
"chunk_id": f"chunk-{chunk_index}",
"chunk_id": f"{doc_id}-chunk-{chunk_index}",
"chunk_index": chunk_index,
"semantic_id": block["semantic_id"],
"chunk_type": block["block_type"],

View File

@@ -75,6 +75,15 @@ class JsonDocumentRepository(DocumentRepository):
documents.sort(key=lambda item: item.updated_at, reverse=True)
return documents[:limit] if limit is not None else documents
def delete(self, doc_id: str) -> bool:
    """Remove the record stored under ``doc_id`` from the persisted payload.

    Returns:
        True when a record was removed and the payload was written back,
        False when no record with that id exists (nothing is written).
    """
    records = self._load()
    try:
        del records[doc_id]
    except KeyError:
        # Unknown id: leave the stored payload untouched.
        return False
    self._save(records)
    return True
def update_status(
self,
doc_id: str,

View File

@@ -0,0 +1,228 @@
"""Implement infrastructure support for postgres document repository."""
from __future__ import annotations
import json
from contextlib import contextmanager
from datetime import UTC, datetime
from typing import Any
import psycopg2
import psycopg2.extras
from psycopg2.pool import ThreadedConnectionPool
from app.config.settings import settings
from app.domain.documents import Document, DocumentRepository, DocumentStatus
# DDL for the `documents` table. It is idempotent (CREATE TABLE IF NOT EXISTS)
# and is executed by PostgresDocumentRepository._ensure_schema() each time a
# repository instance is constructed.
_CREATE_TABLE = """
CREATE TABLE IF NOT EXISTS documents (
doc_id VARCHAR(128) PRIMARY KEY,
doc_name VARCHAR(512) NOT NULL DEFAULT '',
file_name VARCHAR(512) NOT NULL DEFAULT '',
object_name VARCHAR(1024) NOT NULL DEFAULT '',
content_type VARCHAR(128) NOT NULL DEFAULT '',
size_bytes BIGINT NOT NULL DEFAULT 0,
status VARCHAR(32) NOT NULL DEFAULT 'pending',
regulation_type VARCHAR(128) NOT NULL DEFAULT '',
version VARCHAR(64) NOT NULL DEFAULT '',
summary TEXT NOT NULL DEFAULT '',
summary_latency_ms INTEGER NOT NULL DEFAULT 0,
chunk_count INTEGER NOT NULL DEFAULT 0,
parser_name VARCHAR(128) NOT NULL DEFAULT '',
index_name VARCHAR(128) NOT NULL DEFAULT '',
error_message TEXT NOT NULL DEFAULT '',
metadata JSONB NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
# Column names of the `documents` table, in the same order as the DDL above.
# NOTE(review): not referenced in the visible part of this module — confirm it
# is still needed.
_COLUMNS = (
"doc_id", "doc_name", "file_name", "object_name", "content_type",
"size_bytes", "status", "regulation_type", "version", "summary",
"summary_latency_ms", "chunk_count", "parser_name", "index_name",
"error_message", "metadata", "created_at", "updated_at",
)
class PostgresDocumentRepository(DocumentRepository):
"""DocumentRepository implementation backed by PostgreSQL."""
def __init__(self) -> None:
    """Open a threaded connection pool from application settings and ensure the schema.

    The pool keeps between 1 and 5 connections. ``_ensure_schema`` runs
    immediately so the ``documents`` table exists before first use.
    """
    connection_options = {
        "host": settings.postgres_host,
        "port": settings.postgres_port,
        "user": settings.postgres_user,
        "password": settings.postgres_password,
        "dbname": settings.postgres_db,
    }
    self._pool = ThreadedConnectionPool(minconn=1, maxconn=5, **connection_options)
    self._ensure_schema()
def _ensure_schema(self) -> None:
    """Execute the idempotent ``documents`` DDL and commit it."""
    with self._conn() as connection:
        cursor = connection.cursor()
        with cursor:
            cursor.execute(_CREATE_TABLE)
        connection.commit()
@contextmanager
def _conn(self):
conn = self._pool.getconn()
try:
yield conn
finally:
self._pool.putconn(conn)
# ------------------------------------------------------------------
# Serialization helpers
# ------------------------------------------------------------------
def _row_to_document(self, row: dict[str, Any]) -> Document:
    """Build a ``Document`` from one ``documents`` row given as a column -> value dict.

    ``status`` is converted to ``DocumentStatus``. ``metadata`` is used as-is
    when it is already a dict; otherwise it is JSON-decoded, with an empty or
    missing value decoding to ``{}``.
    """
    status = DocumentStatus(row["status"])
    raw_metadata = row["metadata"]
    if isinstance(raw_metadata, dict):
        metadata = raw_metadata
    else:
        metadata = json.loads(raw_metadata or "{}")
    return Document(
        doc_id=row["doc_id"],
        doc_name=row["doc_name"],
        file_name=row["file_name"],
        object_name=row["object_name"],
        content_type=row["content_type"],
        size_bytes=row["size_bytes"],
        status=status,
        regulation_type=row["regulation_type"],
        version=row["version"],
        summary=row["summary"],
        summary_latency_ms=row["summary_latency_ms"],
        chunk_count=row["chunk_count"],
        parser_name=row["parser_name"],
        index_name=row["index_name"],
        error_message=row["error_message"],
        metadata=metadata,
        created_at=row["created_at"],
        updated_at=row["updated_at"],
    )
# ------------------------------------------------------------------
# DocumentRepository interface
# ------------------------------------------------------------------
def create(self, document: Document) -> Document:
    """Insert ``document`` as a new row and return the same object.

    The statement uses ``ON CONFLICT (doc_id) DO NOTHING``, so an existing
    row with the same ``doc_id`` is left unchanged and no error is raised.
    """
    statement = """
INSERT INTO documents
(doc_id, doc_name, file_name, object_name, content_type, size_bytes,
status, regulation_type, version, summary, summary_latency_ms,
chunk_count, parser_name, index_name, error_message, metadata,
created_at, updated_at)
VALUES
(%(doc_id)s, %(doc_name)s, %(file_name)s, %(object_name)s, %(content_type)s,
%(size_bytes)s, %(status)s, %(regulation_type)s, %(version)s, %(summary)s,
%(summary_latency_ms)s, %(chunk_count)s, %(parser_name)s, %(index_name)s,
%(error_message)s, %(metadata)s, %(created_at)s, %(updated_at)s)
ON CONFLICT (doc_id) DO NOTHING
"""
    with self._conn() as connection:
        with connection.cursor() as cursor:
            cursor.execute(statement, self._to_params(document))
        # Commit only after the cursor has been closed.
        connection.commit()
    return document
def update(self, document: Document) -> Document:
document.updated_at = datetime.now(UTC)
sql = """
UPDATE documents SET
doc_name=%(doc_name)s, file_name=%(file_name)s, object_name=%(object_name)s,
content_type=%(content_type)s, size_bytes=%(size_bytes)s, status=%(status)s,
regulation_type=%(regulation_type)s, version=%(version)s, summary=%(summary)s,
summary_latency_ms=%(summary_latency_ms)s, chunk_count=%(chunk_count)s,
parser_name=%(parser_name)s, index_name=%(index_name)s,
error_message=%(error_message)s, metadata=%(metadata)s, updated_at=%(updated_at)s
WHERE doc_id=%(doc_id)s
"""
with self._conn() as conn:
with conn.cursor() as cur:
cur.execute(sql, self._to_params(document))
conn.commit()
return document
def get(self, doc_id: str) -> Document | None:
    """Fetch the document with the given ``doc_id``, or ``None`` when no row matches."""
    query = "SELECT * FROM documents WHERE doc_id = %s"
    with self._conn() as connection:
        with connection.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
            cursor.execute(query, (doc_id,))
            record = cursor.fetchone()
    if not record:
        return None
    return self._row_to_document(dict(record))
def list(self, limit: int | None = None) -> list[Document]:
    """Return documents ordered by ``updated_at``, most recent first.

    Args:
        limit: Maximum number of documents to return. ``None`` returns all.

    Returns:
        One ``Document`` per row, in query order.
    """
    sql = "SELECT * FROM documents ORDER BY updated_at DESC"
    params: tuple[int, ...] | None = None
    if limit is not None:
        # Bind the limit as a query parameter instead of formatting it into
        # the SQL text; int() still rejects non-numeric input early.
        sql += " LIMIT %s"
        params = (int(limit),)
    with self._conn() as conn:
        with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
            cur.execute(sql, params)
            rows = cur.fetchall()
    return [self._row_to_document(dict(row)) for row in rows]
def delete(self, doc_id: str) -> bool:
    """Delete the row for ``doc_id`` and commit.

    Returns:
        True when at least one row was removed, False otherwise.
    """
    statement = "DELETE FROM documents WHERE doc_id = %s"
    with self._conn() as connection:
        with connection.cursor() as cursor:
            cursor.execute(statement, (doc_id,))
            # Read the affected-row count while the cursor is still open.
            removed_any = cursor.rowcount > 0
        connection.commit()
    return removed_any
def update_status(
    self,
    doc_id: str,
    status: DocumentStatus,
    *,
    error_message: str = "",
    chunk_count: int | None = None,
    summary: str | None = None,
    summary_latency_ms: int | None = None,
    parser_name: str | None = None,
    index_name: str | None = None,
    metadata: dict | None = None,
) -> Document | None:
    """Load a document, apply a status change plus optional field updates, and save it.

    ``status`` and ``error_message`` are always overwritten (the latter with
    ``""`` by default). The other keyword arguments are applied only when they
    are not ``None``; a non-empty ``metadata`` dict is merged into the existing
    metadata rather than replacing it.

    Returns:
        The result of ``self.update(...)``, or ``None`` when ``self.get``
        finds no document for ``doc_id``.
    """
    document = self.get(doc_id)
    if not document:
        return None

    document.status = status
    document.error_message = error_message

    optional_fields = {
        "chunk_count": chunk_count,
        "summary": summary,
        "summary_latency_ms": summary_latency_ms,
        "parser_name": parser_name,
        "index_name": index_name,
    }
    for attribute, value in optional_fields.items():
        if value is not None:
            setattr(document, attribute, value)

    if metadata:
        document.metadata.update(metadata)
    return self.update(document)
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _to_params(self, document: Document) -> dict[str, Any]:
return {
"doc_id": document.doc_id,
"doc_name": document.doc_name,
"file_name": document.file_name,
"object_name": document.object_name,
"content_type": document.content_type,
"size_bytes": document.size_bytes,
"status": document.status.value,
"regulation_type": document.regulation_type,
"version": document.version,
"summary": document.summary,
"summary_latency_ms": document.summary_latency_ms,
"chunk_count": document.chunk_count,
"parser_name": document.parser_name,
"index_name": document.index_name,
"error_message": document.error_message,
"metadata": json.dumps(document.metadata, ensure_ascii=False),
"created_at": document.created_at,
"updated_at": document.updated_at,
}

View File

@@ -0,0 +1,196 @@
"""Implement infrastructure support for postgres parse artifact store."""
from __future__ import annotations
import json
from contextlib import contextmanager
from typing import Any
import psycopg2
import psycopg2.extras
from psycopg2.pool import ThreadedConnectionPool
from app.config.settings import settings
from app.domain.documents import ParseArtifactStore
# DDL for the `structure_nodes` table plus its doc_id index. Rows reference
# documents(doc_id) with ON DELETE CASCADE, so the `documents` table must
# already exist when this statement runs.
_CREATE_STRUCTURE_NODES = """
CREATE TABLE IF NOT EXISTS structure_nodes (
id SERIAL PRIMARY KEY,
doc_id VARCHAR(128) NOT NULL,
unique_id VARCHAR(128),
page INTEGER NOT NULL DEFAULT 0,
idx INTEGER NOT NULL DEFAULT 0,
level INTEGER NOT NULL DEFAULT 0,
title TEXT NOT NULL DEFAULT '',
type VARCHAR(64),
sub_type VARCHAR(64),
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT fk_sn_doc FOREIGN KEY (doc_id) REFERENCES documents(doc_id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_structure_nodes_doc_id ON structure_nodes(doc_id);
"""
# DDL for the `semantic_blocks` table plus its doc_id index. Besides the
# cascading foreign key to documents(doc_id), (doc_id, semantic_id) is unique,
# so one document cannot store two blocks with the same semantic_id.
_CREATE_SEMANTIC_BLOCKS = """
CREATE TABLE IF NOT EXISTS semantic_blocks (
id SERIAL PRIMARY KEY,
doc_id VARCHAR(128) NOT NULL,
semantic_id VARCHAR(128) NOT NULL,
block_type VARCHAR(64) NOT NULL DEFAULT '',
page_start INTEGER NOT NULL DEFAULT 0,
page_end INTEGER NOT NULL DEFAULT 0,
section_path JSONB NOT NULL DEFAULT '[]',
section_level INTEGER NOT NULL DEFAULT 0,
section_title VARCHAR(512) NOT NULL DEFAULT '',
source_ids JSONB NOT NULL DEFAULT '[]',
text TEXT NOT NULL DEFAULT '',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT fk_sb_doc FOREIGN KEY (doc_id) REFERENCES documents(doc_id) ON DELETE CASCADE,
CONSTRAINT uq_semantic_blocks UNIQUE (doc_id, semantic_id)
);
CREATE INDEX IF NOT EXISTS idx_semantic_blocks_doc_id ON semantic_blocks(doc_id);
"""
class PostgresParseArtifactStore(ParseArtifactStore):
"""ParseArtifactStore implementation backed by PostgreSQL.
Requires the `documents` table to exist first (created by PostgresDocumentRepository).
"""
def __init__(self) -> None:
    """Open a threaded connection pool from application settings and ensure the schema.

    The pool keeps between 1 and 5 connections. ``_ensure_schema`` runs
    immediately so both artifact tables exist before first use.
    """
    connection_options = {
        "host": settings.postgres_host,
        "port": settings.postgres_port,
        "user": settings.postgres_user,
        "password": settings.postgres_password,
        "dbname": settings.postgres_db,
    }
    self._pool = ThreadedConnectionPool(minconn=1, maxconn=5, **connection_options)
    self._ensure_schema()
def _ensure_schema(self) -> None:
    """Create the structure-node and semantic-block tables if missing, then commit."""
    with self._conn() as connection:
        with connection.cursor() as cursor:
            # structure_nodes first, semantic_blocks second.
            for ddl in (_CREATE_STRUCTURE_NODES, _CREATE_SEMANTIC_BLOCKS):
                cursor.execute(ddl)
        connection.commit()
@contextmanager
def _conn(self):
conn = self._pool.getconn()
try:
yield conn
finally:
self._pool.putconn(conn)
# ------------------------------------------------------------------
# ParseArtifactStore interface
# ------------------------------------------------------------------
def save(
    self,
    doc_id: str,
    structure_nodes: list[dict],
    semantic_blocks: list[dict],
) -> None:
    """Persist structure nodes and semantic blocks, replacing any existing records.

    Existing rows for ``doc_id`` are deleted and the new rows inserted on one
    connection, with a single commit at the end.

    Missing or ``None`` values are normalized the same way for every column
    type: integers become 0, text becomes ``""`` and JSON list columns become
    ``[]``. Previously only integers were handled, so a ``None`` title or text
    was stored as the literal string ``"None"``.

    Args:
        doc_id: Document the artifacts belong to.
        structure_nodes: Node dicts; reads ``unique_id``, ``page``, ``index``,
            ``level``, ``title``, ``type`` and ``sub_type``.
        semantic_blocks: Block dicts; reads ``semantic_id``, ``block_type``,
            ``page_start``, ``page_end``, ``section_path``, ``section_level``,
            ``section_title``, ``source_ids`` and ``text``.
    """

    def to_int(value: Any) -> int:
        return int(value or 0)

    def to_text(value: Any) -> str:
        return "" if value is None else str(value)

    def to_json_list(value: Any) -> str:
        return json.dumps([] if value is None else value, ensure_ascii=False)

    # Build all rows before touching the database so a bad value fails fast.
    node_rows = [
        (
            doc_id,
            node.get("unique_id"),
            to_int(node.get("page")),
            to_int(node.get("index")),
            to_int(node.get("level")),
            to_text(node.get("title")),
            node.get("type"),
            node.get("sub_type"),
        )
        for node in structure_nodes
    ]
    block_rows = [
        (
            doc_id,
            to_text(block.get("semantic_id")),
            to_text(block.get("block_type")),
            to_int(block.get("page_start")),
            to_int(block.get("page_end")),
            to_json_list(block.get("section_path")),
            to_int(block.get("section_level")),
            to_text(block.get("section_title")),
            to_json_list(block.get("source_ids")),
            to_text(block.get("text")),
        )
        for block in semantic_blocks
    ]

    with self._conn() as conn:
        with conn.cursor() as cur:
            # Delete existing records first to keep save idempotent.
            cur.execute("DELETE FROM structure_nodes WHERE doc_id = %s", (doc_id,))
            cur.execute("DELETE FROM semantic_blocks WHERE doc_id = %s", (doc_id,))
            if node_rows:
                psycopg2.extras.execute_values(
                    cur,
                    """
INSERT INTO structure_nodes
(doc_id, unique_id, page, idx, level, title, type, sub_type)
VALUES %s
""",
                    node_rows,
                )
            if block_rows:
                psycopg2.extras.execute_values(
                    cur,
                    """
INSERT INTO semantic_blocks
(doc_id, semantic_id, block_type, page_start, page_end,
section_path, section_level, section_title, source_ids, text)
VALUES %s
""",
                    block_rows,
                )
        conn.commit()
def delete(self, doc_id: str) -> None:
    """Remove all parse artifacts stored for ``doc_id`` and commit.

    Rows are deleted explicitly from both artifact tables; this method does
    not rely on the foreign keys' ON DELETE CASCADE.
    """
    statements = (
        "DELETE FROM structure_nodes WHERE doc_id = %s",
        "DELETE FROM semantic_blocks WHERE doc_id = %s",
    )
    with self._conn() as connection:
        with connection.cursor() as cursor:
            for statement in statements:
                cursor.execute(statement, (doc_id,))
        connection.commit()
def get_semantic_blocks(self, doc_id: str) -> list[dict[str, Any]]:
    """Return the semantic blocks of a document as plain dicts, ordered by row id.

    ``section_path`` and ``source_ids`` are JSON-decoded when they come back
    from the cursor as strings; values of any other type are returned unchanged.
    """
    query = """
SELECT semantic_id, block_type, page_start, page_end,
section_path, section_level, section_title, source_ids, text
FROM semantic_blocks
WHERE doc_id = %s
ORDER BY id
"""
    json_columns = ("section_path", "source_ids")

    def decode(record) -> dict[str, Any]:
        block = dict(record)
        for column in json_columns:
            value = block[column]
            if isinstance(value, str):
                block[column] = json.loads(value)
        return block

    with self._conn() as connection:
        with connection.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
            cursor.execute(query, (doc_id,))
            records = cursor.fetchall()
    return [decode(record) for record in records]
def get_structure_nodes(self, doc_id: str) -> list[dict[str, Any]]:
    """Return the structure nodes of a document as plain dicts, ordered by ``idx``.

    NOTE(review): the position comes back under the key ``idx`` (the column
    name), while ``save`` reads it from ``index`` — confirm callers expect this.
    """
    query = """
SELECT unique_id, page, idx, level, title, type, sub_type
FROM structure_nodes
WHERE doc_id = %s
ORDER BY idx
"""
    with self._conn() as connection:
        with connection.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
            cursor.execute(query, (doc_id,))
            records = cursor.fetchall()
    nodes = []
    for record in records:
        nodes.append(dict(record))
    return nodes

View File

@@ -0,0 +1,84 @@
"""Implement cross-encoder reranking via an OpenAI-compatible reranker API."""
from __future__ import annotations
import time
import requests
from loguru import logger
from app.config.settings import settings
from app.domain.retrieval import Reranker, RetrievedChunk
class OpenAICompatibleReranker(Reranker):
"""Call a TEI / Cohere-style reranker endpoint to re-score retrieved chunks."""
def __init__(
self,
base_url: str | None = None,
model: str | None = None,
api_key: str | None = None,
timeout: int = 30,
) -> None:
self._base_url = (base_url or settings.reranker_base_url).rstrip("/")
self._model = model or settings.reranker_model
self._api_key = api_key or settings.reranker_api_key
self._timeout = timeout
def rerank(self, query: str, chunks: list[RetrievedChunk], top_k: int) -> list[RetrievedChunk]:
    """Return up to ``top_k`` chunks re-sorted by cross-encoder score.

    Each returned chunk has its ``score`` attribute overwritten with the
    reranker score. If the reranker call fails, or it does not return exactly
    one score per chunk, the first ``top_k`` chunks are returned in their
    original order with their scores untouched.

    Args:
        query: The user query to score chunks against.
        chunks: Candidate chunks; their ``content`` is sent to the reranker.
        top_k: Maximum number of chunks to return.
    """
    if not chunks:
        return []
    texts = [chunk.content for chunk in chunks]
    # perf_counter is monotonic, so the measured latency cannot go negative
    # when the wall clock is adjusted.
    start = time.perf_counter()
    try:
        scores = self._call_reranker(query, texts)
        if len(scores) != len(chunks):
            # zip() below would silently drop chunks on a short response.
            raise ValueError(
                f"reranker returned {len(scores)} scores for {len(chunks)} chunks"
            )
    except Exception as exc:
        # Deliberate best-effort: retrieval still works without reranking.
        logger.warning("Reranker call failed ({}), falling back to original order: {}", type(exc).__name__, exc)
        return chunks[:top_k]
    elapsed_ms = int((time.perf_counter() - start) * 1000)
    logger.debug("Reranker scored {} chunks in {}ms", len(chunks), elapsed_ms)
    # sorted() is stable, so equal scores keep their retrieval order.
    ranked = sorted(zip(scores, chunks), key=lambda pair: pair[0], reverse=True)
    result = []
    for score, chunk in ranked[:top_k]:
        chunk.score = float(score)
        result.append(chunk)
    return result
def _call_reranker(self, query: str, texts: list[str]) -> list[float]:
    """Call the reranker API and return one score per text, in input order.

    ``POST {base_url}/rerank`` is tried first with a TEI-style payload. On a
    404 the request is retried as ``POST {base_url}/v1/rerank`` with a
    Cohere / OpenAI-style payload.

    Returns:
        ``scores`` such that ``scores[i]`` belongs to ``texts[i]``.

    Raises:
        requests.HTTPError: The final response has an error status.
        ValueError: The response references an index outside ``texts`` or
            does not contain a score for every text.
    """
    headers = {"Content-Type": "application/json"}
    if self._api_key:
        headers["Authorization"] = f"Bearer {self._api_key}"
    # Try TEI format first: POST /rerank
    payload = {"query": query, "texts": texts, "raw_scores": False, "return_text": False}
    url = f"{self._base_url}/rerank"
    resp = requests.post(url, json=payload, headers=headers, timeout=self._timeout)
    if resp.status_code == 404:
        # Fall back to Cohere / OpenAI-style: POST /v1/rerank
        payload_v1 = {"model": self._model, "query": query, "documents": texts}
        url = f"{self._base_url}/v1/rerank"
        resp = requests.post(url, json=payload_v1, headers=headers, timeout=self._timeout)
    resp.raise_for_status()
    data = resp.json()

    # TEI response: list of {"index": N, "score": F}
    # Cohere/OpenAI response: {"results": [{"index": N, "relevance_score": F}]}
    is_tei = isinstance(data, list)
    items = data if is_tei else data.get("results", [])

    # Place each score at its reported index so a partial or reordered
    # response can never attach a score to the wrong text.
    scores: list[float | None] = [None] * len(texts)
    for item in items:
        index = item["index"]
        if not 0 <= index < len(texts):
            raise ValueError(f"reranker returned out-of-range index {index}")
        if is_tei:
            raw_score = item["score"]
        else:
            raw_score = item.get("relevance_score", item.get("score", 0))
        scores[index] = float(raw_score)

    missing = [position for position, score in enumerate(scores) if score is None]
    if missing:
        raise ValueError(
            f"reranker returned no score for {len(missing)} of {len(texts)} texts"
        )
    return scores

View File

@@ -100,14 +100,42 @@ class MilvusVectorIndex(VectorIndex):
result = self.collection.delete(f'doc_id == "{doc_id}"')
return len(result.primary_keys)
def _parse_filters(self, filters: str | None) -> str | None:
"""Parse filter string into Milvus expression."""
if not filters or not filters.strip():
return None
filters = filters.strip()
# Check if already a Milvus expression (contains operators)
if any(op in filters for op in ["==", "!=", "in", "not in", ">", "<", ">=", "<=", "and", "or"]):
return filters
# Parse simple regulation_type filter
# Support: "GB" or "GB,UN-ECE" or "GB, UN-ECE"
types = [t.strip() for t in filters.split(",") if t.strip()]
if not types:
return None
if len(types) == 1:
# Single value: regulation_type == "GB"
return f'regulation_type == "{types[0]}"'
else:
# Multiple values: regulation_type in ["GB", "UN-ECE"]
quoted_types = [f'"{t}"' for t in types]
return f'regulation_type in [{", ".join(quoted_types)}]'
def search(self, query_vector: list[float], top_k: int, filters: str | None = None) -> list[RetrievedChunk]:
"""Handle search for the Milvus Vector Index instance."""
milvus_expr = self._parse_filters(filters)
results = self.collection.search(
data=[query_vector],
anns_field="embedding",
param={"metric_type": "COSINE", "params": {"nprobe": settings.milvus_nprobe}},
limit=top_k,
filter=filters,
expr=milvus_expr,
output_fields=[
"doc_id",
"doc_name",
@@ -145,6 +173,49 @@ class MilvusVectorIndex(VectorIndex):
)
return payload
def count_by_document(self) -> dict[str, int]:
    """Return a mapping of doc_id to the number of rows the collection holds for it.

    Rows with a missing or empty ``doc_id`` are ignored. If the query raises
    for any reason, an empty dict is returned.
    """
    try:
        hits = self.collection.query(expr='doc_id != ""', output_fields=["doc_id"])
    except Exception:
        return {}
    tally: dict[str, int] = {}
    for identifier in (hit.get("doc_id", "") for hit in hits):
        if not identifier:
            continue
        tally[identifier] = tally.setdefault(identifier, 0) + 1
    return tally
def list_document_metadata(self) -> list[dict]:
    """Return one metadata dict per document found in the collection.

    A single query fetches ``doc_id``, ``doc_name``, ``regulation_type`` and
    ``version`` for every row with a non-empty ``doc_id``. The first row seen
    for a document supplies its metadata, and ``chunk_count`` is the number
    of rows sharing that ``doc_id``. If the query raises, an empty list is
    returned.
    """
    try:
        hits = self.collection.query(
            expr='doc_id != ""',
            output_fields=["doc_id", "doc_name", "regulation_type", "version"],
        )
    except Exception:
        return []
    documents: dict[str, dict] = {}
    for hit in hits:
        identifier = hit.get("doc_id", "")
        if not identifier:
            continue
        entry = documents.get(identifier)
        if entry is None:
            # First row for this document: it supplies the metadata.
            entry = {
                "doc_id": identifier,
                "doc_name": hit.get("doc_name", ""),
                "regulation_type": hit.get("regulation_type", ""),
                "version": hit.get("version", ""),
                "chunk_count": 0,
            }
            documents[identifier] = entry
        entry["chunk_count"] += 1
    return list(documents.values())
def health(self) -> dict:
"""Handle health for the Milvus Vector Index instance."""
return {