"""
Unit tests for agentic retrieval tools.

Tests the HTTP wrapper tools that interface with the retrieval API.
"""

from unittest.mock import AsyncMock, patch, Mock

import httpx
import pytest

from service.retrieval.retrieval import AgenticRetrieval, RetrievalResponse
from service.retrieval.clients import RetrievalAPIError, normalize_search_result


# Index names used both to configure the mocked client and to assert which
# index a search call was routed to. Keeping them in one place ensures the
# fixture and the routing assertions cannot drift apart.
_STANDARD_REGULATION_INDEX = "index-catonline-standard-regulation-v2-prd"
_CHUNK_INDEX = "index-catonline-chunk-v2-prd"
_CHUNK_USER_MANUAL_INDEX = "index-cat-usermanual-chunk-prd"


class TestAgenticRetrieval:
    """Test the agentic retrieval HTTP client."""

    @pytest.fixture
    def retrieval_client(self):
        """Create an ``AgenticRetrieval`` built against a mocked configuration.

        ``service.retrieval.retrieval.get_config`` is patched only while the
        client is constructed; the instance is returned from inside the
        ``with`` block.
        """
        with patch('service.retrieval.retrieval.get_config') as mock_config:
            retrieval_cfg = mock_config.return_value.retrieval

            # Search service settings
            retrieval_cfg.endpoint = "https://test-search.search.azure.com"
            retrieval_cfg.api_key = "test-key"
            retrieval_cfg.api_version = "2024-11-01-preview"
            retrieval_cfg.semantic_configuration = "default"

            # Embedding settings
            retrieval_cfg.embedding.base_url = "http://test-embedding"
            retrieval_cfg.embedding.api_key = "test-embedding-key"
            retrieval_cfg.embedding.model = "test-embedding-model"
            retrieval_cfg.embedding.dimension = 1536

            # Index names
            retrieval_cfg.index.standard_regulation_index = _STANDARD_REGULATION_INDEX
            retrieval_cfg.index.chunk_index = _CHUNK_INDEX
            retrieval_cfg.index.chunk_user_manual_index = _CHUNK_USER_MANUAL_INDEX

            return AgenticRetrieval()

    @pytest.mark.asyncio
    async def test_retrieve_standard_regulation_success(self, retrieval_client):
        """Test successful standards regulation retrieval."""
        mock_response_data = {
            "value": [
                {
                    "id": "test_id_1",
                    "title": "ISO 26262-1:2018",
                    "document_code": "ISO26262-1",
                    "document_category": "Standard",
                    "@order_num": 1,  # Should be preserved
                },
                {
                    "id": "test_id_2",
                    "title": "ISO 26262-3:2018",
                    "document_code": "ISO26262-3",
                    "document_category": "Standard",
                    "@order_num": 2,
                },
            ]
        }

        with patch.object(
            retrieval_client.search_client,
            'search_azure_ai',
            return_value=mock_response_data,
        ) as mock_search:
            result = await retrieval_client.retrieve_standard_regulation("ISO 26262")

            # Verify search call
            mock_search.assert_called_once()
            call_args = mock_search.call_args[1]  # keyword arguments
            assert call_args['search_text'] == "ISO 26262"
            assert call_args['index_name'] == _STANDARD_REGULATION_INDEX
            assert call_args['top_k'] == 10

            # Verify response structure
            assert isinstance(result, RetrievalResponse)
            assert len(result.results) == 2
            assert result.total_count == 2
            assert result.took_ms is not None

            # Verify normalization
            first_result = result.results[0]
            assert first_result['id'] == "test_id_1"
            assert first_result['title'] == "ISO 26262-1:2018"

            # Verify order number is preserved
            assert "@order_num" in first_result
            assert "content" not in first_result  # Should not be included for standards

    @pytest.mark.asyncio
    async def test_retrieve_doc_chunk_success(self, retrieval_client):
        """Test successful document chunk retrieval."""
        mock_response_data = {
            "value": [
                {
                    "id": "chunk_id_1",
                    "title": "Functional Safety Requirements",
                    "content": "Detailed content about functional safety...",
                    "document_code": "ISO26262-1",
                    "@order_num": 1,  # Should be preserved
                }
            ]
        }

        with patch.object(
            retrieval_client.search_client,
            'search_azure_ai',
            return_value=mock_response_data,
        ) as mock_search:
            result = await retrieval_client.retrieve_doc_chunk_standard_regulation(
                "functional safety",
                conversation_history="Previous question about ISO 26262",
            )

            # Verify search call
            mock_search.assert_called_once()
            call_args = mock_search.call_args[1]  # keyword arguments
            assert call_args['search_text'] == "functional safety"
            assert call_args['index_name'] == _CHUNK_INDEX
            assert "filter_query" in call_args  # Should have document filter

            # Verify response
            assert isinstance(result, RetrievalResponse)
            assert len(result.results) == 1

            # Verify content is included for chunks
            first_result = result.results[0]
            assert "content" in first_result
            assert first_result['content'] == "Detailed content about functional safety..."

            # Verify order number is preserved
            assert "@order_num" in first_result

    @pytest.mark.asyncio
    async def test_http_error_handling(self, retrieval_client):
        """Test that search API errors propagate as ``RetrievalAPIError``."""
        with patch.object(retrieval_client.search_client, 'search_azure_ai') as mock_search:
            mock_search.side_effect = RetrievalAPIError("API request failed: 404")

            with pytest.raises(RetrievalAPIError) as exc_info:
                await retrieval_client.retrieve_standard_regulation("nonexistent")

        # Inspect the exception only after the raises-block has exited: any
        # statement placed after the raising call inside it would never run.
        assert "404" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_timeout_handling(self, retrieval_client):
        """Test that a timeout error from the search client propagates."""
        with patch.object(retrieval_client.search_client, 'search_azure_ai') as mock_search:
            mock_search.side_effect = RetrievalAPIError("Request timeout")

            with pytest.raises(RetrievalAPIError) as exc_info:
                await retrieval_client.retrieve_standard_regulation("test query")

        # Checked outside the raises-block so the assertion actually executes.
        assert "timeout" in str(exc_info.value)

    def test_normalize_result_removes_unwanted_fields(self):
        """Test result normalization removes search metadata and empty fields.

        ``normalize_search_result`` is a plain function, so this test does not
        need the ``retrieval_client`` fixture.
        """
        raw_result = {
            "id": "test_id",
            "title": "Test Title",
            "content": "Test content",
            "@search.score": 0.95,  # Should be removed
            "@search.rerankerScore": 2.5,  # Should be removed
            "@search.captions": [],  # Should be removed
            "@subquery_id": 1,  # Should be removed
            "empty_field": "",  # Should be removed
            "null_field": None,  # Should be removed
            "empty_list": [],  # Should be removed
            "empty_dict": {},  # Should be removed
            "valid_field": "valid_value",
        }

        normalized = normalize_search_result(raw_result)

        # Verify unwanted fields are removed
        for removed_key in (
            "@search.score",
            "@search.rerankerScore",
            "@search.captions",
            "@subquery_id",
            "empty_field",
            "null_field",
            "empty_list",
            "empty_dict",
        ):
            assert removed_key not in normalized, removed_key

        # Verify valid fields are preserved (including content)
        assert normalized["id"] == "test_id"
        assert normalized["title"] == "Test Title"
        assert normalized["content"] == "Test content"  # Content is now always preserved
        assert normalized["valid_field"] == "valid_value"

    def test_normalize_result_includes_content_when_requested(self):
        """Test result normalization always includes content.

        The name is historical: content used to be opt-in, but is now always
        preserved by ``normalize_search_result``.
        """
        raw_result = {
            "id": "test_id",
            "title": "Test Title",
            "content": "Test content",
        }

        normalized = normalize_search_result(raw_result)

        assert "content" in normalized
        assert normalized["content"] == "Test content"

    @pytest.mark.asyncio
    async def test_empty_results_handling(self, retrieval_client):
        """Test handling of empty search results."""
        with patch.object(
            retrieval_client.search_client,
            'search_azure_ai',
            return_value={"value": []},
        ):
            result = await retrieval_client.retrieve_standard_regulation("no results query")

            assert isinstance(result, RetrievalResponse)
            assert result.results == []
            assert result.total_count == 0

    @pytest.mark.asyncio
    async def test_malformed_response_handling(self, retrieval_client):
        """Test handling of malformed API responses (no ``value`` key)."""
        with patch.object(
            retrieval_client.search_client,
            'search_azure_ai',
            return_value={"unexpected": "format"},
        ):
            result = await retrieval_client.retrieve_standard_regulation("test query")

            # Should handle gracefully and return empty results
            assert isinstance(result, RetrievalResponse)
            assert result.results == []
            assert result.total_count == 0

    @pytest.mark.asyncio
    async def test_kwargs_override_payload(self, retrieval_client):
        """Test that kwargs can override default values."""
        with patch.object(
            retrieval_client.search_client,
            'search_azure_ai',
            return_value={"value": []},
        ) as mock_search:
            await retrieval_client.retrieve_standard_regulation(
                "test query",
                top_k=5,
                score_threshold=2.0,
            )

            call_args = mock_search.call_args[1]  # keyword arguments
            assert call_args['top_k'] == 5  # Override default 10
            assert call_args['score_threshold'] == 2.0  # Override default 1.5


