201 lines
7.4 KiB
Python
201 lines
7.4 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
|
"""
|
|||
|
|
测试修改后的修剪逻辑:验证只在 tool_rounds=0 时修剪,其他时候跳过修剪
|
|||
|
|
"""
|
|||
|
|
import pytest
|
|||
|
|
from unittest.mock import patch, MagicMock
|
|||
|
|
import asyncio
|
|||
|
|
from datetime import datetime
|
|||
|
|
|
|||
|
|
import sys
|
|||
|
|
import os
|
|||
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|||
|
|
|
|||
|
|
from service.graph.state import AgentState
|
|||
|
|
from service.graph.graph import call_model
|
|||
|
|
from service.graph.user_manual_rag import user_manual_agent_node
|
|||
|
|
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage
|
|||
|
|
|
|||
|
|
|
|||
|
|
@pytest.mark.asyncio
async def test_trimming_only_at_tool_rounds_zero():
    """Trimming in ``call_model`` must happen only when ``tool_rounds == 0``.

    Builds a long tool-calling conversation, patches
    ``create_conversation_trimmer`` and ``LLMClient`` in
    ``service.graph.graph``, then runs ``call_model`` twice:

    * ``tool_rounds=0`` -- the trimmer's ``should_trim`` and
      ``trim_conversation_history`` are expected to be called;
    * ``tool_rounds=1`` -- neither is expected to be called.
    """
    # Build a large conversation (10 tool-call rounds with long tool outputs)
    # so trimming would be warranted whenever the code path reaches it.
    large_messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="请搜索信息"),
    ]
    for i in range(10):
        large_messages.extend([
            AIMessage(content=f"搜索{i}", tool_calls=[
                {"id": f"call_{i}", "name": "search", "args": {"query": f"test{i}"}}
            ]),
            ToolMessage(content="很长的搜索结果内容 " * 500, tool_call_id=f"call_{i}", name="search"),
        ])

    def make_state(tool_rounds):
        """Return an AgentState over ``large_messages`` at the given tool round."""
        return AgentState(
            messages=large_messages,
            session_id="test",
            intent=None,
            tool_results=[],
            final_answer="",
            tool_rounds=tool_rounds,
            max_tool_rounds=3,
            max_tool_rounds_user_manual=2,
        )

    # Async stand-ins for the LLM client's async methods.
    async def mock_ainvoke_with_tools(*args, **kwargs):
        return AIMessage(content="最终答案")

    async def mock_astream(*args, **kwargs):
        for token in ["最", "终", "答", "案"]:
            yield token

    with patch('service.graph.graph.create_conversation_trimmer') as mock_trimmer_factory, \
            patch('service.graph.graph.LLMClient') as mock_llm_factory:
        # The trimmer always reports that trimming is needed, so any call that
        # reaches it is recorded and can be asserted on.
        mock_trimmer = MagicMock()
        mock_trimmer.should_trim.return_value = True
        mock_trimmer.trim_conversation_history.return_value = large_messages[:3]
        mock_trimmer_factory.return_value = mock_trimmer

        mock_llm = MagicMock()
        mock_llm.ainvoke_with_tools = mock_ainvoke_with_tools
        mock_llm.astream = mock_astream
        mock_llm.bind_tools = MagicMock()
        mock_llm_factory.return_value = mock_llm

        # First round (tool_rounds=0): trimming is expected.
        await call_model(make_state(tool_rounds=0))
        mock_trimmer.should_trim.assert_called()
        mock_trimmer.trim_conversation_history.assert_called()

        # Clear recorded calls (configured return values are kept by
        # reset_mock's defaults) before exercising the next round.
        mock_trimmer.reset_mock()

        # Later round (tool_rounds=1): trimming must be skipped.
        await call_model(make_state(tool_rounds=1))
        mock_trimmer.should_trim.assert_not_called()
        mock_trimmer.trim_conversation_history.assert_not_called()
