Files
catonline_ai/vw-agentic-rag/tests/unit/test_message_trimmer.py
2025-09-26 17:15:54 +08:00

195 lines
7.8 KiB
Python

"""
Test conversation history trimming functionality.
"""
import pytest
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage
from service.graph.message_trimmer import ConversationTrimmer, create_conversation_trimmer
def test_conversation_trimmer_basic():
    """Check that an over-budget conversation is shortened but keeps its system prompt."""
    # A 50-token context is deliberately tiny; the original author's note puts
    # the effective budget at 85% of it (42 tokens), so this history overflows.
    tiny_trimmer = ConversationTrimmer(max_context_length=50)

    qa_pairs = (
        ("What is the capital of France?", "The capital of France is Paris."),
        ("What about Germany?", "The capital of Germany is Berlin."),
        ("And Italy?", "The capital of Italy is Rome."),
    )
    history = [SystemMessage(content="You are a helpful assistant.")]
    for question, answer in qa_pairs:
        history.append(HumanMessage(content=question))
        history.append(AIMessage(content=answer))

    # The trimmer itself must report that the history is too long.
    assert tiny_trimmer.should_trim(history)

    shortened = tiny_trimmer.trim_conversation_history(history)

    # Something was dropped, and the result still starts with the system prompt.
    assert len(history) > len(shortened)
    leading = shortened[0]
    assert isinstance(leading, SystemMessage)
    assert "helpful assistant" in leading.content
def test_conversation_trimmer_no_trim_needed():
    """Check that a short conversation under a generous limit keeps its full length."""
    roomy_trimmer = ConversationTrimmer(max_context_length=100000)  # far above what three short messages need

    turns = (
        (SystemMessage, "You are a helpful assistant."),
        (HumanMessage, "Hello"),
        (AIMessage, "Hi there!"),
    )
    history = [message_cls(content=text) for message_cls, text in turns]

    # The trimmer should not flag this history as needing trimming.
    needs_trim = roomy_trimmer.should_trim(history)
    assert not needs_trim

    # Running the trim anyway must hand back the same number of messages.
    result = roomy_trimmer.trim_conversation_history(history)
    assert len(history) == len(result)
def test_conversation_trimmer_fallback():
    """Check that the fallback trim caps the history at ``max_messages`` with the system message first."""
    trimmer = ConversationTrimmer(max_context_length=100)

    # One system message followed by fifty numbered human messages.
    history = [SystemMessage(content="System")]
    history.extend(HumanMessage(content=f"Message {i}") for i in range(50))

    kept = trimmer._fallback_trim(history, max_messages=5)

    # Five messages in total are expected, led by the system message.
    head = kept[0]
    assert isinstance(head, SystemMessage)
    assert len(kept) == 5
def test_create_conversation_trimmer():
    """Check that the factory returns a ConversationTrimmer honouring an explicit context length."""
    requested_length = 128000
    built = create_conversation_trimmer(max_context_length=requested_length)

    # The factory must produce the expected type with the requested limit,
    # and system-message preservation must be switched on.
    assert isinstance(built, ConversationTrimmer)
    assert built.max_context_length == requested_length
    assert built.preserve_system is True
def test_multi_round_tool_optimization():
    """Test that only the latest of several tool rounds survives optimization.

    Builds a conversation with three consecutive AI tool-call / tool-result
    rounds, then checks that ``_identify_tool_rounds`` reports three rounds
    and that ``_optimize_multi_round_tool_calls`` returns the system message,
    the human message and the most recent round (``call_3``) only.
    """
    trimmer = ConversationTrimmer(max_context_length=100000)  # High limit to focus on optimization

    # Build genuinely large tool payloads. The repetition has to happen
    # outside the string literal: writing `* 100` inside the quotes only
    # embeds those characters in the text and leaves the payload tiny.
    food_payload = '{"results": ["' + "Large food safety result data..." * 100 + '"]}'
    cosmetics_payload = '{"results": ["' + "Large cosmetics testing data..." * 100 + '"]}'
    equipment_payload = '{"results": ["' + "Testing equipment info..." * 50 + '"]}'

    # Create a multi-round tool calling scenario
    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Search for food safety and cosmetics testing standards."),
        # First tool round
        AIMessage(content="I'll search for information.", tool_calls=[
            {"id": "call_1", "name": "search_tool", "args": {"query": "food safety"}}
        ]),
        ToolMessage(content=food_payload, tool_call_id="call_1", name="search_tool"),
        # Second tool round
        AIMessage(content="Let me search for more specific information.", tool_calls=[
            {"id": "call_2", "name": "search_tool", "args": {"query": "cosmetics testing"}}
        ]),
        ToolMessage(content=cosmetics_payload, tool_call_id="call_2", name="search_tool"),
        # Third tool round (most recent)
        AIMessage(content="Now let me get equipment information.", tool_calls=[
            {"id": "call_3", "name": "search_tool", "args": {"query": "testing equipment"}}
        ]),
        ToolMessage(content=equipment_payload, tool_call_id="call_3", name="search_tool"),
    ]

    # Test tool round identification
    tool_rounds = trimmer._identify_tool_rounds(messages)
    assert len(tool_rounds) == 3  # Should identify 3 tool rounds

    # Test optimization
    optimized = trimmer._optimize_multi_round_tool_calls(messages)

    # Should preserve: SystemMessage + HumanMessage + latest tool round (AI + Tool messages)
    expected_length = 1 + 1 + 2  # system + human + (latest AI + latest Tool)
    assert len(optimized) == expected_length

    # Check preserved messages
    assert isinstance(optimized[0], SystemMessage)
    assert isinstance(optimized[1], HumanMessage)
    assert isinstance(optimized[2], AIMessage)
    assert optimized[2].tool_calls[0]["id"] == "call_3"  # Should be the latest tool call
    assert isinstance(optimized[3], ToolMessage)
    assert optimized[3].tool_call_id == "call_3"  # Should be the latest tool result
