- Removed multiple failed document entries from `documents.json`. - Added a new document entry with updated metadata and changed the index name to `regulations_dense_1024_v2`. - Updated architecture documentation to reflect changes in the Milvus collection name. - Adjusted requirements by removing the sqlalchemy dependency. - Modified test cases to align with new document structure and naming conventions. - Introduced a new test file for Milvus vector index runtime recovery and error handling. - Updated assertions in various test files to ensure compatibility with the new schema.
184 lines
6.3 KiB
Python
184 lines
6.3 KiB
Python
"""BM25 sparse retriever backed by the existing Milvus collection.
|
||
|
||
Falls back to no-op if rank_bm25 or jieba is not installed.
|
||
The index is built lazily on first query and can be refreshed after
|
||
new documents are ingested by calling refresh().
|
||
"""
|
||
|
||
from __future__ import annotations

import json
import logging
import threading
from typing import TYPE_CHECKING

from app.domain.retrieval import RetrievedChunk

if TYPE_CHECKING:
    from app.infrastructure.vectorstore.milvus_vector_index import MilvusVectorIndex
|
||
|
||
logger = logging.getLogger(__name__)

# rank_bm25 and jieba are optional: if either import fails the module still
# loads, _BM25_AVAILABLE is False, and BM25Retriever.retrieve returns [].
try:
    from rank_bm25 import BM25Okapi
    import jieba
except ImportError:
    _BM25_AVAILABLE = False
    logger.warning("rank_bm25 or jieba not installed – BM25 retrieval disabled. Run: pip install rank-bm25 jieba")
else:
    # Raise jieba's own log threshold so only its warnings and errors surface.
    jieba.setLogLevel(logging.WARNING)
    _BM25_AVAILABLE = True
|
||
|
||
|
||
def _tokenize(text: str) -> list[str]:
    """Split Chinese or mixed-language text into a list of jieba tokens.

    ``jieba`` is only bound when the optional import above succeeded, so
    callers must check ``_BM25_AVAILABLE`` before calling this.
    """
    return jieba.lcut(text)