class TestRetrievalTools:
    """Test the LangGraph tool decorators."""

    # NOTE(review): these tests patch ``AgenticRetrieval`` in
    # ``service.retrieval.retrieval``. If ``service.graph.tools`` binds the
    # class by name at import time, that patch is not seen by the tool and the
    # real client runs instead — confirm against the tools module.

    @pytest.mark.asyncio
    async def test_retrieve_standard_regulation_tool(self):
        """Test the @tool decorated function for standards search."""
        from service.graph.tools import retrieve_standard_regulation

        mock_response = RetrievalResponse(
            results=[{"title": "Test Standard", "id": "test_id"}],
            took_ms=100,
            total_count=1,
        )

        with patch('service.retrieval.retrieval.AgenticRetrieval') as mock_retrieval_class:
            mock_instance = AsyncMock()
            mock_instance.retrieve_standard_regulation.return_value = mock_response
            mock_retrieval_class.return_value.__aenter__.return_value = mock_instance
            mock_retrieval_class.return_value.__aexit__.return_value = None

            # retrieve_standard_regulation is a LangGraph @tool, so it is
            # invoked through its tool interface.
            result = await retrieve_standard_regulation.ainvoke(
                {"query": "test query", "conversation_history": "test history"}
            )

            # Verify result format matches expected tool output
            assert isinstance(result, dict)
            assert "results" in result

    @pytest.mark.asyncio
    async def test_retrieve_doc_chunk_tool(self):
        """Test the @tool decorated function for document chunks search."""
        from service.graph.tools import retrieve_doc_chunk_standard_regulation

        mock_response = RetrievalResponse(
            results=[{"title": "Test Chunk", "content": "Test content", "id": "chunk_id"}],
            took_ms=150,
            total_count=1,
        )

        with patch('service.retrieval.retrieval.AgenticRetrieval') as mock_retrieval_class:
            mock_instance = AsyncMock()
            mock_instance.retrieve_doc_chunk_standard_regulation.return_value = mock_response
            mock_retrieval_class.return_value.__aenter__.return_value = mock_instance
            mock_retrieval_class.return_value.__aexit__.return_value = None

            # Call the tool with proper input format
            result = await retrieve_doc_chunk_standard_regulation.ainvoke({"query": "test query"})

            # Verify result format
            assert isinstance(result, dict)
            assert "results" in result

    @pytest.mark.asyncio
    async def test_tool_error_handling(self):
        """Test that a search API error is reported in the tool result."""
        from service.graph.tools import retrieve_standard_regulation

        with patch('service.retrieval.clients.AzureSearchClient.search_azure_ai') as mock_search:
            mock_search.side_effect = RetrievalAPIError("API Error")

            # Tool should handle error gracefully and return error result
            result = await retrieve_standard_regulation.ainvoke({"query": "test query"})

            # Should return error information in result
            assert isinstance(result, dict)
            assert "error" in result
            assert "API Error" in result["error"]


