2026-05-18 16:32:42 +08:00
|
|
|
"""Implement infrastructure support for openai compatible embedding provider."""
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
|
|
|
|
|
|
|
import httpx
|
|
|
|
|
|
|
|
|
|
from app.config.settings import settings
|
|
|
|
|
from app.domain.retrieval import EmbeddingProvider
|
|
|
|
|
# Keep adapter behavior explicit so integration details remain easy to audit.
|
|
|
|
|
|
2026-05-18 22:30:28 +08:00
|
|
|
# Maximum number of texts sent in a single embeddings request; embed_texts
# slices its input into batches of this size.
EMBEDDING_BATCH_SIZE = 8
|
|
|
|
|
|
2026-05-18 16:32:42 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class OpenAICompatibleEmbeddingProvider(EmbeddingProvider):
    """Embedding provider backed by an OpenAI-compatible ``/embeddings`` HTTP endpoint.

    The base URL, API key, model name, request timeout, and expected vector
    dimension are read from the application settings when the instance is
    created.
    """

    def __init__(self) -> None:
        """Load the endpoint configuration from settings, with env-var API key fallbacks."""
        # Strip trailing slashes so the endpoint path can be joined with a single "/".
        self.base_url = settings.embedding_base_url.rstrip("/")
        # The first non-empty value wins: the dedicated embedding key from
        # settings, then the provider environment variables in the order listed.
        self.api_key = (
            settings.embedding_api_key
            or os.getenv("OPENAI_API_KEY", "")
            or os.getenv("QWEN_API_KEY", "")
            or os.getenv("DEEPSEEK_API_KEY", "")
        )
        self.model = settings.embedding_model
        self.timeout = settings.embedding_timeout_seconds
        self.dimension = settings.embedding_dim

    def _raise_for_status(self, response: httpx.Response, *, batch_size: int) -> None:
        """Raise a detailed error so upstream gateway failures are easier to diagnose.

        Args:
            response: The HTTP response returned by the embeddings endpoint.
            batch_size: Number of texts sent in the request; included in the
                error message only.

        Raises:
            httpx.HTTPStatusError: If ``response.raise_for_status()`` fails. The
                message carries the model, batch size, status code, request
                URL, and the first 500 characters of the response body.
        """
        try:
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            # Truncate the body so a large error page does not flood the message.
            response_preview = response.text[:500].strip()
            detail = (
                f"Embedding request failed for model={self.model}, batch_size={batch_size}, "
                f"status={response.status_code}, url={response.request.url}, response={response_preview}"
            )
            raise httpx.HTTPStatusError(detail, request=exc.request, response=exc.response) from exc

    def _request(self, texts: list[str]) -> list[list[float]]:
        """Send one embeddings request and return the vectors in input order.

        Args:
            texts: The texts to embed in a single HTTP call.

        Returns:
            One embedding vector per input text, ordered by the ``index`` field
            of each item in the response.

        Raises:
            ValueError: If no API key is configured, if the number of returned
                vectors differs from the number of input texts, or if any
                vector's length differs from ``self.dimension``.
            httpx.HTTPStatusError: If the endpoint responds with an error status.
        """
        if not self.api_key:
            raise ValueError("缺少 EMBEDDING_API_KEY / OPENAI_API_KEY")

        response = httpx.post(
            f"{self.base_url}/embeddings",
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
            json={"model": self.model, "input": texts},
            timeout=self.timeout,
        )
        self._raise_for_status(response, batch_size=len(texts))

        data = response.json()
        # Sort by the reported index so vectors line up with the input texts
        # even if the endpoint returns the items out of order.
        items = sorted(data.get("data", []), key=lambda item: item["index"])
        vectors = [item["embedding"] for item in items]

        # A missing or short "data" list would otherwise pass the dimension
        # check and silently misalign vectors with their source texts.
        if len(vectors) != len(texts):
            raise ValueError(
                f"Embedding response returned {len(vectors)} vectors for {len(texts)} inputs"
            )
        if any(len(vector) != self.dimension for vector in vectors):
            raise ValueError(f"embedding 维度不匹配,期望 {self.dimension}")
        return vectors

    def embed_texts(self, texts: list[str]) -> list[list[float]]:
        """Embed a list of texts, issuing one request per batch.

        Args:
            texts: The texts to embed; an empty list returns ``[]`` without
                making any request.

        Returns:
            The embedding vectors, in the same order as ``texts``.
        """
        if not texts:
            return []

        vectors: list[list[float]] = []
        # Batch requests conservatively because some gateways reject larger embedding payloads.
        for start in range(0, len(texts), EMBEDDING_BATCH_SIZE):
            batch = texts[start:start + EMBEDDING_BATCH_SIZE]
            vectors.extend(self._request(batch))
        return vectors

    def embed_query(self, text: str) -> list[float]:
        """Embed a single query string and return its vector."""
        vectors = self._request([text])
        return vectors[0]
|