# Files
# catonline_ai/vw-agentic-rag/tests/unit/test_retrieval.py
# 2025-09-26 17:15:54 +08:00
#
# 353 lines
# 16 KiB
# Python

"""
Unit tests for agentic retrieval tools.
Tests the HTTP wrapper tools that interface with the retrieval API.
"""
import pytest
from unittest.mock import AsyncMock, patch, Mock
import httpx
from service.retrieval.retrieval import AgenticRetrieval, RetrievalResponse
from service.retrieval.clients import RetrievalAPIError, normalize_search_result
class TestAgenticRetrieval:
    """Tests for the ``AgenticRetrieval`` client and ``normalize_search_result``.

    The async tests replace ``search_client.search_azure_ai`` on the client
    under test with a mock, then check both the arguments forwarded to the
    search call and the ``RetrievalResponse`` that comes back.
    """

    @pytest.fixture
    def retrieval_client(self):
        """Build an ``AgenticRetrieval`` while ``get_config`` is patched.

        The patch is only active during construction; the instance is
        returned from inside the ``with`` block.
        """
        with patch('service.retrieval.retrieval.get_config') as mock_config:
            # All settings hang off the same mocked ``retrieval`` sub-object.
            retrieval_cfg = mock_config.return_value.retrieval

            # Search service settings
            retrieval_cfg.endpoint = "https://test-search.search.azure.com"
            retrieval_cfg.api_key = "test-key"
            retrieval_cfg.api_version = "2024-11-01-preview"
            retrieval_cfg.semantic_configuration = "default"

            # Embedding settings
            retrieval_cfg.embedding.base_url = "http://test-embedding"
            retrieval_cfg.embedding.api_key = "test-embedding-key"
            retrieval_cfg.embedding.model = "test-embedding-model"
            retrieval_cfg.embedding.dimension = 1536

            # Index names
            retrieval_cfg.index.standard_regulation_index = "index-catonline-standard-regulation-v2-prd"
            retrieval_cfg.index.chunk_index = "index-catonline-chunk-v2-prd"
            retrieval_cfg.index.chunk_user_manual_index = "index-cat-usermanual-chunk-prd"

            return AgenticRetrieval()

    @pytest.mark.asyncio
    async def test_retrieve_standard_regulation_success(self, retrieval_client):
        """A standards query is forwarded to the standards index and normalized."""
        mock_response_data = {
            "value": [
                {
                    "id": "test_id_1",
                    "title": "ISO 26262-1:2018",
                    "document_code": "ISO26262-1",
                    "document_category": "Standard",
                    "@order_num": 1,  # Should be preserved
                },
                {
                    "id": "test_id_2",
                    "title": "ISO 26262-3:2018",
                    "document_code": "ISO26262-3",
                    "document_category": "Standard",
                    "@order_num": 2,
                },
            ]
        }

        with patch.object(
            retrieval_client.search_client,
            'search_azure_ai',
            return_value=mock_response_data,
        ) as mock_search:
            result = await retrieval_client.retrieve_standard_regulation("ISO 26262")

        # Verify search call (the mock keeps its call record after the patch ends)
        mock_search.assert_called_once()
        call_kwargs = mock_search.call_args.kwargs
        assert call_kwargs['search_text'] == "ISO 26262"
        assert call_kwargs['index_name'] == "index-catonline-standard-regulation-v2-prd"
        assert call_kwargs['top_k'] == 10

        # Verify response structure
        assert isinstance(result, RetrievalResponse)
        assert len(result.results) == 2
        assert result.total_count == 2
        assert result.took_ms is not None

        # Verify normalization
        first_result = result.results[0]
        assert first_result['id'] == "test_id_1"
        assert first_result['title'] == "ISO 26262-1:2018"

        # Verify order number is preserved
        assert "@order_num" in first_result
        assert "content" not in first_result  # Should not be included for standards

    @pytest.mark.asyncio
    async def test_retrieve_doc_chunk_success(self, retrieval_client):
        """A chunk query hits the chunk index with a filter and keeps ``content``."""
        mock_response_data = {
            "value": [
                {
                    "id": "chunk_id_1",
                    "title": "Functional Safety Requirements",
                    "content": "Detailed content about functional safety...",
                    "document_code": "ISO26262-1",
                    "@order_num": 1,  # Should be preserved
                }
            ]
        }

        with patch.object(
            retrieval_client.search_client,
            'search_azure_ai',
            return_value=mock_response_data,
        ) as mock_search:
            result = await retrieval_client.retrieve_doc_chunk_standard_regulation(
                "functional safety",
                conversation_history="Previous question about ISO 26262",
            )

        # Verify search call
        mock_search.assert_called_once()
        call_kwargs = mock_search.call_args.kwargs
        assert call_kwargs['search_text'] == "functional safety"
        assert call_kwargs['index_name'] == "index-catonline-chunk-v2-prd"
        assert "filter_query" in call_kwargs  # Should have document filter

        # Verify response
        assert isinstance(result, RetrievalResponse)
        assert len(result.results) == 1

        # Verify content is included for chunks
        first_result = result.results[0]
        assert "content" in first_result
        assert first_result['content'] == "Detailed content about functional safety..."

        # Verify order number is preserved
        assert "@order_num" in first_result

    @pytest.mark.asyncio
    async def test_http_error_handling(self, retrieval_client):
        """A ``RetrievalAPIError`` from the search call reaches the caller."""
        with patch.object(retrieval_client.search_client, 'search_azure_ai') as mock_search:
            mock_search.side_effect = RetrievalAPIError("API request failed: 404")

            with pytest.raises(RetrievalAPIError) as exc_info:
                await retrieval_client.retrieve_standard_regulation("nonexistent")

            # Must sit outside the ``raises`` block: nothing after the raising
            # call inside that block would ever run.
            assert "404" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_timeout_handling(self, retrieval_client):
        """A timeout reported as ``RetrievalAPIError`` reaches the caller."""
        with patch.object(retrieval_client.search_client, 'search_azure_ai') as mock_search:
            mock_search.side_effect = RetrievalAPIError("Request timeout")

            with pytest.raises(RetrievalAPIError) as exc_info:
                await retrieval_client.retrieve_standard_regulation("test query")

            assert "timeout" in str(exc_info.value)

    def test_normalize_result_removes_unwanted_fields(self, retrieval_client):
        """Normalization drops search metadata and empty values, keeps the rest.

        ``normalize_search_result`` comes from the module-level import; the
        ``retrieval_client`` fixture is not used by the assertions and is kept
        only so the test signature stays unchanged.
        """
        raw_result = {
            "id": "test_id",
            "title": "Test Title",
            "content": "Test content",
            "@search.score": 0.95,  # Should be removed
            "@search.rerankerScore": 2.5,  # Should be removed
            "@search.captions": [],  # Should be removed
            "@subquery_id": 1,  # Should be removed
            "empty_field": "",  # Should be removed
            "null_field": None,  # Should be removed
            "empty_list": [],  # Should be removed
            "empty_dict": {},  # Should be removed
            "valid_field": "valid_value",
        }

        # Test normalization
        normalized = normalize_search_result(raw_result)

        # Verify unwanted fields are removed
        assert "@search.score" not in normalized
        assert "@search.rerankerScore" not in normalized
        assert "@search.captions" not in normalized
        assert "@subquery_id" not in normalized
        assert "empty_field" not in normalized
        assert "null_field" not in normalized
        assert "empty_list" not in normalized
        assert "empty_dict" not in normalized

        # Verify valid fields are preserved (including content)
        assert normalized["id"] == "test_id"
        assert normalized["title"] == "Test Title"
        assert normalized["content"] == "Test content"  # Content is now always preserved
        assert normalized["valid_field"] == "valid_value"

    def test_normalize_result_includes_content_when_requested(self, retrieval_client):
        """Normalization keeps ``content``.

        The test name predates the current behaviour: content is always
        preserved now, with nothing to request.
        """
        raw_result = {
            "id": "test_id",
            "title": "Test Title",
            "content": "Test content",
        }

        # Test normalization (content is always preserved now)
        normalized = normalize_search_result(raw_result)

        assert "content" in normalized
        assert normalized["content"] == "Test content"

    @pytest.mark.asyncio
    async def test_empty_results_handling(self, retrieval_client):
        """An empty ``value`` list yields an empty ``RetrievalResponse``."""
        with patch.object(
            retrieval_client.search_client,
            'search_azure_ai',
            return_value={"value": []},
        ):
            result = await retrieval_client.retrieve_standard_regulation("no results query")

        assert isinstance(result, RetrievalResponse)
        assert result.results == []
        assert result.total_count == 0

    @pytest.mark.asyncio
    async def test_malformed_response_handling(self, retrieval_client):
        """A payload without a ``value`` key is treated as zero results."""
        with patch.object(
            retrieval_client.search_client,
            'search_azure_ai',
            return_value={"unexpected": "format"},
        ):
            result = await retrieval_client.retrieve_standard_regulation("test query")

        # Should handle gracefully and return empty results
        assert isinstance(result, RetrievalResponse)
        assert result.results == []
        assert result.total_count == 0

    @pytest.mark.asyncio
    async def test_kwargs_override_payload(self, retrieval_client):
        """Caller-supplied ``top_k`` / ``score_threshold`` reach the search call."""
        with patch.object(
            retrieval_client.search_client,
            'search_azure_ai',
            return_value={"value": []},
        ) as mock_search:
            await retrieval_client.retrieve_standard_regulation(
                "test query",
                top_k=5,
                score_threshold=2.0,
            )

        call_kwargs = mock_search.call_args.kwargs
        assert call_kwargs['top_k'] == 5  # Override default 10
        assert call_kwargs['score_threshold'] == 2.0  # Override default 1.5
