""" Azure AI Search client utilities for retrieval operations. Contains shared functionality for interacting with Azure AI Search and embedding services. """ import httpx import logging from typing import Dict, Any, List, Optional from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type from ..config import get_config logger = logging.getLogger(__name__) class RetrievalAPIError(Exception): """Custom exception for retrieval API errors""" pass class AzureSearchClient: """Shared Azure AI Search client for embedding and search operations""" def __init__(self): self.config = get_config() self.search_endpoint = self.config.retrieval.endpoint self.api_key = self.config.retrieval.api_key self.api_version = self.config.retrieval.api_version self.semantic_configuration = self.config.retrieval.semantic_configuration self.embedding_client = httpx.AsyncClient(timeout=30.0) self.search_client = httpx.AsyncClient(timeout=30.0) async def __aenter__(self): return self async def __aexit__(self, exc_type, exc_val, exc_tb): await self.embedding_client.aclose() await self.search_client.aclose() async def get_embedding(self, text: str) -> List[float]: """Get embedding vector for text using the configured embedding service""" headers = { "Content-Type": "application/json", "Authorization": f"Bearer {self.config.retrieval.embedding.api_key}" } payload = { "input": text, "model": self.config.retrieval.embedding.model } try: req_url = f"{self.config.retrieval.embedding.base_url}/embeddings" if self.config.retrieval.embedding.api_version: req_url += f"?api-version={self.config.retrieval.embedding.api_version}" response = await self.embedding_client.post(req_url, json=payload, headers=headers) response.raise_for_status() result = response.json() return result["data"][0]["embedding"] except Exception as e: logger.error(f"Failed to get embedding: {e}") raise RetrievalAPIError(f"Embedding generation failed: {str(e)}") @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10), retry=retry_if_exception_type((httpx.HTTPStatusError, httpx.TimeoutException)) ) async def search_azure_ai( self, index_name: str, search_text: str, vector_fields: str, select_fields: str, search_fields: str, filter_query: Optional[str] = None, top_k: int = 10, score_threshold: float = 1.5 ) -> Dict[str, Any]: """Make hybrid search request to Azure AI Search with semantic ranking""" # Get embedding vector for the query query_vector = await self.get_embedding(search_text) # Build vector queries based on the vector fields vector_queries = [] for field in vector_fields.split(","): field = field.strip() vector_queries.append({ "kind": "vector", "vector": query_vector, "fields": field, "k": top_k }) # Build the search request payload search_payload = { "search": search_text, "select": select_fields, "searchFields": search_fields, "top": top_k, "queryType": "semantic", "semanticConfiguration": self.semantic_configuration, "vectorQueries": vector_queries } if filter_query: search_payload["filter"] = filter_query headers = { "Content-Type": "application/json", "api-key": self.api_key } search_url = f"{self.search_endpoint}/indexes/{index_name}/docs/search" try: response = await self.search_client.post( search_url, json=search_payload, headers=headers, params={"api-version": self.api_version} ) response.raise_for_status() result = response.json() # Filter results by reranker score and add order numbers filtered_results = [] for i, item in enumerate(result.get("value", [])): reranker_score = item.get("@search.rerankerScore", 0) if reranker_score >= score_threshold: # Add order number item["@order_num"] = i + 1 # Normalize the result (removes unwanted fields and empty values) normalized_item = normalize_search_result(item) filtered_results.append(normalized_item) return {"value": filtered_results} except httpx.HTTPStatusError as e: logger.error(f"Azure AI Search HTTP error {e.response.status_code}: {e.response.text}") raise RetrievalAPIError(f"Azure AI Search request failed: {e.response.status_code}") except httpx.TimeoutException: logger.error("Azure AI Search request timeout") raise RetrievalAPIError("Azure AI Search request timeout") except Exception as e: logger.error(f"Azure AI Search unexpected error: {e}") raise RetrievalAPIError(f"Azure AI Search unexpected error: {str(e)}") def normalize_search_result(raw_result: Dict[str, Any]) -> Dict[str, Any]: """ Normalize raw Azure AI Search result to clean dynamic structure Args: raw_result: Raw result from Azure AI Search Returns: Cleaned and normalized result dictionary """ # Fields to remove if they exist (belt and suspenders approach) fields_to_remove = { "@search.score", "@search.rerankerScore", "@search.captions", "@subquery_id" } # Create a copy and remove unwanted fields result = raw_result.copy() for field in fields_to_remove: result.pop(field, None) # Remove empty fields (None, empty string, empty list, empty dict) result = { key: value for key, value in result.items() if value is not None and value != "" and value != [] and value != {} } return result