commit db0e5965ec
2025-09-26 17:15:54 +08:00
211 changed files with 40437 additions and 0 deletions

@@ -0,0 +1 @@
# Empty __init__.py files to make packages

@@ -0,0 +1,181 @@
"""
Azure AI Search client utilities for retrieval operations.
Contains shared functionality for interacting with Azure AI Search and embedding services.
"""
import httpx
import logging
from typing import Dict, Any, List, Optional
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
from ..config import get_config
logger = logging.getLogger(__name__)
class RetrievalAPIError(Exception):
"""Custom exception for retrieval API errors"""
pass
class AzureSearchClient:
"""Shared Azure AI Search client for embedding and search operations"""
def __init__(self):
self.config = get_config()
self.search_endpoint = self.config.retrieval.endpoint
self.api_key = self.config.retrieval.api_key
self.api_version = self.config.retrieval.api_version
self.semantic_configuration = self.config.retrieval.semantic_configuration
self.embedding_client = httpx.AsyncClient(timeout=30.0)
self.search_client = httpx.AsyncClient(timeout=30.0)
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.embedding_client.aclose()
await self.search_client.aclose()
async def get_embedding(self, text: str) -> List[float]:
"""Get embedding vector for text using the configured embedding service"""
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.config.retrieval.embedding.api_key}"
}
payload = {
"input": text,
"model": self.config.retrieval.embedding.model
}
try:
req_url = f"{self.config.retrieval.embedding.base_url}/embeddings"
if self.config.retrieval.embedding.api_version:
req_url += f"?api-version={self.config.retrieval.embedding.api_version}"
response = await self.embedding_client.post(req_url, json=payload, headers=headers)
response.raise_for_status()
result = response.json()
return result["data"][0]["embedding"]
except Exception as e:
logger.error(f"Failed to get embedding: {e}")
raise RetrievalAPIError(f"Embedding generation failed: {str(e)}")
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((httpx.HTTPStatusError, httpx.TimeoutException))
)
async def search_azure_ai(
self,
index_name: str,
search_text: str,
vector_fields: str,
select_fields: str,
search_fields: str,
filter_query: Optional[str] = None,
top_k: int = 10,
score_threshold: float = 1.5
) -> Dict[str, Any]:
"""Make hybrid search request to Azure AI Search with semantic ranking"""
# Get embedding vector for the query
query_vector = await self.get_embedding(search_text)
# Build vector queries based on the vector fields
vector_queries = []
for field in vector_fields.split(","):
field = field.strip()
vector_queries.append({
"kind": "vector",
"vector": query_vector,
"fields": field,
"k": top_k
})
# Build the search request payload
search_payload = {
"search": search_text,
"select": select_fields,
"searchFields": search_fields,
"top": top_k,
"queryType": "semantic",
"semanticConfiguration": self.semantic_configuration,
"vectorQueries": vector_queries
}
if filter_query:
search_payload["filter"] = filter_query
headers = {
"Content-Type": "application/json",
"api-key": self.api_key
}
search_url = f"{self.search_endpoint}/indexes/{index_name}/docs/search"
try:
response = await self.search_client.post(
search_url,
json=search_payload,
headers=headers,
params={"api-version": self.api_version}
)
response.raise_for_status()
result = response.json()
# Filter results by reranker score and add order numbers
filtered_results = []
for i, item in enumerate(result.get("value", [])):
reranker_score = item.get("@search.rerankerScore", 0)
if reranker_score >= score_threshold:
# Add order number
item["@order_num"] = i + 1
# Normalize the result (removes unwanted fields and empty values)
normalized_item = normalize_search_result(item)
filtered_results.append(normalized_item)
return {"value": filtered_results}
except httpx.HTTPStatusError as e:
logger.error(f"Azure AI Search HTTP error {e.response.status_code}: {e.response.text}")
raise RetrievalAPIError(f"Azure AI Search request failed: {e.response.status_code}")
except httpx.TimeoutException:
logger.error("Azure AI Search request timeout")
raise RetrievalAPIError("Azure AI Search request timeout")
except Exception as e:
logger.error(f"Azure AI Search unexpected error: {e}")
raise RetrievalAPIError(f"Azure AI Search unexpected error: {str(e)}")
def normalize_search_result(raw_result: Dict[str, Any]) -> Dict[str, Any]:
"""
Normalize raw Azure AI Search result to clean dynamic structure
Args:
raw_result: Raw result from Azure AI Search
Returns:
Cleaned and normalized result dictionary
"""
# Fields to remove if they exist (belt and suspenders approach)
fields_to_remove = {
"@search.score",
"@search.rerankerScore",
"@search.captions",
"@subquery_id"
}
# Create a copy and remove unwanted fields
result = raw_result.copy()
for field in fields_to_remove:
result.pop(field, None)
# Remove empty fields (None, empty string, empty list, empty dict)
result = {
key: value for key, value in result.items()
if value is not None and value != "" and value != [] and value != {}
}
return result
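
For context, a minimal usage sketch of the shared client above. It assumes the module is importable as service.retrieval.clients (the path used by the retrieval classes later in this commit); the index name and query are hypothetical, and the field lists mirror the user-manual defaults used below.

import asyncio

from service.retrieval.clients import AzureSearchClient, RetrievalAPIError

async def main() -> None:
    # The async context manager closes both underlying httpx clients on exit.
    async with AzureSearchClient() as client:
        try:
            response = await client.search_azure_ai(
                index_name="chunk-user-manual-index",  # hypothetical index name
                search_text="How do I reset my password?",
                vector_fields="contentVector",
                select_fields="content, title, full_headers",
                search_fields="content, title, full_headers",
                top_k=5,
                score_threshold=1.5,
            )
            for item in response["value"]:
                print(item.get("@order_num"), item.get("title"))
        except RetrievalAPIError as exc:
            print(f"search failed: {exc}")

asyncio.run(main())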

@@ -0,0 +1,58 @@
import logging
import time

from ..config import get_config
from service.retrieval.clients import AzureSearchClient
from service.retrieval.model import RetrievalResponse

logger = logging.getLogger(__name__)


class GenericChunkRetrieval:
    def __init__(self) -> None:
        self.config = get_config()
        self.search_client = AzureSearchClient()

    async def retrieve_doc_chunk(
        self,
        query: str,
        conversation_history: str = "",
        **kwargs
    ) -> RetrievalResponse:
        """Search CATOnline system user manual document chunks"""
        start_time = time.time()
        # Use the new Azure AI Search approach
        index_name = self.config.retrieval.index.chunk_user_manual_index
        vector_fields = "contentVector"
        select_fields = "content, title, full_headers"
        search_fields = "content, title, full_headers"
        top_k = kwargs.get("top_k", 10)
        score_threshold = kwargs.get("score_threshold", 1.5)
        try:
            response_data = await self.search_client.search_azure_ai(
                index_name=index_name,
                search_text=query,
                vector_fields=vector_fields,
                select_fields=select_fields,
                search_fields=search_fields,
                top_k=top_k,
                score_threshold=score_threshold
            )
            results = response_data.get("value", [])
            took_ms = int((time.time() - start_time) * 1000)
            return RetrievalResponse(
                results=results,
                took_ms=took_ms,
                total_count=len(results)
            )
        except Exception as e:
            logger.error(f"retrieve_doc_chunk failed: {e}")
            raise
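
A short sketch of how this class might be called; the module name generic_chunk and the query text are hypothetical, the rest follows the import paths used in this commit.

import asyncio

from service.retrieval.generic_chunk import GenericChunkRetrieval  # hypothetical module name

async def main() -> None:
    retrieval = GenericChunkRetrieval()
    # Note: this class does not close its AzureSearchClient; use the
    # context-manager wrapper below when cleanup matters.
    response = await retrieval.retrieve_doc_chunk(
        query="How do I submit a test request in CATOnline?",
        top_k=5,
    )
    print(f"{response.total_count} chunks in {response.took_ms} ms")
    for chunk in response.results:
        print(chunk.get("title"), "-", chunk.get("full_headers"))

asyncio.run(main())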

@@ -0,0 +1,11 @@
from typing import Any, Optional

from pydantic import BaseModel


class RetrievalResponse(BaseModel):
    """Simple response container for tool results"""
    results: list[dict[str, Any]]
    took_ms: Optional[int] = None
    total_count: Optional[int] = None
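
A quick sketch of the response model in isolation. The built-in generics (list[dict[...]]) require Python 3.9+, and model_dump assumes pydantic v2 (use .dict() on v1); the sample values are illustrative.

from service.retrieval.model import RetrievalResponse

resp = RetrievalResponse(
    results=[{"title": "Sample chunk", "content": "..."}],
    took_ms=42,
    total_count=1,
)
print(resp.model_dump())  # pydantic v2; on v1 use resp.dict()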

@@ -0,0 +1,158 @@
import logging
import time

from .clients import AzureSearchClient
from .model import RetrievalResponse
from ..config import get_config

logger = logging.getLogger(__name__)


class AgenticRetrieval:
    """Azure AI Search client for retrieval tools"""

    def __init__(self):
        self.config = get_config()
        self.search_client = AzureSearchClient()

    async def __aenter__(self):
        await self.search_client.__aenter__()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.search_client.__aexit__(exc_type, exc_val, exc_tb)

    async def retrieve_standard_regulation(
        self,
        query: str,
        conversation_history: str = "",
        **kwargs
    ) -> RetrievalResponse:
        """Search standard/regulation attributes"""
        start_time = time.time()
        # Use the new Azure AI Search approach
        index_name = self.config.retrieval.index.standard_regulation_index
        vector_fields = "full_metadata_vector"
        select_fields = "id, func_uuid, title, publisher, document_category, document_code, x_Standard_Regulation_Id, x_Attachment_Type, x_Standard_Title_CN, x_Standard_Title_EN, x_Standard_Published_State, x_Standard_Drafting_Status, x_Standard_Published_State_EN, x_Standard_Drafting_Status_EN, x_Standard_Range, x_Standard_Kind, x_Standard_No, x_Standard_Technical_Committee, x_Standard_Vehicle_Type, x_Standard_Power_Type, x_Standard_CCS, x_Standard_ICS, x_Standard_Published_Date, x_Standard_Effective_Date, x_Regulation_Status, x_Regulation_Status_EN, x_Regulation_Title_CN, x_Regulation_Title_EN, x_Regulation_Document_No, x_Regulation_Issued_Date, x_Classification, x_Work_Group, x_Reference_Standard, x_Replaced_by, x_Refer_To, update_time, status"
        search_fields = "title, publisher, document_category, document_code, x_Standard_Regulation_Id, x_Attachment_Type, x_Standard_Title_CN, x_Standard_Title_EN, x_Standard_Published_State, x_Standard_Drafting_Status, x_Standard_Published_State_EN, x_Standard_Drafting_Status_EN, x_Standard_Range, x_Standard_Kind, x_Standard_No, x_Standard_Technical_Committee, x_Standard_Vehicle_Type, x_Standard_Power_Type, x_Standard_CCS, x_Standard_ICS, x_Standard_Published_Date, x_Standard_Effective_Date, x_Regulation_Status, x_Regulation_Status_EN, x_Regulation_Title_CN, x_Regulation_Title_EN, x_Regulation_Document_No, x_Regulation_Issued_Date, x_Classification, x_Work_Group, x_Reference_Standard, x_Replaced_by, x_Refer_To, update_time, status"
        top_k = kwargs.get("top_k", 10)
        score_threshold = kwargs.get("score_threshold", 1.5)
        try:
            response_data = await self.search_client.search_azure_ai(
                index_name=index_name,
                search_text=query,
                vector_fields=vector_fields,
                select_fields=select_fields,
                search_fields=search_fields,
                top_k=top_k,
                score_threshold=score_threshold
            )
            results = response_data.get("value", [])
            took_ms = int((time.time() - start_time) * 1000)
            return RetrievalResponse(
                results=results,
                took_ms=took_ms,
                total_count=len(results)
            )
        except Exception as e:
            logger.error(f"retrieve_standard_regulation failed: {e}")
            raise

    async def retrieve_doc_chunk_standard_regulation(
        self,
        query: str,
        conversation_history: str = "",
        **kwargs
    ) -> RetrievalResponse:
        """Search standard/regulation document chunks"""
        start_time = time.time()
        # Use the new Azure AI Search approach
        index_name = self.config.retrieval.index.chunk_index
        vector_fields = "contentVector, full_metadata_vector"
        select_fields = "content, title, full_headers, document_code, document_category, publisher, x_Regulation_Title_CN, x_Regulation_Title_EN, x_Standard_Title_CN, x_Standard_Title_EN, x_Standard_Kind, x_Standard_CCS, x_Standard_ICS, x_Standard_Vehicle_Type, x_Standard_Power_Type, id, metadata, func_uuid, filepath, x_Standard_Regulation_Id"
        search_fields = "content, title, full_headers, document_code, document_category, publisher, x_Regulation_Title_CN, x_Regulation_Title_EN, x_Standard_Title_CN, x_Standard_Title_EN, x_Standard_Kind, x_Standard_CCS, x_Standard_ICS, x_Standard_Vehicle_Type, x_Standard_Power_Type"
        filter_query = "(document_category eq 'Standard' or document_category eq 'Regulation') and (status eq '已发布') and (x_Standard_Published_State_EN eq 'Effective' or x_Standard_Published_State_EN eq 'Publication' or x_Standard_Published_State_EN eq 'Implementation' or x_Regulation_Status_EN eq 'Publication' or x_Regulation_Status_EN eq 'Implementation') and (x_Attachment_Type eq '标准附件(PUBLISHED_STANDARDS)' or x_Attachment_Type eq '已发布法规附件(ISSUED_REGULATION)')"
        top_k = kwargs.get("top_k", 10)
        score_threshold = kwargs.get("score_threshold", 1.5)
        try:
            response_data = await self.search_client.search_azure_ai(
                index_name=index_name,
                search_text=query,
                vector_fields=vector_fields,
                select_fields=select_fields,
                search_fields=search_fields,
                filter_query=filter_query,
                top_k=top_k,
                score_threshold=score_threshold
            )
            results = response_data.get("value", [])
            took_ms = int((time.time() - start_time) * 1000)
            return RetrievalResponse(
                results=results,
                took_ms=took_ms,
                total_count=len(results)
            )
        except Exception as e:
            logger.error(f"retrieve_doc_chunk_standard_regulation failed: {e}")
            raise

    async def retrieve_doc_chunk_user_manual(
        self,
        query: str,
        conversation_history: str = "",
        **kwargs
    ) -> RetrievalResponse:
        """Search CATOnline system user manual document chunks"""
        start_time = time.time()
        # Use the new Azure AI Search approach
        index_name = self.config.retrieval.index.chunk_user_manual_index
        vector_fields = "contentVector"
        select_fields = "content, title, full_headers"
        search_fields = "content, title, full_headers"
        top_k = kwargs.get("top_k", 10)
        score_threshold = kwargs.get("score_threshold", 1.5)
        try:
            response_data = await self.search_client.search_azure_ai(
                index_name=index_name,
                search_text=query,
                vector_fields=vector_fields,
                select_fields=select_fields,
                search_fields=search_fields,
                top_k=top_k,
                score_threshold=score_threshold
            )
            results = response_data.get("value", [])
            took_ms = int((time.time() - start_time) * 1000)
            return RetrievalResponse(
                results=results,
                took_ms=took_ms,
                total_count=len(results)
            )
        except Exception as e:
            logger.error(f"retrieve_doc_chunk_user_manual failed: {e}")
            raise
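
Finally, a hedged usage sketch of AgenticRetrieval as an async context manager. The module name agentic_retrieval and the query text are assumptions; the context-manager behavior and method signature come from the code above.

import asyncio

from service.retrieval.agentic_retrieval import AgenticRetrieval  # hypothetical module name

async def main() -> None:
    # The context manager delegates to AzureSearchClient.__aenter__/__aexit__,
    # so the underlying httpx clients are closed on exit.
    async with AgenticRetrieval() as retrieval:
        response = await retrieval.retrieve_doc_chunk_standard_regulation(
            query="battery safety requirements for electric vehicles",
            top_k=5,
            score_threshold=1.5,
        )
        for item in response.results:
            print(item.get("document_code"), item.get("title"))

asyncio.run(main())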