Files
catonline_ai/vw-agentic-rag/tests/unit/test_user_manual_tool.py
2025-09-26 17:15:54 +08:00

96 lines
4.1 KiB
Python

"""
Unit test for the new retrieve_system_usermanual tool
"""
import pytest
from unittest.mock import AsyncMock, patch
import asyncio
from service.graph.user_manual_tools import retrieve_system_usermanual, user_manual_tools, get_user_manual_tool_schemas, get_user_manual_tools_by_name
from service.graph.tools import get_tool_schemas, get_tools_by_name
from service.retrieval.retrieval import RetrievalResponse
class TestRetrieveSystemUsermanualTool:
    """Tests for the ``retrieve_system_usermanual`` tool and its registry helpers."""

    @staticmethod
    def _wire_async_retrieval(mock_retrieval_class):
        """Configure a patched ``AgenticRetrieval`` class for ``async with`` use.

        After this call, instantiating the patched class and entering the
        instance as an async context manager yields the returned ``AsyncMock``;
        callers then set the behavior of ``retrieve_doc_chunk_user_manual`` on
        it. ``__aexit__`` returns ``None`` (falsy), so the mock does not
        suppress exceptions raised inside the ``async with`` body.

        Args:
            mock_retrieval_class: the ``MagicMock`` created by ``patch(...)``
                in place of the ``AgenticRetrieval`` class.

        Returns:
            The ``AsyncMock`` standing in for the entered retrieval instance.
        """
        mock_instance = AsyncMock()
        context_manager = mock_retrieval_class.return_value
        context_manager.__aenter__.return_value = mock_instance
        context_manager.__aexit__.return_value = None
        return mock_instance

    def test_tool_in_tools_list(self):
        """Test that the tool is registered in ``user_manual_tools``."""
        tool_names = [tool.name for tool in user_manual_tools]
        assert "retrieve_system_usermanual" in tool_names
        assert len(user_manual_tools) == 1  # Should have 1 user manual tool

    def test_tool_schemas_generation(self):
        """Test that the tool schema (name, description, parameters) is generated correctly."""
        schemas = get_user_manual_tool_schemas()
        tool_names = [schema["function"]["name"] for schema in schemas]
        assert "retrieve_system_usermanual" in tool_names

        # Find the user manual tool schema; its presence is asserted just above,
        # so next() cannot raise StopIteration here.
        user_manual_schema = next(
            schema for schema in schemas
            if schema["function"]["name"] == "retrieve_system_usermanual"
        )
        assert user_manual_schema["function"]["description"] == "Search for document content chunks of user manual of this system(CATOnline)"
        assert "query" in user_manual_schema["function"]["parameters"]["properties"]
        assert user_manual_schema["function"]["parameters"]["required"] == ["query"]

    def test_tools_by_name_mapping(self):
        """Test that the name -> tool mapping includes the user manual tool."""
        tools_mapping = get_user_manual_tools_by_name()
        assert "retrieve_system_usermanual" in tools_mapping
        assert tools_mapping["retrieve_system_usermanual"] == retrieve_system_usermanual

    @pytest.mark.asyncio
    async def test_retrieve_system_usermanual_success(self):
        """Test that a successful retrieval is returned in the tool's result dict."""
        mock_response = RetrievalResponse(
            results=[
                {"title": "User Manual Chapter 1", "content": "How to use the system", "id": "manual_1"}
            ],
            took_ms=120,
            total_count=1,
        )
        # NOTE(review): this patches AgenticRetrieval in its defining module. That
        # only intercepts the tool's call if user_manual_tools resolves the name
        # through that module at call time (e.g. a function-local import) —
        # confirm. The assert_awaited() check below fails first if it does not.
        with patch('service.retrieval.retrieval.AgenticRetrieval') as mock_retrieval_class:
            mock_instance = self._wire_async_retrieval(mock_retrieval_class)
            mock_instance.retrieve_doc_chunk_user_manual.return_value = mock_response

            # Test the tool
            result = await retrieve_system_usermanual.ainvoke({"query": "how to use system"})

            # The result must have been produced from the mocked call.
            mock_instance.retrieve_doc_chunk_user_manual.assert_awaited()

            # Verify result format
            assert isinstance(result, dict)
            assert result["tool_name"] == "retrieve_system_usermanual"
            assert result["results_count"] == 1
            assert len(result["results"]) == 1
            assert result["results"][0]["title"] == "User Manual Chapter 1"
            assert result["took_ms"] == 120

    @pytest.mark.asyncio
    async def test_retrieve_system_usermanual_error(self):
        """Test that a retrieval exception is reported in the result dict rather than raised."""
        with patch('service.retrieval.retrieval.AgenticRetrieval') as mock_retrieval_class:
            mock_instance = self._wire_async_retrieval(mock_retrieval_class)
            mock_instance.retrieve_doc_chunk_user_manual.side_effect = Exception("Search API Error")

            # Test error handling
            result = await retrieve_system_usermanual.ainvoke({"query": "test query"})

            # The error must originate from the mocked call, not a real backend.
            mock_instance.retrieve_doc_chunk_user_manual.assert_awaited()

            # Should return error information
            assert isinstance(result, dict)
            assert "error" in result
            assert "Search API Error" in result["error"]
            assert result["results_count"] == 0
            assert result["results"] == []
if __name__ == "__main__":
    # Run this module's tests directly. pytest.main() returns the session's exit
    # code; propagate it so a failing run exits non-zero instead of always 0.
    raise SystemExit(pytest.main([__file__, "-v"]))