#!/usr/bin/env python3
"""
Tests for the revised conversation-trimming logic.

Verifies that conversation-history trimming happens only when
``tool_rounds == 0`` and is skipped in every later tool round.
"""
import asyncio
import os
import sys
from datetime import datetime
from unittest.mock import patch, MagicMock

import pytest

# Append the parent of this file's directory to sys.path so the ``service``
# package below can be imported when the file is run directly (presumably
# that directory is the project root — TODO confirm).
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from service.graph.state import AgentState  # noqa: E402
from service.graph.graph import call_model  # noqa: E402
from service.graph.user_manual_rag import user_manual_agent_node  # noqa: E402
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage  # noqa: E402


@pytest.mark.asyncio
async def test_trimming_only_at_tool_rounds_zero():
    """Check that ``call_model`` consults the trimmer only when ``tool_rounds == 0``.

    A long conversation (system prompt, user prompt, then ten large
    tool-call / tool-result pairs) is passed to ``call_model`` twice, with
    ``create_conversation_trimmer`` and ``LLMClient`` patched out:

    * ``tool_rounds=0``: ``should_trim`` and ``trim_conversation_history``
      must both be called.
    * ``tool_rounds=1``: neither may be called.
    """
    # Build a large message list that would normally trigger trimming.
    large_messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="请搜索信息"),
    ]
    for i in range(10):
        large_messages.extend([
            AIMessage(content=f"搜索{i}", tool_calls=[
                {"id": f"call_{i}", "name": "search", "args": {"query": f"test{i}"}}
            ]),
            ToolMessage(content="很长的搜索结果内容 " * 500, tool_call_id=f"call_{i}", name="search"),
        ])

    def make_state(tool_rounds):
        """Return an AgentState over ``large_messages`` at the given tool round."""
        return AgentState(
            messages=large_messages,
            session_id="test",
            intent=None,
            tool_results=[],
            final_answer="",
            tool_rounds=tool_rounds,
            max_tool_rounds=3,
            max_tool_rounds_user_manual=2,
        )

    # Async stand-ins for the LLM client's methods.
    async def mock_ainvoke_with_tools(*args, **kwargs):
        return AIMessage(content="最终答案")

    async def mock_astream(*args, **kwargs):
        for token in ["最", "终", "答", "案"]:
            yield token

    with patch('service.graph.graph.create_conversation_trimmer') as mock_trimmer_factory, \
            patch('service.graph.graph.LLMClient') as mock_llm_factory:
        # The trimmer always reports that trimming is needed. Whether it is
        # consulted therefore depends only on call_model's tool_rounds check.
        mock_trimmer = MagicMock()
        mock_trimmer.should_trim.return_value = True
        mock_trimmer.trim_conversation_history.return_value = large_messages[:3]
        mock_trimmer_factory.return_value = mock_trimmer

        mock_llm = MagicMock()
        mock_llm.ainvoke_with_tools = mock_ainvoke_with_tools
        mock_llm.astream = mock_astream
        mock_llm.bind_tools = MagicMock()
        mock_llm_factory.return_value = mock_llm

        # First round (tool_rounds=0): trimming is expected.
        await call_model(make_state(0))
        mock_trimmer.should_trim.assert_called()
        mock_trimmer.trim_conversation_history.assert_called()

        # Clear the recorded calls. reset_mock() keeps the configured
        # return values by default.
        mock_trimmer.reset_mock()

        # Second round (tool_rounds=1): trimming must be skipped.
        await call_model(make_state(1))
        mock_trimmer.should_trim.assert_not_called()
        mock_trimmer.trim_conversation_history.assert_not_called()


@pytest.mark.asyncio
async def test_user_manual_agent_trimming_behavior():
    """Check that ``user_manual_agent_node`` trims only when ``tool_rounds == 0``.

    A long user-manual conversation (system prompt, user prompt, then eight
    large retrieval tool-call / tool-result pairs) is passed to
    ``user_manual_agent_node`` twice. The trimmer factory, ``LLMClient``,
    the tool-schema getter and the config getter are all patched:

    * ``tool_rounds=0``: ``should_trim`` and ``trim_conversation_history``
      must both be called.
    * ``tool_rounds=1``: neither may be called.
    """
    # Build a large message list.
    large_messages = [
        SystemMessage(content="You are a helpful user manual assistant."),
        HumanMessage(content="请帮我查找系统使用说明"),
    ]
    for i in range(8):
        large_messages.extend([
            AIMessage(content=f"查找{i}", tool_calls=[
                {"id": f"manual_call_{i}", "name": "retrieve_system_usermanual", "args": {"query": f"功能{i}"}}
            ]),
            ToolMessage(content="很长的手册内容 " * 400, tool_call_id=f"manual_call_{i}", name="retrieve_system_usermanual"),
        ])

    def make_state(tool_rounds):
        """Return an AgentState over ``large_messages`` at the given tool round."""
        return AgentState(
            messages=large_messages,
            session_id="test",
            intent=None,
            tool_results=[],
            final_answer="",
            tool_rounds=tool_rounds,
            max_tool_rounds=3,
            max_tool_rounds_user_manual=2,
        )

    # Async stand-ins for the LLM client's methods.
    async def mock_ainvoke_with_tools(*args, **kwargs):
        return AIMessage(content="用户手册答案")

    async def mock_astream(*args, **kwargs):
        for token in ["答", "案"]:
            yield token

    with patch('service.graph.user_manual_rag.create_conversation_trimmer') as mock_trimmer_factory, \
            patch('service.graph.user_manual_rag.LLMClient') as mock_llm_factory, \
            patch('service.graph.user_manual_rag.get_user_manual_tool_schemas') as mock_tools, \
            patch('service.graph.user_manual_rag.get_cached_config') as mock_config:
        # The trimmer always reports that trimming is needed. Whether it is
        # consulted therefore depends only on the node's tool_rounds check.
        mock_trimmer = MagicMock()
        mock_trimmer.should_trim.return_value = True
        mock_trimmer.trim_conversation_history.return_value = large_messages[:3]
        mock_trimmer_factory.return_value = mock_trimmer

        mock_llm = MagicMock()
        mock_llm.ainvoke_with_tools = mock_ainvoke_with_tools
        mock_llm.astream = mock_astream
        mock_llm.bind_tools = MagicMock()
        mock_llm_factory.return_value = mock_llm

        mock_tools.return_value = []

        # Minimal config: the user-manual round limit plus a prompt template
        # containing the placeholders this test supplies.
        mock_app_config = MagicMock()
        mock_app_config.app.max_tool_rounds_user_manual = 2
        mock_app_config.get_rag_prompts.return_value = {
            "user_manual_prompt": "System prompt for user manual: {conversation_history} {context_content} {current_query}"
        }
        mock_config.return_value = mock_app_config

        # First round (tool_rounds=0): trimming is expected.
        await user_manual_agent_node(make_state(0))
        mock_trimmer.should_trim.assert_called()
        mock_trimmer.trim_conversation_history.assert_called()

        # Clear the recorded calls. reset_mock() keeps the configured
        # return values by default.
        mock_trimmer.reset_mock()

        # Second round (tool_rounds=1): trimming must be skipped.
        await user_manual_agent_node(make_state(1))
        mock_trimmer.should_trim.assert_not_called()
        mock_trimmer.trim_conversation_history.assert_not_called()


if __name__ == "__main__":
    # pytest.main() returns an exit code rather than exiting. Pass it to
    # sys.exit() so a failing run ends the script with a non-zero status.
    sys.exit(pytest.main([__file__, "-v"]))