init
This commit is contained in:
95
vw-agentic-rag/tests/unit/test_user_manual_tool.py
Normal file
95
vw-agentic-rag/tests/unit/test_user_manual_tool.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""
|
||||
Unit test for the new retrieve_system_usermanual tool
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
import asyncio
|
||||
|
||||
from service.graph.user_manual_tools import retrieve_system_usermanual, user_manual_tools, get_user_manual_tool_schemas, get_user_manual_tools_by_name
|
||||
from service.graph.tools import get_tool_schemas, get_tools_by_name
|
||||
from service.retrieval.retrieval import RetrievalResponse
|
||||
|
||||
|
||||
class TestRetrieveSystemUsermanualTool:
|
||||
"""Test the new user manual retrieval tool"""
|
||||
|
||||
def test_tool_in_tools_list(self):
|
||||
"""Test that the new tool is in the tools list"""
|
||||
tool_names = [tool.name for tool in user_manual_tools]
|
||||
assert "retrieve_system_usermanual" in tool_names
|
||||
assert len(user_manual_tools) == 1 # Should have 1 user manual tool
|
||||
|
||||
def test_tool_schemas_generation(self):
|
||||
"""Test that tool schemas are generated correctly"""
|
||||
schemas = get_user_manual_tool_schemas()
|
||||
tool_names = [schema["function"]["name"] for schema in schemas]
|
||||
assert "retrieve_system_usermanual" in tool_names
|
||||
|
||||
# Find the user manual tool schema
|
||||
user_manual_schema = next(
|
||||
schema for schema in schemas
|
||||
if schema["function"]["name"] == "retrieve_system_usermanual"
|
||||
)
|
||||
|
||||
assert user_manual_schema["function"]["description"] == "Search for document content chunks of user manual of this system(CATOnline)"
|
||||
assert "query" in user_manual_schema["function"]["parameters"]["properties"]
|
||||
assert user_manual_schema["function"]["parameters"]["required"] == ["query"]
|
||||
|
||||
def test_tools_by_name_mapping(self):
|
||||
"""Test that tools_by_name mapping includes the new tool"""
|
||||
tools_mapping = get_user_manual_tools_by_name()
|
||||
assert "retrieve_system_usermanual" in tools_mapping
|
||||
assert tools_mapping["retrieve_system_usermanual"] == retrieve_system_usermanual
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_system_usermanual_success(self):
|
||||
"""Test successful user manual retrieval"""
|
||||
|
||||
mock_response = RetrievalResponse(
|
||||
results=[
|
||||
{"title": "User Manual Chapter 1", "content": "How to use the system", "id": "manual_1"}
|
||||
],
|
||||
took_ms=120,
|
||||
total_count=1
|
||||
)
|
||||
|
||||
with patch('service.retrieval.retrieval.AgenticRetrieval') as mock_retrieval_class:
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.retrieve_doc_chunk_user_manual.return_value = mock_response
|
||||
mock_retrieval_class.return_value.__aenter__.return_value = mock_instance
|
||||
mock_retrieval_class.return_value.__aexit__.return_value = None
|
||||
|
||||
# Test the tool
|
||||
result = await retrieve_system_usermanual.ainvoke({"query": "how to use system"})
|
||||
|
||||
# Verify result format
|
||||
assert isinstance(result, dict)
|
||||
assert result["tool_name"] == "retrieve_system_usermanual"
|
||||
assert result["results_count"] == 1
|
||||
assert len(result["results"]) == 1
|
||||
assert result["results"][0]["title"] == "User Manual Chapter 1"
|
||||
assert result["took_ms"] == 120
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_system_usermanual_error(self):
|
||||
"""Test error handling in user manual retrieval"""
|
||||
|
||||
with patch('service.retrieval.retrieval.AgenticRetrieval') as mock_retrieval_class:
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.retrieve_doc_chunk_user_manual.side_effect = Exception("Search API Error")
|
||||
mock_retrieval_class.return_value.__aenter__.return_value = mock_instance
|
||||
mock_retrieval_class.return_value.__aexit__.return_value = None
|
||||
|
||||
# Test error handling
|
||||
result = await retrieve_system_usermanual.ainvoke({"query": "test query"})
|
||||
|
||||
# Should return error information
|
||||
assert isinstance(result, dict)
|
||||
assert "error" in result
|
||||
assert "Search API Error" in result["error"]
|
||||
assert result["results_count"] == 0
|
||||
assert result["results"] == []
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
Reference in New Issue
Block a user