"""
Unit tests for agentic retrieval tools.

Tests the HTTP wrapper tools that interface with the retrieval API.
"""
|
|
import pytest
|
|
from unittest.mock import AsyncMock, patch, Mock
|
|
import httpx
|
|
|
|
from service.retrieval.retrieval import AgenticRetrieval, RetrievalResponse
|
|
from service.retrieval.clients import RetrievalAPIError, normalize_search_result
|
|
|
|
|
|
class TestAgenticRetrieval:
    """Test the agentic retrieval HTTP client.

    The async tests replace ``search_client.search_azure_ai`` on the client
    under test with a mock; the normalization tests call
    ``normalize_search_result`` directly.
    """

    @pytest.fixture
    def retrieval_client(self):
        """Create an ``AgenticRetrieval`` instance built from a mocked config.

        ``get_config`` is patched only while the client is constructed: the
        ``return`` leaves the ``with`` block, so the patch is undone before
        the test body runs.
        """
        with patch('service.retrieval.retrieval.get_config') as mock_config:
            retrieval_cfg = mock_config.return_value.retrieval

            # Search service settings
            retrieval_cfg.endpoint = "https://test-search.search.azure.com"
            retrieval_cfg.api_key = "test-key"
            retrieval_cfg.api_version = "2024-11-01-preview"
            retrieval_cfg.semantic_configuration = "default"

            # Embedding settings
            retrieval_cfg.embedding.base_url = "http://test-embedding"
            retrieval_cfg.embedding.api_key = "test-embedding-key"
            retrieval_cfg.embedding.model = "test-embedding-model"
            retrieval_cfg.embedding.dimension = 1536

            # Index names (asserted on by the tests below)
            retrieval_cfg.index.standard_regulation_index = "index-catonline-standard-regulation-v2-prd"
            retrieval_cfg.index.chunk_index = "index-catonline-chunk-v2-prd"
            retrieval_cfg.index.chunk_user_manual_index = "index-cat-usermanual-chunk-prd"

            return AgenticRetrieval()

    @pytest.mark.asyncio
    async def test_retrieve_standard_regulation_success(self, retrieval_client):
        """Test successful standards regulation retrieval."""
        mock_response_data = {
            "value": [
                {
                    "id": "test_id_1",
                    "title": "ISO 26262-1:2018",
                    "document_code": "ISO26262-1",
                    "document_category": "Standard",
                    "@order_num": 1,  # Should be preserved
                },
                {
                    "id": "test_id_2",
                    "title": "ISO 26262-3:2018",
                    "document_code": "ISO26262-3",
                    "document_category": "Standard",
                    "@order_num": 2,
                },
            ]
        }

        with patch.object(
            retrieval_client.search_client, 'search_azure_ai', return_value=mock_response_data
        ) as mock_search:
            result = await retrieval_client.retrieve_standard_regulation("ISO 26262")

            # Verify search call
            mock_search.assert_called_once()
            search_kwargs = mock_search.call_args.kwargs
            assert search_kwargs['search_text'] == "ISO 26262"
            assert search_kwargs['index_name'] == "index-catonline-standard-regulation-v2-prd"
            assert search_kwargs['top_k'] == 10

        # Verify response structure
        assert isinstance(result, RetrievalResponse)
        assert len(result.results) == 2
        assert result.total_count == 2
        assert result.took_ms is not None

        # Verify normalization
        first_result = result.results[0]
        assert first_result['id'] == "test_id_1"
        assert first_result['title'] == "ISO 26262-1:2018"
        # Verify order number is preserved
        assert "@order_num" in first_result
        # The mocked standards payload carries no "content" field, so
        # normalization must not introduce one.
        assert "content" not in first_result

    @pytest.mark.asyncio
    async def test_retrieve_doc_chunk_success(self, retrieval_client):
        """Test successful document chunk retrieval."""
        mock_response_data = {
            "value": [
                {
                    "id": "chunk_id_1",
                    "title": "Functional Safety Requirements",
                    "content": "Detailed content about functional safety...",
                    "document_code": "ISO26262-1",
                    "@order_num": 1,  # Should be preserved
                }
            ]
        }

        with patch.object(
            retrieval_client.search_client, 'search_azure_ai', return_value=mock_response_data
        ) as mock_search:
            result = await retrieval_client.retrieve_doc_chunk_standard_regulation(
                "functional safety",
                conversation_history="Previous question about ISO 26262"
            )

            # Verify search call
            mock_search.assert_called_once()
            search_kwargs = mock_search.call_args.kwargs
            assert search_kwargs['search_text'] == "functional safety"
            assert search_kwargs['index_name'] == "index-catonline-chunk-v2-prd"
            assert "filter_query" in search_kwargs  # Should have document filter

        # Verify response
        assert isinstance(result, RetrievalResponse)
        assert len(result.results) == 1

        # Verify content is included for chunks
        first_result = result.results[0]
        assert "content" in first_result
        assert first_result['content'] == "Detailed content about functional safety..."
        # Verify order number is preserved
        assert "@order_num" in first_result

    @pytest.mark.asyncio
    async def test_http_error_handling(self, retrieval_client):
        """Test that a ``RetrievalAPIError`` from the search client propagates."""
        with patch.object(retrieval_client.search_client, 'search_azure_ai') as mock_search:
            mock_search.side_effect = RetrievalAPIError("API request failed: 404")

            with pytest.raises(RetrievalAPIError) as exc_info:
                await retrieval_client.retrieve_standard_regulation("nonexistent")

        assert "404" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_timeout_handling(self, retrieval_client):
        """Test that a timeout reported as ``RetrievalAPIError`` propagates."""
        with patch.object(retrieval_client.search_client, 'search_azure_ai') as mock_search:
            mock_search.side_effect = RetrievalAPIError("Request timeout")

            with pytest.raises(RetrievalAPIError) as exc_info:
                await retrieval_client.retrieve_standard_regulation("test query")

        assert "timeout" in str(exc_info.value)

    def test_normalize_result_removes_unwanted_fields(self, retrieval_client):
        """Test result normalization removes unwanted fields."""
        raw_result = {
            "id": "test_id",
            "title": "Test Title",
            "content": "Test content",
            "@search.score": 0.95,  # Should be removed
            "@search.rerankerScore": 2.5,  # Should be removed
            "@search.captions": [],  # Should be removed
            "@subquery_id": 1,  # Should be removed
            "empty_field": "",  # Should be removed
            "null_field": None,  # Should be removed
            "empty_list": [],  # Should be removed
            "empty_dict": {},  # Should be removed
            "valid_field": "valid_value",
        }

        # Test normalization
        normalized = normalize_search_result(raw_result)

        # Verify unwanted fields are removed
        assert "@search.score" not in normalized
        assert "@search.rerankerScore" not in normalized
        assert "@search.captions" not in normalized
        assert "@subquery_id" not in normalized
        assert "empty_field" not in normalized
        assert "null_field" not in normalized
        assert "empty_list" not in normalized
        assert "empty_dict" not in normalized

        # Verify valid fields are preserved (including content)
        assert normalized["id"] == "test_id"
        assert normalized["title"] == "Test Title"
        assert normalized["content"] == "Test content"  # Content is now always preserved
        assert normalized["valid_field"] == "valid_value"

    def test_normalize_result_includes_content_when_requested(self, retrieval_client):
        """Test result normalization always includes content.

        The test name predates the current behavior; no flag is passed to
        request content here.
        """
        raw_result = {
            "id": "test_id",
            "title": "Test Title",
            "content": "Test content",
        }

        # Test normalization (content is always preserved now)
        normalized = normalize_search_result(raw_result)
        assert "content" in normalized
        assert normalized["content"] == "Test content"

    @pytest.mark.asyncio
    async def test_empty_results_handling(self, retrieval_client):
        """Test handling of empty search results."""
        with patch.object(retrieval_client.search_client, 'search_azure_ai', return_value={"value": []}):
            result = await retrieval_client.retrieve_standard_regulation("no results query")

        assert isinstance(result, RetrievalResponse)
        assert result.results == []
        assert result.total_count == 0

    @pytest.mark.asyncio
    async def test_malformed_response_handling(self, retrieval_client):
        """Test handling of malformed API responses."""
        with patch.object(retrieval_client.search_client, 'search_azure_ai', return_value={"unexpected": "format"}):
            result = await retrieval_client.retrieve_standard_regulation("test query")

        # Should handle gracefully and return empty results
        assert isinstance(result, RetrievalResponse)
        assert result.results == []
        assert result.total_count == 0

    @pytest.mark.asyncio
    async def test_kwargs_override_payload(self, retrieval_client):
        """Test that kwargs can override default values."""
        with patch.object(
            retrieval_client.search_client, 'search_azure_ai', return_value={"value": []}
        ) as mock_search:
            await retrieval_client.retrieve_standard_regulation(
                "test query",
                top_k=5,
                score_threshold=2.0
            )

        search_kwargs = mock_search.call_args.kwargs
        assert search_kwargs['top_k'] == 5  # Override default 10
        assert search_kwargs['score_threshold'] == 2.0  # Override default 1.5
