"""
Test conversation history trimming functionality.

Covers ``ConversationTrimmer`` token-based trimming, the message-count
fallback, the ``create_conversation_trimmer`` factory, and the multi-round
tool-call optimization that keeps only the most recent tool round.
"""

import json

import pytest
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage

from service.graph.message_trimmer import ConversationTrimmer, create_conversation_trimmer


def _tool_result(text: str, repeat: int = 1) -> str:
    """Return a JSON tool payload whose single result is ``text`` repeated ``repeat`` times.

    The repetition has to happen in Python, outside the string literal.
    Writing ``'{"results": ["data..." * 100]}'`` embeds the characters
    ``* 100`` in the payload instead of producing a large result, and the
    payload is not valid JSON either.
    """
    return json.dumps({"results": [text * repeat]})


def test_conversation_trimmer_basic():
    """Test basic trimming functionality."""
    # Very low limit for testing (85% = 42 tokens)
    trimmer = ConversationTrimmer(max_context_length=50)

    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="What is the capital of France?"),
        AIMessage(content="The capital of France is Paris."),
        HumanMessage(content="What about Germany?"),
        AIMessage(content="The capital of Germany is Berlin."),
        HumanMessage(content="And Italy?"),
        AIMessage(content="The capital of Italy is Rome."),
    ]

    # Should trigger trimming due to low token limit
    assert trimmer.should_trim(messages)

    trimmed = trimmer.trim_conversation_history(messages)

    # Should preserve system message and keep recent messages
    assert len(trimmed) < len(messages)
    assert isinstance(trimmed[0], SystemMessage)
    assert "helpful assistant" in trimmed[0].content


def test_conversation_trimmer_no_trim_needed():
    """Test when no trimming is needed."""
    trimmer = ConversationTrimmer(max_context_length=100000)  # High limit

    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Hello"),
        AIMessage(content="Hi there!"),
    ]

    # Should not need trimming
    assert not trimmer.should_trim(messages)

    trimmed = trimmer.trim_conversation_history(messages)

    # Should return same messages
    assert len(trimmed) == len(messages)


def test_conversation_trimmer_fallback():
    """Test fallback trimming logic."""
    trimmer = ConversationTrimmer(max_context_length=100)

    # Create many messages to trigger fallback
    messages = [SystemMessage(content="System")] + [
        HumanMessage(content=f"Message {i}") for i in range(50)
    ]

    trimmed = trimmer._fallback_trim(messages, max_messages=5)

    # Should keep system message + 4 recent messages
    assert len(trimmed) == 5
    assert isinstance(trimmed[0], SystemMessage)


def test_create_conversation_trimmer():
    """Test trimmer factory function."""
    # Test with explicit configuration
    custom_trimmer = create_conversation_trimmer(max_context_length=128000)
    assert custom_trimmer.max_context_length == 128000
    assert isinstance(custom_trimmer, ConversationTrimmer)
    assert custom_trimmer.preserve_system is True


def test_multi_round_tool_optimization():
    """Test multi-round tool call optimization."""
    # High limit to focus on optimization
    trimmer = ConversationTrimmer(max_context_length=100000)

    # Create a multi-round tool calling scenario
    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Search for food safety and cosmetics testing standards."),

        # First tool round
        AIMessage(content="I'll search for information.", tool_calls=[
            {"id": "call_1", "name": "search_tool", "args": {"query": "food safety"}}
        ]),
        ToolMessage(
            content=_tool_result("Large food safety result data...", 100),
            tool_call_id="call_1",
            name="search_tool",
        ),

        # Second tool round
        AIMessage(content="Let me search for more specific information.", tool_calls=[
            {"id": "call_2", "name": "search_tool", "args": {"query": "cosmetics testing"}}
        ]),
        ToolMessage(
            content=_tool_result("Large cosmetics testing data...", 100),
            tool_call_id="call_2",
            name="search_tool",
        ),

        # Third tool round (most recent)
        AIMessage(content="Now let me get equipment information.", tool_calls=[
            {"id": "call_3", "name": "search_tool", "args": {"query": "testing equipment"}}
        ]),
        ToolMessage(
            content=_tool_result("Testing equipment info...", 50),
            tool_call_id="call_3",
            name="search_tool",
        ),
    ]

    # Test tool round identification
    tool_rounds = trimmer._identify_tool_rounds(messages)
    assert len(tool_rounds) == 3  # Should identify 3 tool rounds

    # Test optimization
    optimized = trimmer._optimize_multi_round_tool_calls(messages)

    # Should preserve: SystemMessage + HumanMessage + latest tool round (AI + Tool messages)
    expected_length = 1 + 1 + 2  # system + human + (latest AI + latest Tool)
    assert len(optimized) == expected_length

    # Check preserved messages
    assert isinstance(optimized[0], SystemMessage)
    assert isinstance(optimized[1], HumanMessage)
    assert isinstance(optimized[2], AIMessage)
    assert optimized[2].tool_calls[0]["id"] == "call_3"  # Should be the latest tool call
    assert isinstance(optimized[3], ToolMessage)
    assert optimized[3].tool_call_id == "call_3"  # Should be the latest tool result


def test_multi_round_optimization_single_round():
    """Test that single tool round is not optimized."""
    trimmer = ConversationTrimmer(max_context_length=100000)

    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Search for something."),
        AIMessage(content="I'll search.", tool_calls=[
            {"id": "call_1", "name": "search_tool", "args": {}}
        ]),
        ToolMessage(content=_tool_result("data"), tool_call_id="call_1", name="search_tool"),
    ]

    optimized = trimmer._optimize_multi_round_tool_calls(messages)

    # Should return all messages unchanged for single round
    assert len(optimized) == len(messages)


def test_multi_round_optimization_no_tools():
    """Test that conversations without tools are not optimized."""
    trimmer = ConversationTrimmer(max_context_length=100000)

    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Hello"),
        AIMessage(content="Hi there!"),
        HumanMessage(content="How are you?"),
        AIMessage(content="I'm doing well, thanks!"),
    ]

    optimized = trimmer._optimize_multi_round_tool_calls(messages)

    # Should return all messages unchanged
    assert len(optimized) == len(messages)


def test_full_trimming_with_optimization():
    """Test complete trimming process with multi-round optimization."""
    trimmer = ConversationTrimmer(max_context_length=50)  # Very low limit

    # Create messages that will trigger both optimization and regular trimming
    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Search for food safety and cosmetics testing standards."),

        # First tool round (should be removed by optimization)
        AIMessage(content="I'll search for information.", tool_calls=[
            {"id": "call_1", "name": "search_tool", "args": {"query": "food safety"}}
        ]),
        ToolMessage(
            content=_tool_result("Large food safety result data...", 100),
            tool_call_id="call_1",
            name="search_tool",
        ),

        # Second tool round (most recent, should be preserved)
        AIMessage(content="Let me search for more information.", tool_calls=[
            {"id": "call_2", "name": "search_tool", "args": {"query": "cosmetics testing"}}
        ]),
        ToolMessage(
            content=_tool_result("Large cosmetics testing data...", 100),
            tool_call_id="call_2",
            name="search_tool",
        ),
    ]

    trimmed = trimmer.trim_conversation_history(messages)

    # Should apply optimization first, then regular trimming if needed
    assert len(trimmed) <= len(messages)

    # System message should always be preserved
    assert any(isinstance(msg, SystemMessage) for msg in trimmed)


if __name__ == "__main__":
    pytest.main([__file__])