|
|||
|
|
|
|||
|
|
|
|||
|
|
@pytest.mark.asyncio
async def test_user_manual_agent_trimming_behavior():
    """Trimming in ``user_manual_agent_node`` must happen only when ``tool_rounds == 0``.

    Builds a long user-manual retrieval conversation, patches the trimmer
    factory, ``LLMClient``, ``get_user_manual_tool_schemas`` and
    ``get_cached_config`` in ``service.graph.user_manual_rag``, then runs the
    node twice:

    * ``tool_rounds=0`` -- ``should_trim`` and ``trim_conversation_history``
      are expected to be called;
    * ``tool_rounds=1`` -- neither is expected to be called.
    """
    # Build a large conversation (8 retrieval rounds with long tool outputs).
    large_messages = [
        SystemMessage(content="You are a helpful user manual assistant."),
        HumanMessage(content="请帮我查找系统使用说明"),
    ]
    for i in range(8):
        large_messages.extend([
            AIMessage(content=f"查找{i}", tool_calls=[
                {"id": f"manual_call_{i}", "name": "retrieve_system_usermanual", "args": {"query": f"功能{i}"}}
            ]),
            ToolMessage(content="很长的手册内容 " * 400, tool_call_id=f"manual_call_{i}", name="retrieve_system_usermanual"),
        ])

    def make_state(tool_rounds):
        """Return an AgentState over ``large_messages`` at the given tool round."""
        return AgentState(
            messages=large_messages,
            session_id="test",
            intent=None,
            tool_results=[],
            final_answer="",
            tool_rounds=tool_rounds,
            max_tool_rounds=3,
            max_tool_rounds_user_manual=2,
        )

    # Async stand-ins for the LLM client's async methods.
    async def mock_ainvoke_with_tools(*args, **kwargs):
        return AIMessage(content="用户手册答案")

    async def mock_astream(*args, **kwargs):
        for token in ["答", "案"]:
            yield token

    with patch('service.graph.user_manual_rag.create_conversation_trimmer') as mock_trimmer_factory, \
            patch('service.graph.user_manual_rag.LLMClient') as mock_llm_factory, \
            patch('service.graph.user_manual_rag.get_user_manual_tool_schemas') as mock_tools, \
            patch('service.graph.user_manual_rag.get_cached_config') as mock_config:
        # The trimmer always reports that trimming is needed, so any call that
        # reaches it is recorded and can be asserted on.
        mock_trimmer = MagicMock()
        mock_trimmer.should_trim.return_value = True
        mock_trimmer.trim_conversation_history.return_value = large_messages[:3]
        mock_trimmer_factory.return_value = mock_trimmer

        mock_llm = MagicMock()
        mock_llm.ainvoke_with_tools = mock_ainvoke_with_tools
        mock_llm.astream = mock_astream
        mock_llm.bind_tools = MagicMock()
        mock_llm_factory.return_value = mock_llm

        # No tool schemas are needed for this test.
        mock_tools.return_value = []

        # Minimal config: tool-round limit plus the prompt template the node reads.
        mock_app_config = MagicMock()
        mock_app_config.app.max_tool_rounds_user_manual = 2
        mock_app_config.get_rag_prompts.return_value = {
            "user_manual_prompt": "System prompt for user manual: {conversation_history} {context_content} {current_query}"
        }
        mock_config.return_value = mock_app_config

        # First round (tool_rounds=0): trimming is expected.
        await user_manual_agent_node(make_state(tool_rounds=0))
        mock_trimmer.should_trim.assert_called()
        mock_trimmer.trim_conversation_history.assert_called()

        # Clear recorded calls (configured return values are kept by
        # reset_mock's defaults) before exercising the next round.
        mock_trimmer.reset_mock()

        # Later round (tool_rounds=1): trimming must be skipped.
        await user_manual_agent_node(make_state(tool_rounds=1))
        mock_trimmer.should_trim.assert_not_called()
        mock_trimmer.trim_conversation_history.assert_not_called()
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
    # Propagate pytest's exit status so that running this file directly
    # reports failing tests as a non-zero process exit code.
    raise SystemExit(pytest.main([__file__, "-v"]))
|