init
This commit is contained in:
194
vw-agentic-rag/tests/unit/test_message_trimmer.py
Normal file
194
vw-agentic-rag/tests/unit/test_message_trimmer.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""
|
||||
Test conversation history trimming functionality.
|
||||
"""
|
||||
import pytest
|
||||
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage
|
||||
|
||||
from service.graph.message_trimmer import ConversationTrimmer, create_conversation_trimmer
|
||||
|
||||
|
||||
def test_conversation_trimmer_basic():
    """Check that an over-budget chat is shortened and keeps its system prompt.

    A trimmer with a very small ``max_context_length`` is given a system
    prompt followed by three question/answer turns. The test expects
    ``should_trim`` to be truthy, the trimmed history to be shorter than the
    input, and the first trimmed message to still be the system prompt.
    """
    # Deliberately tiny budget. The original note says the effective limit is
    # 85% of this value (about 42 tokens) -- TODO confirm against
    # ConversationTrimmer.
    tiny_trimmer = ConversationTrimmer(max_context_length=50)

    turns = (
        ("What is the capital of France?", "The capital of France is Paris."),
        ("What about Germany?", "The capital of Germany is Berlin."),
        ("And Italy?", "The capital of Italy is Rome."),
    )
    history = [SystemMessage(content="You are a helpful assistant.")]
    for question, answer in turns:
        history.append(HumanMessage(content=question))
        history.append(AIMessage(content=answer))

    # The budget is small enough that trimming must be requested.
    needs_trimming = tiny_trimmer.should_trim(history)
    assert needs_trimming

    shortened = tiny_trimmer.trim_conversation_history(history)

    # Fewer messages than before, with the system prompt still leading.
    assert len(shortened) < len(history)
    leading = shortened[0]
    assert isinstance(leading, SystemMessage)
    assert "helpful assistant" in leading.content
|
||||
|
||||
|
||||
def test_conversation_trimmer_no_trim_needed():
    """Check that a short chat under a generous budget is left at full length.

    With a large ``max_context_length`` the trimmer is expected to report that
    no trimming is required and to hand back as many messages as it was given.
    """
    roomy_trimmer = ConversationTrimmer(max_context_length=100000)

    system_prompt = SystemMessage(content="You are a helpful assistant.")
    greeting = HumanMessage(content="Hello")
    reply = AIMessage(content="Hi there!")
    short_chat = [system_prompt, greeting, reply]

    # Well under the budget, so no trimming should be requested.
    assert not roomy_trimmer.should_trim(short_chat)

    result = roomy_trimmer.trim_conversation_history(short_chat)

    # Only the message count is compared here, not the message contents.
    assert len(result) == len(short_chat)
|
||||
|
||||
|
||||
def test_conversation_trimmer_fallback():
    """Check the private ``_fallback_trim`` helper with a cap of five messages.

    One system message followed by fifty human messages is passed in; the
    result must contain exactly five messages and start with a system message.
    """
    trimmer = ConversationTrimmer(max_context_length=100)

    # One system prompt plus fifty numbered human messages.
    long_history = [SystemMessage(content="System")]
    long_history.extend(
        HumanMessage(content=f"Message {idx}") for idx in range(50)
    )

    kept = trimmer._fallback_trim(long_history, max_messages=5)

    # Five messages in total, the first of which is the system message.
    assert len(kept) == 5
    assert isinstance(kept[0], SystemMessage)
|
||||
|
||||
|
||||
def test_create_conversation_trimmer():
    """Check the ``create_conversation_trimmer`` factory with an explicit limit.

    The factory is called with ``max_context_length=128000``; the returned
    object must be a ``ConversationTrimmer`` carrying that limit and having
    ``preserve_system`` set to ``True``.
    """
    requested_limit = 128000
    built = create_conversation_trimmer(max_context_length=requested_limit)

    assert built.max_context_length == requested_limit
    assert isinstance(built, ConversationTrimmer)
    assert built.preserve_system is True
|
||||
|
||||
|
||||
def test_multi_round_tool_optimization():
    """Test multi-round tool call optimization.

    Builds a conversation with a system prompt, one human request and three
    consecutive tool rounds (AI tool call followed by its tool result), then
    checks that:

    * ``_identify_tool_rounds`` reports three rounds, and
    * ``_optimize_multi_round_tool_calls`` returns exactly four messages:
      the system prompt, the human request, and the AI/tool pair of the
      latest round (``call_3``).
    """

    def bulky_results(snippet, repeats):
        # The repetition must happen outside the JSON template: writing
        # ``* 100`` inside the quoted literal only embeds the text "* 100"
        # in the payload and leaves the tool result small.
        return '{"results": ["' + snippet * repeats + '"]}'

    # High limit so the test focuses on the optimization step.
    trimmer = ConversationTrimmer(max_context_length=100000)

    # Create a multi-round tool calling scenario.
    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Search for food safety and cosmetics testing standards."),

        # First tool round
        AIMessage(content="I'll search for information.", tool_calls=[
            {"id": "call_1", "name": "search_tool", "args": {"query": "food safety"}}
        ]),
        ToolMessage(
            content=bulky_results("Large food safety result data...", 100),
            tool_call_id="call_1",
            name="search_tool",
        ),

        # Second tool round
        AIMessage(content="Let me search for more specific information.", tool_calls=[
            {"id": "call_2", "name": "search_tool", "args": {"query": "cosmetics testing"}}
        ]),
        ToolMessage(
            content=bulky_results("Large cosmetics testing data...", 100),
            tool_call_id="call_2",
            name="search_tool",
        ),

        # Third tool round (most recent)
        AIMessage(content="Now let me get equipment information.", tool_calls=[
            {"id": "call_3", "name": "search_tool", "args": {"query": "testing equipment"}}
        ]),
        ToolMessage(
            content=bulky_results("Testing equipment info...", 50),
            tool_call_id="call_3",
            name="search_tool",
        ),
    ]

    # Test tool round identification.
    tool_rounds = trimmer._identify_tool_rounds(messages)
    assert len(tool_rounds) == 3  # Should identify 3 tool rounds

    # Test optimization.
    optimized = trimmer._optimize_multi_round_tool_calls(messages)

    # Should preserve: SystemMessage + HumanMessage + latest tool round.
    expected_length = 1 + 1 + 2  # system + human + (latest AI + latest Tool)
    assert len(optimized) == expected_length

    # Check preserved messages and that the surviving round is the latest one.
    assert isinstance(optimized[0], SystemMessage)
    assert isinstance(optimized[1], HumanMessage)
    assert isinstance(optimized[2], AIMessage)
    assert optimized[2].tool_calls[0]["id"] == "call_3"  # latest tool call
    assert isinstance(optimized[3], ToolMessage)
    assert optimized[3].tool_call_id == "call_3"  # latest tool result
|
||||
|
||||
|
||||
def test_multi_round_optimization_single_round():
    """Check that a conversation with one tool round keeps its message count.

    The history holds a system prompt, a human request, one AI tool call and
    the matching tool result; ``_optimize_multi_round_tool_calls`` must return
    the same number of messages.
    """
    trimmer = ConversationTrimmer(max_context_length=100000)

    only_call = {"id": "call_1", "name": "search_tool", "args": {}}
    history = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Search for something."),
        AIMessage(content="I'll search.", tool_calls=[only_call]),
        ToolMessage(
            content='{"results": ["data"]}',
            tool_call_id="call_1",
            name="search_tool",
        ),
    ]

    result = trimmer._optimize_multi_round_tool_calls(history)

    # A single round leaves nothing to drop, so the count must match.
    assert len(result) == len(history)
|
||||
|
||||
|
||||
def test_multi_round_optimization_no_tools():
    """Check that a tool-free conversation keeps its message count.

    The history contains only system, human and AI messages without tool
    calls; ``_optimize_multi_round_tool_calls`` must return the same number
    of messages.
    """
    trimmer = ConversationTrimmer(max_context_length=100000)

    exchanges = (
        ("Hello", "Hi there!"),
        ("How are you?", "I'm doing well, thanks!"),
    )
    history = [SystemMessage(content="You are a helpful assistant.")]
    for user_text, assistant_text in exchanges:
        history.append(HumanMessage(content=user_text))
        history.append(AIMessage(content=assistant_text))

    result = trimmer._optimize_multi_round_tool_calls(history)

    # No tool rounds are present, so the count must be unchanged.
    assert len(result) == len(history)
|
||||
|
||||
|
||||
def test_full_trimming_with_optimization():
    """Test complete trimming process with multi-round optimization.

    Runs ``trim_conversation_history`` with a very small context budget on a
    conversation containing two tool rounds with large tool results. The
    assertions are intentionally loose: the result must not be longer than
    the input and must still contain a system message.
    """

    def bulky_results(snippet, repeats):
        # The repetition must happen outside the JSON template: writing
        # ``* 100`` inside the quoted literal only embeds the text "* 100"
        # in the payload and leaves the tool result small.
        return '{"results": ["' + snippet * repeats + '"]}'

    trimmer = ConversationTrimmer(max_context_length=50)  # Very low limit

    # Messages intended to trigger both optimization and regular trimming.
    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Search for food safety and cosmetics testing standards."),

        # First tool round (expected to be dropped by the optimization step;
        # not asserted directly below).
        AIMessage(content="I'll search for information.", tool_calls=[
            {"id": "call_1", "name": "search_tool", "args": {"query": "food safety"}}
        ]),
        ToolMessage(
            content=bulky_results("Large food safety result data...", 100),
            tool_call_id="call_1",
            name="search_tool",
        ),

        # Second tool round (most recent; expected to be the one retained
        # by the optimization step -- also not asserted directly below).
        AIMessage(content="Let me search for more information.", tool_calls=[
            {"id": "call_2", "name": "search_tool", "args": {"query": "cosmetics testing"}}
        ]),
        ToolMessage(
            content=bulky_results("Large cosmetics testing data...", 100),
            tool_call_id="call_2",
            name="search_tool",
        ),
    ]

    trimmed = trimmer.trim_conversation_history(messages)

    # Optimization and/or regular trimming must never grow the history.
    assert len(trimmed) <= len(messages)

    # System message should always be preserved.
    assert any(isinstance(msg, SystemMessage) for msg in trimmed)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this file directly. pytest.main returns the exit code of
    # the test run; raising SystemExit with it makes the process exit status
    # reflect test failures instead of always being 0.
    raise SystemExit(pytest.main([__file__]))
|
||||
Reference in New Issue
Block a user