Files
catonline_ai/vw-agentic-rag/tests/unit/test_trimming_fix.py
2025-09-26 17:15:54 +08:00

201 lines
7.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
测试修改后的修剪逻辑:验证只在 tool_rounds=0 时修剪,其他时候跳过修剪
"""
import pytest
from unittest.mock import patch, MagicMock
import asyncio
from datetime import datetime
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from service.graph.state import AgentState
from service.graph.graph import call_model
from service.graph.user_manual_rag import user_manual_agent_node
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage
@pytest.mark.asyncio
async def test_trimming_only_at_tool_rounds_zero():
    """Check that ``call_model`` consults the trimmer only when ``tool_rounds == 0``.

    A conversation with ten large tool-call / tool-result pairs is built. The
    patched trimmer always reports that trimming is needed, so any call to it
    is recorded.

    The test asserts two things:

    * On the first round (``tool_rounds=0``) both ``should_trim`` and
      ``trim_conversation_history`` are invoked.
    * On a later round (``tool_rounds=1``) neither is invoked.
    """
    # Build a long conversation so that trimming would be warranted.
    large_messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="请搜索信息"),
    ]
    # Append many large tool-call / tool-result pairs to inflate the history.
    for i in range(10):
        large_messages.extend([
            AIMessage(content=f"搜索{i}", tool_calls=[
                {"id": f"call_{i}", "name": "search", "args": {"query": f"test{i}"}}
            ]),
            ToolMessage(content="很长的搜索结果内容 " * 500, tool_call_id=f"call_{i}", name="search"),
        ])

    def make_state(tool_rounds):
        # The two scenarios below differ only in tool_rounds.
        return AgentState(
            messages=large_messages,
            session_id="test",
            intent=None,
            tool_results=[],
            final_answer="",
            tool_rounds=tool_rounds,
            max_tool_rounds=3,
            max_tool_rounds_user_manual=2,
        )

    with patch('service.graph.graph.create_conversation_trimmer') as mock_trimmer_factory, \
            patch('service.graph.graph.LLMClient') as mock_llm_factory:
        # Trimmer mock: always claims trimming is needed, returns a short history.
        mock_trimmer = MagicMock()
        mock_trimmer.should_trim.return_value = True
        mock_trimmer.trim_conversation_history.return_value = large_messages[:3]
        mock_trimmer_factory.return_value = mock_trimmer

        # LLM client mock with async stand-ins for the invoke/stream methods.
        mock_llm = MagicMock()

        async def mock_ainvoke_with_tools(*args, **kwargs):
            return AIMessage(content="最终答案")

        async def mock_astream(*args, **kwargs):
            for token in ["", "", "", ""]:
                yield token

        mock_llm.ainvoke_with_tools = mock_ainvoke_with_tools
        mock_llm.astream = mock_astream
        mock_llm.bind_tools = MagicMock()
        mock_llm_factory.return_value = mock_llm

        # First round (tool_rounds=0): the trimmer must be consulted and applied.
        await call_model(make_state(0))
        mock_trimmer.should_trim.assert_called()
        mock_trimmer.trim_conversation_history.assert_called()

        # Clear recorded calls before the second scenario.
        mock_trimmer.reset_mock()

        # Later round (tool_rounds=1): trimming must be skipped entirely.
        await call_model(make_state(1))
        mock_trimmer.should_trim.assert_not_called()
        mock_trimmer.trim_conversation_history.assert_not_called()
@pytest.mark.asyncio
async def test_user_manual_agent_trimming_behavior():
    """Check that ``user_manual_agent_node`` consults the trimmer only when ``tool_rounds == 0``.

    A conversation with eight large user-manual retrieval pairs is built.
    These module-level dependencies of ``service.graph.user_manual_rag`` are
    patched:

    * the trimmer factory (the trimmer always asks to trim);
    * the LLM client;
    * the tool-schema getter;
    * the cached config.

    The test asserts two things:

    * On ``tool_rounds=0`` both ``should_trim`` and
      ``trim_conversation_history`` are invoked.
    * On ``tool_rounds=1`` neither is invoked.
    """
    # Build a long conversation so that trimming would be warranted.
    large_messages = [
        SystemMessage(content="You are a helpful user manual assistant."),
        HumanMessage(content="请帮我查找系统使用说明"),
    ]
    # Append many large retrieval tool-call / tool-result pairs.
    for i in range(8):
        large_messages.extend([
            AIMessage(content=f"查找{i}", tool_calls=[
                {"id": f"manual_call_{i}", "name": "retrieve_system_usermanual", "args": {"query": f"功能{i}"}}
            ]),
            ToolMessage(content="很长的手册内容 " * 400, tool_call_id=f"manual_call_{i}", name="retrieve_system_usermanual"),
        ])

    def make_state(tool_rounds):
        # The two scenarios below differ only in tool_rounds.
        return AgentState(
            messages=large_messages,
            session_id="test",
            intent=None,
            tool_results=[],
            final_answer="",
            tool_rounds=tool_rounds,
            max_tool_rounds=3,
            max_tool_rounds_user_manual=2,
        )

    with patch('service.graph.user_manual_rag.create_conversation_trimmer') as mock_trimmer_factory, \
            patch('service.graph.user_manual_rag.LLMClient') as mock_llm_factory, \
            patch('service.graph.user_manual_rag.get_user_manual_tool_schemas') as mock_tools, \
            patch('service.graph.user_manual_rag.get_cached_config') as mock_config:
        # Trimmer mock: always claims trimming is needed, returns a short history.
        mock_trimmer = MagicMock()
        mock_trimmer.should_trim.return_value = True
        mock_trimmer.trim_conversation_history.return_value = large_messages[:3]
        mock_trimmer_factory.return_value = mock_trimmer

        # LLM client mock with async stand-ins for the invoke/stream methods.
        mock_llm = MagicMock()

        async def mock_ainvoke_with_tools(*args, **kwargs):
            return AIMessage(content="用户手册答案")

        async def mock_astream(*args, **kwargs):
            for token in ["", ""]:
                yield token

        mock_llm.ainvoke_with_tools = mock_ainvoke_with_tools
        mock_llm.astream = mock_astream
        mock_llm.bind_tools = MagicMock()
        mock_llm_factory.return_value = mock_llm

        # No tool schemas are needed for this behavior check.
        mock_tools.return_value = []

        # Config mock: round limit and a prompt template with the placeholders
        # the node is expected to fill in.
        mock_app_config = MagicMock()
        mock_app_config.app.max_tool_rounds_user_manual = 2
        mock_app_config.get_rag_prompts.return_value = {
            "user_manual_prompt": "System prompt for user manual: {conversation_history} {context_content} {current_query}"
        }
        mock_config.return_value = mock_app_config

        # First round (tool_rounds=0): the trimmer must be consulted and applied.
        await user_manual_agent_node(make_state(0))
        mock_trimmer.should_trim.assert_called()
        mock_trimmer.trim_conversation_history.assert_called()

        # Clear recorded calls before the second scenario.
        mock_trimmer.reset_mock()

        # Later round (tool_rounds=1): trimming must be skipped entirely.
        await user_manual_agent_node(make_state(1))
        mock_trimmer.should_trim.assert_not_called()
        mock_trimmer.trim_conversation_history.assert_not_called()
if __name__ == "__main__":
    # Propagate pytest's exit status so running this file directly reports
    # test failures to the caller (e.g. a shell or CI step) instead of
    # always exiting 0.
    sys.exit(pytest.main([__file__, "-v"]))