Files
catonline_ai/vw-agentic-rag/tests/unit/test_aggressive_trimming.py
2025-09-26 17:15:54 +08:00

115 lines
4.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
Tests for the aggressive trimming strategy: historical tool-call results are
trimmed even when the conversation's token count is small.
"""
import pytest
from service.graph.message_trimmer import ConversationTrimmer
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage
from langchain_core.messages.utils import count_tokens_approximately
def test_aggressive_tool_history_trimming():
    """Historical tool results are trimmed even when the token count is low.

    Builds a short conversation with two completed tool-call rounds followed
    by a fresh user query, then checks that:

    * the token count is far below ``trimmer.history_token_limit``, so any
      trimming is driven by the multiple-tool-rounds rule rather than by size;
    * ``_identify_tool_rounds`` finds both rounds and ``should_trim`` fires;
    * ``trim_conversation_history`` removes messages, keeps the leading
      system and human messages and the latest user query, and retains only
      the most recent tool result.
    """
    # Build the trimmer directly to avoid depending on application config.
    trimmer = ConversationTrimmer(max_context_length=100000)

    latest_query = '搜索安全标准'
    # Conversation with several tool-call rounds but very few tokens.
    messages = [
        SystemMessage(content='You are a helpful assistant.'),
        # Historical round 1
        HumanMessage(content='搜索汽车标准'),
        AIMessage(
            content='搜索中',
            tool_calls=[{'id': 'call_1', 'name': 'search', 'args': {'query': '汽车标准'}}],
        ),
        ToolMessage(content='历史结果1', tool_call_id='call_1', name='search'),
        AIMessage(content='汽车标准信息'),
        # Historical round 2
        HumanMessage(content='搜索电池标准'),
        AIMessage(
            content='搜索中',
            tool_calls=[{'id': 'call_2', 'name': 'search', 'args': {'query': '电池标准'}}],
        ),
        ToolMessage(content='历史结果2', tool_call_id='call_2', name='search'),
        AIMessage(content='电池标准信息'),
        # New user query (the moment at which trimming is triggered)
        HumanMessage(content=latest_query),
    ]

    # The conversation must be tiny, far below the history limit.
    token_count = count_tokens_approximately(messages)
    assert token_count < 1000, f"Token count should be low, got {token_count}"
    assert token_count < trimmer.history_token_limit, "Token count should be well below limit"

    # Both tool rounds must be recognised.
    tool_rounds = trimmer._identify_tool_rounds(messages)
    assert len(tool_rounds) == 2, f"Should identify 2 tool rounds, got {len(tool_rounds)}"

    # Trimming is requested because there are multiple tool rounds, not because of size.
    should_trim = trimmer.should_trim(messages)
    assert should_trim, "Should trigger trimming due to multiple tool rounds"

    trimmed = trimmer.trim_conversation_history(messages)

    assert len(trimmed) < len(messages), "Should have fewer messages after trimming"
    # Guard the positional checks below: an over-aggressive trim should fail
    # with a readable message, not an IndexError.
    assert len(trimmed) >= 2, (
        f"Should keep at least the system and initial human message, got {len(trimmed)} messages"
    )

    # The system message and the initial human query are preserved.
    assert isinstance(trimmed[0], SystemMessage), "Should preserve system message"
    assert isinstance(trimmed[1], HumanMessage), "Should preserve initial human message"

    # The user's current question must survive trimming, or the agent has nothing to answer.
    assert isinstance(trimmed[-1], HumanMessage) and trimmed[-1].content == latest_query, (
        "Should preserve the latest user query as the final message"
    )

    # Only the most recent round's tool result is kept.
    tool_messages = [msg for msg in trimmed if isinstance(msg, ToolMessage)]
    assert len(tool_messages) == 1, f"Should only keep 1 tool message, got {len(tool_messages)}"
    assert tool_messages[0].content == '历史结果2', "Should keep the most recent tool result"
def test_single_tool_round_no_trimming():
    """A conversation containing exactly one tool-call round is not trimmed.

    With a single round and a small token count, ``should_trim`` must
    return a falsy value.
    """
    trimmer = ConversationTrimmer(max_context_length=100000)

    search_call = {'id': 'call_1', 'name': 'search', 'args': {'query': '信息'}}
    conversation = [
        SystemMessage(content='You are a helpful assistant.'),
        HumanMessage(content='搜索信息'),
        AIMessage(content='搜索中', tool_calls=[search_call]),
        ToolMessage(content='搜索结果', tool_call_id='call_1', name='search'),
        AIMessage(content='这是搜索到的信息'),
        HumanMessage(content='新的问题'),
    ]

    # Exactly one round should be detected.
    round_count = len(trimmer._identify_tool_rounds(conversation))
    assert round_count == 1, f"Should identify 1 tool round, got {round_count}"

    # One round plus few tokens: no trimming expected.
    needs_trim = trimmer.should_trim(conversation)
    assert not needs_trim, "Should not trigger trimming for single tool round with low tokens"
def test_no_tool_rounds_no_trimming():
    """A plain chat with no tool calls is never trimmed.

    ``_identify_tool_rounds`` must find nothing, and ``should_trim`` must
    return a falsy value.
    """
    trimmer = ConversationTrimmer(max_context_length=100000)

    # Alternating human/AI turns only; no tool_calls anywhere.
    exchanges = [
        ('Hello', 'Hi there!'),
        ('How are you?', 'I am doing well, thank you!'),
    ]
    conversation = [SystemMessage(content='You are a helpful assistant.')]
    for user_text, reply_text in exchanges:
        conversation.append(HumanMessage(content=user_text))
        conversation.append(AIMessage(content=reply_text))

    round_count = len(trimmer._identify_tool_rounds(conversation))
    assert round_count == 0, f"Should identify 0 tool rounds, got {round_count}"

    needs_trim = trimmer.should_trim(conversation)
    assert not needs_trim, "Should not trigger trimming without tool rounds"
if __name__ == "__main__":
    # pytest.main() returns the session exit code; propagate it so running
    # this file directly reports failures to the shell/CI instead of exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))