Files
AIRegulation-DocAnalysis/backend/app/infrastructure/vectorstore/cross_encoder_reranker.py
ash66 30c7bda389 Refactor document handling and update Milvus collection settings
- Removed multiple failed document entries from `documents.json`.
- Added a new document entry with updated metadata and changed the index name to `regulations_dense_1024_v2`.
- Updated architecture documentation to reflect changes in the Milvus collection name.
- Adjusted requirements by removing the sqlalchemy dependency.
- Modified test cases to align with new document structure and naming conventions.
- Introduced a new test file for Milvus vector index runtime recovery and error handling.
- Updated assertions in various test files to ensure compatibility with the new schema.
2026-05-26 20:21:31 +08:00

85 lines
3.2 KiB
Python

"""Implement cross-encoder reranking via an OpenAI-compatible reranker API."""
from __future__ import annotations
import time
import requests
from loguru import logger
from app.config.settings import settings
from app.domain.retrieval import Reranker, RetrievedChunk
class OpenAICompatibleReranker(Reranker):
    """Call a TEI / Cohere-style reranker endpoint to re-score retrieved chunks.

    The TEI-style ``POST /rerank`` route is tried first; if it answers 404 the
    Cohere / OpenAI-style ``POST /v1/rerank`` route is used instead. Any failure
    (network error, HTTP error, malformed or incomplete response) makes
    :meth:`rerank` fall back to the incoming chunk order instead of raising.
    """

    def __init__(
        self,
        base_url: str | None = None,
        model: str | None = None,
        api_key: str | None = None,
        timeout: int = 30,
    ) -> None:
        """Store endpoint configuration, defaulting to the application settings.

        Args:
            base_url: Root URL of the reranker service; a trailing ``/`` is
                stripped. Falls back to ``settings.reranker_base_url``.
            model: Model name sent on the ``/v1/rerank`` route. Falls back to
                ``settings.reranker_model``.
            api_key: Bearer token; when falsy after the fallback to
                ``settings.reranker_api_key`` no ``Authorization`` header is sent.
            timeout: Timeout passed to each ``requests.post`` call.
        """
        self._base_url = (base_url or settings.reranker_base_url).rstrip("/")
        self._model = model or settings.reranker_model
        self._api_key = api_key or settings.reranker_api_key
        self._timeout = timeout

    def rerank(self, query: str, chunks: list[RetrievedChunk], top_k: int) -> list[RetrievedChunk]:
        """Return up to top_k chunks re-sorted by cross-encoder score.

        Args:
            query: The query the chunks are scored against.
            chunks: Candidate chunks; ``chunk.text`` is sent to the reranker.
            top_k: Maximum number of chunks to return.

        Returns:
            The highest-scoring chunks in descending score order, each with
            ``score`` overwritten (in place) by the reranker score. If the
            reranker call fails or returns an unusable response, the first
            ``top_k`` chunks are returned unchanged in their original order.
        """
        if not chunks:
            return []

        texts = [chunk.text for chunk in chunks]
        # perf_counter is monotonic, so the latency cannot go negative or jump
        # when the wall clock is adjusted.
        start = time.perf_counter()
        try:
            scores = self._call_reranker(query, texts)
            # Guard against a short or long score list (e.g. a patched or
            # overridden _call_reranker): zip() would otherwise silently drop
            # chunks or pair them with the wrong score.
            if len(scores) != len(chunks):
                raise ValueError(
                    f"reranker returned {len(scores)} scores for {len(chunks)} chunks"
                )
        except Exception as exc:
            # Deliberate best-effort: reranking is an optimisation, so any
            # failure degrades to the retriever's original ordering.
            logger.warning("Reranker call failed ({}), falling back to original order: {}", type(exc).__name__, exc)
            return chunks[:top_k]
        elapsed_ms = int((time.perf_counter() - start) * 1000)
        logger.debug("Reranker scored {} chunks in {}ms", len(chunks), elapsed_ms)

        # Sort on the score only; sorted() is stable, so ties keep input order.
        ranked = sorted(zip(scores, chunks), key=lambda pair: pair[0], reverse=True)

        result = []
        for score, chunk in ranked[:top_k]:
            chunk.score = float(score)
            result.append(chunk)
        return result

    def _call_reranker(self, query: str, texts: list[str]) -> list[float]:
        """Call the reranker API and return a score per text.

        Args:
            query: The query string.
            texts: Documents to score.

        Returns:
            A list with exactly ``len(texts)`` floats where element ``i`` is the
            score of ``texts[i]``.

        Raises:
            requests.RequestException: On transport errors or a non-2xx status.
            ValueError: If the response has an unexpected shape, references an
                index outside ``texts``, or does not score every text.
            KeyError: If a result item lacks ``index`` (or ``score`` in the
                TEI list format).
        """
        headers = {"Content-Type": "application/json"}
        if self._api_key:
            headers["Authorization"] = f"Bearer {self._api_key}"

        # Try TEI format first: POST /rerank
        payload = {"query": query, "texts": texts, "raw_scores": False, "return_text": False}
        url = f"{self._base_url}/rerank"
        resp = requests.post(url, json=payload, headers=headers, timeout=self._timeout)
        if resp.status_code == 404:
            # Fall back to Cohere / OpenAI-style: POST /v1/rerank
            payload_v1 = {"model": self._model, "query": query, "documents": texts}
            url = f"{self._base_url}/v1/rerank"
            resp = requests.post(url, json=payload_v1, headers=headers, timeout=self._timeout)
        resp.raise_for_status()
        data = resp.json()

        if isinstance(data, list):
            # TEI response: list of {"index": N, "score": F}
            pairs = [(item["index"], item["score"]) for item in data]
        elif isinstance(data, dict):
            # Cohere/OpenAI response: {"results": [{"index": N, "relevance_score": F}]}
            pairs = [
                (item["index"], item.get("relevance_score", item.get("score", 0)))
                for item in data.get("results", [])
            ]
        else:
            raise ValueError(f"unexpected reranker response type: {type(data).__name__}")

        # Place each score at the position the server says it belongs to rather
        # than trusting list order/length; a truncated result set would
        # otherwise shift scores onto the wrong texts.
        by_index: dict[int, float] = {}
        for index, raw_score in pairs:
            if not isinstance(index, int) or not 0 <= index < len(texts):
                raise ValueError(f"reranker returned out-of-range index: {index!r}")
            by_index[index] = float(raw_score)

        if len(by_index) != len(texts):
            raise ValueError(
                f"reranker scored {len(by_index)} of {len(texts)} texts"
            )
        return [by_index[i] for i in range(len(texts))]