Refactor document handling and update Milvus collection settings
- Removed multiple failed document entries from `documents.json`. - Added a new document entry with updated metadata and changed the index name to `regulations_dense_1024_v2`. - Updated architecture documentation to reflect changes in the Milvus collection name. - Adjusted requirements by removing the sqlalchemy dependency. - Modified test cases to align with new document structure and naming conventions. - Introduced a new test file for Milvus vector index runtime recovery and error handling. - Updated assertions in various test files to ensure compatibility with the new schema.
This commit is contained in:
@@ -45,10 +45,10 @@ class OpenAICompatibleAnswerGenerator(AnswerGenerator):
|
||||
context_tokens = 0
|
||||
for idx, chunk in enumerate(retrieved_chunks, start=1):
|
||||
block = (
|
||||
f"[{idx}] 文档: {chunk.doc_name}\n"
|
||||
f"[{idx}] 文档: {chunk.doc_title}\n"
|
||||
f"章节: {chunk.section_title or '未标注'}\n"
|
||||
f"页码: {chunk.page_number}\n"
|
||||
f"内容: {chunk.content}"
|
||||
f"页码: {chunk.page_start}" + (f"-{chunk.page_end}" if chunk.page_end and chunk.page_end != chunk.page_start else "") + "\n"
|
||||
f"内容: {chunk.text}"
|
||||
)
|
||||
block_tokens = self._estimate_tokens(block)
|
||||
if context_tokens + block_tokens > settings.rag_max_context_tokens:
|
||||
@@ -73,10 +73,10 @@ class OpenAICompatibleAnswerGenerator(AnswerGenerator):
|
||||
return False
|
||||
estimated_total_tokens = sum(
|
||||
self._estimate_tokens(
|
||||
f"[{idx}] 文档: {chunk.doc_name}\n"
|
||||
f"[{idx}] 文档: {chunk.doc_title}\n"
|
||||
f"章节: {chunk.section_title or '未标注'}\n"
|
||||
f"页码: {chunk.page_number}\n"
|
||||
f"内容: {chunk.content}"
|
||||
f"页码: {chunk.page_start}" + (f"-{chunk.page_end}" if chunk.page_end and chunk.page_end != chunk.page_start else "") + "\n"
|
||||
f"内容: {chunk.text}"
|
||||
)
|
||||
for idx, chunk in enumerate(retrieved_chunks, start=1)
|
||||
)
|
||||
@@ -87,12 +87,17 @@ class OpenAICompatibleAnswerGenerator(AnswerGenerator):
|
||||
return [
|
||||
AnswerSource(
|
||||
doc_id=chunk.doc_id,
|
||||
doc_name=chunk.doc_name,
|
||||
doc_title=chunk.doc_title,
|
||||
chunk_id=chunk.chunk_id,
|
||||
chunk_type=chunk.chunk_type,
|
||||
section_title=chunk.section_title,
|
||||
page_number=chunk.page_number,
|
||||
page_start=chunk.page_start,
|
||||
page_end=chunk.page_end,
|
||||
section_level=chunk.section_level,
|
||||
chunk_index=chunk.chunk_index,
|
||||
piece_index=chunk.piece_index,
|
||||
score=chunk.score,
|
||||
content=chunk.content,
|
||||
text=chunk.text,
|
||||
metadata=chunk.metadata,
|
||||
)
|
||||
for chunk in chunks
|
||||
|
||||
@@ -10,6 +10,7 @@ class LocalRegulationChunkBuilder(ChunkBuilder):
|
||||
"""Adapt the existing markdown chunker to the new chunk builder port."""
|
||||
|
||||
def __init__(self, *, chunk_size: int = 512, chunk_overlap: int = 50) -> None:
|
||||
"""Initialize the local markdown chunk builder."""
|
||||
self.chunker = RegulationChunker(
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
@@ -22,6 +23,7 @@ class LocalRegulationChunkBuilder(ChunkBuilder):
|
||||
regulation_type: str,
|
||||
version: str,
|
||||
) -> list[Chunk]:
|
||||
"""Build migrated chunk objects from the legacy markdown chunker output."""
|
||||
markdown_text = parsed_document.raw_text.strip()
|
||||
if not markdown_text:
|
||||
return []
|
||||
@@ -50,16 +52,18 @@ class LocalRegulationChunkBuilder(ChunkBuilder):
|
||||
Chunk(
|
||||
chunk_id=item.metadata.chunk_id,
|
||||
doc_id=parsed_document.doc_id,
|
||||
doc_name=parsed_document.doc_name,
|
||||
content=item.content,
|
||||
doc_title=parsed_document.doc_name,
|
||||
text=item.content,
|
||||
embedding_text=item.content,
|
||||
chunk_type="local_markdown_chunk",
|
||||
section_title=item.metadata.section_title or item.metadata.section_number,
|
||||
section_path=section_path,
|
||||
page_number=item.metadata.page_number,
|
||||
page_start=item.metadata.page_number,
|
||||
page_end=item.metadata.page_number,
|
||||
section_level=len(section_path),
|
||||
regulation_type=regulation_type,
|
||||
version=version,
|
||||
semantic_id=item.metadata.clause_number,
|
||||
block_type="local_markdown_chunk",
|
||||
metadata=metadata,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -19,29 +19,35 @@ class AliyunVectorChunkBuilder(ChunkBuilder):
|
||||
"""Handle build for the Aliyun Vector Chunk Builder instance."""
|
||||
chunks: list[Chunk] = []
|
||||
for index, item in enumerate(parsed_document.vector_chunks):
|
||||
content = item.get("content") or item.get("text") or ""
|
||||
embedding_text = item.get("embedding_text") or content
|
||||
text = item.get("text") or ""
|
||||
embedding_text = item.get("embedding_text") or text
|
||||
if not embedding_text.strip():
|
||||
continue
|
||||
section_path = item.get("section_path") or []
|
||||
section_title = item.get("section_title") or (section_path[-1] if section_path else "")
|
||||
page_number = item.get("page_start") or item.get("page") or 0
|
||||
chunk_id = item.get("chunk_id") or f"{parsed_document.doc_id}-chunk-{index}"
|
||||
metadata = {k: v for k, v in item.items() if k not in {"content", "embedding_text"}}
|
||||
metadata = dict(item)
|
||||
metadata["regulation_type"] = regulation_type
|
||||
metadata["version"] = version
|
||||
chunks.append(
|
||||
Chunk(
|
||||
chunk_id=str(chunk_id),
|
||||
doc_id=parsed_document.doc_id,
|
||||
doc_name=parsed_document.doc_name,
|
||||
content=content,
|
||||
doc_title=str(item.get("doc_title") or parsed_document.doc_name),
|
||||
text=text,
|
||||
embedding_text=embedding_text,
|
||||
chunk_type=str(item.get("chunk_type", item.get("block_type", ""))),
|
||||
chunk_index=int(item.get("chunk_index") or 0),
|
||||
piece_index=int(item.get("piece_index") or 0),
|
||||
page_start=int(item.get("page_start") or 0),
|
||||
page_end=int(item.get("page_end") or 0),
|
||||
section_title=section_title,
|
||||
section_path=section_path,
|
||||
page_number=int(page_number or 0),
|
||||
section_level=int(item.get("section_level") or len(section_path)),
|
||||
source_ids=[str(v) for v in item.get("source_ids", [])],
|
||||
regulation_type=regulation_type,
|
||||
version=version,
|
||||
semantic_id=item.get("semantic_id", ""),
|
||||
block_type=item.get("block_type", ""),
|
||||
metadata=metadata,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -56,7 +56,21 @@ class BM25Retriever:
|
||||
try:
|
||||
rows = self._vector_index.collection.query(
|
||||
expr='doc_id != ""',
|
||||
output_fields=["id", "doc_id", "doc_name", "content", "section_title", "page_number"],
|
||||
output_fields=[
|
||||
"id",
|
||||
"chunk_id",
|
||||
"doc_id",
|
||||
"doc_title",
|
||||
"text",
|
||||
"chunk_type",
|
||||
"section_title",
|
||||
"page_start",
|
||||
"page_end",
|
||||
"section_level",
|
||||
"chunk_index",
|
||||
"piece_index",
|
||||
"metadata_json",
|
||||
],
|
||||
limit=16384,
|
||||
)
|
||||
except Exception:
|
||||
@@ -64,19 +78,33 @@ class BM25Retriever:
|
||||
return []
|
||||
return [
|
||||
RetrievedChunk(
|
||||
chunk_id=str(row.get("id", "")),
|
||||
chunk_id=str(row.get("chunk_id") or row.get("id", "")),
|
||||
doc_id=str(row.get("doc_id", "")),
|
||||
doc_name=str(row.get("doc_name", "")),
|
||||
content=str(row.get("content", "")),
|
||||
doc_title=str(row.get("doc_title", "")),
|
||||
text=str(row.get("text", "")),
|
||||
score=0.0,
|
||||
chunk_type=str(row.get("chunk_type", "")),
|
||||
section_title=str(row.get("section_title", "")),
|
||||
page_number=int(row.get("page_number") or 0),
|
||||
metadata={},
|
||||
page_start=int(row.get("page_start") or 0),
|
||||
page_end=int(row.get("page_end") or 0),
|
||||
section_level=int(row.get("section_level") or 0),
|
||||
chunk_index=int(row.get("chunk_index") or 0),
|
||||
piece_index=int(row.get("piece_index") or 0),
|
||||
metadata=self._parse_metadata_json(row.get("metadata_json", "")),
|
||||
)
|
||||
for row in rows
|
||||
if row.get("content")
|
||||
if row.get("text")
|
||||
]
|
||||
|
||||
def _parse_metadata_json(self, raw_metadata: str) -> dict:
|
||||
"""Parse metadata_json into a dict for BM25-side filtering."""
|
||||
if not raw_metadata:
|
||||
return {}
|
||||
try:
|
||||
return dict(__import__("json").loads(raw_metadata))
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
def _ensure_built(self) -> None:
|
||||
if self._index is not None:
|
||||
return
|
||||
@@ -93,7 +121,7 @@ class BM25Retriever:
|
||||
self._chunks = []
|
||||
self._index = BM25Okapi([[]])
|
||||
return
|
||||
tokenized = [_tokenize(c.content) for c in chunks]
|
||||
tokenized = [_tokenize(c.text) for c in chunks]
|
||||
self._chunks = chunks
|
||||
self._index = BM25Okapi(tokenized)
|
||||
logger.info("BM25Retriever: index built with %d chunks", len(chunks))
|
||||
@@ -127,20 +155,26 @@ class BM25Retriever:
|
||||
for score, chunk in ranked[: top_k * 2]:
|
||||
if score <= 0:
|
||||
break
|
||||
# Apply simple regulation_type filter if provided
|
||||
if filters and chunk.metadata.get("regulation_type"):
|
||||
types = [t.strip() for t in filters.split(",")]
|
||||
if chunk.metadata.get("regulation_type") not in types:
|
||||
continue
|
||||
if filters:
|
||||
normalized_filter = filters.replace("doc_name", "doc_title").strip()
|
||||
if normalized_filter.startswith('doc_title == "'):
|
||||
expected_title = normalized_filter[len('doc_title == "'):-1]
|
||||
if chunk.doc_title != expected_title:
|
||||
continue
|
||||
results.append(
|
||||
RetrievedChunk(
|
||||
chunk_id=chunk.chunk_id,
|
||||
doc_id=chunk.doc_id,
|
||||
doc_name=chunk.doc_name,
|
||||
content=chunk.content,
|
||||
doc_title=chunk.doc_title,
|
||||
text=chunk.text,
|
||||
score=score,
|
||||
chunk_type=chunk.chunk_type,
|
||||
section_title=chunk.section_title,
|
||||
page_number=chunk.page_number,
|
||||
page_start=chunk.page_start,
|
||||
page_end=chunk.page_end,
|
||||
section_level=chunk.section_level,
|
||||
chunk_index=chunk.chunk_index,
|
||||
piece_index=chunk.piece_index,
|
||||
metadata=chunk.metadata,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -31,7 +31,7 @@ class OpenAICompatibleReranker(Reranker):
|
||||
if not chunks:
|
||||
return []
|
||||
|
||||
texts = [chunk.content for chunk in chunks]
|
||||
texts = [chunk.text for chunk in chunks]
|
||||
start = time.time()
|
||||
try:
|
||||
scores = self._call_reranker(query, texts)
|
||||
|
||||
@@ -4,57 +4,150 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Iterable
|
||||
|
||||
from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections, utility
|
||||
from loguru import logger
|
||||
from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, MilvusException, connections, utility
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.domain.documents import Chunk
|
||||
from app.domain.retrieval import RetrievedChunk, VectorIndex
|
||||
from app.shared.errors import VectorStoreSchemaError
|
||||
# Keep adapter behavior explicit so integration details remain easy to audit.
|
||||
|
||||
|
||||
_REQUIRED_SCHEMA_FIELDS = (
|
||||
"doc_id",
|
||||
"doc_title",
|
||||
"chunk_id",
|
||||
"text",
|
||||
"embedding",
|
||||
"section_title",
|
||||
"metadata_json",
|
||||
)
|
||||
_SCHEMA_RECOVERY_TOKENS = (
|
||||
"field doc_title not exist",
|
||||
"field text not exist",
|
||||
"field embedding not exist",
|
||||
"collection not loaded",
|
||||
"can't find collection",
|
||||
"not found[collection",
|
||||
)
|
||||
|
||||
|
||||
|
||||
class MilvusVectorIndex(VectorIndex):
|
||||
"""Provide the Milvus Vector Index index implementation."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the Milvus Vector Index instance."""
|
||||
self.collection_name = settings.milvus_collection
|
||||
self.db_name = settings.milvus_db_name
|
||||
self.host = settings.milvus_host
|
||||
self.port = settings.milvus_port
|
||||
# Use an adapter-specific alias so this index never reuses unrelated global Milvus state.
|
||||
self.alias = f"vector-index::{self.host}:{self.port}/{self.db_name}/{self.collection_name}"
|
||||
self._connect()
|
||||
self.collection = self._bind_collection()
|
||||
|
||||
def _connect(self, *, refresh: bool = False) -> None:
|
||||
"""Establish the Milvus connection for this adapter."""
|
||||
if refresh:
|
||||
try:
|
||||
connections.disconnect(self.alias)
|
||||
except Exception:
|
||||
# Best-effort disconnect keeps refresh idempotent when no alias is active yet.
|
||||
pass
|
||||
connections.connect(
|
||||
alias="default",
|
||||
host=settings.milvus_host,
|
||||
port=settings.milvus_port,
|
||||
alias=self.alias,
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
db_name=self.db_name,
|
||||
)
|
||||
self.collection = self._ensure_collection()
|
||||
|
||||
def _schema_field_names(self, collection: Collection) -> list[str]:
|
||||
"""Return the field names exposed by the bound Milvus collection."""
|
||||
return [field.name for field in collection.schema.fields]
|
||||
|
||||
def _raise_schema_error(self, *, message: str, actual_fields: Iterable[str]) -> None:
|
||||
"""Raise a typed schema error for the active collection."""
|
||||
raise VectorStoreSchemaError(
|
||||
message=message,
|
||||
host=self.host,
|
||||
db_name=self.db_name,
|
||||
collection_name=self.collection_name,
|
||||
expected_fields=list(_REQUIRED_SCHEMA_FIELDS),
|
||||
actual_fields=list(actual_fields),
|
||||
)
|
||||
|
||||
def _validate_schema(self, collection: Collection) -> None:
|
||||
"""Ensure the collection schema matches the dense-only adapter contract."""
|
||||
actual_fields = self._schema_field_names(collection)
|
||||
missing_fields = [field_name for field_name in _REQUIRED_SCHEMA_FIELDS if field_name not in actual_fields]
|
||||
if missing_fields:
|
||||
self._raise_schema_error(
|
||||
message=f"Milvus collection schema mismatch; missing required fields: {missing_fields}",
|
||||
actual_fields=actual_fields,
|
||||
)
|
||||
|
||||
def _log_collection_binding(self, collection: Collection, *, event: str) -> None:
|
||||
"""Record the bound collection details for runtime diagnostics."""
|
||||
try:
|
||||
num_entities = collection.num_entities
|
||||
except Exception:
|
||||
num_entities = "unknown"
|
||||
logger.info(
|
||||
"Milvus binding {} alias={} host={} db={} collection={} fields={} num_entities={}",
|
||||
event,
|
||||
self.alias,
|
||||
self.host,
|
||||
self.db_name,
|
||||
self.collection_name,
|
||||
self._schema_field_names(collection),
|
||||
num_entities,
|
||||
)
|
||||
|
||||
def _bind_collection(self, *, force_refresh: bool = False) -> Collection:
|
||||
"""Bind and validate the configured Milvus collection."""
|
||||
if force_refresh:
|
||||
self._connect(refresh=True)
|
||||
collection = self._ensure_collection()
|
||||
self._validate_schema(collection)
|
||||
self._log_collection_binding(collection, event="refreshed" if force_refresh else "initialized")
|
||||
return collection
|
||||
|
||||
def _ensure_collection(self) -> Collection:
|
||||
"""Handle ensure collection for this module for the Milvus Vector Index instance."""
|
||||
if utility.has_collection(self.collection_name):
|
||||
collection = Collection(self.collection_name)
|
||||
if utility.has_collection(self.collection_name, using=self.alias):
|
||||
collection = Collection(self.collection_name, using=self.alias)
|
||||
collection.load()
|
||||
return collection
|
||||
schema = CollectionSchema(
|
||||
fields=[
|
||||
FieldSchema(name="id", dtype=DataType.VARCHAR, max_length=128, is_primary=True, auto_id=False),
|
||||
FieldSchema(name="doc_id", dtype=DataType.VARCHAR, max_length=64),
|
||||
FieldSchema(name="doc_name", dtype=DataType.VARCHAR, max_length=256),
|
||||
FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=65535),
|
||||
FieldSchema(name="doc_title", dtype=DataType.VARCHAR, max_length=256),
|
||||
FieldSchema(name="chunk_id", dtype=DataType.VARCHAR, max_length=128),
|
||||
FieldSchema(name="chunk_index", dtype=DataType.INT64),
|
||||
FieldSchema(name="piece_index", dtype=DataType.INT64),
|
||||
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
|
||||
FieldSchema(name="embedding_text", dtype=DataType.VARCHAR, max_length=65535),
|
||||
FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=settings.embedding_dim),
|
||||
FieldSchema(name="section_title", dtype=DataType.VARCHAR, max_length=512),
|
||||
FieldSchema(name="section_path", dtype=DataType.VARCHAR, max_length=4096),
|
||||
FieldSchema(name="page_number", dtype=DataType.INT64),
|
||||
FieldSchema(name="regulation_type", dtype=DataType.VARCHAR, max_length=128),
|
||||
FieldSchema(name="version", dtype=DataType.VARCHAR, max_length=64),
|
||||
FieldSchema(name="semantic_id", dtype=DataType.VARCHAR, max_length=128),
|
||||
FieldSchema(name="block_type", dtype=DataType.VARCHAR, max_length=64),
|
||||
FieldSchema(name="chunk_type", dtype=DataType.VARCHAR, max_length=64),
|
||||
FieldSchema(name="page_start", dtype=DataType.INT64),
|
||||
FieldSchema(name="page_end", dtype=DataType.INT64),
|
||||
FieldSchema(name="section_level", dtype=DataType.INT64),
|
||||
FieldSchema(name="source_ids", dtype=DataType.VARCHAR, max_length=4096),
|
||||
FieldSchema(name="section_path", dtype=DataType.VARCHAR, max_length=4096),
|
||||
FieldSchema(name="section_title", dtype=DataType.VARCHAR, max_length=512),
|
||||
FieldSchema(name="metadata_json", dtype=DataType.VARCHAR, max_length=65535),
|
||||
FieldSchema(name="created_at", dtype=DataType.INT64),
|
||||
],
|
||||
description="Dense-only regulations index",
|
||||
enable_dynamic_field=False,
|
||||
)
|
||||
collection = Collection(name=self.collection_name, schema=schema)
|
||||
collection = Collection(name=self.collection_name, schema=schema, using=self.alias)
|
||||
collection.create_index(
|
||||
field_name="embedding",
|
||||
index_params={
|
||||
@@ -73,21 +166,34 @@ class MilvusVectorIndex(VectorIndex):
|
||||
data = []
|
||||
now = int(time.time())
|
||||
for chunk, vector in zip(chunks, vectors):
|
||||
metadata = dict(chunk.metadata)
|
||||
doc_title = str(metadata.get("doc_title", chunk.doc_title))
|
||||
text = str(metadata.get("text", chunk.text))
|
||||
embedding_text = str(metadata.get("embedding_text", chunk.embedding_text))
|
||||
page_start = int(metadata.get("page_start", 0) or 0)
|
||||
page_end = int(metadata.get("page_end", 0) or 0)
|
||||
section_path = metadata.get("section_path", chunk.section_path)
|
||||
source_ids = metadata.get("source_ids", [])
|
||||
data.append(
|
||||
{
|
||||
"id": chunk.chunk_id,
|
||||
"doc_id": chunk.doc_id,
|
||||
"doc_name": chunk.doc_name,
|
||||
"content": chunk.content[:65535],
|
||||
"doc_title": doc_title[:256],
|
||||
"chunk_id": chunk.chunk_id[:128],
|
||||
"chunk_index": int(metadata.get("chunk_index", chunk.chunk_index) or 0),
|
||||
"piece_index": int(metadata.get("piece_index", chunk.piece_index) or 0),
|
||||
"text": text[:65535],
|
||||
"embedding_text": embedding_text[:65535],
|
||||
"embedding": vector,
|
||||
"section_title": chunk.section_title[:512],
|
||||
"section_path": json.dumps(chunk.section_path, ensure_ascii=False)[:4096],
|
||||
"page_number": chunk.page_number,
|
||||
"regulation_type": chunk.regulation_type[:128],
|
||||
"version": chunk.version[:64],
|
||||
"semantic_id": chunk.semantic_id[:128],
|
||||
"block_type": chunk.block_type[:64],
|
||||
"metadata_json": json.dumps(chunk.metadata, ensure_ascii=False)[:65535],
|
||||
"semantic_id": str(metadata.get("semantic_id", chunk.semantic_id))[:128],
|
||||
"chunk_type": str(metadata.get("chunk_type", chunk.chunk_type))[:64],
|
||||
"page_start": page_start,
|
||||
"page_end": page_end,
|
||||
"section_level": int(metadata.get("section_level", chunk.section_level) or 0),
|
||||
"source_ids": json.dumps(source_ids, ensure_ascii=False)[:4096],
|
||||
"section_path": json.dumps(section_path, ensure_ascii=False)[:4096],
|
||||
"section_title": str(metadata.get("section_title", chunk.section_title))[:512],
|
||||
"metadata_json": json.dumps(metadata, ensure_ascii=False)[:65535],
|
||||
"created_at": now,
|
||||
}
|
||||
)
|
||||
@@ -107,47 +213,97 @@ class MilvusVectorIndex(VectorIndex):
|
||||
|
||||
filters = filters.strip()
|
||||
|
||||
# Normalize legacy field names so callers can keep older filter payloads.
|
||||
replacements = {
|
||||
"doc_name": "doc_title",
|
||||
"content": "text",
|
||||
"page_number": "page_start",
|
||||
"block_type": "chunk_type",
|
||||
}
|
||||
for legacy_name, new_name in replacements.items():
|
||||
filters = filters.replace(legacy_name, new_name)
|
||||
|
||||
# Check if already a Milvus expression (contains operators)
|
||||
if any(op in filters for op in ["==", "!=", "in", "not in", ">", "<", ">=", "<=", "and", "or"]):
|
||||
return filters
|
||||
|
||||
# Parse simple regulation_type filter
|
||||
# Support: "GB" or "GB,UN-ECE" or "GB, UN-ECE"
|
||||
types = [t.strip() for t in filters.split(",") if t.strip()]
|
||||
# Parse simple document-title filter.
|
||||
titles = [title.strip() for title in filters.split(",") if title.strip()]
|
||||
|
||||
if not types:
|
||||
if not titles:
|
||||
return None
|
||||
|
||||
if len(types) == 1:
|
||||
# Single value: regulation_type == "GB"
|
||||
return f'regulation_type == "{types[0]}"'
|
||||
else:
|
||||
# Multiple values: regulation_type in ["GB", "UN-ECE"]
|
||||
quoted_types = [f'"{t}"' for t in types]
|
||||
return f'regulation_type in [{", ".join(quoted_types)}]'
|
||||
if len(titles) == 1:
|
||||
return f'doc_title == "{titles[0]}"'
|
||||
|
||||
quoted_titles = [f'"{title}"' for title in titles]
|
||||
return f'doc_title in [{", ".join(quoted_titles)}]'
|
||||
|
||||
def _should_refresh_after_exception(self, exc: Exception) -> bool:
|
||||
"""Return whether the Milvus error suggests stale connection or collection state."""
|
||||
if not isinstance(exc, MilvusException):
|
||||
return False
|
||||
normalized = str(exc).lower()
|
||||
return any(token in normalized for token in _SCHEMA_RECOVERY_TOKENS)
|
||||
|
||||
def _run_with_refresh(self, operation):
|
||||
"""Run a Milvus operation and retry once after a forced reconnect when appropriate."""
|
||||
try:
|
||||
return operation()
|
||||
except VectorStoreSchemaError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
if not self._should_refresh_after_exception(exc):
|
||||
raise
|
||||
logger.warning(
|
||||
"Milvus operation failed for alias={} collection={}; forcing reconnect and retry: {}",
|
||||
self.alias,
|
||||
self.collection_name,
|
||||
exc,
|
||||
)
|
||||
self.collection = self._bind_collection(force_refresh=True)
|
||||
try:
|
||||
return operation()
|
||||
except VectorStoreSchemaError:
|
||||
raise
|
||||
except Exception as retry_exc:
|
||||
if isinstance(retry_exc, MilvusException):
|
||||
self._raise_schema_error(
|
||||
message=f"Milvus operation failed after refresh: {retry_exc}",
|
||||
actual_fields=self._schema_field_names(self.collection),
|
||||
)
|
||||
raise
|
||||
|
||||
def search(self, query_vector: list[float], top_k: int, filters: str | None = None) -> list[RetrievedChunk]:
|
||||
"""Handle search for the Milvus Vector Index instance."""
|
||||
milvus_expr = self._parse_filters(filters)
|
||||
|
||||
results = self.collection.search(
|
||||
data=[query_vector],
|
||||
anns_field="embedding",
|
||||
param={"metric_type": "COSINE", "params": {"nprobe": settings.milvus_nprobe}},
|
||||
limit=top_k,
|
||||
expr=milvus_expr,
|
||||
output_fields=[
|
||||
"doc_id",
|
||||
"doc_name",
|
||||
"content",
|
||||
"section_title",
|
||||
"page_number",
|
||||
"regulation_type",
|
||||
"version",
|
||||
"semantic_id",
|
||||
"block_type",
|
||||
"metadata_json",
|
||||
],
|
||||
results = self._run_with_refresh(
|
||||
lambda: self.collection.search(
|
||||
data=[query_vector],
|
||||
anns_field="embedding",
|
||||
param={"metric_type": "COSINE", "params": {"nprobe": settings.milvus_nprobe}},
|
||||
limit=top_k,
|
||||
expr=milvus_expr,
|
||||
output_fields=[
|
||||
"doc_id",
|
||||
"doc_title",
|
||||
"chunk_id",
|
||||
"chunk_index",
|
||||
"piece_index",
|
||||
"text",
|
||||
"embedding_text",
|
||||
"section_title",
|
||||
"semantic_id",
|
||||
"chunk_type",
|
||||
"page_start",
|
||||
"page_end",
|
||||
"section_level",
|
||||
"source_ids",
|
||||
"section_path",
|
||||
"metadata_json",
|
||||
],
|
||||
)
|
||||
)
|
||||
payload: list[RetrievedChunk] = []
|
||||
for hits in results:
|
||||
@@ -161,13 +317,18 @@ class MilvusVectorIndex(VectorIndex):
|
||||
metadata = {"raw_metadata": raw_metadata}
|
||||
payload.append(
|
||||
RetrievedChunk(
|
||||
chunk_id=str(hit.id),
|
||||
chunk_id=str(hit.entity.get("chunk_id", hit.id)),
|
||||
doc_id=hit.entity.get("doc_id", ""),
|
||||
doc_name=hit.entity.get("doc_name", ""),
|
||||
content=hit.entity.get("content", ""),
|
||||
doc_title=hit.entity.get("doc_title", ""),
|
||||
text=hit.entity.get("text", ""),
|
||||
score=float(hit.score),
|
||||
chunk_type=hit.entity.get("chunk_type", ""),
|
||||
section_title=hit.entity.get("section_title", ""),
|
||||
page_number=int(hit.entity.get("page_number", 0) or 0),
|
||||
page_start=int(hit.entity.get("page_start", 0) or 0),
|
||||
page_end=int(hit.entity.get("page_end", 0) or 0),
|
||||
section_level=int(hit.entity.get("section_level", 0) or 0),
|
||||
chunk_index=int(hit.entity.get("chunk_index", 0) or 0),
|
||||
piece_index=int(hit.entity.get("piece_index", 0) or 0),
|
||||
metadata=metadata,
|
||||
)
|
||||
)
|
||||
@@ -176,7 +337,9 @@ class MilvusVectorIndex(VectorIndex):
|
||||
def count_by_document(self) -> dict[str, int]:
|
||||
"""Return doc_id -> chunk count from Milvus."""
|
||||
try:
|
||||
rows = self.collection.query(expr="doc_id != \"\"", output_fields=["doc_id"])
|
||||
rows = self._run_with_refresh(
|
||||
lambda: self.collection.query(expr="doc_id != \"\"", output_fields=["doc_id", "doc_title"])
|
||||
)
|
||||
except Exception:
|
||||
return {}
|
||||
counts: dict[str, int] = {}
|
||||
@@ -189,9 +352,11 @@ class MilvusVectorIndex(VectorIndex):
|
||||
def list_document_metadata(self) -> list[dict]:
|
||||
"""Return one metadata row per document from Milvus (single query, no embeddings)."""
|
||||
try:
|
||||
rows = self.collection.query(
|
||||
expr="doc_id != \"\"",
|
||||
output_fields=["doc_id", "doc_name", "regulation_type", "version"],
|
||||
rows = self._run_with_refresh(
|
||||
lambda: self.collection.query(
|
||||
expr="doc_id != \"\"",
|
||||
output_fields=["doc_id", "doc_title", "metadata_json"],
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
return []
|
||||
@@ -204,15 +369,26 @@ class MilvusVectorIndex(VectorIndex):
|
||||
continue
|
||||
counts[doc_id] = counts.get(doc_id, 0) + 1
|
||||
if doc_id not in seen:
|
||||
metadata: dict[str, object] = {}
|
||||
raw_metadata = row.get("metadata_json", "")
|
||||
if raw_metadata:
|
||||
try:
|
||||
metadata = json.loads(raw_metadata)
|
||||
except json.JSONDecodeError:
|
||||
metadata = {}
|
||||
seen[doc_id] = {
|
||||
"doc_id": doc_id,
|
||||
"doc_name": row.get("doc_name", ""),
|
||||
"regulation_type": row.get("regulation_type", ""),
|
||||
"version": row.get("version", ""),
|
||||
"doc_title": row.get("doc_title", ""),
|
||||
"regulation_type": str(metadata.get("regulation_type", "")),
|
||||
"version": str(metadata.get("version", "")),
|
||||
}
|
||||
|
||||
return [
|
||||
{**meta, "chunk_count": counts[meta["doc_id"]]}
|
||||
{
|
||||
**meta,
|
||||
"doc_name": meta.get("doc_title", ""),
|
||||
"chunk_count": counts[meta["doc_id"]],
|
||||
}
|
||||
for meta in seen.values()
|
||||
]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user