Files
catonline_ai/vw-agentic-rag/tests/unit/test_aggressive_trimming.py

115 lines
4.8 KiB
Python
Raw Normal View History

2025-09-26 17:15:54 +08:00
#!/usr/bin/env python3
"""
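Tests for the aggressive trimming strategy: historical tool-call results are trimmed even when the token count is low.
测试新的积极修剪策略：即使token数很少，也要修剪历史工具调用结果。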
"""
import pytest
from service.graph.message_trimmer import ConversationTrimmer
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage
from langchain_core.messages.utils import count_tokens_approximately
def test_aggressive_tool_history_trimming():
    """Aggressive trimming drops historical tool results even at low token counts.

    Builds a short conversation with two completed tool-call rounds followed
    by a new user query, then checks that:

    * the conversation is far below the trimmer's history token limit,
    * ``_identify_tool_rounds`` finds both tool rounds,
    * ``should_trim`` still requests trimming (because of the multiple rounds),
    * ``trim_conversation_history`` keeps the system message and the first
      human message, and retains only the most recent tool result.
    """
    # Construct the trimmer directly to avoid depending on app configuration.
    trimmer = ConversationTrimmer(max_context_length=100000)

    # A conversation with several tool-call rounds but very few tokens.
    messages = [
        SystemMessage(content='You are a helpful assistant.'),
        # Historical round 1
        HumanMessage(content='搜索汽车标准'),
        AIMessage(content='搜索中', tool_calls=[{'id': 'call_1', 'name': 'search', 'args': {'query': '汽车标准'}}]),
        ToolMessage(content='历史结果1', tool_call_id='call_1', name='search'),
        AIMessage(content='汽车标准信息'),
        # Historical round 2
        HumanMessage(content='搜索电池标准'),
        AIMessage(content='搜索中', tool_calls=[{'id': 'call_2', 'name': 'search', 'args': {'query': '电池标准'}}]),
        ToolMessage(content='历史结果2', tool_call_id='call_2', name='search'),
        AIMessage(content='电池标准信息'),
        # New user query (the point at which trimming is triggered)
        HumanMessage(content='搜索安全标准'),
    ]

    # The token count is tiny, well below the trimming threshold.
    token_count = count_tokens_approximately(messages)
    assert token_count < 1000, f"Token count should be low, got {token_count}"
    assert token_count < trimmer.history_token_limit, "Token count should be well below limit"

    # Both historical tool rounds are recognized.
    tool_rounds = trimmer._identify_tool_rounds(messages)
    assert len(tool_rounds) == 2, f"Should identify 2 tool rounds, got {len(tool_rounds)}"

    # Trimming is requested because there is more than one tool round,
    # not because of token pressure.
    should_trim = trimmer.should_trim(messages)
    assert should_trim, "Should trigger trimming due to multiple tool rounds"

    trimmed = trimmer.trim_conversation_history(messages)

    assert len(trimmed) < len(messages), "Should have fewer messages after trimming"

    # Guard the positional checks below so that an over-aggressive trim fails
    # with a descriptive message instead of an IndexError.
    assert len(trimmed) >= 2, (
        f"Should keep at least the system message and initial human message, got {len(trimmed)} messages"
    )

    # The system message and the initial query are preserved in place.
    assert isinstance(trimmed[0], SystemMessage), "Should preserve system message"
    assert isinstance(trimmed[1], HumanMessage), "Should preserve initial human message"

    # Only the most recent round's tool result survives.
    tool_messages = [msg for msg in trimmed if isinstance(msg, ToolMessage)]
    assert len(tool_messages) == 1, f"Should only keep 1 tool message, got {len(tool_messages)}"
    assert tool_messages[0].content == '历史结果2', "Should keep the most recent tool result"
def test_single_tool_round_no_trimming():
    """A conversation containing just one tool-call round must not be trimmed.

    With only a single tool round and a small token count, ``should_trim``
    is expected to return False.
    """
    conversation_trimmer = ConversationTrimmer(max_context_length=100000)

    # One tool round (query -> AI tool call -> tool result -> AI answer),
    # followed by a fresh user question.
    search_call = {'id': 'call_1', 'name': 'search', 'args': {'query': '信息'}}
    history = [
        SystemMessage(content='You are a helpful assistant.'),
        HumanMessage(content='搜索信息'),
        AIMessage(content='搜索中', tool_calls=[search_call]),
        ToolMessage(content='搜索结果', tool_call_id='call_1', name='search'),
        AIMessage(content='这是搜索到的信息'),
        HumanMessage(content='新的问题'),
    ]

    # Exactly one tool round should be detected.
    round_count = len(conversation_trimmer._identify_tool_rounds(history))
    assert round_count == 1, f"Should identify 1 tool round, got {round_count}"

    # Neither the round count nor the token count should trigger trimming.
    assert not conversation_trimmer.should_trim(history), "Should not trigger trimming for single tool round with low tokens"
def test_no_tool_rounds_no_trimming():
    """A plain chat with no tool calls must not be trimmed.

    No tool rounds should be identified, and ``should_trim`` should return
    False.
    """
    conversation_trimmer = ConversationTrimmer(max_context_length=100000)

    # Alternating human/AI turns only; no AIMessage carries tool_calls.
    system_prompt = SystemMessage(content='You are a helpful assistant.')
    chat_turns = [
        HumanMessage(content='Hello'),
        AIMessage(content='Hi there!'),
        HumanMessage(content='How are you?'),
        AIMessage(content='I am doing well, thank you!'),
    ]
    history = [system_prompt, *chat_turns]

    # No tool rounds should be found.
    round_count = len(conversation_trimmer._identify_tool_rounds(history))
    assert round_count == 0, f"Should identify 0 tool rounds, got {round_count}"

    # Without tool rounds, trimming should not be requested.
    assert not conversation_trimmer.should_trim(history), "Should not trigger trimming without tool rounds"
if __name__ == "__main__":
    # Propagate pytest's exit code so running this file as a script reports
    # test failures to the shell / CI (pytest.main returns the code instead
    # of exiting).
    raise SystemExit(pytest.main([__file__, "-v"]))