init
This commit is contained in:
1
vw-agentic-rag/service/retrieval/__init__.py
Normal file
1
vw-agentic-rag/service/retrieval/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Empty __init__.py files to make packages
|
||||
181
vw-agentic-rag/service/retrieval/clients.py
Normal file
181
vw-agentic-rag/service/retrieval/clients.py
Normal file
@@ -0,0 +1,181 @@
|
||||
"""
|
||||
Azure AI Search client utilities for retrieval operations.
|
||||
Contains shared functionality for interacting with Azure AI Search and embedding services.
|
||||
"""
|
||||
|
||||
import httpx
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
|
||||
|
||||
from ..config import get_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RetrievalAPIError(Exception):
    """Raised when a retrieval API call (embedding or search) fails."""
|
||||
|
||||
|
||||
class AzureSearchClient:
    """Shared Azure AI Search client for embedding and hybrid search operations.

    Owns two async httpx clients — one for the embedding service, one for
    Azure AI Search — so their traffic can be managed independently. Use as
    an async context manager so both clients are closed deterministically.
    """

    def __init__(self):
        self.config = get_config()
        self.search_endpoint = self.config.retrieval.endpoint
        self.api_key = self.config.retrieval.api_key
        self.api_version = self.config.retrieval.api_version
        self.semantic_configuration = self.config.retrieval.semantic_configuration
        self.embedding_client = httpx.AsyncClient(timeout=30.0)
        self.search_client = httpx.AsyncClient(timeout=30.0)

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.embedding_client.aclose()
        await self.search_client.aclose()

    async def get_embedding(self, text: str) -> List[float]:
        """Get the embedding vector for *text* from the configured embedding service.

        Args:
            text: Input text to embed.

        Returns:
            The embedding vector as a list of floats.

        Raises:
            RetrievalAPIError: if the request fails or the response is malformed.
        """
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.config.retrieval.embedding.api_key}"
        }
        payload = {
            "input": text,
            "model": self.config.retrieval.embedding.model
        }

        try:
            req_url = f"{self.config.retrieval.embedding.base_url}/embeddings"
            # Azure-style deployments require an api-version query parameter;
            # other OpenAI-compatible endpoints may leave it unset.
            if self.config.retrieval.embedding.api_version:
                req_url += f"?api-version={self.config.retrieval.embedding.api_version}"

            response = await self.embedding_client.post(req_url, json=payload, headers=headers)
            response.raise_for_status()
            result = response.json()
            return result["data"][0]["embedding"]
        except Exception as e:
            logger.error(f"Failed to get embedding: {e}")
            # Chain the original exception so the root cause is preserved.
            raise RetrievalAPIError(f"Embedding generation failed: {str(e)}") from e

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type((httpx.HTTPStatusError, httpx.TimeoutException))
    )
    async def _post_search(self, search_url: str, search_payload: Dict[str, Any]) -> Dict[str, Any]:
        """Execute the raw search POST and return the parsed JSON body.

        Kept separate from search_azure_ai so that @retry actually observes
        httpx.HTTPStatusError / httpx.TimeoutException. In the previous layout
        the decorator sat on search_azure_ai, whose own except clauses wrapped
        those exceptions into RetrievalAPIError before tenacity could inspect
        them — so no retry ever happened.
        """
        response = await self.search_client.post(
            search_url,
            json=search_payload,
            headers={
                "Content-Type": "application/json",
                "api-key": self.api_key
            },
            params={"api-version": self.api_version}
        )
        response.raise_for_status()
        return response.json()

    async def search_azure_ai(
        self,
        index_name: str,
        search_text: str,
        vector_fields: str,
        select_fields: str,
        search_fields: str,
        filter_query: Optional[str] = None,
        top_k: int = 10,
        score_threshold: float = 1.5
    ) -> Dict[str, Any]:
        """Make a hybrid (semantic + vector) search request to Azure AI Search.

        Args:
            index_name: Target search index.
            search_text: Query text; also embedded for the vector leg.
            vector_fields: Comma-separated vector field names.
            select_fields: Comma-separated fields to return.
            search_fields: Comma-separated fields to match the text against.
            filter_query: Optional OData filter expression.
            top_k: Maximum hits per vector query and overall.
            score_threshold: Minimum "@search.rerankerScore" to keep a hit.

        Returns:
            {"value": [...]} with normalized, score-filtered results.

        Raises:
            RetrievalAPIError: on HTTP error, timeout, embedding failure, or
                any unexpected error.
        """
        # Get embedding vector for the query (raises RetrievalAPIError on failure).
        query_vector = await self.get_embedding(search_text)

        # One vector query per configured vector field.
        vector_queries = [
            {
                "kind": "vector",
                "vector": query_vector,
                "fields": field.strip(),
                "k": top_k
            }
            for field in vector_fields.split(",")
        ]

        search_payload: Dict[str, Any] = {
            "search": search_text,
            "select": select_fields,
            "searchFields": search_fields,
            "top": top_k,
            "queryType": "semantic",
            "semanticConfiguration": self.semantic_configuration,
            "vectorQueries": vector_queries
        }

        if filter_query:
            search_payload["filter"] = filter_query

        search_url = f"{self.search_endpoint}/indexes/{index_name}/docs/search"

        try:
            result = await self._post_search(search_url, search_payload)

            # Filter results by reranker score, attach order numbers, and
            # normalize (drops internal fields and empty values).
            filtered_results = []
            for i, item in enumerate(result.get("value", [])):
                reranker_score = item.get("@search.rerankerScore", 0)
                if reranker_score >= score_threshold:
                    item["@order_num"] = i + 1
                    filtered_results.append(normalize_search_result(item))

            return {"value": filtered_results}

        except httpx.HTTPStatusError as e:
            logger.error(f"Azure AI Search HTTP error {e.response.status_code}: {e.response.text}")
            raise RetrievalAPIError(f"Azure AI Search request failed: {e.response.status_code}") from e
        except httpx.TimeoutException as e:
            logger.error("Azure AI Search request timeout")
            raise RetrievalAPIError("Azure AI Search request timeout") from e
        except Exception as e:
            logger.error(f"Azure AI Search unexpected error: {e}")
            raise RetrievalAPIError(f"Azure AI Search unexpected error: {str(e)}") from e
