Files
AIRegulation-DocAnalysis/backend/app/infrastructure/vectorstore/bm25_retriever.py

184 lines
6.3 KiB
Python
Raw Normal View History

2026-05-21 23:20:39 +08:00
"""BM25 sparse retriever backed by the existing Milvus collection.
Falls back to no-op if rank_bm25 or jieba is not installed.
The index is built lazily on first query and can be refreshed after
new documents are ingested by calling refresh().
"""
from __future__ import annotations

import heapq
import json
import logging
import re
import threading
from typing import TYPE_CHECKING

from app.domain.retrieval import RetrievedChunk

if TYPE_CHECKING:
    from app.infrastructure.vectorstore.milvus_vector_index import MilvusVectorIndex
logger = logging.getLogger(__name__)

# Optional dependencies: rank_bm25 supplies the BM25 scorer and jieba the
# Chinese word segmentation. When either import fails, the module records the
# feature as unavailable and logs a hint instead of failing at import time.
try:
    from rank_bm25 import BM25Okapi
    import jieba
except ImportError:
    _BM25_AVAILABLE = False
    logger.warning("rank_bm25 or jieba not installed BM25 retrieval disabled. Run: pip install rank-bm25 jieba")
else:
    # Raise jieba's own log threshold to WARNING.
    jieba.setLogLevel(logging.WARNING)
    _BM25_AVAILABLE = True
def _is_content_token(token: str) -> bool:
    """Return True if *token* carries lexical content worth indexing.

    jieba emits whitespace runs and punctuation marks as standalone tokens.
    Such tokens occur in almost every chunk, and rank_bm25's BM25Okapi
    replaces their negative IDF with a small epsilon-based value (typically
    positive). Keeping them would let any query containing a space or a comma
    give unrelated chunks a positive score.
    """
    return any(ch.isalnum() for ch in token)


def _tokenize(text: str) -> list[str]:
    """Tokenize Chinese/mixed text with jieba.

    Whitespace-only and punctuation-only tokens are dropped (see
    ``_is_content_token``). The same function is used for both indexed chunks
    and queries, so the two stay consistent.
    """
    return [token for token in jieba.cut(text) if _is_content_token(token)]
