fix 文档管理模块 & 法规对话模块
This commit is contained in:
@@ -0,0 +1,84 @@
|
||||
"""Implement cross-encoder reranking via an OpenAI-compatible reranker API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
import requests
|
||||
from loguru import logger
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.domain.retrieval import Reranker, RetrievedChunk
|
||||
|
||||
|
||||
class OpenAICompatibleReranker(Reranker):
|
||||
"""Call a TEI / Cohere-style reranker endpoint to re-score retrieved chunks."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str | None = None,
|
||||
model: str | None = None,
|
||||
api_key: str | None = None,
|
||||
timeout: int = 30,
|
||||
) -> None:
|
||||
self._base_url = (base_url or settings.reranker_base_url).rstrip("/")
|
||||
self._model = model or settings.reranker_model
|
||||
self._api_key = api_key or settings.reranker_api_key
|
||||
self._timeout = timeout
|
||||
|
||||
def rerank(self, query: str, chunks: list[RetrievedChunk], top_k: int) -> list[RetrievedChunk]:
|
||||
"""Return up to top_k chunks re-sorted by cross-encoder score."""
|
||||
if not chunks:
|
||||
return []
|
||||
|
||||
texts = [chunk.content for chunk in chunks]
|
||||
start = time.time()
|
||||
try:
|
||||
scores = self._call_reranker(query, texts)
|
||||
except Exception as exc:
|
||||
logger.warning("Reranker call failed ({}), falling back to original order: {}", type(exc).__name__, exc)
|
||||
return chunks[:top_k]
|
||||
|
||||
elapsed_ms = int((time.time() - start) * 1000)
|
||||
logger.debug("Reranker scored {} chunks in {}ms", len(chunks), elapsed_ms)
|
||||
|
||||
ranked = sorted(
|
||||
[(score, chunk) for score, chunk in zip(scores, chunks)],
|
||||
key=lambda x: x[0],
|
||||
reverse=True,
|
||||
)
|
||||
result = []
|
||||
for score, chunk in ranked[:top_k]:
|
||||
chunk.score = float(score)
|
||||
result.append(chunk)
|
||||
return result
|
||||
|
||||
def _call_reranker(self, query: str, texts: list[str]) -> list[float]:
|
||||
"""Call the reranker API and return a score per text."""
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if self._api_key:
|
||||
headers["Authorization"] = f"Bearer {self._api_key}"
|
||||
|
||||
# Try TEI format first: POST /rerank
|
||||
payload = {"query": query, "texts": texts, "raw_scores": False, "return_text": False}
|
||||
url = f"{self._base_url}/rerank"
|
||||
resp = requests.post(url, json=payload, headers=headers, timeout=self._timeout)
|
||||
|
||||
if resp.status_code == 404:
|
||||
# Fall back to Cohere / OpenAI-style: POST /v1/rerank
|
||||
payload_v1 = {"model": self._model, "query": query, "documents": texts}
|
||||
url = f"{self._base_url}/v1/rerank"
|
||||
resp = requests.post(url, json=payload_v1, headers=headers, timeout=self._timeout)
|
||||
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
# TEI response: list of {"index": N, "score": F}
|
||||
if isinstance(data, list):
|
||||
ordered = sorted(data, key=lambda x: x["index"])
|
||||
return [float(item["score"]) for item in ordered]
|
||||
|
||||
# Cohere/OpenAI response: {"results": [{"index": N, "relevance_score": F}]}
|
||||
results = data.get("results", [])
|
||||
ordered = sorted(results, key=lambda x: x["index"])
|
||||
return [float(item.get("relevance_score", item.get("score", 0))) for item in ordered]
|
||||
@@ -100,14 +100,42 @@ class MilvusVectorIndex(VectorIndex):
|
||||
result = self.collection.delete(f'doc_id == "{doc_id}"')
|
||||
return len(result.primary_keys)
|
||||
|
||||
def _parse_filters(self, filters: str | None) -> str | None:
|
||||
"""Parse filter string into Milvus expression."""
|
||||
if not filters or not filters.strip():
|
||||
return None
|
||||
|
||||
filters = filters.strip()
|
||||
|
||||
# Check if already a Milvus expression (contains operators)
|
||||
if any(op in filters for op in ["==", "!=", "in", "not in", ">", "<", ">=", "<=", "and", "or"]):
|
||||
return filters
|
||||
|
||||
# Parse simple regulation_type filter
|
||||
# Support: "GB" or "GB,UN-ECE" or "GB, UN-ECE"
|
||||
types = [t.strip() for t in filters.split(",") if t.strip()]
|
||||
|
||||
if not types:
|
||||
return None
|
||||
|
||||
if len(types) == 1:
|
||||
# Single value: regulation_type == "GB"
|
||||
return f'regulation_type == "{types[0]}"'
|
||||
else:
|
||||
# Multiple values: regulation_type in ["GB", "UN-ECE"]
|
||||
quoted_types = [f'"{t}"' for t in types]
|
||||
return f'regulation_type in [{", ".join(quoted_types)}]'
|
||||
|
||||
def search(self, query_vector: list[float], top_k: int, filters: str | None = None) -> list[RetrievedChunk]:
|
||||
"""Handle search for the Milvus Vector Index instance."""
|
||||
milvus_expr = self._parse_filters(filters)
|
||||
|
||||
results = self.collection.search(
|
||||
data=[query_vector],
|
||||
anns_field="embedding",
|
||||
param={"metric_type": "COSINE", "params": {"nprobe": settings.milvus_nprobe}},
|
||||
limit=top_k,
|
||||
filter=filters,
|
||||
expr=milvus_expr,
|
||||
output_fields=[
|
||||
"doc_id",
|
||||
"doc_name",
|
||||
@@ -145,6 +173,49 @@ class MilvusVectorIndex(VectorIndex):
|
||||
)
|
||||
return payload
|
||||
|
||||
def count_by_document(self) -> dict[str, int]:
|
||||
"""Return doc_id -> chunk count from Milvus."""
|
||||
try:
|
||||
rows = self.collection.query(expr="doc_id != \"\"", output_fields=["doc_id"])
|
||||
except Exception:
|
||||
return {}
|
||||
counts: dict[str, int] = {}
|
||||
for row in rows:
|
||||
doc_id = row.get("doc_id", "")
|
||||
if doc_id:
|
||||
counts[doc_id] = counts.get(doc_id, 0) + 1
|
||||
return counts
|
||||
|
||||
def list_document_metadata(self) -> list[dict]:
|
||||
"""Return one metadata row per document from Milvus (single query, no embeddings)."""
|
||||
try:
|
||||
rows = self.collection.query(
|
||||
expr="doc_id != \"\"",
|
||||
output_fields=["doc_id", "doc_name", "regulation_type", "version"],
|
||||
)
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
seen: dict[str, dict] = {}
|
||||
counts: dict[str, int] = {}
|
||||
for row in rows:
|
||||
doc_id = row.get("doc_id", "")
|
||||
if not doc_id:
|
||||
continue
|
||||
counts[doc_id] = counts.get(doc_id, 0) + 1
|
||||
if doc_id not in seen:
|
||||
seen[doc_id] = {
|
||||
"doc_id": doc_id,
|
||||
"doc_name": row.get("doc_name", ""),
|
||||
"regulation_type": row.get("regulation_type", ""),
|
||||
"version": row.get("version", ""),
|
||||
}
|
||||
|
||||
return [
|
||||
{**meta, "chunk_count": counts[meta["doc_id"]]}
|
||||
for meta in seen.values()
|
||||
]
|
||||
|
||||
def health(self) -> dict:
|
||||
"""Handle health for the Milvus Vector Index instance."""
|
||||
return {
|
||||
|
||||
Reference in New Issue
Block a user