|
||||
|
||||
|
||||
def normalize_search_result(raw_result: Dict[str, Any]) -> Dict[str, Any]:
    """
    Normalize a raw Azure AI Search hit into a clean dynamic structure.

    Drops Azure-internal metadata fields and any key whose value is
    None, "", [], or {}. Other falsy values (0, False) are kept.

    Args:
        raw_result: Raw result document from Azure AI Search.

    Returns:
        Cleaned and normalized result dictionary (input is not mutated).
    """
    # Internal fields to drop if present (belt and suspenders approach).
    internal_fields = (
        "@search.score",
        "@search.rerankerScore",
        "@search.captions",
        "@subquery_id",
    )

    # Single pass: skip internal fields and empty values.
    return {
        key: value
        for key, value in raw_result.items()
        if key not in internal_fields
        and value is not None
        and value not in ("", [], {})
    }
|
||||
58
vw-agentic-rag/service/retrieval/generic_chunk_retrieval.py
Normal file
58
vw-agentic-rag/service/retrieval/generic_chunk_retrieval.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import logging
|
||||
|
||||
|
||||
import time
|
||||
|
||||
from ..config import get_config
|
||||
from service.retrieval.clients import AzureSearchClient
|
||||
from service.retrieval.model import RetrievalResponse
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GenericChunkRetrieval:
    """Retrieves CATOnline user-manual document chunks via Azure AI Search."""

    def __init__(self) -> None:
        self.config = get_config()
        self.search_client = AzureSearchClient()

    async def __aenter__(self):
        await self.search_client.__aenter__()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Close the underlying httpx clients; without this they leak
        # connections (mirrors AgenticRetrieval's context-manager support).
        await self.search_client.__aexit__(exc_type, exc_val, exc_tb)

    async def retrieve_doc_chunk(
        self,
        query: str,
        conversation_history: str = "",
        **kwargs
    ) -> RetrievalResponse:
        """Search CATOnline system user manual document chunks.

        Args:
            query: Search query text.
            conversation_history: Accepted for interface parity; not used here.
            **kwargs: Optional ``top_k`` (default 10) and ``score_threshold``
                (default 1.5) overrides.

        Returns:
            RetrievalResponse with the matching chunks, elapsed time, and count.

        Raises:
            Exception: any failure from the underlying search is logged and re-raised.
        """
        start_time = time.time()

        # Use the new Azure AI Search approach.
        index_name = self.config.retrieval.index.chunk_user_manual_index
        vector_fields = "contentVector"
        select_fields = "content, title, full_headers"
        search_fields = "content, title, full_headers"

        top_k = kwargs.get("top_k", 10)
        score_threshold = kwargs.get("score_threshold", 1.5)

        try:
            response_data = await self.search_client.search_azure_ai(
                index_name=index_name,
                search_text=query,
                vector_fields=vector_fields,
                select_fields=select_fields,
                search_fields=search_fields,
                top_k=top_k,
                score_threshold=score_threshold
            )

            results = response_data.get("value", [])

            took_ms = int((time.time() - start_time) * 1000)
            return RetrievalResponse(
                results=results,
                took_ms=took_ms,
                total_count=len(results)
            )
        except Exception as e:
            # Log under the actual method name (was wrongly logged as
            # "retrieve_doc_chunk_user_manual", copied from retrieval.py).
            logger.error(f"retrieve_doc_chunk failed: {e}")
            raise
