195 lines
7.8 KiB
Python
195 lines
7.8 KiB
Python
"""
Test conversation history trimming functionality.
"""
|
|
import pytest
|
|
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage
|
|
|
|
from service.graph.message_trimmer import ConversationTrimmer, create_conversation_trimmer
|
|
|
|
|
|
def test_conversation_trimmer_basic():
    """Check that an over-budget conversation is shortened and keeps its system prompt.

    A system message followed by three question/answer exchanges is passed to
    a trimmer with a very small context limit. The test asserts that trimming
    is reported as necessary, that the trimmed history is strictly shorter,
    and that the first surviving message is the original system message.
    """
    exchanges = (
        ("What is the capital of France?", "The capital of France is Paris."),
        ("What about Germany?", "The capital of Germany is Berlin."),
        ("And Italy?", "The capital of Italy is Rome."),
    )
    history = [SystemMessage(content="You are a helpful assistant.")]
    for question, answer in exchanges:
        history.append(HumanMessage(content=question))
        history.append(AIMessage(content=answer))

    # Deliberately tiny limit so trimming must kick in. The original author's
    # note gives the effective budget as 85% of this value, i.e. 42 tokens.
    tiny_trimmer = ConversationTrimmer(max_context_length=50)

    # The low token limit should make the trimmer report that work is needed.
    assert tiny_trimmer.should_trim(history)

    shortened = tiny_trimmer.trim_conversation_history(history)

    # Something must have been dropped, and the system prompt must lead.
    assert len(shortened) < len(history)
    head = shortened[0]
    assert isinstance(head, SystemMessage)
    assert "helpful assistant" in head.content
|
|
|
|
|
|
def test_conversation_trimmer_no_trim_needed():
    """Check that a short conversation under a generous limit keeps its full length.

    Only the message count is compared; the test does not check that the
    returned messages are the same objects.
    """
    roomy_trimmer = ConversationTrimmer(max_context_length=100000)  # High limit
    history = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Hello"),
        AIMessage(content="Hi there!"),
    ]

    # With so much headroom the trimmer should not ask for trimming.
    needs_trimming = roomy_trimmer.should_trim(history)
    assert not needs_trimming

    # Trimming anyway should hand back a history of the same length.
    result = roomy_trimmer.trim_conversation_history(history)
    assert len(result) == len(history)
|
|
|
|
|
|
def test_conversation_trimmer_fallback():
    """Exercise the private ``_fallback_trim`` helper with an explicit message cap.

    One system message plus fifty human messages are trimmed to at most five.
    The intent is "system message + 4 recent messages"; the assertions check
    the resulting count and that the first element is a ``SystemMessage``.
    """
    message_cap = 5

    # Build a long history so the cap actually has something to cut.
    history = [SystemMessage(content="System")]
    history.extend(HumanMessage(content=f"Message {i}") for i in range(50))

    trimmer = ConversationTrimmer(max_context_length=100)
    result = trimmer._fallback_trim(history, max_messages=message_cap)

    assert len(result) == message_cap
    assert isinstance(result[0], SystemMessage)
|
|
|
|
|
|
def test_create_conversation_trimmer():
    """Check the factory returns a ``ConversationTrimmer`` with the requested limit.

    Also asserts that ``preserve_system`` is ``True`` on the returned trimmer
    when the factory is called with only ``max_context_length``.
    """
    # Explicit configuration passed through the factory.
    requested_limit = 128000
    built = create_conversation_trimmer(max_context_length=requested_limit)

    assert isinstance(built, ConversationTrimmer)
    assert built.max_context_length == requested_limit
    assert built.preserve_system is True
|
|
|
|
|
|
def test_multi_round_tool_optimization():
    """Test multi-round tool call optimization.

    Builds a conversation containing three AI tool-call / tool-result rounds
    and checks that ``_identify_tool_rounds`` reports three rounds and that
    ``_optimize_multi_round_tool_calls`` returns only the system message, the
    human message and the most recent round (``call_3``).
    """

    def tool_payload(snippet, repeat):
        # The repetition has to happen outside the string literal: a ``* N``
        # written inside the quotes is just literal text and would leave the
        # "large" payload only a few dozen characters long.
        return '{"results": ["' + snippet * repeat + '"]}'

    trimmer = ConversationTrimmer(max_context_length=100000)  # High limit to focus on optimization

    # Create a multi-round tool calling scenario
    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Search for food safety and cosmetics testing standards."),
        # First tool round
        AIMessage(
            content="I'll search for information.",
            tool_calls=[
                {"id": "call_1", "name": "search_tool", "args": {"query": "food safety"}},
            ],
        ),
        ToolMessage(
            content=tool_payload("Large food safety result data...", 100),
            tool_call_id="call_1",
            name="search_tool",
        ),
        # Second tool round
        AIMessage(
            content="Let me search for more specific information.",
            tool_calls=[
                {"id": "call_2", "name": "search_tool", "args": {"query": "cosmetics testing"}},
            ],
        ),
        ToolMessage(
            content=tool_payload("Large cosmetics testing data...", 100),
            tool_call_id="call_2",
            name="search_tool",
        ),
        # Third tool round (most recent)
        AIMessage(
            content="Now let me get equipment information.",
            tool_calls=[
                {"id": "call_3", "name": "search_tool", "args": {"query": "testing equipment"}},
            ],
        ),
        ToolMessage(
            content=tool_payload("Testing equipment info...", 50),
            tool_call_id="call_3",
            name="search_tool",
        ),
    ]

    # Test tool round identification
    tool_rounds = trimmer._identify_tool_rounds(messages)
    assert len(tool_rounds) == 3  # Should identify 3 tool rounds

    # Test optimization
    optimized = trimmer._optimize_multi_round_tool_calls(messages)

    # Should preserve: SystemMessage + HumanMessage + latest tool round (AI + Tool messages)
    expected_length = 1 + 1 + 2  # system + human + (latest AI + latest Tool)
    assert len(optimized) == expected_length

    # Check preserved messages
    assert isinstance(optimized[0], SystemMessage)
    assert isinstance(optimized[1], HumanMessage)
    assert isinstance(optimized[2], AIMessage)
    assert optimized[2].tool_calls[0]["id"] == "call_3"  # Should be the latest tool call
    assert isinstance(optimized[3], ToolMessage)
    assert optimized[3].tool_call_id == "call_3"  # Should be the latest tool result
