Files
AIRegulation-DocAnalysis/backend/app/infrastructure/vectorstore/bm25_retriever.py
ash66 30c7bda389 Refactor document handling and update Milvus collection settings
- Removed multiple failed document entries from `documents.json`.
- Added a new document entry with updated metadata and changed the index name to `regulations_dense_1024_v2`.
- Updated architecture documentation to reflect changes in the Milvus collection name.
- Adjusted requirements by removing the sqlalchemy dependency.
- Modified test cases to align with new document structure and naming conventions.
- Introduced a new test file for Milvus vector index runtime recovery and error handling.
- Updated assertions in various test files to ensure compatibility with the new schema.
2026-05-26 20:21:31 +08:00

184 lines
6.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""BM25 sparse retriever backed by the existing Milvus collection.
Falls back to no-op if rank_bm25 or jieba is not installed.
The index is built lazily on first query and can be refreshed after
new documents are ingested by calling refresh().
"""
from __future__ import annotations

import json
import logging
import re
import threading
from typing import TYPE_CHECKING

from app.domain.retrieval import RetrievedChunk

if TYPE_CHECKING:
    from app.infrastructure.vectorstore.milvus_vector_index import MilvusVectorIndex

logger = logging.getLogger(__name__)

try:
    from rank_bm25 import BM25Okapi
    import jieba

    jieba.setLogLevel(logging.WARNING)
    _BM25_AVAILABLE = True
except ImportError:
    _BM25_AVAILABLE = False
    logger.warning("rank_bm25 or jieba not installed BM25 retrieval disabled. Run: pip install rank-bm25 jieba")
def _tokenize(text: str) -> list[str]:
    """Tokenize Chinese/mixed text with jieba, dropping whitespace-only tokens.

    jieba yields spaces, tabs and newlines as standalone tokens. Kept as BM25
    terms they occur in nearly every chunk, so any query containing a space
    would give unrelated chunks a small positive score. Filtering them keeps
    "no real term overlap" at a score of zero.

    Requires jieba: callers must check ``_BM25_AVAILABLE`` first, because
    ``jieba`` is undefined when the optional import failed.
    """
    return [token for token in jieba.cut(text) if token.strip()]
class BM25Retriever:
    """Sparse BM25 retriever that indexes chunks stored in Milvus.

    Thread-safe: the BM25 index and the chunk list it was built from are
    published together as one ``(index, chunks)`` tuple. A concurrent
    ``retrieve`` therefore never pairs scores from one build with chunks from
    another. Builds are serialized by a lock, so concurrent first requests
    build only once. ``refresh`` keeps serving the previous snapshot until
    the new one is ready.
    """

    # 16384 is Milvus's default upper bound for offset + limit in a single
    # query; chunks beyond it are not fetched, so hitting it logs a warning.
    _MAX_QUERY_ROWS = 16384

    # The only filter shape honoured on the BM25 side: doc_title == "<title>"
    # (doc_name is accepted as an alias; single or double quotes). Quotes are
    # not allowed inside the title so compound expressions do not match.
    _TITLE_FILTER_RE = re.compile(
        r"""(?:doc_title|doc_name)\s*==\s*(?:"([^"]*)"|'([^']*)')"""
    )

    def __init__(self, vector_index: "MilvusVectorIndex") -> None:
        self._vector_index = vector_index
        self._lock = threading.Lock()
        # None until the first build. Afterwards an (index, chunks) pair, where
        # index is None when there was nothing to index (BM25Okapi cannot be
        # constructed over a corpus with no terms).
        self._state: "tuple[BM25Okapi | None, list[RetrievedChunk]] | None" = None

    @property
    def available(self) -> bool:
        """Whether rank_bm25 and jieba were importable."""
        return _BM25_AVAILABLE

    def _load_all_chunks(self) -> list[RetrievedChunk]:
        """Fetch all chunk records from Milvus (no vectors needed).

        Returns:
            One zero-score ``RetrievedChunk`` per row with non-empty ``text``.
            Returns an empty list (after logging the exception) if the query
            fails.
        """
        try:
            rows = self._vector_index.collection.query(
                expr='doc_id != ""',
                output_fields=[
                    "id",
                    "chunk_id",
                    "doc_id",
                    "doc_title",
                    "text",
                    "chunk_type",
                    "section_title",
                    "page_start",
                    "page_end",
                    "section_level",
                    "chunk_index",
                    "piece_index",
                    "metadata_json",
                ],
                limit=self._MAX_QUERY_ROWS,
            )
        except Exception:
            # Best-effort: BM25 is an auxiliary retriever, so a Milvus failure
            # degrades to an empty index instead of failing the request.
            logger.exception("BM25Retriever: failed to fetch chunks from Milvus")
            return []
        if len(rows) >= self._MAX_QUERY_ROWS:
            logger.warning(
                "BM25Retriever: Milvus query hit the %d-row limit; "
                "chunks beyond it are missing from the BM25 index",
                self._MAX_QUERY_ROWS,
            )
        return [self._row_to_chunk(row) for row in rows if row.get("text")]

    def _row_to_chunk(self, row: dict) -> RetrievedChunk:
        """Convert one Milvus query row into a zero-score ``RetrievedChunk``."""
        return RetrievedChunk(
            chunk_id=str(row.get("chunk_id") or row.get("id", "")),
            doc_id=str(row.get("doc_id", "")),
            doc_title=str(row.get("doc_title", "")),
            text=str(row.get("text", "")),
            score=0.0,
            chunk_type=str(row.get("chunk_type", "")),
            section_title=str(row.get("section_title", "")),
            page_start=int(row.get("page_start") or 0),
            page_end=int(row.get("page_end") or 0),
            section_level=int(row.get("section_level") or 0),
            chunk_index=int(row.get("chunk_index") or 0),
            piece_index=int(row.get("piece_index") or 0),
            metadata=self._parse_metadata_json(row.get("metadata_json", "")),
        )

    def _parse_metadata_json(self, raw_metadata: str | dict | None) -> dict:
        """Parse ``metadata_json`` into a dict for BM25-side filtering.

        Accepts a JSON object string or an already-decoded dict. Input that
        is empty, malformed, or not a JSON object yields ``{}``.
        """
        if not raw_metadata:
            return {}
        if isinstance(raw_metadata, dict):
            return dict(raw_metadata)
        try:
            parsed = json.loads(raw_metadata)
        except (TypeError, ValueError):
            return {}
        return dict(parsed) if isinstance(parsed, dict) else {}

    def _ensure_built(self) -> "tuple[BM25Okapi | None, list[RetrievedChunk]]":
        """Return the current snapshot, building it under the lock on first use."""
        state = self._state
        if state is not None:
            return state
        with self._lock:
            # Re-check: another thread may have built it while we waited.
            if self._state is None:
                self._state = self._build()
            return self._state

    def _build(self) -> "tuple[BM25Okapi | None, list[RetrievedChunk]]":
        """Load chunks from Milvus and build a fresh ``(index, chunks)`` pair.

        The index is ``None`` (with empty chunks) when there is nothing to
        index. ``BM25Okapi`` divides by the vocabulary size and raises
        ZeroDivisionError on a corpus without terms.
        """
        logger.info("BM25Retriever: building index …")
        chunks = self._load_all_chunks()
        if not chunks:
            logger.warning("BM25Retriever: no chunks found, index is empty")
            return None, []
        tokenized = [_tokenize(chunk.text) for chunk in chunks]
        if not any(tokenized):
            logger.warning("BM25Retriever: chunks produced no tokens, index is empty")
            return None, []
        index = BM25Okapi(tokenized)
        logger.info("BM25Retriever: index built with %d chunks", len(chunks))
        return index, chunks

    def refresh(self) -> None:
        """Rebuild the index (call after new documents are ingested).

        This is a no-op when rank_bm25/jieba are unavailable. The previous
        snapshot keeps serving queries until the rebuilt one is published.
        If the rebuild raises, the previous snapshot is kept.
        """
        if not _BM25_AVAILABLE:
            return
        with self._lock:
            self._state = self._build()

    @classmethod
    def _parse_title_filter(cls, filters: str | None) -> str | None:
        """Extract ``<title>`` from a ``doc_title == "<title>"`` expression.

        ``doc_name`` is accepted as an alias of ``doc_title``. Returns
        ``None`` when ``filters`` is empty or has any other shape.
        """
        if not filters:
            return None
        match = cls._TITLE_FILTER_RE.fullmatch(filters.strip())
        if match is None:
            return None
        double_quoted, single_quoted = match.groups()
        return double_quoted if double_quoted is not None else single_quoted

    @staticmethod
    def _with_score(chunk: RetrievedChunk, score: float) -> RetrievedChunk:
        """Return a copy of ``chunk`` carrying ``score`` (metadata dict is shared)."""
        return RetrievedChunk(
            chunk_id=chunk.chunk_id,
            doc_id=chunk.doc_id,
            doc_title=chunk.doc_title,
            text=chunk.text,
            score=score,
            chunk_type=chunk.chunk_type,
            section_title=chunk.section_title,
            page_start=chunk.page_start,
            page_end=chunk.page_end,
            section_level=chunk.section_level,
            chunk_index=chunk.chunk_index,
            piece_index=chunk.piece_index,
            metadata=chunk.metadata,
        )

    def retrieve(self, query: str, top_k: int, filters: str | None = None) -> list[RetrievedChunk]:
        """Return up to ``top_k`` chunks ranked by BM25 score.

        Args:
            query: Free-text query, tokenized like the indexed chunks.
            top_k: Maximum number of results; non-positive values return ``[]``.
            filters: Optional filter expression. Only ``doc_title == "<title>"``
                (or ``doc_name == ...``) is applied, before ranking, so
                ``top_k`` matching chunks are returned when they exist. Other
                expressions are ignored on the BM25 side and logged at DEBUG.

        Returns:
            New ``RetrievedChunk`` objects carrying their BM25 score, highest
            first. Chunks scoring <= 0 (no term overlap) are omitted.
        """
        if not _BM25_AVAILABLE or top_k <= 0:
            return []
        # Read one immutable snapshot so a concurrent refresh() cannot mix builds.
        index, chunks = self._ensure_built()
        if index is None or not chunks:
            return []
        tokens = _tokenize(query)
        if not tokens:
            return []

        expected_title = self._parse_title_filter(filters)
        if filters and expected_title is None:
            logger.debug("BM25Retriever: unsupported filter ignored: %s", filters)

        scores = index.get_scores(tokens)
        candidates: list[tuple[float, RetrievedChunk]] = []
        for raw_score, chunk in zip(scores, chunks):
            score = float(raw_score)
            if score <= 0:
                continue
            if expected_title is not None and chunk.doc_title != expected_title:
                continue
            candidates.append((score, chunk))

        # Stable sort: equal scores keep corpus order, as before.
        candidates.sort(key=lambda pair: pair[0], reverse=True)
        return [self._with_score(chunk, score) for score, chunk in candidates[:top_k]]