fix 文档管理模块 & 法规对话模块
This commit is contained in:
@@ -289,7 +289,7 @@ def build_vector_chunks(
|
||||
{
|
||||
"doc_id": doc_id,
|
||||
"doc_title": doc_title,
|
||||
"chunk_id": f"chunk-{chunk_index}",
|
||||
"chunk_id": f"{doc_id}-chunk-{chunk_index}",
|
||||
"chunk_index": chunk_index,
|
||||
"semantic_id": block["semantic_id"],
|
||||
"chunk_type": block["block_type"],
|
||||
|
||||
@@ -75,6 +75,15 @@ class JsonDocumentRepository(DocumentRepository):
|
||||
documents.sort(key=lambda item: item.updated_at, reverse=True)
|
||||
return documents[:limit] if limit is not None else documents
|
||||
|
||||
def delete(self, doc_id: str) -> bool:
|
||||
"""Delete a document record."""
|
||||
payload = self._load()
|
||||
if doc_id not in payload:
|
||||
return False
|
||||
del payload[doc_id]
|
||||
self._save(payload)
|
||||
return True
|
||||
|
||||
def update_status(
|
||||
self,
|
||||
doc_id: str,
|
||||
|
||||
@@ -0,0 +1,228 @@
|
||||
"""Implement infrastructure support for postgres document repository."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from contextlib import contextmanager
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
import psycopg2
|
||||
import psycopg2.extras
|
||||
from psycopg2.pool import ThreadedConnectionPool
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.domain.documents import Document, DocumentRepository, DocumentStatus
|
||||
|
||||
_CREATE_TABLE = """
|
||||
CREATE TABLE IF NOT EXISTS documents (
|
||||
doc_id VARCHAR(128) PRIMARY KEY,
|
||||
doc_name VARCHAR(512) NOT NULL DEFAULT '',
|
||||
file_name VARCHAR(512) NOT NULL DEFAULT '',
|
||||
object_name VARCHAR(1024) NOT NULL DEFAULT '',
|
||||
content_type VARCHAR(128) NOT NULL DEFAULT '',
|
||||
size_bytes BIGINT NOT NULL DEFAULT 0,
|
||||
status VARCHAR(32) NOT NULL DEFAULT 'pending',
|
||||
regulation_type VARCHAR(128) NOT NULL DEFAULT '',
|
||||
version VARCHAR(64) NOT NULL DEFAULT '',
|
||||
summary TEXT NOT NULL DEFAULT '',
|
||||
summary_latency_ms INTEGER NOT NULL DEFAULT 0,
|
||||
chunk_count INTEGER NOT NULL DEFAULT 0,
|
||||
parser_name VARCHAR(128) NOT NULL DEFAULT '',
|
||||
index_name VARCHAR(128) NOT NULL DEFAULT '',
|
||||
error_message TEXT NOT NULL DEFAULT '',
|
||||
metadata JSONB NOT NULL DEFAULT '{}',
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
"""
|
||||
|
||||
_COLUMNS = (
|
||||
"doc_id", "doc_name", "file_name", "object_name", "content_type",
|
||||
"size_bytes", "status", "regulation_type", "version", "summary",
|
||||
"summary_latency_ms", "chunk_count", "parser_name", "index_name",
|
||||
"error_message", "metadata", "created_at", "updated_at",
|
||||
)
|
||||
|
||||
|
||||
class PostgresDocumentRepository(DocumentRepository):
|
||||
"""DocumentRepository implementation backed by PostgreSQL."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._pool = ThreadedConnectionPool(
|
||||
minconn=1,
|
||||
maxconn=5,
|
||||
host=settings.postgres_host,
|
||||
port=settings.postgres_port,
|
||||
user=settings.postgres_user,
|
||||
password=settings.postgres_password,
|
||||
dbname=settings.postgres_db,
|
||||
)
|
||||
self._ensure_schema()
|
||||
|
||||
def _ensure_schema(self) -> None:
|
||||
with self._conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(_CREATE_TABLE)
|
||||
conn.commit()
|
||||
|
||||
@contextmanager
|
||||
def _conn(self):
|
||||
conn = self._pool.getconn()
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
self._pool.putconn(conn)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Serialization helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _row_to_document(self, row: dict[str, Any]) -> Document:
|
||||
return Document(
|
||||
doc_id=row["doc_id"],
|
||||
doc_name=row["doc_name"],
|
||||
file_name=row["file_name"],
|
||||
object_name=row["object_name"],
|
||||
content_type=row["content_type"],
|
||||
size_bytes=row["size_bytes"],
|
||||
status=DocumentStatus(row["status"]),
|
||||
regulation_type=row["regulation_type"],
|
||||
version=row["version"],
|
||||
summary=row["summary"],
|
||||
summary_latency_ms=row["summary_latency_ms"],
|
||||
chunk_count=row["chunk_count"],
|
||||
parser_name=row["parser_name"],
|
||||
index_name=row["index_name"],
|
||||
error_message=row["error_message"],
|
||||
metadata=row["metadata"] if isinstance(row["metadata"], dict) else json.loads(row["metadata"] or "{}"),
|
||||
created_at=row["created_at"],
|
||||
updated_at=row["updated_at"],
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# DocumentRepository interface
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def create(self, document: Document) -> Document:
|
||||
sql = """
|
||||
INSERT INTO documents
|
||||
(doc_id, doc_name, file_name, object_name, content_type, size_bytes,
|
||||
status, regulation_type, version, summary, summary_latency_ms,
|
||||
chunk_count, parser_name, index_name, error_message, metadata,
|
||||
created_at, updated_at)
|
||||
VALUES
|
||||
(%(doc_id)s, %(doc_name)s, %(file_name)s, %(object_name)s, %(content_type)s,
|
||||
%(size_bytes)s, %(status)s, %(regulation_type)s, %(version)s, %(summary)s,
|
||||
%(summary_latency_ms)s, %(chunk_count)s, %(parser_name)s, %(index_name)s,
|
||||
%(error_message)s, %(metadata)s, %(created_at)s, %(updated_at)s)
|
||||
ON CONFLICT (doc_id) DO NOTHING
|
||||
"""
|
||||
with self._conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(sql, self._to_params(document))
|
||||
conn.commit()
|
||||
return document
|
||||
|
||||
def update(self, document: Document) -> Document:
|
||||
document.updated_at = datetime.now(UTC)
|
||||
sql = """
|
||||
UPDATE documents SET
|
||||
doc_name=%(doc_name)s, file_name=%(file_name)s, object_name=%(object_name)s,
|
||||
content_type=%(content_type)s, size_bytes=%(size_bytes)s, status=%(status)s,
|
||||
regulation_type=%(regulation_type)s, version=%(version)s, summary=%(summary)s,
|
||||
summary_latency_ms=%(summary_latency_ms)s, chunk_count=%(chunk_count)s,
|
||||
parser_name=%(parser_name)s, index_name=%(index_name)s,
|
||||
error_message=%(error_message)s, metadata=%(metadata)s, updated_at=%(updated_at)s
|
||||
WHERE doc_id=%(doc_id)s
|
||||
"""
|
||||
with self._conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(sql, self._to_params(document))
|
||||
conn.commit()
|
||||
return document
|
||||
|
||||
def get(self, doc_id: str) -> Document | None:
|
||||
sql = "SELECT * FROM documents WHERE doc_id = %s"
|
||||
with self._conn() as conn:
|
||||
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
|
||||
cur.execute(sql, (doc_id,))
|
||||
row = cur.fetchone()
|
||||
return self._row_to_document(dict(row)) if row else None
|
||||
|
||||
def list(self, limit: int | None = None) -> list[Document]:
|
||||
sql = "SELECT * FROM documents ORDER BY updated_at DESC"
|
||||
if limit is not None:
|
||||
sql += f" LIMIT {int(limit)}"
|
||||
with self._conn() as conn:
|
||||
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
|
||||
cur.execute(sql)
|
||||
rows = cur.fetchall()
|
||||
return [self._row_to_document(dict(r)) for r in rows]
|
||||
|
||||
def delete(self, doc_id: str) -> bool:
|
||||
sql = "DELETE FROM documents WHERE doc_id = %s"
|
||||
with self._conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(sql, (doc_id,))
|
||||
deleted = cur.rowcount > 0
|
||||
conn.commit()
|
||||
return deleted
|
||||
|
||||
def update_status(
|
||||
self,
|
||||
doc_id: str,
|
||||
status: DocumentStatus,
|
||||
*,
|
||||
error_message: str = "",
|
||||
chunk_count: int | None = None,
|
||||
summary: str | None = None,
|
||||
summary_latency_ms: int | None = None,
|
||||
parser_name: str | None = None,
|
||||
index_name: str | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> Document | None:
|
||||
document = self.get(doc_id)
|
||||
if not document:
|
||||
return None
|
||||
document.status = status
|
||||
document.error_message = error_message
|
||||
if chunk_count is not None:
|
||||
document.chunk_count = chunk_count
|
||||
if summary is not None:
|
||||
document.summary = summary
|
||||
if summary_latency_ms is not None:
|
||||
document.summary_latency_ms = summary_latency_ms
|
||||
if parser_name is not None:
|
||||
document.parser_name = parser_name
|
||||
if index_name is not None:
|
||||
document.index_name = index_name
|
||||
if metadata:
|
||||
document.metadata.update(metadata)
|
||||
return self.update(document)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _to_params(self, document: Document) -> dict[str, Any]:
|
||||
return {
|
||||
"doc_id": document.doc_id,
|
||||
"doc_name": document.doc_name,
|
||||
"file_name": document.file_name,
|
||||
"object_name": document.object_name,
|
||||
"content_type": document.content_type,
|
||||
"size_bytes": document.size_bytes,
|
||||
"status": document.status.value,
|
||||
"regulation_type": document.regulation_type,
|
||||
"version": document.version,
|
||||
"summary": document.summary,
|
||||
"summary_latency_ms": document.summary_latency_ms,
|
||||
"chunk_count": document.chunk_count,
|
||||
"parser_name": document.parser_name,
|
||||
"index_name": document.index_name,
|
||||
"error_message": document.error_message,
|
||||
"metadata": json.dumps(document.metadata, ensure_ascii=False),
|
||||
"created_at": document.created_at,
|
||||
"updated_at": document.updated_at,
|
||||
}
|
||||
@@ -0,0 +1,196 @@
|
||||
"""Implement infrastructure support for postgres parse artifact store."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
import psycopg2
|
||||
import psycopg2.extras
|
||||
from psycopg2.pool import ThreadedConnectionPool
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.domain.documents import ParseArtifactStore
|
||||
|
||||
_CREATE_STRUCTURE_NODES = """
|
||||
CREATE TABLE IF NOT EXISTS structure_nodes (
|
||||
id SERIAL PRIMARY KEY,
|
||||
doc_id VARCHAR(128) NOT NULL,
|
||||
unique_id VARCHAR(128),
|
||||
page INTEGER NOT NULL DEFAULT 0,
|
||||
idx INTEGER NOT NULL DEFAULT 0,
|
||||
level INTEGER NOT NULL DEFAULT 0,
|
||||
title TEXT NOT NULL DEFAULT '',
|
||||
type VARCHAR(64),
|
||||
sub_type VARCHAR(64),
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
CONSTRAINT fk_sn_doc FOREIGN KEY (doc_id) REFERENCES documents(doc_id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_structure_nodes_doc_id ON structure_nodes(doc_id);
|
||||
"""
|
||||
|
||||
_CREATE_SEMANTIC_BLOCKS = """
|
||||
CREATE TABLE IF NOT EXISTS semantic_blocks (
|
||||
id SERIAL PRIMARY KEY,
|
||||
doc_id VARCHAR(128) NOT NULL,
|
||||
semantic_id VARCHAR(128) NOT NULL,
|
||||
block_type VARCHAR(64) NOT NULL DEFAULT '',
|
||||
page_start INTEGER NOT NULL DEFAULT 0,
|
||||
page_end INTEGER NOT NULL DEFAULT 0,
|
||||
section_path JSONB NOT NULL DEFAULT '[]',
|
||||
section_level INTEGER NOT NULL DEFAULT 0,
|
||||
section_title VARCHAR(512) NOT NULL DEFAULT '',
|
||||
source_ids JSONB NOT NULL DEFAULT '[]',
|
||||
text TEXT NOT NULL DEFAULT '',
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
CONSTRAINT fk_sb_doc FOREIGN KEY (doc_id) REFERENCES documents(doc_id) ON DELETE CASCADE,
|
||||
CONSTRAINT uq_semantic_blocks UNIQUE (doc_id, semantic_id)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_semantic_blocks_doc_id ON semantic_blocks(doc_id);
|
||||
"""
|
||||
|
||||
|
||||
class PostgresParseArtifactStore(ParseArtifactStore):
|
||||
"""ParseArtifactStore implementation backed by PostgreSQL.
|
||||
|
||||
Requires the `documents` table to exist first (created by PostgresDocumentRepository).
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._pool = ThreadedConnectionPool(
|
||||
minconn=1,
|
||||
maxconn=5,
|
||||
host=settings.postgres_host,
|
||||
port=settings.postgres_port,
|
||||
user=settings.postgres_user,
|
||||
password=settings.postgres_password,
|
||||
dbname=settings.postgres_db,
|
||||
)
|
||||
self._ensure_schema()
|
||||
|
||||
def _ensure_schema(self) -> None:
|
||||
with self._conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(_CREATE_STRUCTURE_NODES)
|
||||
cur.execute(_CREATE_SEMANTIC_BLOCKS)
|
||||
conn.commit()
|
||||
|
||||
@contextmanager
|
||||
def _conn(self):
|
||||
conn = self._pool.getconn()
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
self._pool.putconn(conn)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# ParseArtifactStore interface
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def save(
|
||||
self,
|
||||
doc_id: str,
|
||||
structure_nodes: list[dict],
|
||||
semantic_blocks: list[dict],
|
||||
) -> None:
|
||||
"""Persist structure nodes and semantic blocks, replacing any existing records."""
|
||||
with self._conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
# Delete existing records first to keep save idempotent.
|
||||
cur.execute("DELETE FROM structure_nodes WHERE doc_id = %s", (doc_id,))
|
||||
cur.execute("DELETE FROM semantic_blocks WHERE doc_id = %s", (doc_id,))
|
||||
|
||||
if structure_nodes:
|
||||
psycopg2.extras.execute_values(
|
||||
cur,
|
||||
"""
|
||||
INSERT INTO structure_nodes
|
||||
(doc_id, unique_id, page, idx, level, title, type, sub_type)
|
||||
VALUES %s
|
||||
""",
|
||||
[
|
||||
(
|
||||
doc_id,
|
||||
node.get("unique_id"),
|
||||
int(node.get("page", 0) or 0),
|
||||
int(node.get("index", 0) or 0),
|
||||
int(node.get("level", 0) or 0),
|
||||
str(node.get("title", "")),
|
||||
node.get("type"),
|
||||
node.get("sub_type"),
|
||||
)
|
||||
for node in structure_nodes
|
||||
],
|
||||
)
|
||||
|
||||
if semantic_blocks:
|
||||
psycopg2.extras.execute_values(
|
||||
cur,
|
||||
"""
|
||||
INSERT INTO semantic_blocks
|
||||
(doc_id, semantic_id, block_type, page_start, page_end,
|
||||
section_path, section_level, section_title, source_ids, text)
|
||||
VALUES %s
|
||||
""",
|
||||
[
|
||||
(
|
||||
doc_id,
|
||||
block.get("semantic_id", ""),
|
||||
block.get("block_type", ""),
|
||||
int(block.get("page_start", 0) or 0),
|
||||
int(block.get("page_end", 0) or 0),
|
||||
json.dumps(block.get("section_path", []), ensure_ascii=False),
|
||||
int(block.get("section_level", 0) or 0),
|
||||
str(block.get("section_title", "")),
|
||||
json.dumps(block.get("source_ids", []), ensure_ascii=False),
|
||||
str(block.get("text", "")),
|
||||
)
|
||||
for block in semantic_blocks
|
||||
],
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def delete(self, doc_id: str) -> None:
|
||||
"""Remove all parse artifacts for a document (ON DELETE CASCADE handles child rows)."""
|
||||
with self._conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("DELETE FROM structure_nodes WHERE doc_id = %s", (doc_id,))
|
||||
cur.execute("DELETE FROM semantic_blocks WHERE doc_id = %s", (doc_id,))
|
||||
conn.commit()
|
||||
|
||||
def get_semantic_blocks(self, doc_id: str) -> list[dict[str, Any]]:
|
||||
"""Return all semantic blocks for a document ordered by id."""
|
||||
sql = """
|
||||
SELECT semantic_id, block_type, page_start, page_end,
|
||||
section_path, section_level, section_title, source_ids, text
|
||||
FROM semantic_blocks
|
||||
WHERE doc_id = %s
|
||||
ORDER BY id
|
||||
"""
|
||||
with self._conn() as conn:
|
||||
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
|
||||
cur.execute(sql, (doc_id,))
|
||||
rows = cur.fetchall()
|
||||
results = []
|
||||
for row in rows:
|
||||
item = dict(row)
|
||||
for key in ("section_path", "source_ids"):
|
||||
if isinstance(item[key], str):
|
||||
item[key] = json.loads(item[key])
|
||||
results.append(item)
|
||||
return results
|
||||
|
||||
def get_structure_nodes(self, doc_id: str) -> list[dict[str, Any]]:
|
||||
"""Return all structure nodes for a document ordered by idx."""
|
||||
sql = """
|
||||
SELECT unique_id, page, idx, level, title, type, sub_type
|
||||
FROM structure_nodes
|
||||
WHERE doc_id = %s
|
||||
ORDER BY idx
|
||||
"""
|
||||
with self._conn() as conn:
|
||||
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
|
||||
cur.execute(sql, (doc_id,))
|
||||
rows = cur.fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
@@ -0,0 +1,84 @@
|
||||
"""Implement cross-encoder reranking via an OpenAI-compatible reranker API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
import requests
|
||||
from loguru import logger
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.domain.retrieval import Reranker, RetrievedChunk
|
||||
|
||||
|
||||
class OpenAICompatibleReranker(Reranker):
|
||||
"""Call a TEI / Cohere-style reranker endpoint to re-score retrieved chunks."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str | None = None,
|
||||
model: str | None = None,
|
||||
api_key: str | None = None,
|
||||
timeout: int = 30,
|
||||
) -> None:
|
||||
self._base_url = (base_url or settings.reranker_base_url).rstrip("/")
|
||||
self._model = model or settings.reranker_model
|
||||
self._api_key = api_key or settings.reranker_api_key
|
||||
self._timeout = timeout
|
||||
|
||||
def rerank(self, query: str, chunks: list[RetrievedChunk], top_k: int) -> list[RetrievedChunk]:
|
||||
"""Return up to top_k chunks re-sorted by cross-encoder score."""
|
||||
if not chunks:
|
||||
return []
|
||||
|
||||
texts = [chunk.content for chunk in chunks]
|
||||
start = time.time()
|
||||
try:
|
||||
scores = self._call_reranker(query, texts)
|
||||
except Exception as exc:
|
||||
logger.warning("Reranker call failed ({}), falling back to original order: {}", type(exc).__name__, exc)
|
||||
return chunks[:top_k]
|
||||
|
||||
elapsed_ms = int((time.time() - start) * 1000)
|
||||
logger.debug("Reranker scored {} chunks in {}ms", len(chunks), elapsed_ms)
|
||||
|
||||
ranked = sorted(
|
||||
[(score, chunk) for score, chunk in zip(scores, chunks)],
|
||||
key=lambda x: x[0],
|
||||
reverse=True,
|
||||
)
|
||||
result = []
|
||||
for score, chunk in ranked[:top_k]:
|
||||
chunk.score = float(score)
|
||||
result.append(chunk)
|
||||
return result
|
||||
|
||||
def _call_reranker(self, query: str, texts: list[str]) -> list[float]:
|
||||
"""Call the reranker API and return a score per text."""
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if self._api_key:
|
||||
headers["Authorization"] = f"Bearer {self._api_key}"
|
||||
|
||||
# Try TEI format first: POST /rerank
|
||||
payload = {"query": query, "texts": texts, "raw_scores": False, "return_text": False}
|
||||
url = f"{self._base_url}/rerank"
|
||||
resp = requests.post(url, json=payload, headers=headers, timeout=self._timeout)
|
||||
|
||||
if resp.status_code == 404:
|
||||
# Fall back to Cohere / OpenAI-style: POST /v1/rerank
|
||||
payload_v1 = {"model": self._model, "query": query, "documents": texts}
|
||||
url = f"{self._base_url}/v1/rerank"
|
||||
resp = requests.post(url, json=payload_v1, headers=headers, timeout=self._timeout)
|
||||
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
# TEI response: list of {"index": N, "score": F}
|
||||
if isinstance(data, list):
|
||||
ordered = sorted(data, key=lambda x: x["index"])
|
||||
return [float(item["score"]) for item in ordered]
|
||||
|
||||
# Cohere/OpenAI response: {"results": [{"index": N, "relevance_score": F}]}
|
||||
results = data.get("results", [])
|
||||
ordered = sorted(results, key=lambda x: x["index"])
|
||||
return [float(item.get("relevance_score", item.get("score", 0))) for item in ordered]
|
||||
@@ -100,14 +100,42 @@ class MilvusVectorIndex(VectorIndex):
|
||||
result = self.collection.delete(f'doc_id == "{doc_id}"')
|
||||
return len(result.primary_keys)
|
||||
|
||||
def _parse_filters(self, filters: str | None) -> str | None:
|
||||
"""Parse filter string into Milvus expression."""
|
||||
if not filters or not filters.strip():
|
||||
return None
|
||||
|
||||
filters = filters.strip()
|
||||
|
||||
# Check if already a Milvus expression (contains operators)
|
||||
if any(op in filters for op in ["==", "!=", "in", "not in", ">", "<", ">=", "<=", "and", "or"]):
|
||||
return filters
|
||||
|
||||
# Parse simple regulation_type filter
|
||||
# Support: "GB" or "GB,UN-ECE" or "GB, UN-ECE"
|
||||
types = [t.strip() for t in filters.split(",") if t.strip()]
|
||||
|
||||
if not types:
|
||||
return None
|
||||
|
||||
if len(types) == 1:
|
||||
# Single value: regulation_type == "GB"
|
||||
return f'regulation_type == "{types[0]}"'
|
||||
else:
|
||||
# Multiple values: regulation_type in ["GB", "UN-ECE"]
|
||||
quoted_types = [f'"{t}"' for t in types]
|
||||
return f'regulation_type in [{", ".join(quoted_types)}]'
|
||||
|
||||
def search(self, query_vector: list[float], top_k: int, filters: str | None = None) -> list[RetrievedChunk]:
|
||||
"""Handle search for the Milvus Vector Index instance."""
|
||||
milvus_expr = self._parse_filters(filters)
|
||||
|
||||
results = self.collection.search(
|
||||
data=[query_vector],
|
||||
anns_field="embedding",
|
||||
param={"metric_type": "COSINE", "params": {"nprobe": settings.milvus_nprobe}},
|
||||
limit=top_k,
|
||||
filter=filters,
|
||||
expr=milvus_expr,
|
||||
output_fields=[
|
||||
"doc_id",
|
||||
"doc_name",
|
||||
@@ -145,6 +173,49 @@ class MilvusVectorIndex(VectorIndex):
|
||||
)
|
||||
return payload
|
||||
|
||||
def count_by_document(self) -> dict[str, int]:
|
||||
"""Return doc_id -> chunk count from Milvus."""
|
||||
try:
|
||||
rows = self.collection.query(expr="doc_id != \"\"", output_fields=["doc_id"])
|
||||
except Exception:
|
||||
return {}
|
||||
counts: dict[str, int] = {}
|
||||
for row in rows:
|
||||
doc_id = row.get("doc_id", "")
|
||||
if doc_id:
|
||||
counts[doc_id] = counts.get(doc_id, 0) + 1
|
||||
return counts
|
||||
|
||||
def list_document_metadata(self) -> list[dict]:
|
||||
"""Return one metadata row per document from Milvus (single query, no embeddings)."""
|
||||
try:
|
||||
rows = self.collection.query(
|
||||
expr="doc_id != \"\"",
|
||||
output_fields=["doc_id", "doc_name", "regulation_type", "version"],
|
||||
)
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
seen: dict[str, dict] = {}
|
||||
counts: dict[str, int] = {}
|
||||
for row in rows:
|
||||
doc_id = row.get("doc_id", "")
|
||||
if not doc_id:
|
||||
continue
|
||||
counts[doc_id] = counts.get(doc_id, 0) + 1
|
||||
if doc_id not in seen:
|
||||
seen[doc_id] = {
|
||||
"doc_id": doc_id,
|
||||
"doc_name": row.get("doc_name", ""),
|
||||
"regulation_type": row.get("regulation_type", ""),
|
||||
"version": row.get("version", ""),
|
||||
}
|
||||
|
||||
return [
|
||||
{**meta, "chunk_count": counts[meta["doc_id"]]}
|
||||
for meta in seen.values()
|
||||
]
|
||||
|
||||
def health(self) -> dict:
|
||||
"""Handle health for the Milvus Vector Index instance."""
|
||||
return {
|
||||
|
||||
Reference in New Issue
Block a user