def test_multi_round_optimization_single_round():
    """Check that a conversation with exactly one tool round keeps all of its messages."""
    trimmer = ConversationTrimmer(max_context_length=100000)

    # A single tool call and its matching result form the only round.
    only_call = {"id": "call_1", "name": "search_tool", "args": {}}
    history = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Search for something."),
        AIMessage(content="I'll search.", tool_calls=[only_call]),
        ToolMessage(content='{"results": ["data"]}', tool_call_id="call_1", name="search_tool"),
    ]

    result = trimmer._optimize_multi_round_tool_calls(history)

    # With only one round, the message count must be unchanged.
    assert len(history) == len(result)
def test_multi_round_optimization_no_tools():
    """Check that a plain chat with no tool calls keeps all of its messages."""
    trimmer = ConversationTrimmer(max_context_length=100000)

    # System prompt followed by two ordinary human/AI exchanges.
    history = [SystemMessage(content="You are a helpful assistant.")]
    exchanges = (
        ("Hello", "Hi there!"),
        ("How are you?", "I'm doing well, thanks!"),
    )
    for user_text, reply_text in exchanges:
        history.append(HumanMessage(content=user_text))
        history.append(AIMessage(content=reply_text))

    result = trimmer._optimize_multi_round_tool_calls(history)

    # With no tool rounds present, the message count must be unchanged.
    assert len(history) == len(result)
def test_full_trimming_with_optimization():
    """Test the complete trimming process on a two-round tool conversation.

    Runs ``trim_conversation_history`` under a very low context limit on a
    conversation with two tool rounds and large tool results, then checks
    that the result is no longer than the input and still contains a
    system message.
    """
    trimmer = ConversationTrimmer(max_context_length=50)  # Very low limit

    # Build genuinely large tool payloads. The repetition has to happen
    # outside the string literal: writing `* 100` inside the quotes only
    # embeds those characters in the text and leaves the payload tiny.
    food_payload = '{"results": ["' + "Large food safety result data..." * 100 + '"]}'
    cosmetics_payload = '{"results": ["' + "Large cosmetics testing data..." * 100 + '"]}'

    # Create messages that will trigger both optimization and regular trimming
    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Search for food safety and cosmetics testing standards."),
        # First tool round (should be removed by optimization)
        AIMessage(content="I'll search for information.", tool_calls=[
            {"id": "call_1", "name": "search_tool", "args": {"query": "food safety"}}
        ]),
        ToolMessage(content=food_payload, tool_call_id="call_1", name="search_tool"),
        # Second tool round (most recent, should be preserved)
        AIMessage(content="Let me search for more information.", tool_calls=[
            {"id": "call_2", "name": "search_tool", "args": {"query": "cosmetics testing"}}
        ]),
        ToolMessage(content=cosmetics_payload, tool_call_id="call_2", name="search_tool"),
    ]

    trimmed = trimmer.trim_conversation_history(messages)

    # Should apply optimization first, then regular trimming if needed
    assert len(trimmed) <= len(messages)

    # System message should always be preserved
    assert any(isinstance(msg, SystemMessage) for msg in trimmed)
if __name__ == "__main__":
    # Propagate pytest's exit status so that running this file directly
    # exits non-zero when any test fails, instead of always exiting 0.
    raise SystemExit(pytest.main([__file__]))