class TestRetrievalTools:
    """Exercise the LangGraph ``@tool`` wrappers around the retrieval client."""

    @staticmethod
    def _wire_async_context(retrieval_cls_mock, client_mock):
        """Make the patched class's instance yield ``client_mock`` on ``async with``."""
        context = retrieval_cls_mock.return_value
        context.__aenter__.return_value = client_mock
        context.__aexit__.return_value = None

    @pytest.mark.asyncio
    async def test_retrieve_standard_regulation_tool(self):
        """The standards-search tool returns a dict carrying ``results``."""
        from service.graph.tools import retrieve_standard_regulation

        canned = RetrievalResponse(
            results=[{"title": "Test Standard", "id": "test_id"}],
            took_ms=100,
            total_count=1,
        )
        tool_input = {"query": "test query", "conversation_history": "test history"}

        with patch('service.retrieval.retrieval.AgenticRetrieval') as retrieval_cls:
            fake_client = AsyncMock()
            fake_client.retrieve_standard_regulation.return_value = canned
            self._wire_async_context(retrieval_cls, fake_client)

            # retrieve_standard_regulation is a LangGraph @tool, so it is
            # driven through ``ainvoke`` with a dict of arguments.
            result = await retrieve_standard_regulation.ainvoke(tool_input)

        # Result format must match the expected tool output
        assert isinstance(result, dict)
        assert "results" in result

    @pytest.mark.asyncio
    async def test_retrieve_doc_chunk_tool(self):
        """The chunk-search tool returns a dict carrying ``results``."""
        from service.graph.tools import retrieve_doc_chunk_standard_regulation

        canned = RetrievalResponse(
            results=[{"title": "Test Chunk", "content": "Test content", "id": "chunk_id"}],
            took_ms=150,
            total_count=1,
        )

        with patch('service.retrieval.retrieval.AgenticRetrieval') as retrieval_cls:
            fake_client = AsyncMock()
            fake_client.retrieve_doc_chunk_standard_regulation.return_value = canned
            self._wire_async_context(retrieval_cls, fake_client)

            # Invoke the tool with its expected input shape
            result = await retrieve_doc_chunk_standard_regulation.ainvoke({"query": "test query"})

        # Result format check
        assert isinstance(result, dict)
        assert "results" in result

    @pytest.mark.asyncio
    async def test_tool_error_handling(self):
        """A failing search surfaces as an ``error`` entry, not an exception."""
        from service.graph.tools import retrieve_standard_regulation

        with patch(
            'service.retrieval.clients.AzureSearchClient.search_azure_ai',
            side_effect=RetrievalAPIError("API Error"),
        ):
            # The tool is expected to swallow the failure and report it
            result = await retrieve_standard_regulation.ainvoke({"query": "test query"})

        # Error details must be present in the returned dict
        assert isinstance(result, dict)
        assert "error" in result
        assert "API Error" in result["error"]
