init
This commit is contained in:
200
vw-agentic-rag/tests/unit/test_trimming_fix.py
Normal file
200
vw-agentic-rag/tests/unit/test_trimming_fix.py
Normal file
@@ -0,0 +1,200 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
测试修改后的修剪逻辑:验证只在 tool_rounds=0 时修剪,其他时候跳过修剪
|
||||
"""
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from service.graph.state import AgentState
|
||||
from service.graph.graph import call_model
|
||||
from service.graph.user_manual_rag import user_manual_agent_node
|
||||
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_trimming_only_at_tool_rounds_zero():
    """Check that ``call_model`` uses the trimmer only when ``tool_rounds == 0``.

    ``create_conversation_trimmer`` and ``LLMClient`` are patched inside
    ``service.graph.graph``. The mocked trimmer always reports that trimming
    is needed, so the assertions depend only on whether ``call_model``
    consults it:

    * ``tool_rounds=0`` -> ``should_trim`` and ``trim_conversation_history``
      must have been called.
    * ``tool_rounds=1`` -> neither method may have been called.
    """
    # Build a large message list intended to trigger trimming.
    large_messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="请搜索信息"),
    ]

    # Pad the history with many tool-call / tool-result pairs.
    for i in range(10):
        large_messages.extend([
            AIMessage(content=f"搜索{i}", tool_calls=[
                {"id": f"call_{i}", "name": "search", "args": {"query": f"test{i}"}}
            ]),
            ToolMessage(content="很长的搜索结果内容 " * 500, tool_call_id=f"call_{i}", name="search"),
        ])

    def build_state(tool_rounds):
        # The two scenarios differ only in tool_rounds; everything else is fixed.
        return AgentState(
            messages=large_messages,
            session_id="test",
            intent=None,
            tool_results=[],
            final_answer="",
            tool_rounds=tool_rounds,
            max_tool_rounds=3,
            max_tool_rounds_user_manual=2,
        )

    # Async stand-ins for the LLM client's coroutine / async-generator methods.
    async def mock_ainvoke_with_tools(*args, **kwargs):
        return AIMessage(content="最终答案")

    async def mock_astream(*args, **kwargs):
        for token in ["最", "终", "答", "案"]:
            yield token

    with patch('service.graph.graph.create_conversation_trimmer') as mock_trimmer_factory, \
            patch('service.graph.graph.LLMClient') as mock_llm_factory:
        # Mock trimmer so that calls can be tracked.
        mock_trimmer = MagicMock()
        mock_trimmer.should_trim.return_value = True  # always report "needs trimming"
        mock_trimmer.trim_conversation_history.return_value = large_messages[:3]  # trimmed result
        mock_trimmer_factory.return_value = mock_trimmer

        # Mock LLM client.
        mock_llm = MagicMock()
        mock_llm.ainvoke_with_tools = mock_ainvoke_with_tools
        mock_llm.astream = mock_astream
        mock_llm.bind_tools = MagicMock()
        mock_llm_factory.return_value = mock_llm

        # tool_rounds=0 (first round): the trimmer should be used.
        await call_model(build_state(0))

        mock_trimmer.should_trim.assert_called()
        mock_trimmer.trim_conversation_history.assert_called()

        # Forget the recorded calls before the second scenario.
        mock_trimmer.reset_mock()

        # tool_rounds>0 (second round): the trimmer should be skipped.
        await call_model(build_state(1))

        mock_trimmer.should_trim.assert_not_called()
        mock_trimmer.trim_conversation_history.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_user_manual_agent_trimming_behavior():
    """Check that ``user_manual_agent_node`` uses the trimmer only when ``tool_rounds == 0``.

    The trimmer factory, ``LLMClient``, ``get_user_manual_tool_schemas`` and
    ``get_cached_config`` are patched inside ``service.graph.user_manual_rag``.
    The mocked trimmer always reports that trimming is needed, so the
    assertions depend only on whether the node consults it:

    * ``tool_rounds=0`` -> ``should_trim`` and ``trim_conversation_history``
      must have been called.
    * ``tool_rounds=1`` -> neither method may have been called.
    """
    # Build a large message list.
    large_messages = [
        SystemMessage(content="You are a helpful user manual assistant."),
        HumanMessage(content="请帮我查找系统使用说明"),
    ]

    # Pad the history with many tool-call / tool-result pairs.
    for i in range(8):
        large_messages.extend([
            AIMessage(content=f"查找{i}", tool_calls=[
                {"id": f"manual_call_{i}", "name": "retrieve_system_usermanual", "args": {"query": f"功能{i}"}}
            ]),
            ToolMessage(content="很长的手册内容 " * 400, tool_call_id=f"manual_call_{i}", name="retrieve_system_usermanual"),
        ])

    def build_state(tool_rounds):
        # The two scenarios differ only in tool_rounds; everything else is fixed.
        return AgentState(
            messages=large_messages,
            session_id="test",
            intent=None,
            tool_results=[],
            final_answer="",
            tool_rounds=tool_rounds,
            max_tool_rounds=3,
            max_tool_rounds_user_manual=2,
        )

    # Async stand-ins for the LLM client's coroutine / async-generator methods.
    async def mock_ainvoke_with_tools(*args, **kwargs):
        return AIMessage(content="用户手册答案")

    async def mock_astream(*args, **kwargs):
        for token in ["答", "案"]:
            yield token

    # Mock dependencies.
    with patch('service.graph.user_manual_rag.create_conversation_trimmer') as mock_trimmer_factory, \
            patch('service.graph.user_manual_rag.LLMClient') as mock_llm_factory, \
            patch('service.graph.user_manual_rag.get_user_manual_tool_schemas') as mock_tools, \
            patch('service.graph.user_manual_rag.get_cached_config') as mock_config:

        # Trimmer that always claims trimming is required.
        mock_trimmer = MagicMock()
        mock_trimmer.should_trim.return_value = True
        mock_trimmer.trim_conversation_history.return_value = large_messages[:3]
        mock_trimmer_factory.return_value = mock_trimmer

        # LLM client backed by the async stand-ins above.
        mock_llm = MagicMock()
        mock_llm.ainvoke_with_tools = mock_ainvoke_with_tools
        mock_llm.astream = mock_astream
        mock_llm.bind_tools = MagicMock()
        mock_llm_factory.return_value = mock_llm

        mock_tools.return_value = []

        # Config object exposing the round limit and the prompt template.
        mock_app_config = MagicMock()
        mock_app_config.app.max_tool_rounds_user_manual = 2
        mock_app_config.get_rag_prompts.return_value = {
            "user_manual_prompt": "System prompt for user manual: {conversation_history} {context_content} {current_query}"
        }
        mock_config.return_value = mock_app_config

        # tool_rounds=0: the trimmer should be used.
        await user_manual_agent_node(build_state(0))

        mock_trimmer.should_trim.assert_called()
        mock_trimmer.trim_conversation_history.assert_called()

        # Forget the recorded calls before the second scenario.
        mock_trimmer.reset_mock()

        # tool_rounds=1: the trimmer should be skipped.
        await user_manual_agent_node(build_state(1))

        mock_trimmer.should_trim.assert_not_called()
        mock_trimmer.trim_conversation_history.assert_not_called()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Hand pytest's return code to the shell so that a failing run of this
    # file as a script produces a non-zero exit status.
    sys.exit(pytest.main([__file__, "-v"]))
|
||||
Reference in New Issue
Block a user