|
|
|
|
|
|
def test_multi_round_optimization_single_round():
    """Check that a conversation with a single tool round keeps its length.

    The history holds exactly one AI tool call and its tool result; after
    ``_optimize_multi_round_tool_calls`` the message count must be unchanged.
    """
    call_id = "call_1"
    tool_name = "search_tool"

    tool_request = AIMessage(
        content="I'll search.",
        tool_calls=[{"id": call_id, "name": tool_name, "args": {}}],
    )
    tool_response = ToolMessage(
        content='{"results": ["data"]}',
        tool_call_id=call_id,
        name=tool_name,
    )
    history = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Search for something."),
        tool_request,
        tool_response,
    ]

    trimmer = ConversationTrimmer(max_context_length=100000)
    result = trimmer._optimize_multi_round_tool_calls(history)

    # A lone round gives the optimizer nothing to drop.
    assert len(result) == len(history)
|
|
|
|
|
|
def test_multi_round_optimization_no_tools():
    """Check that a tool-free conversation keeps its length after optimization.

    None of the messages carry tool calls, so ``_optimize_multi_round_tool_calls``
    is expected to return a history with the same number of messages.
    """
    turns = (
        (SystemMessage, "You are a helpful assistant."),
        (HumanMessage, "Hello"),
        (AIMessage, "Hi there!"),
        (HumanMessage, "How are you?"),
        (AIMessage, "I'm doing well, thanks!"),
    )
    history = [message_cls(content=text) for message_cls, text in turns]

    trimmer = ConversationTrimmer(max_context_length=100000)
    result = trimmer._optimize_multi_round_tool_calls(history)

    # Without tool rounds nothing should be removed.
    assert len(result) == len(history)
|
|
|
|
|
|
def test_full_trimming_with_optimization():
    """Test complete trimming process with multi-round optimization.

    Feeds a two-round tool conversation with genuinely large tool results to
    a trimmer with a very low context limit, then checks that the result is
    no longer than the input and still contains a ``SystemMessage``.
    """

    def tool_payload(snippet, repeat):
        # The repetition has to happen outside the string literal: a ``* N``
        # written inside the quotes is just literal text and would leave the
        # "large" payload only a few dozen characters long.
        return '{"results": ["' + snippet * repeat + '"]}'

    trimmer = ConversationTrimmer(max_context_length=50)  # Very low limit

    # Create messages that will trigger both optimization and regular trimming
    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Search for food safety and cosmetics testing standards."),
        # First tool round (should be removed by optimization)
        AIMessage(
            content="I'll search for information.",
            tool_calls=[
                {"id": "call_1", "name": "search_tool", "args": {"query": "food safety"}},
            ],
        ),
        ToolMessage(
            content=tool_payload("Large food safety result data...", 100),
            tool_call_id="call_1",
            name="search_tool",
        ),
        # Second tool round (most recent, should be preserved)
        AIMessage(
            content="Let me search for more information.",
            tool_calls=[
                {"id": "call_2", "name": "search_tool", "args": {"query": "cosmetics testing"}},
            ],
        ),
        ToolMessage(
            content=tool_payload("Large cosmetics testing data...", 100),
            tool_call_id="call_2",
            name="search_tool",
        ),
    ]

    trimmed = trimmer.trim_conversation_history(messages)

    # Should apply optimization first, then regular trimming if needed
    assert len(trimmed) <= len(messages)

    # System message should always be preserved
    assert any(isinstance(msg, SystemMessage) for msg in trimmed)
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this file directly. pytest.main returns the session's exit
    # code; re-raise it as SystemExit so a failing run exits non-zero instead
    # of always reporting success to the shell or CI.
    raise SystemExit(pytest.main([__file__]))
|