"""BM25 sparse retriever backed by the existing Milvus collection.

Falls back to a no-op if rank_bm25 or jieba is not installed.
The index is built lazily on first query and can be refreshed after new
documents are ingested by calling refresh().
"""

from __future__ import annotations

import heapq
import json
import logging
import threading
from typing import TYPE_CHECKING

from app.domain.retrieval import RetrievedChunk

if TYPE_CHECKING:
    from app.infrastructure.vectorstore.milvus_vector_index import MilvusVectorIndex

logger = logging.getLogger(__name__)

try:
    from rank_bm25 import BM25Okapi
    import jieba

    jieba.setLogLevel(logging.WARNING)
    _BM25_AVAILABLE = True
except ImportError:
    _BM25_AVAILABLE = False
    logger.warning("rank_bm25 or jieba not installed – BM25 retrieval disabled. Run: pip install rank-bm25 jieba")


def _tokenize(text: str) -> list[str]:
    """Tokenize Chinese/mixed text with jieba.

    Whitespace-only tokens are dropped: jieba emits the separators between
    words as tokens, and keeping them would let a query match any chunk that
    merely contains a space.
    """
    return [token for token in jieba.cut(text) if token.strip()]


class BM25Retriever:
    """Sparse BM25 retriever that indexes chunks stored in Milvus.

    Thread-safe: builds are serialized by a lock, the index and its chunk
    list are published together, and ``retrieve`` works on a snapshot taken
    under the same lock, so a concurrent ``refresh`` can never expose a
    half-updated index.

    A failed Milvus fetch is treated like an empty collection; call
    ``refresh()`` to retry.
    """

    # Largest ``limit`` requested from Milvus in a single query.
    _MAX_CHUNKS = 16384

    # Scalar fields fetched for every chunk (vectors are not needed).
    _OUTPUT_FIELDS = (
        "id",
        "chunk_id",
        "doc_id",
        "doc_title",
        "text",
        "chunk_type",
        "section_title",
        "page_start",
        "page_end",
        "section_level",
        "chunk_index",
        "piece_index",
        "metadata_json",
    )

    # Field-name prefixes accepted by the title filter ("doc_name" is an alias).
    _TITLE_FILTER_PREFIXES = ('doc_title == "', 'doc_name == "')

    def __init__(self, vector_index: "MilvusVectorIndex") -> None:
        self._vector_index = vector_index
        self._lock = threading.Lock()
        # ``_index`` stays None while the collection has nothing to index;
        # ``_built`` records whether a build has been attempted at all.
        self._index: "BM25Okapi | None" = None
        self._chunks: list[RetrievedChunk] = []
        self._built = False

    @property
    def available(self) -> bool:
        """Whether the optional BM25 dependencies were importable."""
        return _BM25_AVAILABLE

    def _load_all_chunks(self) -> list[RetrievedChunk]:
        """Fetch all chunk records from Milvus (no vectors needed).

        Returns an empty list when the query fails; rows without text are
        skipped because they cannot be indexed.
        """
        try:
            rows = self._vector_index.collection.query(
                expr='doc_id != ""',
                output_fields=list(self._OUTPUT_FIELDS),
                limit=self._MAX_CHUNKS,
            )
        except Exception:
            logger.exception("BM25Retriever: failed to fetch chunks from Milvus")
            return []

        if len(rows) >= self._MAX_CHUNKS:
            # The query is capped, so anything beyond the cap is silently
            # missing from the index; make that visible to operators.
            logger.warning(
                "BM25Retriever: fetched %d chunks, which is the query limit; "
                "the index may be incomplete",
                len(rows),
            )

        return [self._chunk_from_row(row) for row in rows if row.get("text")]

    def _chunk_from_row(self, row: dict) -> RetrievedChunk:
        """Convert one Milvus row into an unscored ``RetrievedChunk``."""

        def text(key: str) -> str:
            # ``or ""`` keeps a null field from turning into the string "None".
            return str(row.get(key) or "")

        def number(key: str) -> int:
            return int(row.get(key) or 0)

        return RetrievedChunk(
            chunk_id=str(row.get("chunk_id") or row.get("id", "")),
            doc_id=text("doc_id"),
            doc_title=text("doc_title"),
            text=text("text"),
            score=0.0,
            chunk_type=text("chunk_type"),
            section_title=text("section_title"),
            page_start=number("page_start"),
            page_end=number("page_end"),
            section_level=number("section_level"),
            chunk_index=number("chunk_index"),
            piece_index=number("piece_index"),
            metadata=self._parse_metadata_json(row.get("metadata_json", "")),
        )

    def _parse_metadata_json(self, raw_metadata: "str | dict | None") -> dict:
        """Parse metadata_json into a dict for BM25-side filtering.

        Returns an empty dict for missing, malformed or non-mapping input.
        A value that is already a mapping is copied as-is.
        """
        if not raw_metadata:
            return {}
        if isinstance(raw_metadata, dict):
            return dict(raw_metadata)
        try:
            return dict(json.loads(raw_metadata))
        except (TypeError, ValueError):
            # Covers invalid JSON as well as JSON that is not an object.
            return {}

    def _ensure_built(self) -> None:
        """Build the index on first use; later calls return immediately."""
        if self._built:
            return
        with self._lock:
            # Re-check: another thread may have finished the build while
            # this one was waiting for the lock.
            if not self._built:
                self._build()

    def _build(self) -> None:
        """Load all chunks and (re)build the index. Caller must hold the lock."""
        logger.info("BM25Retriever: building index …")
        chunks = self._load_all_chunks()

        index = None
        if chunks:
            tokenized = [_tokenize(chunk.text) for chunk in chunks]
            # BM25Okapi cannot be built from a corpus without any token (its
            # IDF averaging divides by the vocabulary size), so such a corpus
            # is handled as an empty index instead.
            if any(tokenized):
                index = BM25Okapi(tokenized)

        if index is None:
            logger.warning("BM25Retriever: no chunks found, index is empty")
            chunks = []
        else:
            logger.info("BM25Retriever: index built with %d chunks", len(chunks))

        # Publish only after everything is computed, so the previous index
        # stays intact if the build raises part-way through.
        self._chunks = chunks
        self._index = index
        self._built = True

    def refresh(self) -> None:
        """Rebuild the index (call after new documents are ingested).

        Does nothing when the BM25 dependencies are not installed.
        """
        if not _BM25_AVAILABLE:
            return
        with self._lock:
            self._build()

    @classmethod
    def _parse_title_filter(cls, filters: "str | None") -> "str | None":
        """Extract the title from a ``doc_title == "..."`` filter expression.

        ``doc_name`` is accepted as an alias of ``doc_title``. Returns None
        when there is no filter or the expression has another shape, in which
        case no BM25-side filtering is applied.
        """
        if not filters:
            return None
        expression = filters.strip()
        for prefix in cls._TITLE_FILTER_PREFIXES:
            if expression.startswith(prefix):
                title = expression[len(prefix):]
                # Drop the closing quote only if it is actually there.
                return title[:-1] if title.endswith('"') else title
        return None

    @staticmethod
    def _with_score(chunk: RetrievedChunk, score: float) -> RetrievedChunk:
        """Return a copy of *chunk* carrying *score*; the cached chunk is untouched."""
        return RetrievedChunk(
            chunk_id=chunk.chunk_id,
            doc_id=chunk.doc_id,
            doc_title=chunk.doc_title,
            text=chunk.text,
            score=score,
            chunk_type=chunk.chunk_type,
            section_title=chunk.section_title,
            page_start=chunk.page_start,
            page_end=chunk.page_end,
            section_level=chunk.section_level,
            chunk_index=chunk.chunk_index,
            piece_index=chunk.piece_index,
            metadata=chunk.metadata,
        )

    def retrieve(self, query: str, top_k: int, filters: str | None = None) -> list[RetrievedChunk]:
        """Return up to top_k chunks ranked by BM25 score.

        Args:
            query: Free-text query; tokenized the same way as the chunks.
            top_k: Maximum number of chunks to return; values <= 0 yield [].
            filters: Optional ``doc_title == "..."`` (or ``doc_name == "..."``)
                expression restricting results to a single document title.
                Other expressions are ignored.

        Returns:
            Copies of the matching chunks with ``score`` set, best first.
            Only chunks with a strictly positive score are returned. The list
            is empty when BM25 is unavailable or nothing is indexed.
        """
        if not _BM25_AVAILABLE or top_k <= 0:
            return []

        self._ensure_built()
        # Snapshot both attributes together so the scores always line up
        # with the chunk list they were computed for.
        with self._lock:
            index, chunks = self._index, self._chunks
        if index is None or not chunks:
            return []

        tokens = _tokenize(query)
        if not tokens:
            return []

        scores = index.get_scores(tokens)
        expected_title = self._parse_title_filter(filters)

        # Filter before ranking so a title filter still yields up to top_k
        # hits. ``score > 0`` also discards NaN scores.
        candidates = [
            (score, chunk)
            for score, chunk in zip(map(float, scores), chunks)
            if score > 0 and (expected_title is None or chunk.doc_title == expected_title)
        ]
        best = heapq.nlargest(top_k, candidates, key=lambda pair: pair[0])
        return [self._with_score(chunk, score) for score, chunk in best]