Refactor document handling and update Milvus collection settings

- Removed multiple failed document entries from `documents.json`.
- Added a new document entry with updated metadata and changed the index name to `regulations_dense_1024_v2`.
- Updated architecture documentation to reflect changes in the Milvus collection name.
- Adjusted requirements by removing the sqlalchemy dependency.
- Modified test cases to align with new document structure and naming conventions.
- Introduced a new test file for Milvus vector index runtime recovery and error handling.
- Updated assertions in various test files to ensure compatibility with the new schema.
This commit is contained in:
ash66
2026-05-26 20:21:31 +08:00
parent fec22a3a2c
commit 30c7bda389
42 changed files with 7482 additions and 569 deletions

View File

@@ -56,7 +56,21 @@ class BM25Retriever:
try:
rows = self._vector_index.collection.query(
expr='doc_id != ""',
output_fields=["id", "doc_id", "doc_name", "content", "section_title", "page_number"],
output_fields=[
"id",
"chunk_id",
"doc_id",
"doc_title",
"text",
"chunk_type",
"section_title",
"page_start",
"page_end",
"section_level",
"chunk_index",
"piece_index",
"metadata_json",
],
limit=16384,
)
except Exception:
@@ -64,19 +78,33 @@ class BM25Retriever:
return []
return [
RetrievedChunk(
chunk_id=str(row.get("id", "")),
chunk_id=str(row.get("chunk_id") or row.get("id", "")),
doc_id=str(row.get("doc_id", "")),
doc_name=str(row.get("doc_name", "")),
content=str(row.get("content", "")),
doc_title=str(row.get("doc_title", "")),
text=str(row.get("text", "")),
score=0.0,
chunk_type=str(row.get("chunk_type", "")),
section_title=str(row.get("section_title", "")),
page_number=int(row.get("page_number") or 0),
metadata={},
page_start=int(row.get("page_start") or 0),
page_end=int(row.get("page_end") or 0),
section_level=int(row.get("section_level") or 0),
chunk_index=int(row.get("chunk_index") or 0),
piece_index=int(row.get("piece_index") or 0),
metadata=self._parse_metadata_json(row.get("metadata_json", "")),
)
for row in rows
if row.get("content")
if row.get("text")
]
def _parse_metadata_json(self, raw_metadata: str) -> dict:
"""Parse metadata_json into a dict for BM25-side filtering."""
if not raw_metadata:
return {}
try:
return dict(__import__("json").loads(raw_metadata))
except Exception:
return {}
def _ensure_built(self) -> None:
if self._index is not None:
return
@@ -93,7 +121,7 @@ class BM25Retriever:
self._chunks = []
self._index = BM25Okapi([[]])
return
tokenized = [_tokenize(c.content) for c in chunks]
tokenized = [_tokenize(c.text) for c in chunks]
self._chunks = chunks
self._index = BM25Okapi(tokenized)
logger.info("BM25Retriever: index built with %d chunks", len(chunks))
@@ -127,20 +155,26 @@ class BM25Retriever:
for score, chunk in ranked[: top_k * 2]:
if score <= 0:
break
# Apply simple regulation_type filter if provided
if filters and chunk.metadata.get("regulation_type"):
types = [t.strip() for t in filters.split(",")]
if chunk.metadata.get("regulation_type") not in types:
continue
if filters:
normalized_filter = filters.replace("doc_name", "doc_title").strip()
if normalized_filter.startswith('doc_title == "'):
expected_title = normalized_filter[len('doc_title == "'):-1]
if chunk.doc_title != expected_title:
continue
results.append(
RetrievedChunk(
chunk_id=chunk.chunk_id,
doc_id=chunk.doc_id,
doc_name=chunk.doc_name,
content=chunk.content,
doc_title=chunk.doc_title,
text=chunk.text,
score=score,
chunk_type=chunk.chunk_type,
section_title=chunk.section_title,
page_number=chunk.page_number,
page_start=chunk.page_start,
page_end=chunk.page_end,
section_level=chunk.section_level,
chunk_index=chunk.chunk_index,
piece_index=chunk.piece_index,
metadata=chunk.metadata,
)
)

View File

@@ -31,7 +31,7 @@ class OpenAICompatibleReranker(Reranker):
if not chunks:
return []
texts = [chunk.content for chunk in chunks]
texts = [chunk.text for chunk in chunks]
start = time.time()
try:
scores = self._call_reranker(query, texts)