|
||||
11
vw-agentic-rag/service/retrieval/model.py
Normal file
11
vw-agentic-rag/service/retrieval/model.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
|
||||
class RetrievalResponse(BaseModel):
    """Simple response container for tool results"""
    # Raw result documents returned by the search backend.
    results: list[dict[str, Any]]
    # Wall-clock duration of the retrieval call, in milliseconds.
    took_ms: Optional[int] = None
    # Number of documents in `results`.
    total_count: Optional[int] = None
|
||||
158
vw-agentic-rag/service/retrieval/retrieval.py
Normal file
158
vw-agentic-rag/service/retrieval/retrieval.py
Normal file
@@ -0,0 +1,158 @@
|
||||
import httpx
|
||||
import time
|
||||
import json
|
||||
from typing import Dict, Any, List, Optional
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
|
||||
import logging
|
||||
|
||||
from .model import RetrievalResponse
|
||||
from ..config import get_config
|
||||
from .clients import AzureSearchClient, RetrievalAPIError
|
||||
|
||||
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgenticRetrieval:
    """Azure AI Search client for retrieval tools.

    Wraps an AzureSearchClient and exposes one retrieval method per index.
    Use as an async context manager so the underlying HTTP clients are closed.
    """

    def __init__(self):
        self.config = get_config()
        self.search_client = AzureSearchClient()

    async def __aenter__(self):
        await self.search_client.__aenter__()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.search_client.__aexit__(exc_type, exc_val, exc_tb)

    async def _retrieve(
        self,
        op_name: str,
        query: str,
        index_name: str,
        vector_fields: str,
        select_fields: str,
        search_fields: str,
        filter_query: Optional[str] = None,
        top_k: int = 10,
        score_threshold: float = 1.5
    ) -> RetrievalResponse:
        """Shared retrieval flow used by all public retrieve_* methods.

        Runs the hybrid search, measures elapsed time, and wraps the hits in
        a RetrievalResponse. `op_name` is used only in the failure log line.

        Raises:
            Exception: any failure from the search client is logged and re-raised.
        """
        start_time = time.time()
        try:
            response_data = await self.search_client.search_azure_ai(
                index_name=index_name,
                search_text=query,
                vector_fields=vector_fields,
                select_fields=select_fields,
                search_fields=search_fields,
                filter_query=filter_query,
                top_k=top_k,
                score_threshold=score_threshold
            )

            results = response_data.get("value", [])

            took_ms = int((time.time() - start_time) * 1000)
            return RetrievalResponse(
                results=results,
                took_ms=took_ms,
                total_count=len(results)
            )
        except Exception as e:
            logger.error(f"{op_name} failed: {e}")
            raise

    async def retrieve_standard_regulation(
        self,
        query: str,
        conversation_history: str = "",
        **kwargs
    ) -> RetrievalResponse:
        """Search standard/regulation attributes.

        Args:
            query: Search query text.
            conversation_history: Accepted for interface parity; not used here.
            **kwargs: Optional ``top_k`` (default 10) and ``score_threshold``
                (default 1.5) overrides.
        """
        return await self._retrieve(
            op_name="retrieve_standard_regulation",
            query=query,
            index_name=self.config.retrieval.index.standard_regulation_index,
            vector_fields="full_metadata_vector",
            select_fields="id, func_uuid, title, publisher, document_category, document_code, x_Standard_Regulation_Id, x_Attachment_Type, x_Standard_Title_CN, x_Standard_Title_EN, x_Standard_Published_State, x_Standard_Drafting_Status, x_Standard_Published_State_EN, x_Standard_Drafting_Status_EN, x_Standard_Range, x_Standard_Kind, x_Standard_No, x_Standard_Technical_Committee, x_Standard_Vehicle_Type, x_Standard_Power_Type, x_Standard_CCS, x_Standard_ICS, x_Standard_Published_Date, x_Standard_Effective_Date, x_Regulation_Status, x_Regulation_Status_EN, x_Regulation_Title_CN, x_Regulation_Title_EN, x_Regulation_Document_No, x_Regulation_Issued_Date, x_Classification, x_Work_Group, x_Reference_Standard, x_Replaced_by, x_Refer_To, update_time, status",
            search_fields="title, publisher, document_category, document_code, x_Standard_Regulation_Id, x_Attachment_Type, x_Standard_Title_CN, x_Standard_Title_EN, x_Standard_Published_State, x_Standard_Drafting_Status, x_Standard_Published_State_EN, x_Standard_Drafting_Status_EN, x_Standard_Range, x_Standard_Kind, x_Standard_No, x_Standard_Technical_Committee, x_Standard_Vehicle_Type, x_Standard_Power_Type, x_Standard_CCS, x_Standard_ICS, x_Standard_Published_Date, x_Standard_Effective_Date, x_Regulation_Status, x_Regulation_Status_EN, x_Regulation_Title_CN, x_Regulation_Title_EN, x_Regulation_Document_No, x_Regulation_Issued_Date, x_Classification, x_Work_Group, x_Reference_Standard, x_Replaced_by, x_Refer_To, update_time, status",
            top_k=kwargs.get("top_k", 10),
            score_threshold=kwargs.get("score_threshold", 1.5)
        )

    async def retrieve_doc_chunk_standard_regulation(
        self,
        query: str,
        conversation_history: str = "",
        **kwargs
    ) -> RetrievalResponse:
        """Search standard/regulation document chunks.

        Applies a fixed OData filter restricting hits to published, effective
        standard/regulation attachments.

        Args:
            query: Search query text.
            conversation_history: Accepted for interface parity; not used here.
            **kwargs: Optional ``top_k`` (default 10) and ``score_threshold``
                (default 1.5) overrides.
        """
        return await self._retrieve(
            op_name="retrieve_doc_chunk_standard_regulation",
            query=query,
            index_name=self.config.retrieval.index.chunk_index,
            vector_fields="contentVector, full_metadata_vector",
            select_fields="content, title, full_headers, document_code, document_category, publisher, x_Regulation_Title_CN, x_Regulation_Title_EN, x_Standard_Title_CN, x_Standard_Title_EN, x_Standard_Kind, x_Standard_CCS, x_Standard_ICS, x_Standard_Vehicle_Type, x_Standard_Power_Type, id, metadata, func_uuid, filepath, x_Standard_Regulation_Id",
            search_fields="content, title, full_headers, document_code, document_category, publisher, x_Regulation_Title_CN, x_Regulation_Title_EN, x_Standard_Title_CN, x_Standard_Title_EN, x_Standard_Kind, x_Standard_CCS, x_Standard_ICS, x_Standard_Vehicle_Type, x_Standard_Power_Type",
            filter_query="(document_category eq 'Standard' or document_category eq 'Regulation') and (status eq '已发布') and (x_Standard_Published_State_EN eq 'Effective' or x_Standard_Published_State_EN eq 'Publication' or x_Standard_Published_State_EN eq 'Implementation' or x_Regulation_Status_EN eq 'Publication' or x_Regulation_Status_EN eq 'Implementation') and (x_Attachment_Type eq '标准附件(PUBLISHED_STANDARDS)' or x_Attachment_Type eq '已发布法规附件(ISSUED_REGULATION)')",
            top_k=kwargs.get("top_k", 10),
            score_threshold=kwargs.get("score_threshold", 1.5)
        )

    async def retrieve_doc_chunk_user_manual(
        self,
        query: str,
        conversation_history: str = "",
        **kwargs
    ) -> RetrievalResponse:
        """Search CATOnline system user manual document chunks.

        Args:
            query: Search query text.
            conversation_history: Accepted for interface parity; not used here.
            **kwargs: Optional ``top_k`` (default 10) and ``score_threshold``
                (default 1.5) overrides.
        """
        return await self._retrieve(
            op_name="retrieve_doc_chunk_user_manual",
            query=query,
            index_name=self.config.retrieval.index.chunk_user_manual_index,
            vector_fields="contentVector",
            select_fields="content, title, full_headers",
            search_fields="content, title, full_headers",
            top_k=kwargs.get("top_k", 10),
            score_threshold=kwargs.get("score_threshold", 1.5)
        )
|
||||
|
||||
|
||||
Reference in New Issue
Block a user