150 lines
5.0 KiB
Python
150 lines
5.0 KiB
Python
|
|
"""BM25 sparse retriever backed by the existing Milvus collection.
|
|||
|
|
|
|||
|
|
Falls back to no-op if rank_bm25 or jieba is not installed.
|
|||
|
|
The index is built lazily on first query and can be refreshed after
|
|||
|
|
new documents are ingested by calling refresh().
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import logging
|
|||
|
|
import threading
|
|||
|
|
from typing import TYPE_CHECKING
|
|||
|
|
|
|||
|
|
from app.domain.retrieval import RetrievedChunk
|
|||
|
|
|
|||
|
|
if TYPE_CHECKING:
|
|||
|
|
from app.infrastructure.vectorstore.milvus_vector_index import MilvusVectorIndex
|
|||
|
|
|
|||
|
|
logger = logging.getLogger(__name__)

# Optional dependencies: BM25 retrieval is switched off (with a warning) when
# either rank_bm25 or jieba cannot be imported. The try body holds only the
# imports; the configuration that depends on them runs in the else branch.
try:
    from rank_bm25 import BM25Okapi
    import jieba
except ImportError:
    _BM25_AVAILABLE = False
    logger.warning("rank_bm25 or jieba not installed – BM25 retrieval disabled. Run: pip install rank-bm25 jieba")
else:
    # Keep jieba's own dictionary-loading chatter out of the application logs.
    jieba.setLogLevel(logging.WARNING)
    _BM25_AVAILABLE = True
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _tokenize(text: str) -> list[str]:
    """Tokenize Chinese/mixed text with jieba for BM25 indexing and querying.

    The same function is applied to indexed chunks and to queries, so both
    sides are normalised identically:

    * tokens are stripped and lower-cased, so English terms match
      case-insensitively;
    * tokens containing no letter or digit (whitespace, punctuation such as
      "，" or "？") are dropped. ``jieba.cut`` emits these as separate
      tokens. Being present in nearly every chunk, they would otherwise give
      almost every chunk a small positive BM25 score for any query that
      contains them.

    Args:
        text: Raw text to tokenize.

    Returns:
        The normalised tokens, in order. Empty if ``text`` holds no
        letters or digits.
    """
    return [
        token.strip().lower()
        for token in jieba.cut(text)
        if any(ch.isalnum() for ch in token)
    ]
|
|||
|
|
|
|||
|
|
|
|||
|
|
class BM25Retriever:
    """Sparse BM25 retriever that indexes chunks stored in Milvus.

    The index is built lazily on the first query and rebuilt on demand with
    ``refresh()``. Builds are serialised by a lock, so concurrent first
    requests build the index only once.

    The fitted index and the chunk list it was built from are published
    together as a single tuple. A query running concurrently with
    ``refresh()`` therefore always scores against a matching pair and never
    sees a half-reset state. The previous index keeps serving until the new
    one is ready.
    """

    def __init__(self, vector_index: "MilvusVectorIndex", *, max_chunks: int = 16384) -> None:
        """Create a retriever over ``vector_index``'s collection.

        Args:
            vector_index: Index whose ``collection`` is queried for chunk text.
            max_chunks: Maximum number of rows fetched from Milvus per build.
                Rows beyond this limit are not indexed; a warning is logged
                when the limit is reached.
        """
        self._vector_index = vector_index
        self._max_chunks = max_chunks
        self._lock = threading.Lock()
        # (index, chunks), replaced by a single assignment so readers never
        # pair an index with a different chunk list. None means "not built
        # yet"; an index of None means the corpus has nothing to score.
        self._snapshot: "tuple[BM25Okapi | None, list[RetrievedChunk]] | None" = None

    @property
    def available(self) -> bool:
        """Whether rank_bm25 and jieba were importable at module load time."""
        return _BM25_AVAILABLE

    def _load_all_chunks(self) -> list[RetrievedChunk] | None:
        """Fetch chunk records from Milvus (text and citation fields, no vectors).

        Returns:
            Chunks with non-empty ``content``, or ``None`` if the Milvus
            query failed. The failure is logged. Returning ``None`` lets
            callers tell a failure apart from an empty collection.
        """
        try:
            rows = self._vector_index.collection.query(
                expr='doc_id != ""',
                output_fields=["id", "doc_id", "doc_name", "content", "section_title", "page_number"],
                limit=self._max_chunks,
            )
        except Exception:
            logger.exception("BM25Retriever: failed to fetch chunks from Milvus")
            return None

        if len(rows) >= self._max_chunks:
            # The query hit its row cap, so some chunks are probably not indexed.
            logger.warning(
                "BM25Retriever: fetched %d rows (the max_chunks limit); "
                "chunks beyond the limit are not indexed",
                len(rows),
            )

        # NOTE(review): "regulation_type" (read by retrieve()'s filter) is not
        # among output_fields and metadata is always {}, so that filter never
        # excludes anything. Confirm the collection schema before wiring it up.
        return [
            RetrievedChunk(
                chunk_id=str(row.get("id", "")),
                doc_id=str(row.get("doc_id", "")),
                doc_name=str(row.get("doc_name", "")),
                content=str(row.get("content", "")),
                score=0.0,
                section_title=str(row.get("section_title", "")),
                page_number=int(row.get("page_number") or 0),
                metadata={},
            )
            for row in rows
            if row.get("content")
        ]

    def _ensure_built(self) -> None:
        """Build the index on first use; concurrent first callers build it once.

        If loading from Milvus fails, the snapshot stays ``None`` so that the
        next query retries instead of caching an empty index forever.
        """
        if self._snapshot is not None:
            return
        with self._lock:
            if self._snapshot is None:
                self._snapshot = self._build()

    def _build(self) -> "tuple[BM25Okapi | None, list[RetrievedChunk]] | None":
        """Load chunks and fit a fresh BM25 index without touching shared state.

        Returns:
            ``(index, chunks)`` on success, with ``index`` set to ``None`` when
            no chunk yields any token. Returns ``None`` when the Milvus fetch
            failed.
        """
        logger.info("BM25Retriever: building index …")
        chunks = self._load_all_chunks()
        if chunks is None:
            return None

        tokenized = [_tokenize(c.content) for c in chunks]
        if not any(tokenized):
            # rank_bm25 averages IDF over the vocabulary, so a corpus without
            # a single token (e.g. the old BM25Okapi([[]]) sentinel) raises
            # ZeroDivisionError. Record "nothing to score" instead.
            logger.warning("BM25Retriever: no chunks found, index is empty")
            return None, []

        index = BM25Okapi(tokenized)
        logger.info("BM25Retriever: index built with %d chunks", len(chunks))
        return index, chunks

    def refresh(self) -> None:
        """Rebuild the index (call after new documents are ingested).

        This is a no-op when rank_bm25/jieba are not installed. Queries keep
        using the previous index until the rebuild finishes. If the Milvus
        fetch fails, the previous index is kept.
        """
        if not _BM25_AVAILABLE:
            return
        with self._lock:
            snapshot = self._build()
            if snapshot is not None:
                self._snapshot = snapshot

    def retrieve(self, query: str, top_k: int, filters: str | None = None) -> list[RetrievedChunk]:
        """Return up to ``top_k`` chunks ranked by BM25 score.

        Args:
            query: Free-text query, tokenized the same way as indexed chunks.
            top_k: Maximum number of results. Values <= 0 yield ``[]``.
            filters: Optional comma-separated list of allowed
                ``regulation_type`` values. Chunks whose metadata has no
                ``regulation_type`` are never excluded.

        Returns:
            Copies of the matching chunks with ``score`` set to their BM25
            score, highest first. Chunks scoring <= 0 are omitted. The list is
            empty when BM25 is unavailable, the corpus is empty, or the index
            could not be built.
        """
        if not _BM25_AVAILABLE or top_k <= 0:
            return []

        self._ensure_built()
        # Read the snapshot once so a concurrent refresh() cannot swap the
        # index and the chunk list out from under this query.
        snapshot = self._snapshot
        if snapshot is None:
            return []
        index, chunks = snapshot
        if index is None or not chunks:
            return []

        tokens = _tokenize(query)
        if not tokens:
            return []
        scores = index.get_scores(tokens)

        # Parse the filter once rather than for every candidate.
        allowed_types = {t.strip() for t in filters.split(",")} if filters else None

        ranked = sorted(
            ((float(score), chunk) for score, chunk in zip(scores, chunks)),
            key=lambda pair: pair[0],
            reverse=True,
        )

        results: list[RetrievedChunk] = []
        for score, chunk in ranked:
            if score <= 0:
                break  # Sorted descending: every remaining chunk also scores <= 0.
            regulation_type = chunk.metadata.get("regulation_type")
            if allowed_types is not None and regulation_type and regulation_type not in allowed_types:
                continue
            results.append(
                RetrievedChunk(
                    chunk_id=chunk.chunk_id,
                    doc_id=chunk.doc_id,
                    doc_name=chunk.doc_name,
                    content=chunk.content,
                    score=score,
                    section_title=chunk.section_title,
                    page_number=chunk.page_number,
                    # Copy so callers cannot mutate the cached chunk's metadata.
                    metadata=dict(chunk.metadata),
                )
            )
            if len(results) >= top_k:
                break
        return results
|