Files
catonline_ai/vw-agentic-rag/service/retrieval/clients.py
2025-09-26 17:15:54 +08:00

182 lines
6.4 KiB
Python

"""
Azure AI Search client utilities for retrieval operations.
Contains shared functionality for interacting with Azure AI Search and embedding services.
"""
import logging
from typing import Dict, Any, List, Optional

import httpx
from tenacity import (
    retry,
    retry_if_exception,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from ..config import get_config
# Module-level logger, named after this module via logging.getLogger(__name__).
logger = logging.getLogger(__name__)
class RetrievalAPIError(Exception):
    """Error raised when a retrieval backend call fails.

    This module uses it to wrap failures from the embedding service and from
    Azure AI Search (HTTP errors, timeouts, unexpected response shapes) behind
    one exception type that callers can catch.
    """
def _is_transient_http_error(exc: BaseException) -> bool:
    """Return True if ``exc`` is worth retrying: a timeout, HTTP 429, or HTTP 5xx.

    Other HTTP errors (400, 401, 403, 404, ...) are not expected to succeed on
    an identical retry, so they are surfaced immediately instead of waiting
    out the exponential backoff.
    """
    if isinstance(exc, httpx.TimeoutException):
        return True
    if isinstance(exc, httpx.HTTPStatusError):
        status_code = exc.response.status_code
        return status_code == 429 or status_code >= 500
    return False


class AzureSearchClient:
    """Shared Azure AI Search client for embedding and search operations.

    Use it as an async context manager so the underlying HTTP clients are closed::

        async with AzureSearchClient() as client:
            hits = await client.search_azure_ai(...)

    If it is not used with ``async with``, call :meth:`aclose` when done.
    """

    def __init__(self):
        self.config = get_config()
        self.search_endpoint = self.config.retrieval.endpoint
        self.api_key = self.config.retrieval.api_key
        self.api_version = self.config.retrieval.api_version
        self.semantic_configuration = self.config.retrieval.semantic_configuration
        # Two independent HTTP clients (30 s timeout each): one for the
        # embedding service, one for Azure AI Search.
        self.embedding_client = httpx.AsyncClient(timeout=30.0)
        self.search_client = httpx.AsyncClient(timeout=30.0)

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.aclose()

    async def aclose(self) -> None:
        """Close both HTTP clients.

        The search client is closed even if closing the embedding client raises.
        """
        try:
            await self.embedding_client.aclose()
        finally:
            await self.search_client.aclose()

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception(_is_transient_http_error),
        # Re-raise the last attempt's httpx exception (instead of tenacity's
        # RetryError) so callers can map it to RetrievalAPIError.
        reraise=True,
    )
    async def _post_with_retry(
        self, client: httpx.AsyncClient, url: str, **kwargs: Any
    ) -> httpx.Response:
        """POST to ``url`` with ``client`` and raise on a non-2xx status.

        Timeouts, HTTP 429 and HTTP 5xx are retried (3 attempts in total).
        Retrying happens here, before any exception translation, because
        translating first would hide the httpx errors from tenacity.

        Raises:
            httpx.HTTPStatusError: For a non-2xx response (after retries if transient).
            httpx.TimeoutException: If every attempt timed out.
            httpx.HTTPError: For other transport failures (not retried).
        """
        response = await client.post(url, **kwargs)
        response.raise_for_status()
        return response

    async def get_embedding(self, text: str) -> List[float]:
        """Get the embedding vector for ``text`` from the configured embedding service.

        Args:
            text: Text to embed.

        Returns:
            The vector found at ``data[0].embedding`` in the service response.

        Raises:
            RetrievalAPIError: If the request fails (transient errors are retried
                first) or the response lacks the expected structure.
        """
        embedding_cfg = self.config.retrieval.embedding
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {embedding_cfg.api_key}",
        }
        payload = {
            "input": text,
            "model": embedding_cfg.model,
        }
        try:
            req_url = f"{embedding_cfg.base_url}/embeddings"
            # Only send api-version when one is configured.
            params = (
                {"api-version": embedding_cfg.api_version}
                if embedding_cfg.api_version
                else None
            )
            response = await self._post_with_retry(
                self.embedding_client,
                req_url,
                json=payload,
                headers=headers,
                params=params,
            )
            result = response.json()
            return result["data"][0]["embedding"]
        except Exception as e:
            logger.error(f"Failed to get embedding: {e}")
            raise RetrievalAPIError(f"Embedding generation failed: {str(e)}") from e

    async def search_azure_ai(
        self,
        index_name: str,
        search_text: str,
        vector_fields: str,
        select_fields: str,
        search_fields: str,
        filter_query: Optional[str] = None,
        top_k: int = 10,
        score_threshold: float = 1.5
    ) -> Dict[str, Any]:
        """Make a hybrid (keyword + vector) search request with semantic ranking.

        The query text is embedded once, and that vector is reused for every
        vector field. Hits whose ``@search.rerankerScore`` is below
        ``score_threshold`` are dropped. The remaining hits are cleaned with
        ``normalize_search_result``.

        Args:
            index_name: Name of the Azure AI Search index to query.
            search_text: Query text, used for the keyword search and the embedding.
            vector_fields: Comma-separated vector field names. One vector query is
                sent per non-blank name.
            select_fields: Passed through as the request's ``select``.
            search_fields: Passed through as the request's ``searchFields``.
            filter_query: Optional ``filter`` expression. It is omitted when falsy.
            top_k: Used both as ``top`` and as each vector query's ``k``.
            score_threshold: Minimum reranker score a hit needs to be kept. A hit
                with a missing or null reranker score is treated as scoring 0.

        Returns:
            ``{"value": [...]}``: the kept, normalized hits in response order.
            Each hit carries ``@order_num``, its 1-based position in the
            unfiltered response, so the numbers may have gaps.

        Raises:
            RetrievalAPIError: If embedding fails, the search request fails
                (timeouts, 429 and 5xx are retried first), or the response
                cannot be processed.
        """
        # Embedding errors are already RetrievalAPIError; keep this outside the
        # try below so they are not wrapped a second time.
        query_vector = await self.get_embedding(search_text)

        # One vector query per field. Blank entries (e.g. from a trailing comma)
        # are skipped rather than sent as an empty field name.
        vector_queries = [
            {
                "kind": "vector",
                "vector": query_vector,
                "fields": field,
                "k": top_k,
            }
            for field in (name.strip() for name in vector_fields.split(","))
            if field
        ]

        search_payload = {
            "search": search_text,
            "select": select_fields,
            "searchFields": search_fields,
            "top": top_k,
            "queryType": "semantic",
            "semanticConfiguration": self.semantic_configuration,
            "vectorQueries": vector_queries,
        }
        if filter_query:
            search_payload["filter"] = filter_query

        headers = {
            "Content-Type": "application/json",
            "api-key": self.api_key,
        }
        search_url = f"{self.search_endpoint}/indexes/{index_name}/docs/search"

        try:
            response = await self._post_with_retry(
                self.search_client,
                search_url,
                json=search_payload,
                headers=headers,
                params={"api-version": self.api_version},
            )
            result = response.json()

            filtered_results = []
            for rank, item in enumerate(result.get("value", []), start=1):
                # The key may be present with a null value; treat that like a missing score.
                reranker_score = item.get("@search.rerankerScore") or 0
                if reranker_score >= score_threshold:
                    # Record the hit's position in the unfiltered ranking.
                    item["@order_num"] = rank
                    filtered_results.append(normalize_search_result(item))
            return {"value": filtered_results}
        except httpx.HTTPStatusError as e:
            logger.error(f"Azure AI Search HTTP error {e.response.status_code}: {e.response.text}")
            raise RetrievalAPIError(f"Azure AI Search request failed: {e.response.status_code}") from e
        except httpx.TimeoutException as e:
            logger.error("Azure AI Search request timeout")
            raise RetrievalAPIError("Azure AI Search request timeout") from e
        except Exception as e:
            logger.error(f"Azure AI Search unexpected error: {e}")
            raise RetrievalAPIError(f"Azure AI Search unexpected error: {str(e)}") from e
def normalize_search_result(raw_result: Dict[str, Any]) -> Dict[str, Any]:
    """
    Return a cleaned copy of a single Azure AI Search hit.

    Search metadata keys (``@search.score``, ``@search.rerankerScore``,
    ``@search.captions``, ``@subquery_id``) are dropped. So is any field whose
    value is None, an empty string, an empty list or an empty dict. Values that
    are falsy but meaningful, such as 0 or False, are kept. The input dictionary
    is not modified, and the surviving keys keep their original order.

    Args:
        raw_result: One hit from an Azure AI Search response.

    Returns:
        A new dictionary holding only the remaining non-empty fields.
    """
    # Drop these if present, even though a select list would normally exclude them.
    metadata_keys = frozenset((
        "@search.score",
        "@search.rerankerScore",
        "@search.captions",
        "@subquery_id",
    ))

    cleaned: Dict[str, Any] = {}
    for key, value in raw_result.items():
        if key in metadata_keys:
            continue
        # Explicit comparisons rather than truthiness, so 0 / False survive.
        if value is None or value == "" or value == [] or value == {}:
            continue
        cleaned[key] = value
    return cleaned