View File

@@ -4,57 +4,150 @@ from __future__ import annotations
import json
import time
from typing import Iterable
from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections, utility
from loguru import logger
from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, MilvusException, connections, utility
from app.config.settings import settings
from app.domain.documents import Chunk
from app.domain.retrieval import RetrievedChunk, VectorIndex
from app.shared.errors import VectorStoreSchemaError
# Keep adapter behavior explicit so integration details remain easy to audit.
# Field names a bound collection must expose. Schema validation reports any
# of these that are absent, and schema errors list them as the expected fields.
_REQUIRED_SCHEMA_FIELDS = (
"doc_id",
"doc_title",
"chunk_id",
"text",
"embedding",
"section_title",
"metadata_json",
)
# Fragments matched as substrings against the lower-cased text of a
# MilvusException. A match means the failure is treated as stale connection or
# collection state, which triggers a forced reconnect and a single retry.
# Entries must stay lowercase because the comparison lower-cases the message.
_SCHEMA_RECOVERY_TOKENS = (
"field doc_title not exist",
"field text not exist",
"field embedding not exist",
"collection not loaded",
"can't find collection",
"not found[collection",
)
class MilvusVectorIndex(VectorIndex):
"""Provide the Milvus Vector Index index implementation."""
def __init__(self) -> None:
    """Read the Milvus settings, connect under a dedicated alias, and bind the collection."""
    self.collection_name = settings.milvus_collection
    self.db_name = settings.milvus_db_name
    self.host = settings.milvus_host
    self.port = settings.milvus_port
    # The alias is derived from endpoint, database and collection so this
    # adapter never reuses unrelated global Milvus connection state.
    endpoint = f"{self.host}:{self.port}"
    target = f"{self.db_name}/{self.collection_name}"
    self.alias = f"vector-index::{endpoint}/{target}"
    self._connect()
    self.collection = self._bind_collection()
def _connect(self, *, refresh: bool = False) -> None:
"""Establish the Milvus connection for this adapter."""
if refresh:
try:
connections.disconnect(self.alias)
except Exception:
# Best-effort disconnect keeps refresh idempotent when no alias is active yet.
pass
connections.connect(
alias="default",
host=settings.milvus_host,
port=settings.milvus_port,
alias=self.alias,
host=self.host,
port=self.port,
db_name=self.db_name,
)
self.collection = self._ensure_collection()
def _schema_field_names(self, collection: Collection) -> list[str]:
"""Return the field names exposed by the bound Milvus collection."""
return [field.name for field in collection.schema.fields]
def _raise_schema_error(self, *, message: str, actual_fields: Iterable[str]) -> None:
    """Raise ``VectorStoreSchemaError`` describing the active collection.

    Args:
        message: Human-readable description of the schema problem.
        actual_fields: Field names observed on the collection.

    Raises:
        VectorStoreSchemaError: Always. It carries the host, database,
            collection name, and the expected versus actual field lists.
    """
    error_details = {
        "message": message,
        "host": self.host,
        "db_name": self.db_name,
        "collection_name": self.collection_name,
        "expected_fields": list(_REQUIRED_SCHEMA_FIELDS),
        "actual_fields": list(actual_fields),
    }
    raise VectorStoreSchemaError(**error_details)
def _validate_schema(self, collection: Collection) -> None:
    """Check that the collection exposes every required field.

    Raises a schema error (via ``_raise_schema_error``) naming the missing
    fields when at least one required field is absent; otherwise returns.
    """
    present = self._schema_field_names(collection)
    absent = [required for required in _REQUIRED_SCHEMA_FIELDS if required not in present]
    if not absent:
        return
    self._raise_schema_error(
        message=f"Milvus collection schema mismatch; missing required fields: {absent}",
        actual_fields=present,
    )
def _log_collection_binding(self, collection: Collection, *, event: str) -> None:
    """Emit an info log describing the collection this adapter is bound to.

    Args:
        collection: The bound Milvus collection.
        event: Short label for what happened to the binding; it is placed
            first in the log line.
    """
    try:
        entity_count = collection.num_entities
    except Exception:
        # Diagnostics only: if the entity count cannot be read, log a
        # placeholder instead of failing the bind.
        entity_count = "unknown"
    field_names = self._schema_field_names(collection)
    logger.info(
        "Milvus binding {} alias={} host={} db={} collection={} fields={} num_entities={}",
        event,
        self.alias,
        self.host,
        self.db_name,
        self.collection_name,
        field_names,
        entity_count,
    )