class BM25Retriever:
    """Sparse BM25 retriever that indexes chunks stored in Milvus.

    The index is built lazily on the first query and rebuilt by ``refresh()``.
    The built state is published as one immutable ``(chunks, index)`` tuple,
    so every query scores against a chunk list and index that belong together.
    This holds even while another thread runs ``refresh()``. Index builds are
    serialized by a lock.
    """

    # Scalar fields fetched per chunk; vectors are not needed for BM25.
    _OUTPUT_FIELDS = (
        "id",
        "chunk_id",
        "doc_id",
        "doc_title",
        "text",
        "chunk_type",
        "section_title",
        "page_start",
        "page_end",
        "section_level",
        "chunk_index",
        "piece_index",
        "metadata_json",
    )

    # Row cap for the single Milvus query that loads the corpus.
    # NOTE(review): presumably mirrors Milvus's max query result window;
    # collections larger than this are only partially indexed.
    _MAX_QUERY_ROWS = 16384

    # Matches the start of a `doc_title == "` / `doc_name == "` filter
    # expression, tolerating whitespace around `==`.
    _TITLE_FILTER_PREFIX = re.compile(r'\s*doc_(?:title|name)\s*==\s*"')

    def __init__(self, vector_index: "MilvusVectorIndex") -> None:
        self._vector_index = vector_index
        # Serializes index builds (first use and refresh()).
        self._lock = threading.Lock()
        # (chunks, index) published in a single assignment so readers always
        # see a matching pair; None until the first build. `index` is None
        # when there is nothing to search.
        self._snapshot: "tuple[tuple[RetrievedChunk, ...], BM25Okapi | None] | None" = None

    @property
    def available(self) -> bool:
        """Whether rank_bm25 and jieba could be imported."""
        return _BM25_AVAILABLE

    # ------------------------------------------------------------------ loading

    def _load_all_chunks(self) -> list[RetrievedChunk]:
        """Fetch all chunk records from Milvus (scalar fields only, no vectors).

        Returns an empty list when the Milvus query fails; the failure is
        logged. Rows with empty ``text`` are skipped because they cannot be
        scored.
        """
        try:
            rows = self._vector_index.collection.query(
                expr='doc_id != ""',
                output_fields=list(self._OUTPUT_FIELDS),
                limit=self._MAX_QUERY_ROWS,
            )
        except Exception:
            logger.exception("BM25Retriever: failed to fetch chunks from Milvus")
            return []
        if len(rows) >= self._MAX_QUERY_ROWS:
            # Hitting the cap means the collection may hold more chunks than
            # were returned; make the silent truncation visible.
            logger.warning(
                "BM25Retriever: Milvus query returned %d rows (the query limit); "
                "chunks beyond the limit are not indexed",
                len(rows),
            )
        return [self._row_to_chunk(row) for row in rows if row.get("text")]

    def _row_to_chunk(self, row: dict) -> RetrievedChunk:
        """Convert one Milvus query row into a RetrievedChunk with score 0.0."""
        return RetrievedChunk(
            chunk_id=str(row.get("chunk_id") or row.get("id", "")),
            doc_id=str(row.get("doc_id", "")),
            doc_title=str(row.get("doc_title", "")),
            text=str(row.get("text", "")),
            score=0.0,
            chunk_type=str(row.get("chunk_type", "")),
            section_title=str(row.get("section_title", "")),
            page_start=int(row.get("page_start") or 0),
            page_end=int(row.get("page_end") or 0),
            section_level=int(row.get("section_level") or 0),
            chunk_index=int(row.get("chunk_index") or 0),
            piece_index=int(row.get("piece_index") or 0),
            metadata=self._parse_metadata_json(row.get("metadata_json", "")),
        )

    def _parse_metadata_json(self, raw_metadata: str) -> dict:
        """Parse a chunk's ``metadata_json`` field into a dict.

        Best-effort: empty, malformed, or non-object JSON yields ``{}``, so a
        single bad row cannot abort the index build.
        """
        if not raw_metadata:
            return {}
        try:
            parsed = json.loads(raw_metadata)
        except (TypeError, ValueError):
            return {}
        return parsed if isinstance(parsed, dict) else {}

    # ----------------------------------------------------------------- indexing

    def _ensure_built(self) -> tuple[tuple[RetrievedChunk, ...], BM25Okapi | None]:
        """Return the current ``(chunks, index)`` snapshot, building it on first use."""
        snapshot = self._snapshot
        if snapshot is None:
            with self._lock:
                # Re-check under the lock: another thread may have finished
                # the build while this one was waiting.
                if self._snapshot is None:
                    self._build()
                snapshot = self._snapshot
        return snapshot

    def _build(self) -> None:
        """Load all chunks, build a BM25 index, and publish both as one snapshot.

        Must be called with ``self._lock`` held. Readers keep using the
        previous snapshot until the new one is assigned.
        """
        logger.info("BM25Retriever: building index …")
        chunks = tuple(self._load_all_chunks())
        if not chunks:
            # rank_bm25's BM25Okapi raises ZeroDivisionError on a corpus with
            # no tokens, so "nothing to search" is represented by a None index.
            logger.warning("BM25Retriever: no chunks found, index is empty")
            self._snapshot = ((), None)
            return
        tokenized = [_tokenize(chunk.text) for chunk in chunks]
        if not any(tokenized):
            logger.warning(
                "BM25Retriever: %d chunks found but none has indexable tokens, index is empty",
                len(chunks),
            )
            self._snapshot = ((), None)
            return
        self._snapshot = (chunks, BM25Okapi(tokenized))
        logger.info("BM25Retriever: index built with %d chunks", len(chunks))

    def refresh(self) -> None:
        """Rebuild the index (call after new documents are ingested).

        Queries running concurrently keep using the previous snapshot until
        the rebuild completes; they never observe a half-built state.
        """
        with self._lock:
            self._build()

    # ---------------------------------------------------------------- retrieval

    @classmethod
    def _parse_title_filter(cls, filters: str | None) -> str | None:
        """Extract the expected title from a ``doc_title == "..."`` filter.

        ``doc_name`` is accepted as an alias of ``doc_title``. Only the field
        name is normalized; the quoted value is returned verbatim. Returns
        None when *filters* is empty or is not a title-equality expression.
        """
        if not filters:
            return None
        match = cls._TITLE_FILTER_PREFIX.match(filters)
        if match is None:
            return None
        value = filters[match.end():].rstrip()
        if value.endswith('"'):
            value = value[:-1]
        return value

    @staticmethod
    def _with_score(chunk: RetrievedChunk, score: float) -> RetrievedChunk:
        """Return a copy of *chunk* carrying the given BM25 *score*."""
        return RetrievedChunk(
            chunk_id=chunk.chunk_id,
            doc_id=chunk.doc_id,
            doc_title=chunk.doc_title,
            text=chunk.text,
            score=score,
            chunk_type=chunk.chunk_type,
            section_title=chunk.section_title,
            page_start=chunk.page_start,
            page_end=chunk.page_end,
            section_level=chunk.section_level,
            chunk_index=chunk.chunk_index,
            piece_index=chunk.piece_index,
            metadata=chunk.metadata,
        )

    def retrieve(self, query: str, top_k: int, filters: str | None = None) -> list[RetrievedChunk]:
        """Return up to ``top_k`` chunks ranked by BM25 score, highest first.

        Args:
            query: Free-text query, tokenized the same way as indexed chunks.
            top_k: Maximum number of results; values <= 0 yield ``[]``.
            filters: Optional filter expression. Only ``doc_title == "..."``
                (or ``doc_name == "..."``) is honoured; other expressions are
                ignored and logged at debug level.

        Returns:
            Copies of matching chunks with ``score`` set to their BM25 score.
            Chunks scoring <= 0 are never returned. The result is empty when
            BM25 support is unavailable or the index is empty.
        """
        if top_k <= 0 or not _BM25_AVAILABLE:
            return []
        chunks, index = self._ensure_built()
        if index is None:
            return []
        tokens = _tokenize(query)
        if not tokens:
            return []

        expected_title = self._parse_title_filter(filters)
        if filters and expected_title is None:
            logger.debug("BM25Retriever: ignoring unsupported filter %r", filters)

        scores = index.get_scores(tokens)
        # Filter across the whole corpus before taking the top-k, so a title
        # filter cannot starve the result set when the target document's
        # chunks rank below other documents' chunks.
        candidates: list[tuple[float, RetrievedChunk]] = []
        for raw_score, chunk in zip(scores, chunks):
            score = float(raw_score)
            if score <= 0:
                continue
            if expected_title is not None and chunk.doc_title != expected_title:
                continue
            candidates.append((score, chunk))

        # nlargest is documented as equivalent to sorted(..., reverse=True)[:n],
        # so ties keep corpus order just like the previous full sort.
        best = heapq.nlargest(top_k, candidates, key=lambda pair: pair[0])
        return [self._with_score(chunk, score) for score, chunk in best]