This commit is contained in:
2025-09-26 17:15:54 +08:00
commit db0e5965ec
211 changed files with 40437 additions and 0 deletions

View File

@@ -0,0 +1,200 @@
#!/usr/bin/env python3
"""
测试修改后的修剪逻辑:验证只在 tool_rounds=0 时修剪,其他时候跳过修剪
"""
import pytest
from unittest.mock import patch, MagicMock
import asyncio
from datetime import datetime
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from service.graph.state import AgentState
from service.graph.graph import call_model
from service.graph.user_manual_rag import user_manual_agent_node
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage
@pytest.mark.asyncio
async def test_trimming_only_at_tool_rounds_zero():
"""测试修剪只在 tool_rounds=0 时发生"""
# 创建一个大的消息列表来触发修剪
large_messages = [
SystemMessage(content="You are a helpful assistant."),
HumanMessage(content="请搜索信息"),
]
# 添加大量内容来触发修剪
for i in range(10):
large_messages.extend([
AIMessage(content=f"搜索{i}", tool_calls=[
{"id": f"call_{i}", "name": "search", "args": {"query": f"test{i}"}}
]),
ToolMessage(content="很长的搜索结果内容 " * 500, tool_call_id=f"call_{i}", name="search"),
])
# Mock trimmer to track calls
with patch('service.graph.graph.create_conversation_trimmer') as mock_trimmer_factory:
mock_trimmer = MagicMock()
mock_trimmer.should_trim.return_value = True # 总是返回需要修剪
mock_trimmer.trim_conversation_history.return_value = large_messages[:3] # 返回修剪后的消息
mock_trimmer_factory.return_value = mock_trimmer
# Mock LLM client
with patch('service.graph.graph.LLMClient') as mock_llm_factory:
mock_llm = MagicMock()
# 设置异步方法的返回值
async def mock_ainvoke_with_tools(*args, **kwargs):
return AIMessage(content="最终答案")
async def mock_astream(*args, **kwargs):
for token in ["", "", "", ""]:
yield token
mock_llm.ainvoke_with_tools = mock_ainvoke_with_tools
mock_llm.astream = mock_astream
mock_llm.bind_tools = MagicMock()
mock_llm_factory.return_value = mock_llm
# 测试 tool_rounds=0 的情况(应该修剪)
from service.graph.state import AgentState
state_round_0 = AgentState(
messages=large_messages,
session_id="test",
intent=None,
tool_results=[],
final_answer="",
tool_rounds=0, # 第一轮
max_tool_rounds=3,
max_tool_rounds_user_manual=2
)
result = await call_model(state_round_0)
# 验证修剪器被调用
mock_trimmer.should_trim.assert_called()
mock_trimmer.trim_conversation_history.assert_called()
# 重置 mock
mock_trimmer.reset_mock()
# 测试 tool_rounds>0 的情况(不应该修剪)
from service.graph.state import AgentState
state_round_1 = AgentState(
messages=large_messages,
session_id="test",
intent=None,
tool_results=[],
final_answer="",
tool_rounds=1, # 第二轮
max_tool_rounds=3,
max_tool_rounds_user_manual=2
)
result = await call_model(state_round_1)
# 验证修剪器没有被调用
mock_trimmer.should_trim.assert_not_called()
mock_trimmer.trim_conversation_history.assert_not_called()
@pytest.mark.asyncio
async def test_user_manual_agent_trimming_behavior():
"""测试 user_manual_agent_node 的修剪行为"""
# 创建大的消息列表
large_messages = [
SystemMessage(content="You are a helpful user manual assistant."),
HumanMessage(content="请帮我查找系统使用说明"),
]
# 添加大量内容
for i in range(8):
large_messages.extend([
AIMessage(content=f"查找{i}", tool_calls=[
{"id": f"manual_call_{i}", "name": "retrieve_system_usermanual", "args": {"query": f"功能{i}"}}
]),
ToolMessage(content="很长的手册内容 " * 400, tool_call_id=f"manual_call_{i}", name="retrieve_system_usermanual"),
])
# Mock dependencies
with patch('service.graph.user_manual_rag.create_conversation_trimmer') as mock_trimmer_factory, \
patch('service.graph.user_manual_rag.LLMClient') as mock_llm_factory, \
patch('service.graph.user_manual_rag.get_user_manual_tool_schemas') as mock_tools, \
patch('service.graph.user_manual_rag.get_cached_config') as mock_config:
# Setup mocks
mock_trimmer = MagicMock()
mock_trimmer.should_trim.return_value = True
mock_trimmer.trim_conversation_history.return_value = large_messages[:3]
mock_trimmer_factory.return_value = mock_trimmer
mock_llm = MagicMock()
# 设置异步方法的返回值
async def mock_ainvoke_with_tools(*args, **kwargs):
return AIMessage(content="用户手册答案")
async def mock_astream(*args, **kwargs):
for token in ["", ""]:
yield token
mock_llm.ainvoke_with_tools = mock_ainvoke_with_tools
mock_llm.astream = mock_astream
mock_llm.bind_tools = MagicMock()
mock_llm_factory.return_value = mock_llm
mock_tools.return_value = []
mock_app_config = MagicMock()
mock_app_config.app.max_tool_rounds_user_manual = 2
mock_app_config.get_rag_prompts.return_value = {
"user_manual_prompt": "System prompt for user manual: {conversation_history} {context_content} {current_query}"
}
mock_config.return_value = mock_app_config
# 测试 tool_rounds=0应该修剪
from service.graph.state import AgentState
state_round_0 = AgentState(
messages=large_messages,
session_id="test",
intent=None,
tool_results=[],
final_answer="",
tool_rounds=0,
max_tool_rounds=3,
max_tool_rounds_user_manual=2
)
result = await user_manual_agent_node(state_round_0)
# 验证修剪器被调用
mock_trimmer.should_trim.assert_called()
mock_trimmer.trim_conversation_history.assert_called()
# 重置 mock
mock_trimmer.reset_mock()
# 测试 tool_rounds=1不应该修剪
state_round_1 = AgentState(
messages=large_messages,
session_id="test",
intent=None,
tool_results=[],
final_answer="",
tool_rounds=1,
max_tool_rounds=3,
max_tool_rounds_user_manual=2
)
result = await user_manual_agent_node(state_round_1)
# 验证修剪器没有被调用
mock_trimmer.should_trim.assert_not_called()
mock_trimmer.trim_conversation_history.assert_not_called()
if __name__ == "__main__":
pytest.main([__file__, "-v"])