def _bind_collection(self, *, force_refresh: bool = False) -> Collection:
    """Obtain the configured collection, validate its schema, and log the binding.

    Args:
        force_refresh: When true, reconnect (dropping the current alias
            first) before the collection is looked up.

    Returns:
        The validated collection.
    """
    if force_refresh:
        self._connect(refresh=True)
    bound = self._ensure_collection()
    self._validate_schema(bound)
    binding_event = "refreshed" if force_refresh else "initialized"
    self._log_collection_binding(bound, event=binding_event)
    return bound
def _ensure_collection(self) -> Collection:
"""Handle ensure collection for this module for the Milvus Vector Index instance."""
if utility.has_collection(self.collection_name):
collection = Collection(self.collection_name)
if utility.has_collection(self.collection_name, using=self.alias):
collection = Collection(self.collection_name, using=self.alias)
collection.load()
return collection
schema = CollectionSchema(
fields=[
FieldSchema(name="id", dtype=DataType.VARCHAR, max_length=128, is_primary=True, auto_id=False),
FieldSchema(name="doc_id", dtype=DataType.VARCHAR, max_length=64),
FieldSchema(name="doc_name", dtype=DataType.VARCHAR, max_length=256),
FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=65535),
FieldSchema(name="doc_title", dtype=DataType.VARCHAR, max_length=256),
FieldSchema(name="chunk_id", dtype=DataType.VARCHAR, max_length=128),
FieldSchema(name="chunk_index", dtype=DataType.INT64),
FieldSchema(name="piece_index", dtype=DataType.INT64),
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
FieldSchema(name="embedding_text", dtype=DataType.VARCHAR, max_length=65535),
FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=settings.embedding_dim),
FieldSchema(name="section_title", dtype=DataType.VARCHAR, max_length=512),
FieldSchema(name="section_path", dtype=DataType.VARCHAR, max_length=4096),
FieldSchema(name="page_number", dtype=DataType.INT64),
FieldSchema(name="regulation_type", dtype=DataType.VARCHAR, max_length=128),
FieldSchema(name="version", dtype=DataType.VARCHAR, max_length=64),
FieldSchema(name="semantic_id", dtype=DataType.VARCHAR, max_length=128),
FieldSchema(name="block_type", dtype=DataType.VARCHAR, max_length=64),
FieldSchema(name="chunk_type", dtype=DataType.VARCHAR, max_length=64),
FieldSchema(name="page_start", dtype=DataType.INT64),
FieldSchema(name="page_end", dtype=DataType.INT64),
FieldSchema(name="section_level", dtype=DataType.INT64),
FieldSchema(name="source_ids", dtype=DataType.VARCHAR, max_length=4096),
FieldSchema(name="section_path", dtype=DataType.VARCHAR, max_length=4096),
FieldSchema(name="section_title", dtype=DataType.VARCHAR, max_length=512),
FieldSchema(name="metadata_json", dtype=DataType.VARCHAR, max_length=65535),
FieldSchema(name="created_at", dtype=DataType.INT64),
],
description="Dense-only regulations index",
enable_dynamic_field=False,
)
collection = Collection(name=self.collection_name, schema=schema)
collection = Collection(name=self.collection_name, schema=schema, using=self.alias)
collection.create_index(
field_name="embedding",
index_params={
@@ -73,21 +166,34 @@ class MilvusVectorIndex(VectorIndex):
data = []
now = int(time.time())
for chunk, vector in zip(chunks, vectors):
metadata = dict(chunk.metadata)
doc_title = str(metadata.get("doc_title", chunk.doc_title))
text = str(metadata.get("text", chunk.text))
embedding_text = str(metadata.get("embedding_text", chunk.embedding_text))
page_start = int(metadata.get("page_start", 0) or 0)
page_end = int(metadata.get("page_end", 0) or 0)
section_path = metadata.get("section_path", chunk.section_path)
source_ids = metadata.get("source_ids", [])
data.append(
{
"id": chunk.chunk_id,
"doc_id": chunk.doc_id,
"doc_name": chunk.doc_name,
"content": chunk.content[:65535],
"doc_title": doc_title[:256],
"chunk_id": chunk.chunk_id[:128],
"chunk_index": int(metadata.get("chunk_index", chunk.chunk_index) or 0),
"piece_index": int(metadata.get("piece_index", chunk.piece_index) or 0),
"text": text[:65535],
"embedding_text": embedding_text[:65535],
"embedding": vector,
"section_title": chunk.section_title[:512],
"section_path": json.dumps(chunk.section_path, ensure_ascii=False)[:4096],
"page_number": chunk.page_number,
"regulation_type": chunk.regulation_type[:128],
"version": chunk.version[:64],
"semantic_id": chunk.semantic_id[:128],
"block_type": chunk.block_type[:64],
"metadata_json": json.dumps(chunk.metadata, ensure_ascii=False)[:65535],
"semantic_id": str(metadata.get("semantic_id", chunk.semantic_id))[:128],
"chunk_type": str(metadata.get("chunk_type", chunk.chunk_type))[:64],
"page_start": page_start,
"page_end": page_end,
"section_level": int(metadata.get("section_level", chunk.section_level) or 0),
"source_ids": json.dumps(source_ids, ensure_ascii=False)[:4096],
"section_path": json.dumps(section_path, ensure_ascii=False)[:4096],
"section_title": str(metadata.get("section_title", chunk.section_title))[:512],
"metadata_json": json.dumps(metadata, ensure_ascii=False)[:65535],
"created_at": now,
}
)
@@ -107,47 +213,97 @@ class MilvusVectorIndex(VectorIndex):
filters = filters.strip()
# Normalize legacy field names so callers can keep older filter payloads.
replacements = {
"doc_name": "doc_title",
"content": "text",
"page_number": "page_start",
"block_type": "chunk_type",
}
for legacy_name, new_name in replacements.items():
filters = filters.replace(legacy_name, new_name)
# Check if already a Milvus expression (contains operators)
if any(op in filters for op in ["==", "!=", "in", "not in", ">", "<", ">=", "<=", "and", "or"]):
return filters
# Parse simple regulation_type filter
# Support: "GB" or "GB,UN-ECE" or "GB, UN-ECE"
types = [t.strip() for t in filters.split(",") if t.strip()]
# Parse simple document-title filter.
titles = [title.strip() for title in filters.split(",") if title.strip()]
if not types:
if not titles:
return None
if len(types) == 1:
# Single value: regulation_type == "GB"
return f'regulation_type == "{types[0]}"'
else:
# Multiple values: regulation_type in ["GB", "UN-ECE"]
quoted_types = [f'"{t}"' for t in types]
return f'regulation_type in [{", ".join(quoted_types)}]'
if len(titles) == 1:
return f'doc_title == "{titles[0]}"'
quoted_titles = [f'"{title}"' for title in titles]
return f'doc_title in [{", ".join(quoted_titles)}]'
def _should_refresh_after_exception(self, exc: Exception) -> bool:
    """Decide whether an error looks like stale connection or collection state.

    Only ``MilvusException`` instances qualify, and only when their
    lower-cased message contains one of the recovery tokens.
    """
    if isinstance(exc, MilvusException):
        lowered_message = str(exc).lower()
        for token in _SCHEMA_RECOVERY_TOKENS:
            if token in lowered_message:
                return True
    return False
def _run_with_refresh(self, operation):
"""Run a Milvus operation and retry once after a forced reconnect when appropriate."""
try:
return operation()
except VectorStoreSchemaError:
raise
except Exception as exc:
if not self._should_refresh_after_exception(exc):
raise
logger.warning(
"Milvus operation failed for alias={} collection={}; forcing reconnect and retry: {}",
self.alias,
self.collection_name,
exc,
)
self.collection = self._bind_collection(force_refresh=True)
try:
return operation()
except VectorStoreSchemaError:
raise
except Exception as retry_exc:
if isinstance(retry_exc, MilvusException):
self._raise_schema_error(
message=f"Milvus operation failed after refresh: {retry_exc}",
actual_fields=self._schema_field_names(self.collection),
)
raise
def search(self, query_vector: list[float], top_k: int, filters: str | None = None) -> list[RetrievedChunk]:
"""Handle search for the Milvus Vector Index instance."""
milvus_expr = self._parse_filters(filters)
results = self.collection.search(
data=[query_vector],
anns_field="embedding",
param={"metric_type": "COSINE", "params": {"nprobe": settings.milvus_nprobe}},
limit=top_k,
expr=milvus_expr,
output_fields=[
"doc_id",
"doc_name",
"content",
"section_title",
"page_number",
"regulation_type",
"version",
"semantic_id",
"block_type",
"metadata_json",
],
results = self._run_with_refresh(
lambda: self.collection.search(
data=[query_vector],
anns_field="embedding",
param={"metric_type": "COSINE", "params": {"nprobe": settings.milvus_nprobe}},
limit=top_k,
expr=milvus_expr,
output_fields=[
"doc_id",
"doc_title",
"chunk_id",
"chunk_index",
"piece_index",
"text",
"embedding_text",
"section_title",
"semantic_id",
"chunk_type",
"page_start",
"page_end",
"section_level",
"source_ids",
"section_path",
"metadata_json",
],
)
)
payload: list[RetrievedChunk] = []
for hits in results:
@@ -161,13 +317,18 @@ class MilvusVectorIndex(VectorIndex):
metadata = {"raw_metadata": raw_metadata}
payload.append(
RetrievedChunk(
chunk_id=str(hit.id),
chunk_id=str(hit.entity.get("chunk_id", hit.id)),
doc_id=hit.entity.get("doc_id", ""),
doc_name=hit.entity.get("doc_name", ""),
content=hit.entity.get("content", ""),
doc_title=hit.entity.get("doc_title", ""),
text=hit.entity.get("text", ""),
score=float(hit.score),
chunk_type=hit.entity.get("chunk_type", ""),
section_title=hit.entity.get("section_title", ""),
page_number=int(hit.entity.get("page_number", 0) or 0),
page_start=int(hit.entity.get("page_start", 0) or 0),
page_end=int(hit.entity.get("page_end", 0) or 0),
section_level=int(hit.entity.get("section_level", 0) or 0),
chunk_index=int(hit.entity.get("chunk_index", 0) or 0),
piece_index=int(hit.entity.get("piece_index", 0) or 0),
metadata=metadata,
)
)
@@ -176,7 +337,9 @@ class MilvusVectorIndex(VectorIndex):
def count_by_document(self) -> dict[str, int]:
"""Return doc_id -> chunk count from Milvus."""
try:
rows = self.collection.query(expr="doc_id != \"\"", output_fields=["doc_id"])
rows = self._run_with_refresh(
lambda: self.collection.query(expr="doc_id != \"\"", output_fields=["doc_id", "doc_title"])
)
except Exception:
return {}
counts: dict[str, int] = {}
@@ -189,9 +352,11 @@ class MilvusVectorIndex(VectorIndex):
def list_document_metadata(self) -> list[dict]:
"""Return one metadata row per document from Milvus (single query, no embeddings)."""
try:
rows = self.collection.query(
expr="doc_id != \"\"",
output_fields=["doc_id", "doc_name", "regulation_type", "version"],
rows = self._run_with_refresh(
lambda: self.collection.query(
expr="doc_id != \"\"",
output_fields=["doc_id", "doc_title", "metadata_json"],
)
)
except Exception:
return []
@@ -204,15 +369,26 @@ class MilvusVectorIndex(VectorIndex):
continue
counts[doc_id] = counts.get(doc_id, 0) + 1
if doc_id not in seen:
metadata: dict[str, object] = {}
raw_metadata = row.get("metadata_json", "")
if raw_metadata:
try:
metadata = json.loads(raw_metadata)
except json.JSONDecodeError:
metadata = {}
seen[doc_id] = {
"doc_id": doc_id,
"doc_name": row.get("doc_name", ""),
"regulation_type": row.get("regulation_type", ""),
"version": row.get("version", ""),
"doc_title": row.get("doc_title", ""),
"regulation_type": str(metadata.get("regulation_type", "")),
"version": str(metadata.get("version", "")),
}
return [
{**meta, "chunk_count": counts[meta["doc_id"]]}
{
**meta,
"doc_name": meta.get("doc_title", ""),
"chunk_count": counts[meta["doc_id"]],
}
for meta in seen.values()
]