class TestRetrievalIntegration:
    """Integration-style tests for retrieval functionality."""

    @pytest.mark.asyncio
    async def test_full_retrieval_workflow(self):
        """Run both retrieval tools against a patched search client.

        ``AzureSearchClient.search_azure_ai`` is replaced with a side effect
        that picks its payload from the ``index_name`` keyword argument: the
        standards index gets a standards hit, anything else gets a chunk hit.
        Each tool must then return a dict whose ``results`` holds one entry.
        """
        standards_index = 'index-catonline-standard-regulation-v2-prd'

        def fake_search(*args, **kwargs):
            """Return a one-hit payload chosen by the requested index."""
            if kwargs.get('index_name') == standards_index:
                return {
                    "value": [{"id": "std_1", "title": "ISO 26262-1", "document_category": "Standard", "@order_num": 1}]
                }
            return {
                "value": [{"id": "chunk_1", "title": "Safety Requirements", "content": "Detailed safety requirements...", "@order_num": 1}]
            }

        with patch('service.retrieval.clients.AzureSearchClient.search_azure_ai') as mock_search:
            mock_search.side_effect = fake_search

            # Import and test both tools
            from service.graph.tools import retrieve_standard_regulation, retrieve_doc_chunk_standard_regulation

            # Test standards search
            std_result = await retrieve_standard_regulation.ainvoke({"query": "ISO 26262"})
            assert isinstance(std_result, dict)
            assert "results" in std_result
            assert len(std_result["results"]) == 1

            # Test chunks search
            chunk_result = await retrieve_doc_chunk_standard_regulation.ainvoke({"query": "safety requirements"})
            assert isinstance(chunk_result, dict)
            assert "results" in chunk_result
            assert len(chunk_result["results"]) == 1