"""Infrastructure adapter for an OpenAI-compatible embedding provider."""

from __future__ import annotations

import os

import httpx

from app.config.settings import settings
from app.domain.retrieval import EmbeddingProvider

# Keep adapter behavior explicit so integration details remain easy to audit.
EMBEDDING_BATCH_SIZE = 8


class OpenAICompatibleEmbeddingProvider(EmbeddingProvider):
    """Embedding provider backed by an OpenAI-compatible ``/embeddings`` endpoint."""

    def __init__(self) -> None:
        """Read the endpoint, credentials, model, timeout, and dimension from settings."""
        self.base_url = settings.embedding_base_url.rstrip("/")
        # Fall back to provider-specific environment variables when
        # settings.embedding_api_key is empty.
        self.api_key = (
            settings.embedding_api_key
            or os.getenv("OPENAI_API_KEY", "")
            or os.getenv("QWEN_API_KEY", "")
            or os.getenv("DEEPSEEK_API_KEY", "")
        )
        self.model = settings.embedding_model
        self.timeout = settings.embedding_timeout_seconds
        self.dimension = settings.embedding_dim

    def _raise_for_status(self, response: httpx.Response, *, batch_size: int) -> None:
        """Raise a detailed error so upstream gateway failures are easier to diagnose."""
        try:
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            # Truncate the body so a large error page does not flood the logs.
            response_preview = response.text[:500].strip()
            detail = (
                f"Embedding request failed for model={self.model}, batch_size={batch_size}, "
                f"status={response.status_code}, url={response.request.url}, "
                f"response={response_preview}"
            )
            raise httpx.HTTPStatusError(detail, request=exc.request, response=exc.response) from exc

    def _request(self, texts: list[str]) -> list[list[float]]:
        """Send one batch to the embeddings endpoint and return vectors in input order."""
        if not self.api_key:
            raise ValueError(
                "Missing embedding API key: set EMBEDDING_API_KEY, OPENAI_API_KEY, "
                "QWEN_API_KEY, or DEEPSEEK_API_KEY"
            )
        response = httpx.post(
            f"{self.base_url}/embeddings",
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
            json={"model": self.model, "input": texts},
            timeout=self.timeout,
        )
        self._raise_for_status(response, batch_size=len(texts))
        data = response.json()
        # Sort by the returned index so each vector lines up with its input text.
        items = sorted(data.get("data", []), key=lambda item: item["index"])
        vectors = [item["embedding"] for item in items]
        # Guard against a short or empty response, which would otherwise misalign
        # results or surface later as an IndexError in embed_query.
        if len(vectors) != len(texts):
            raise ValueError(
                f"Embedding count mismatch: expected {len(texts)}, got {len(vectors)}"
            )
        if any(len(vector) != self.dimension for vector in vectors):
            raise ValueError(f"Embedding dimension mismatch: expected {self.dimension}")
        return vectors

    def embed_texts(self, texts: list[str]) -> list[list[float]]:
        """Embed a list of texts, sending them in batches of EMBEDDING_BATCH_SIZE."""
        if not texts:
            return []
        vectors: list[list[float]] = []
        # Batch requests conservatively because some gateways reject larger embedding payloads.
        for start in range(0, len(texts), EMBEDDING_BATCH_SIZE):
            batch = texts[start:start + EMBEDDING_BATCH_SIZE]
            vectors.extend(self._request(batch))
        return vectors

    def embed_query(self, text: str) -> list[float]:
        """Embed a single query string and return its vector."""
        vectors = self._request([text])
        return vectors[0]