|
|
|
|
|
|
class TestRetrievalTools:
    """Exercise the LangGraph ``@tool`` wrappers around the retrieval client."""

    @staticmethod
    def _wire_async_context(retrieval_cls, client):
        """Make ``async with retrieval_cls()`` yield ``client`` and exit cleanly."""
        context = retrieval_cls.return_value
        context.__aenter__.return_value = client
        context.__aexit__.return_value = None

    @pytest.mark.asyncio
    async def test_retrieve_standard_regulation_tool(self):
        """The standards tool returns a dict that carries a ``results`` key."""
        from service.graph.tools import retrieve_standard_regulation

        canned_response = RetrievalResponse(
            results=[{"title": "Test Standard", "id": "test_id"}],
            took_ms=100,
            total_count=1
        )

        # NOTE(review): this patches the class where it is defined; if
        # service.graph.tools binds AgenticRetrieval by name at import time the
        # tool may not see this mock — confirm against that module.
        with patch('service.retrieval.retrieval.AgenticRetrieval') as retrieval_cls:
            client = AsyncMock()
            client.retrieve_standard_regulation.return_value = canned_response
            self._wire_async_context(retrieval_cls, client)

            # The tool object is invoked through its async ``ainvoke`` entry point.
            tool_input = {"query": "test query", "conversation_history": "test history"}
            outcome = await retrieve_standard_regulation.ainvoke(tool_input)

            # Verify result format matches expected tool output
            assert isinstance(outcome, dict)
            assert "results" in outcome

    @pytest.mark.asyncio
    async def test_retrieve_doc_chunk_tool(self):
        """The document-chunk tool returns a dict that carries a ``results`` key."""
        from service.graph.tools import retrieve_doc_chunk_standard_regulation

        canned_response = RetrievalResponse(
            results=[{"title": "Test Chunk", "content": "Test content", "id": "chunk_id"}],
            took_ms=150,
            total_count=1
        )

        # NOTE(review): same patch-target caveat as the standards tool test —
        # confirm how service.graph.tools looks up AgenticRetrieval.
        with patch('service.retrieval.retrieval.AgenticRetrieval') as retrieval_cls:
            client = AsyncMock()
            client.retrieve_doc_chunk_standard_regulation.return_value = canned_response
            self._wire_async_context(retrieval_cls, client)

            # Call the tool with proper input format
            outcome = await retrieve_doc_chunk_standard_regulation.ainvoke({"query": "test query"})

            # Verify result format
            assert isinstance(outcome, dict)
            assert "results" in outcome

    @pytest.mark.asyncio
    async def test_tool_error_handling(self):
        """A search-layer ``RetrievalAPIError`` surfaces as an ``error`` entry."""
        from service.graph.tools import retrieve_standard_regulation

        with patch('service.retrieval.clients.AzureSearchClient.search_azure_ai') as failing_search:
            failing_search.side_effect = RetrievalAPIError("API Error")

            # Tool should handle error gracefully and return error result
            outcome = await retrieve_standard_regulation.ainvoke({"query": "test query"})

            # Should return error information in result
            assert isinstance(outcome, dict)
            assert "error" in outcome
            assert "API Error" in outcome["error"]
