# File-viewer metadata (not part of the module): Python, 96 lines, 4.1 KiB
"""
|
|
Unit test for the new retrieve_system_usermanual tool
|
|
"""
|
|
|
|
import pytest
|
|
from unittest.mock import AsyncMock, patch
|
|
import asyncio
|
|
|
|
from service.graph.user_manual_tools import retrieve_system_usermanual, user_manual_tools, get_user_manual_tool_schemas, get_user_manual_tools_by_name
|
|
from service.graph.tools import get_tool_schemas, get_tools_by_name
|
|
from service.retrieval.retrieval import RetrievalResponse
|
|
|
|
|
|
class TestRetrieveSystemUsermanualTool:
    """Check registration, schema and runtime behavior of the user manual tool."""

    # Name under which the tool is expected to be registered and reported.
    TOOL_NAME = "retrieve_system_usermanual"

    # Dotted path replaced with a mock in the async tests.
    # NOTE(review): the patch only takes effect if the tool resolves
    # AgenticRetrieval through service.retrieval.retrieval at call time --
    # confirm against service.graph.user_manual_tools.
    RETRIEVAL_TARGET = "service.retrieval.retrieval.AgenticRetrieval"

    @staticmethod
    def _wire_async_context(retrieval_cls_mock, entered_client):
        """Make ``async with retrieval_cls_mock() as c`` bind *entered_client* to ``c``.

        ``__aexit__`` is set to return ``None`` so that an exception raised
        inside the ``async with`` body is not suppressed by the mock.
        """
        context_manager = retrieval_cls_mock.return_value
        context_manager.__aenter__.return_value = entered_client
        context_manager.__aexit__.return_value = None

    def test_tool_in_tools_list(self):
        """The tool appears in ``user_manual_tools`` and is the only entry."""
        registered_names = [entry.name for entry in user_manual_tools]

        assert self.TOOL_NAME in registered_names
        assert len(user_manual_tools) == 1  # Should have 1 user manual tool

    def test_tool_schemas_generation(self):
        """The generated schema exposes the expected name, description and parameters."""
        all_schemas = get_user_manual_tool_schemas()
        schema_names = [item["function"]["name"] for item in all_schemas]
        assert self.TOOL_NAME in schema_names

        # Pick out the schema entry that belongs to the user manual tool.
        matching_schema = next(
            item
            for item in all_schemas
            if item["function"]["name"] == self.TOOL_NAME
        )
        function_spec = matching_schema["function"]

        assert function_spec["description"] == (
            "Search for document content chunks of user manual of this system(CATOnline)"
        )
        assert "query" in function_spec["parameters"]["properties"]
        assert function_spec["parameters"]["required"] == ["query"]

    def test_tools_by_name_mapping(self):
        """The name-to-tool mapping resolves the tool name to the tool object."""
        name_to_tool = get_user_manual_tools_by_name()

        assert self.TOOL_NAME in name_to_tool
        assert name_to_tool[self.TOOL_NAME] == retrieve_system_usermanual

    @pytest.mark.asyncio
    async def test_retrieve_system_usermanual_success(self):
        """A successful retrieval is returned as a dict carrying results and timing."""
        canned_response = RetrievalResponse(
            results=[
                {
                    "title": "User Manual Chapter 1",
                    "content": "How to use the system",
                    "id": "manual_1",
                },
            ],
            took_ms=120,
            total_count=1,
        )
        retrieval_client = AsyncMock()
        retrieval_client.retrieve_doc_chunk_user_manual.return_value = canned_response

        with patch(self.RETRIEVAL_TARGET) as retrieval_cls_mock:
            self._wire_async_context(retrieval_cls_mock, retrieval_client)

            # Invoke the tool while the retrieval class is mocked.
            outcome = await retrieve_system_usermanual.ainvoke(
                {"query": "how to use system"}
            )

        # Verify the shape and content of the returned payload.
        assert isinstance(outcome, dict)
        assert outcome["tool_name"] == self.TOOL_NAME
        assert outcome["results_count"] == 1
        assert len(outcome["results"]) == 1
        assert outcome["results"][0]["title"] == "User Manual Chapter 1"
        assert outcome["took_ms"] == 120

    @pytest.mark.asyncio
    async def test_retrieve_system_usermanual_error(self):
        """A retrieval failure is reported in the result dict with empty results."""
        failing_client = AsyncMock()
        failing_client.retrieve_doc_chunk_user_manual.side_effect = Exception(
            "Search API Error"
        )

        with patch(self.RETRIEVAL_TARGET) as retrieval_cls_mock:
            self._wire_async_context(retrieval_cls_mock, failing_client)

            # The tool is expected to hand back error details, not raise.
            outcome = await retrieve_system_usermanual.ainvoke({"query": "test query"})

        # The payload should carry the error message and no results.
        assert isinstance(outcome, dict)
        assert "error" in outcome
        assert "Search API Error" in outcome["error"]
        assert outcome["results_count"] == 0
        assert outcome["results"] == []
|
|
|
|
if __name__ == "__main__":
    # Allow running this file directly: hand this module to pytest in verbose mode.
    cli_args = [__file__, "-v"]
    pytest.main(cli_args)
|