class TestRetrievalIntegration:
    """Integration-style tests for retrieval functionality."""

    @pytest.mark.asyncio
    async def test_full_retrieval_workflow(self):
        """Test complete retrieval workflow with both tools.

        Only the low-level ``AzureSearchClient.search_azure_ai`` call is
        mocked; everything above it runs for real.
        """
        with patch('service.retrieval.clients.AzureSearchClient.search_azure_ai') as mock_search:

            def mock_response_side_effect(*args, **kwargs):
                """Return a standards hit for the standards index, a chunk hit otherwise."""
                if kwargs.get('index_name') == _STANDARD_REGULATION_INDEX:
                    return {
                        "value": [
                            {
                                "id": "std_1",
                                "title": "ISO 26262-1",
                                "document_category": "Standard",
                                "@order_num": 1,
                            }
                        ]
                    }
                return {
                    "value": [
                        {
                            "id": "chunk_1",
                            "title": "Safety Requirements",
                            "content": "Detailed safety requirements...",
                            "@order_num": 1,
                        }
                    ]
                }

            mock_search.side_effect = mock_response_side_effect

            # Import and test both tools
            from service.graph.tools import (
                retrieve_doc_chunk_standard_regulation,
                retrieve_standard_regulation,
            )

            # Test standards search
            std_result = await retrieve_standard_regulation.ainvoke({"query": "ISO 26262"})
            assert isinstance(std_result, dict)
            assert "results" in std_result
            assert len(std_result["results"]) == 1

            # Test chunks search
            chunk_result = await retrieve_doc_chunk_standard_regulation.ainvoke(
                {"query": "safety requirements"}
            )
            assert isinstance(chunk_result, dict)
            assert "results" in chunk_result
            assert len(chunk_result["results"]) == 1