|
|
|
|
|
|
class TestRetrievalIntegration:
    """Integration-style tests for retrieval functionality."""

    @pytest.mark.asyncio
    async def test_full_retrieval_workflow(self):
        """Test complete retrieval workflow with both tools.

        ``AzureSearchClient.search_azure_ai`` is patched at class level with a
        side effect that answers by ``index_name``: the standards index gets a
        single standard document, any other index gets a single chunk. Both
        tools are then invoked and each must return a dict whose ``results``
        list has exactly one entry.
        """
        with patch('service.retrieval.clients.AzureSearchClient.search_azure_ai') as mock_search:
            # Set up mock to return different responses for different calls
            def mock_response_side_effect(*args, **kwargs):
                """Return a fresh canned payload chosen by the ``index_name`` keyword."""
                if kwargs.get('index_name') == 'index-catonline-standard-regulation-v2-prd':
                    return {
                        "value": [{"id": "std_1", "title": "ISO 26262-1", "document_category": "Standard", "@order_num": 1}]
                    }
                return {
                    "value": [{"id": "chunk_1", "title": "Safety Requirements", "content": "Detailed safety requirements...", "@order_num": 1}]
                }

            mock_search.side_effect = mock_response_side_effect

            # Import and test both tools
            from service.graph.tools import retrieve_standard_regulation, retrieve_doc_chunk_standard_regulation

            # Test standards search
            std_result = await retrieve_standard_regulation.ainvoke({"query": "ISO 26262"})
            assert isinstance(std_result, dict)
            assert "results" in std_result
            assert len(std_result["results"]) == 1

            # Test chunks search
            chunk_result = await retrieve_doc_chunk_standard_regulation.ainvoke({"query": "safety requirements"})
            assert isinstance(chunk_result, dict)
            assert "results" in chunk_result
            assert len(chunk_result["results"]) == 1