|
||
|
||
|
||
class BM25Retriever:
    """Sparse BM25 retriever that indexes chunks stored in Milvus.

    The index is built lazily on the first ``retrieve`` call and can be
    rebuilt with ``refresh()`` after new documents are ingested. When
    rank_bm25 or jieba is not installed, ``retrieve`` returns ``[]`` and
    ``refresh`` does nothing.

    Thread-safety: ``_lock`` guards ``_index``, ``_chunks`` and ``_built``.
    An index and the chunk list it was built from are always published
    together under the lock and read together under the lock. A query
    racing a ``refresh()`` therefore never pairs one corpus's scores with
    another corpus's chunks.
    """

    # Row cap for the single Milvus query. Milvus requires offset + limit
    # <= 16384, so larger collections are only partially indexed.
    _QUERY_LIMIT = 16384

    # Scalar fields needed to rebuild a RetrievedChunk (no vectors).
    _OUTPUT_FIELDS = (
        "id",
        "chunk_id",
        "doc_id",
        "doc_title",
        "text",
        "chunk_type",
        "section_title",
        "page_start",
        "page_end",
        "section_level",
        "chunk_index",
        "piece_index",
        "metadata_json",
    )

    def __init__(self, vector_index: "MilvusVectorIndex") -> None:
        self._vector_index = vector_index
        self._lock = threading.Lock()
        self._index: BM25Okapi | None = None
        self._chunks: list[RetrievedChunk] = []
        # True once a build has succeeded -- including an empty one, which
        # has no BM25Okapi object (rank_bm25 cannot index a corpus with no
        # tokens), so ``_index is None`` alone cannot mean "not built yet".
        self._built = False

    @property
    def available(self) -> bool:
        """Whether the optional rank_bm25 and jieba dependencies imported."""
        return _BM25_AVAILABLE

    def _fetch_chunks(self) -> list[RetrievedChunk]:
        """Fetch every chunk with non-empty text from Milvus.

        Errors raised by the Milvus client (or by row conversion) propagate.
        """
        rows = self._vector_index.collection.query(
            expr='doc_id != ""',
            output_fields=list(self._OUTPUT_FIELDS),
            limit=self._QUERY_LIMIT,
        )
        if len(rows) >= self._QUERY_LIMIT:
            # Hitting the cap means some chunks were probably left out.
            logger.warning(
                "BM25Retriever: Milvus returned %d rows (query limit); "
                "chunks beyond the limit are not indexed",
                len(rows),
            )
        return [self._row_to_chunk(row) for row in rows if row.get("text")]

    def _load_all_chunks(self) -> list[RetrievedChunk]:
        """Best-effort ``_fetch_chunks``: log and return [] on any failure."""
        try:
            return self._fetch_chunks()
        except Exception:
            logger.exception("BM25Retriever: failed to fetch chunks from Milvus")
            return []

    def _row_to_chunk(self, row: dict) -> RetrievedChunk:
        """Convert one Milvus query row into a zero-score RetrievedChunk."""
        return RetrievedChunk(
            chunk_id=str(row.get("chunk_id") or row.get("id", "")),
            doc_id=str(row.get("doc_id", "")),
            doc_title=str(row.get("doc_title", "")),
            text=str(row.get("text", "")),
            score=0.0,
            chunk_type=str(row.get("chunk_type", "")),
            section_title=str(row.get("section_title", "")),
            page_start=int(row.get("page_start") or 0),
            page_end=int(row.get("page_end") or 0),
            section_level=int(row.get("section_level") or 0),
            chunk_index=int(row.get("chunk_index") or 0),
            piece_index=int(row.get("piece_index") or 0),
            metadata=self._parse_metadata_json(row.get("metadata_json", "")),
        )

    def _parse_metadata_json(self, raw_metadata: str | dict | None) -> dict:
        """Parse ``metadata_json`` into a dict for BM25-side filtering.

        A dict (as returned for a JSON-typed field) is copied as-is. A
        string is decoded as JSON. Returns {} when the value is empty or
        cannot be turned into a dict.
        """
        if isinstance(raw_metadata, dict):
            return dict(raw_metadata)
        if not raw_metadata:
            return {}
        try:
            return dict(json.loads(raw_metadata))
        except (TypeError, ValueError, RecursionError):
            # Malformed JSON, non-string input, or a JSON value that is
            # not convertible to a dict.
            return {}

    def _build(self) -> tuple[BM25Okapi | None, list[RetrievedChunk]] | None:
        """Load chunks and build a fresh index without touching shared state.

        Returns:
            ``(index, chunks)``, or ``(None, [])`` when there is no indexable
            text. Returns ``None`` when the dependencies are missing or the
            Milvus fetch fails, so the caller keeps its current state.
        """
        if not _BM25_AVAILABLE:
            return None
        logger.info("BM25Retriever: building index …")
        try:
            chunks = self._fetch_chunks()
        except Exception:
            logger.exception("BM25Retriever: failed to fetch chunks from Milvus")
            return None
        tokenized = [_tokenize(chunk.text) for chunk in chunks]
        # BM25Okapi divides by the vocabulary size when computing its
        # average IDF, so a corpus without a single token cannot be indexed.
        if not any(tokenized):
            logger.warning("BM25Retriever: no chunks found, index is empty")
            return None, []
        index = BM25Okapi(tokenized)
        logger.info("BM25Retriever: index built with %d chunks", len(chunks))
        return index, chunks

    def _ensure_built(self) -> None:
        """Build the index on first use; concurrent first callers build once.

        A failed build leaves ``_built`` False, so the next call retries.
        """
        if self._built:
            return
        with self._lock:
            if self._built:
                return
            built = self._build()
            if built is not None:
                self._index, self._chunks = built
                self._built = True

    def refresh(self) -> None:
        """Rebuild the index (call after new documents are ingested).

        The new index is built outside the lock, so concurrent queries keep
        using the previous index until the new one is swapped in. If the
        rebuild fails, or the dependencies are missing, the previous index
        is kept.
        """
        built = self._build()
        if built is None:
            return
        with self._lock:
            self._index, self._chunks = built
            self._built = True

    @staticmethod
    def _expected_title(filters: str | None) -> str | None:
        """Extract the title from a ``doc_title == "..."`` filter expression.

        ``doc_name`` is accepted as an alias for ``doc_title``. Returns None
        when there is no filter or it has any other form; such filters do
        not restrict the BM25 results.
        """
        if not filters:
            return None
        prefix = 'doc_title == "'
        normalized = filters.replace("doc_name", "doc_title").strip()
        if not normalized.startswith(prefix):
            return None
        # Strip the prefix and the trailing closing quote.
        return normalized[len(prefix):-1]

    def retrieve(self, query: str, top_k: int, filters: str | None = None) -> list[RetrievedChunk]:
        """Return up to ``top_k`` chunks ranked by descending BM25 score.

        Only chunks with a positive score are returned. A
        ``doc_title == "..."`` filter is applied before truncating to
        ``top_k``, so matching chunks are not lost behind higher-scoring
        chunks from other documents.
        """
        if not _BM25_AVAILABLE or top_k <= 0:
            return []
        self._ensure_built()
        # Take index and chunks as one consistent snapshot (see class doc).
        with self._lock:
            index, chunks = self._index, self._chunks
        if index is None or not chunks:
            return []

        expected_title = self._expected_title(filters)
        scores = index.get_scores(_tokenize(query))
        # Stable sort: ties keep corpus order, as before.
        ranked = sorted(zip(map(float, scores), chunks), key=lambda pair: pair[0], reverse=True)

        results: list[RetrievedChunk] = []
        for score, chunk in ranked:
            if score <= 0 or len(results) >= top_k:
                break
            if expected_title is not None and chunk.doc_title != expected_title:
                continue
            results.append(
                RetrievedChunk(
                    chunk_id=chunk.chunk_id,
                    doc_id=chunk.doc_id,
                    doc_title=chunk.doc_title,
                    text=chunk.text,
                    score=score,
                    chunk_type=chunk.chunk_type,
                    section_title=chunk.section_title,
                    page_start=chunk.page_start,
                    page_end=chunk.page_end,
                    section_level=chunk.section_level,
                    chunk_index=chunk.chunk_index,
                    piece_index=chunk.piece_index,
                    metadata=chunk.metadata,
                )
            )
        return results
|