fix 文档管理模块 & 法规对话模块
This commit is contained in:
@@ -0,0 +1,84 @@
|
||||
"""Implement cross-encoder reranking via an OpenAI-compatible reranker API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
import requests
|
||||
from loguru import logger
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.domain.retrieval import Reranker, RetrievedChunk
|
||||
|
||||
|
||||
class OpenAICompatibleReranker(Reranker):
|
||||
"""Call a TEI / Cohere-style reranker endpoint to re-score retrieved chunks."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str | None = None,
|
||||
model: str | None = None,
|
||||
api_key: str | None = None,
|
||||
timeout: int = 30,
|
||||
) -> None:
|
||||
self._base_url = (base_url or settings.reranker_base_url).rstrip("/")
|
||||
self._model = model or settings.reranker_model
|
||||
self._api_key = api_key or settings.reranker_api_key
|
||||
self._timeout = timeout
|
||||
|
||||
def rerank(self, query: str, chunks: list[RetrievedChunk], top_k: int) -> list[RetrievedChunk]:
|
||||
"""Return up to top_k chunks re-sorted by cross-encoder score."""
|
||||
if not chunks:
|
||||
return []
|
||||
|
||||
texts = [chunk.content for chunk in chunks]
|
||||
start = time.time()
|
||||
try:
|
||||
scores = self._call_reranker(query, texts)
|
||||
except Exception as exc:
|
||||
logger.warning("Reranker call failed ({}), falling back to original order: {}", type(exc).__name__, exc)
|
||||
return chunks[:top_k]
|
||||
|
||||
elapsed_ms = int((time.time() - start) * 1000)
|
||||
logger.debug("Reranker scored {} chunks in {}ms", len(chunks), elapsed_ms)
|
||||
|
||||
ranked = sorted(
|
||||
[(score, chunk) for score, chunk in zip(scores, chunks)],
|
||||
key=lambda x: x[0],
|
||||
reverse=True,
|
||||
)
|
||||
result = []
|
||||
for score, chunk in ranked[:top_k]:
|
||||
chunk.score = float(score)
|
||||
result.append(chunk)
|
||||
return result
|
||||
|
||||
def _call_reranker(self, query: str, texts: list[str]) -> list[float]:
|
||||
"""Call the reranker API and return a score per text."""
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if self._api_key:
|
||||
headers["Authorization"] = f"Bearer {self._api_key}"
|
||||
|
||||
# Try TEI format first: POST /rerank
|
||||
payload = {"query": query, "texts": texts, "raw_scores": False, "return_text": False}
|
||||
url = f"{self._base_url}/rerank"
|
||||
resp = requests.post(url, json=payload, headers=headers, timeout=self._timeout)
|
||||
|
||||
if resp.status_code == 404:
|
||||
# Fall back to Cohere / OpenAI-style: POST /v1/rerank
|
||||
payload_v1 = {"model": self._model, "query": query, "documents": texts}
|
||||
url = f"{self._base_url}/v1/rerank"
|
||||
resp = requests.post(url, json=payload_v1, headers=headers, timeout=self._timeout)
|
||||
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
# TEI response: list of {"index": N, "score": F}
|
||||
if isinstance(data, list):
|
||||
ordered = sorted(data, key=lambda x: x["index"])
|
||||
return [float(item["score"]) for item in ordered]
|
||||
|
||||
# Cohere/OpenAI response: {"results": [{"index": N, "relevance_score": F}]}
|
||||
results = data.get("results", [])
|
||||
ordered = sorted(results, key=lambda x: x["index"])
|
||||
return [float(item.get("relevance_score", item.get("score", 0))) for item in ordered]
|
||||
Reference in New Issue
Block a user