#!/usr/bin/env python3
"""
测试新的积极修剪策略:即使token数很少,也要修剪历史工具调用结果
"""
import pytest
|
||
from service.graph.message_trimmer import ConversationTrimmer
|
||
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage
|
||
from langchain_core.messages.utils import count_tokens_approximately
|
||
|
||
|
||
def test_aggressive_tool_history_trimming():
    """Aggressive strategy: old tool results are trimmed even at a tiny token count.

    Builds a short conversation with two completed tool-call rounds followed
    by a fresh user query. It then checks four things:
    - the trimmer detects both rounds;
    - it decides to trim despite the low token count;
    - trimming shrinks the conversation while keeping the leading system and
      human messages;
    - only the most recent ToolMessage survives.
    """

    def tool_round(question, call_id, query, tool_output, answer):
        # One completed round: the user asks, the AI calls `search`,
        # the tool replies, and the AI answers.
        return [
            HumanMessage(content=question),
            AIMessage(
                content='搜索中',
                tool_calls=[{'id': call_id, 'name': 'search', 'args': {'query': query}}],
            ),
            ToolMessage(content=tool_output, tool_call_id=call_id, name='search'),
            AIMessage(content=answer),
        ]

    # Build the trimmer directly so the test does not depend on external configuration.
    conv_trimmer = ConversationTrimmer(max_context_length=100000)

    conversation = [
        SystemMessage(content='You are a helpful assistant.'),
        # Historical round 1
        *tool_round('搜索汽车标准', 'call_1', '汽车标准', '历史结果1', '汽车标准信息'),
        # Historical round 2
        *tool_round('搜索电池标准', 'call_2', '电池标准', '历史结果2', '电池标准信息'),
        # The new user query, i.e. the point at which trimming would run.
        HumanMessage(content='搜索安全标准'),
    ]

    # The conversation is deliberately tiny, far below the token threshold.
    approx_tokens = count_tokens_approximately(conversation)
    assert approx_tokens < 1000, f"Token count should be low, got {approx_tokens}"
    assert approx_tokens < conv_trimmer.history_token_limit, (
        "Token count should be well below limit"
    )

    # Both historical tool-call rounds must be detected.
    rounds = conv_trimmer._identify_tool_rounds(conversation)
    assert len(rounds) == 2, f"Should identify 2 tool rounds, got {len(rounds)}"

    # Having several tool rounds alone must be enough to request trimming.
    assert conv_trimmer.should_trim(conversation), (
        "Should trigger trimming due to multiple tool rounds"
    )

    result = conv_trimmer.trim_conversation_history(conversation)
    assert len(result) < len(conversation), "Should have fewer messages after trimming"

    # The system prompt and the first user query stay at the front.
    assert isinstance(result[0], SystemMessage), "Should preserve system message"
    assert isinstance(result[1], HumanMessage), "Should preserve initial human message"

    # Only the latest round's tool output may survive.
    kept_tool_outputs = list(filter(lambda m: isinstance(m, ToolMessage), result))
    assert len(kept_tool_outputs) == 1, (
        f"Should only keep 1 tool message, got {len(kept_tool_outputs)}"
    )
    (latest,) = kept_tool_outputs
    assert latest.content == '历史结果2', "Should keep the most recent tool result"


def test_single_tool_round_no_trimming():
    """A conversation with just one tool-call round and few tokens is not trimmed."""
    conv_trimmer = ConversationTrimmer(max_context_length=100000)

    # Exactly one completed tool round, followed by a new user question.
    search_call = {'id': 'call_1', 'name': 'search', 'args': {'query': '信息'}}
    conversation = [
        SystemMessage(content='You are a helpful assistant.'),
        HumanMessage(content='搜索信息'),
        AIMessage(content='搜索中', tool_calls=[search_call]),
        ToolMessage(content='搜索结果', tool_call_id=search_call['id'], name='search'),
        AIMessage(content='这是搜索到的信息'),
        HumanMessage(content='新的问题'),
    ]

    # Exactly one tool round must be detected.
    round_count = len(conv_trimmer._identify_tool_rounds(conversation))
    assert round_count == 1, f"Should identify 1 tool round, got {round_count}"

    # A single round at a low token count stays below the trimming trigger.
    assert not conv_trimmer.should_trim(conversation), (
        "Should not trigger trimming for single tool round with low tokens"
    )


def test_no_tool_rounds_no_trimming():
    """A plain chat with no tool calls at all is not trimmed."""
    conv_trimmer = ConversationTrimmer(max_context_length=100000)

    # Alternating human/AI turns only; no AIMessage carries tool_calls.
    exchanges = [
        ('Hello', 'Hi there!'),
        ('How are you?', 'I am doing well, thank you!'),
    ]
    conversation = [SystemMessage(content='You are a helpful assistant.')]
    for question, reply in exchanges:
        conversation += [HumanMessage(content=question), AIMessage(content=reply)]

    # No tool rounds should be detected.
    detected = conv_trimmer._identify_tool_rounds(conversation)
    assert len(detected) == 0, f"Should identify 0 tool rounds, got {len(detected)}"

    # Without tool rounds there is nothing to trim.
    assert not conv_trimmer.should_trim(conversation), (
        "Should not trigger trimming without tool rounds"
    )


if __name__ == "__main__":
    # Run this module's tests verbosely when the file is executed directly.
    # pytest.main() returns an exit code rather than exiting. Propagate it so
    # a failing run yields a non-zero process status; otherwise CI and shell
    # scripts would see success.
    raise SystemExit(pytest.main([__file__, "-v"]))