159 lines
7.5 KiB
Python
159 lines
7.5 KiB
Python
import httpx
|
|
import time
|
|
import json
|
|
from typing import Dict, Any, List, Optional
|
|
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
|
|
import logging
|
|
|
|
from .model import RetrievalResponse
|
|
from ..config import get_config
|
|
from .clients import AzureSearchClient, RetrievalAPIError
|
|
|
|
|
|
|
|
|
|
# Module-level logger, named after this module (standard logging convention).
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class AgenticRetrieval:
    """Azure AI Search client for retrieval tools.

    Wraps an ``AzureSearchClient`` and exposes one coroutine per retrieval
    tool. Each tool queries a different index (read from
    ``config.retrieval.index``) with its own vector/select/search field lists
    and returns a ``RetrievalResponse``.

    Use it as an async context manager so the underlying search client is
    entered and exited together with this object::

        async with AgenticRetrieval() as retrieval:
            response = await retrieval.retrieve_doc_chunk_user_manual("...")
    """

    def __init__(self):
        # Config is loaded once per instance; index names are read from it on each call.
        self.config = get_config()
        self.search_client = AzureSearchClient()

    async def __aenter__(self):
        """Enter the underlying search client and return this instance."""
        await self.search_client.__aenter__()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Exit the underlying search client.

        Returns None, so an exception raised inside the ``async with`` body
        propagates instead of being suppressed.
        """
        await self.search_client.__aexit__(exc_type, exc_val, exc_tb)

    async def _search(
        self,
        operation: str,
        query: str,
        *,
        index_name: str,
        vector_fields: str,
        select_fields: str,
        search_fields: str,
        options: Dict[str, Any],
        filter_query: Optional[str] = None,
    ) -> RetrievalResponse:
        """Run one Azure AI Search query and wrap the hits in a RetrievalResponse.

        Args:
            operation: Name of the calling public method; used only in the error log.
            query: Free-text search string, forwarded as ``search_text``.
            index_name: Name of the index to query.
            vector_fields: Comma-separated vector field names, forwarded as-is.
            select_fields: Comma-separated fields to return, forwarded as-is.
            search_fields: Comma-separated fields to search, forwarded as-is.
            options: The caller's ``**kwargs``. Only ``top_k`` (default 10) and
                ``score_threshold`` (default 1.5) are read from it.
            filter_query: Optional filter expression. When None it is omitted
                from the request entirely, so the search client's own default
                applies.

        Returns:
            A ``RetrievalResponse`` with ``results`` set to the response's
            ``"value"`` list (empty if the key is missing), ``took_ms`` set to
            the elapsed milliseconds, and ``total_count`` set to
            ``len(results)``.

        Raises:
            Any exception from the search call or from building the response.
            It is logged with its traceback and then re-raised unchanged.
            Presumably ``RetrievalAPIError`` for API failures (TODO confirm
            against ``AzureSearchClient``).
        """
        # perf_counter is monotonic: the elapsed time is unaffected by wall-clock adjustments.
        start = time.perf_counter()

        search_kwargs: Dict[str, Any] = {
            "index_name": index_name,
            "search_text": query,
            "vector_fields": vector_fields,
            "select_fields": select_fields,
            "search_fields": search_fields,
            "top_k": options.get("top_k", 10),
            "score_threshold": options.get("score_threshold", 1.5),
        }
        if filter_query is not None:
            search_kwargs["filter_query"] = filter_query

        try:
            response_data = await self.search_client.search_azure_ai(**search_kwargs)
            results = response_data.get("value", [])
            took_ms = int((time.perf_counter() - start) * 1000)
            return RetrievalResponse(
                results=results,
                took_ms=took_ms,
                total_count=len(results),
            )
        except Exception as exc:
            # logger.exception logs at ERROR level and keeps the traceback before re-raising.
            logger.exception("%s failed: %s", operation, exc)
            raise

    async def retrieve_standard_regulation(
        self,
        query: str,
        conversation_history: str = "",
        **kwargs
    ) -> RetrievalResponse:
        """Search standard/regulation attributes.

        Queries the standard/regulation index against its
        ``full_metadata_vector`` field. No filter is applied.

        Args:
            query: Free-text search string.
            conversation_history: Accepted for interface compatibility; not
                used by this method.
            **kwargs: Optional ``top_k`` (default 10) and ``score_threshold``
                (default 1.5).

        Returns:
            A ``RetrievalResponse`` with the matching records.
        """
        return await self._search(
            "retrieve_standard_regulation",
            query,
            index_name=self.config.retrieval.index.standard_regulation_index,
            vector_fields="full_metadata_vector",
            select_fields="id, func_uuid, title, publisher, document_category, document_code, x_Standard_Regulation_Id, x_Attachment_Type, x_Standard_Title_CN, x_Standard_Title_EN, x_Standard_Published_State, x_Standard_Drafting_Status, x_Standard_Published_State_EN, x_Standard_Drafting_Status_EN, x_Standard_Range, x_Standard_Kind, x_Standard_No, x_Standard_Technical_Committee, x_Standard_Vehicle_Type, x_Standard_Power_Type, x_Standard_CCS, x_Standard_ICS, x_Standard_Published_Date, x_Standard_Effective_Date, x_Regulation_Status, x_Regulation_Status_EN, x_Regulation_Title_CN, x_Regulation_Title_EN, x_Regulation_Document_No, x_Regulation_Issued_Date, x_Classification, x_Work_Group, x_Reference_Standard, x_Replaced_by, x_Refer_To, update_time, status",
            search_fields="title, publisher, document_category, document_code, x_Standard_Regulation_Id, x_Attachment_Type, x_Standard_Title_CN, x_Standard_Title_EN, x_Standard_Published_State, x_Standard_Drafting_Status, x_Standard_Published_State_EN, x_Standard_Drafting_Status_EN, x_Standard_Range, x_Standard_Kind, x_Standard_No, x_Standard_Technical_Committee, x_Standard_Vehicle_Type, x_Standard_Power_Type, x_Standard_CCS, x_Standard_ICS, x_Standard_Published_Date, x_Standard_Effective_Date, x_Regulation_Status, x_Regulation_Status_EN, x_Regulation_Title_CN, x_Regulation_Title_EN, x_Regulation_Document_No, x_Regulation_Issued_Date, x_Classification, x_Work_Group, x_Reference_Standard, x_Replaced_by, x_Refer_To, update_time, status",
            options=kwargs,
        )

    async def retrieve_doc_chunk_standard_regulation(
        self,
        query: str,
        conversation_history: str = "",
        **kwargs
    ) -> RetrievalResponse:
        """Search standard/regulation document chunks.

        Queries the chunk index against both ``contentVector`` and
        ``full_metadata_vector``. The filter limits results to currently
        valid, published standards and regulations.

        Args:
            query: Free-text search string.
            conversation_history: Accepted for interface compatibility; not
                used by this method.
            **kwargs: Optional ``top_k`` (default 10) and ``score_threshold``
                (default 1.5).

        Returns:
            A ``RetrievalResponse`` with the matching chunks.
        """
        # The filter keeps only documents that satisfy all of these:
        #   * category is Standard or Regulation;
        #   * status is '已发布' ("published");
        #   * standard state is Effective/Publication/Implementation, OR
        #     regulation status is Publication/Implementation;
        #   * attachment type is '标准附件' ("published standard attachment")
        #     or '已发布法规附件' ("issued regulation attachment").
        # The literal values must match the index data exactly, so they stay untranslated.
        return await self._search(
            "retrieve_doc_chunk_standard_regulation",
            query,
            index_name=self.config.retrieval.index.chunk_index,
            vector_fields="contentVector, full_metadata_vector",
            select_fields="content, title, full_headers, document_code, document_category, publisher, x_Regulation_Title_CN, x_Regulation_Title_EN, x_Standard_Title_CN, x_Standard_Title_EN, x_Standard_Kind, x_Standard_CCS, x_Standard_ICS, x_Standard_Vehicle_Type, x_Standard_Power_Type, id, metadata, func_uuid, filepath, x_Standard_Regulation_Id",
            search_fields="content, title, full_headers, document_code, document_category, publisher, x_Regulation_Title_CN, x_Regulation_Title_EN, x_Standard_Title_CN, x_Standard_Title_EN, x_Standard_Kind, x_Standard_CCS, x_Standard_ICS, x_Standard_Vehicle_Type, x_Standard_Power_Type",
            filter_query="(document_category eq 'Standard' or document_category eq 'Regulation') and (status eq '已发布') and (x_Standard_Published_State_EN eq 'Effective' or x_Standard_Published_State_EN eq 'Publication' or x_Standard_Published_State_EN eq 'Implementation' or x_Regulation_Status_EN eq 'Publication' or x_Regulation_Status_EN eq 'Implementation') and (x_Attachment_Type eq '标准附件(PUBLISHED_STANDARDS)' or x_Attachment_Type eq '已发布法规附件(ISSUED_REGULATION)')",
            options=kwargs,
        )

    async def retrieve_doc_chunk_user_manual(
        self,
        query: str,
        conversation_history: str = "",
        **kwargs
    ) -> RetrievalResponse:
        """Search CATOnline system user manual document chunks.

        Queries the user-manual chunk index against ``contentVector``. No
        filter is applied.

        Args:
            query: Free-text search string.
            conversation_history: Accepted for interface compatibility; not
                used by this method.
            **kwargs: Optional ``top_k`` (default 10) and ``score_threshold``
                (default 1.5).

        Returns:
            A ``RetrievalResponse`` with the matching chunks.
        """
        return await self._search(
            "retrieve_doc_chunk_user_manual",
            query,
            index_name=self.config.retrieval.index.chunk_user_manual_index,
            vector_fields="contentVector",
            select_fields="content, title, full_headers",
            search_fields="content, title, full_headers",
            options=kwargs,
        )
|
|
|
|
|