85 lines
3.2 KiB
Python
85 lines
3.2 KiB
Python
"""Implement cross-encoder reranking via an OpenAI-compatible reranker API."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import time
|
|
|
|
import requests
|
|
from loguru import logger
|
|
|
|
from app.config.settings import settings
|
|
from app.domain.retrieval import Reranker, RetrievedChunk
|
|
|
|
|
|
class OpenAICompatibleReranker(Reranker):
    """Call a TEI / Cohere-style reranker endpoint to re-score retrieved chunks.

    The endpoint is first addressed as ``POST {base_url}/rerank``; when that
    answers 404 the request is retried as ``POST {base_url}/v1/rerank``. Any
    failure while obtaining scores makes :meth:`rerank` fall back to the
    incoming chunk order instead of raising.
    """

    def __init__(
        self,
        base_url: str | None = None,
        model: str | None = None,
        api_key: str | None = None,
        timeout: int = 30,
    ) -> None:
        """Store connection parameters, defaulting to the application settings.

        Args:
            base_url: Root URL of the reranker service; a trailing ``/`` is
                stripped. Falls back to ``settings.reranker_base_url``.
            model: Model name sent in the ``/v1/rerank`` payload. Falls back
                to ``settings.reranker_model``.
            api_key: Bearer token; when falsy after falling back to
                ``settings.reranker_api_key`` no ``Authorization`` header is
                sent.
            timeout: Value passed as ``timeout`` to each ``requests.post``.
        """
        self._base_url = (base_url or settings.reranker_base_url).rstrip("/")
        self._model = model or settings.reranker_model
        self._api_key = api_key or settings.reranker_api_key
        self._timeout = timeout

    def rerank(self, query: str, chunks: list[RetrievedChunk], top_k: int) -> list[RetrievedChunk]:
        """Return up to top_k chunks re-sorted by cross-encoder score.

        Args:
            query: Text the chunks are scored against.
            chunks: Candidate chunks; each chunk's ``content`` is sent to the
                reranker.
            top_k: Maximum number of chunks to return.

        Returns:
            The highest-scoring chunks in descending score order, with each
            returned chunk's ``score`` attribute overwritten in place by the
            reranker score. If the reranker call fails or returns an unusable
            response, the first ``top_k`` chunks in their original order
            (scores untouched). An empty input yields an empty list.
        """
        if not chunks:
            return []

        texts = [chunk.content for chunk in chunks]
        # perf_counter is monotonic, so the duration cannot go negative or
        # jump when the system clock is adjusted.
        start = time.perf_counter()
        try:
            scores = self._call_reranker(query, texts)
            if len(scores) != len(chunks):
                # zip() below would otherwise silently drop or misalign chunks.
                raise ValueError(f"expected {len(chunks)} scores, got {len(scores)}")
        except Exception as exc:
            # Deliberate best-effort: reranking is an optimisation, so any
            # failure degrades to the retrieval order rather than propagating.
            logger.warning("Reranker call failed ({}), falling back to original order: {}", type(exc).__name__, exc)
            return chunks[:top_k]

        elapsed_ms = int((time.perf_counter() - start) * 1000)
        logger.debug("Reranker scored {} chunks in {}ms", len(chunks), elapsed_ms)

        # Sort on the score only (never compare chunk objects); sorted() is
        # stable, so ties keep their original relative order.
        ranked = sorted(zip(scores, chunks), key=lambda pair: pair[0], reverse=True)

        result = []
        for score, chunk in ranked[:top_k]:
            chunk.score = float(score)
            result.append(chunk)
        return result

    def _call_reranker(self, query: str, texts: list[str]) -> list[float]:
        """Call the reranker API and return a score per text.

        Args:
            query: Text the documents are scored against.
            texts: Documents to score.

        Returns:
            One score per entry of ``texts``, in the same order as ``texts``.

        Raises:
            requests.HTTPError: The final response has an error status.
            ValueError: The response does not provide exactly one in-range
                score entry position for every input text.
        """
        headers = {"Content-Type": "application/json"}
        if self._api_key:
            headers["Authorization"] = f"Bearer {self._api_key}"

        # Try TEI format first: POST /rerank
        payload = {"query": query, "texts": texts, "raw_scores": False, "return_text": False}
        url = f"{self._base_url}/rerank"
        resp = requests.post(url, json=payload, headers=headers, timeout=self._timeout)

        if resp.status_code == 404:
            # Fall back to Cohere / OpenAI-style: POST /v1/rerank
            payload_v1 = {"model": self._model, "query": query, "documents": texts}
            url = f"{self._base_url}/v1/rerank"
            resp = requests.post(url, json=payload_v1, headers=headers, timeout=self._timeout)

        resp.raise_for_status()
        return self._scores_by_index(resp.json(), len(texts))

    @staticmethod
    def _scores_by_index(data: object, expected: int) -> list[float]:
        """Place each response score at the position named by its ``index``.

        Args:
            data: Decoded JSON body, either a list of
                ``{"index": N, "score": F}`` items or a dict of the form
                ``{"results": [{"index": N, "relevance_score": F}]}``.
            expected: Number of texts that were submitted.

        Returns:
            A list of ``expected`` floats ordered like the submitted texts.

        Raises:
            ValueError: The body has an unsupported shape, an index is out of
                range, or some submitted text received no score.
        """
        if isinstance(data, list):
            # TEI response: list of {"index": N, "score": F}
            items = data
            is_tei = True
        elif isinstance(data, dict):
            # Cohere/OpenAI response: {"results": [{"index": N, "relevance_score": F}]}
            items = data.get("results", [])
            is_tei = False
        else:
            raise ValueError(f"unsupported reranker response type: {type(data).__name__}")

        scores = [None] * expected
        for item in items:
            position = item["index"]
            if not 0 <= position < expected:
                raise ValueError(f"reranker returned out-of-range index {position}")
            if is_tei:
                value = item["score"]
            else:
                value = item.get("relevance_score", item.get("score", 0))
            scores[position] = float(value)

        # Writing by index (rather than sorting and reading positionally)
        # exposes gaps, which would otherwise shift scores onto wrong texts.
        missing = scores.count(None)
        if missing:
            raise ValueError(f"reranker response is missing scores for {missing} of {expected} texts")
        return scores
|