#!/usr/bin/env python3
"""
Tests for the revised conversation-trimming logic.

Verifies that trimming happens only when ``tool_rounds == 0`` and is skipped
on every later tool round, both in ``call_model`` and in
``user_manual_agent_node``.
"""
import asyncio
import os
import sys
from datetime import datetime
from typing import Sequence
from unittest.mock import MagicMock, patch

import pytest

# Make the project root (parent of this file's directory) importable.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from service.graph.state import AgentState
from service.graph.graph import call_model
from service.graph.user_manual_rag import user_manual_agent_node
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage


def _make_mock_llm(answer: str, tokens: Sequence[str]) -> MagicMock:
    """Build a mock LLM client with async tool-calling and streaming methods.

    Args:
        answer: Content of the ``AIMessage`` returned by ``ainvoke_with_tools``
            (no tool calls are attached).
        tokens: Tokens yielded, in order, by the async generator ``astream``.

    Returns:
        A ``MagicMock`` whose ``ainvoke_with_tools`` and ``astream`` are real
        coroutine/async-generator functions and whose ``bind_tools`` is a mock.
    """
    llm = MagicMock()

    async def mock_ainvoke_with_tools(*args, **kwargs):
        return AIMessage(content=answer)

    async def mock_astream(*args, **kwargs):
        for token in tokens:
            yield token

    llm.ainvoke_with_tools = mock_ainvoke_with_tools
    llm.astream = mock_astream
    llm.bind_tools = MagicMock()
    return llm


def _make_trimmer_mock(trimmed_messages) -> MagicMock:
    """Build a trimmer mock that always reports trimming as needed.

    Because ``should_trim`` always returns True, any skipped trim in a test
    must come from the caller's ``tool_rounds`` gate, not from message size.

    Args:
        trimmed_messages: Value returned by ``trim_conversation_history``.
    """
    trimmer = MagicMock()
    trimmer.should_trim.return_value = True
    trimmer.trim_conversation_history.return_value = trimmed_messages
    return trimmer


def _make_state(messages, tool_rounds: int) -> AgentState:
    """Build the ``AgentState`` used by both tests, varying only ``tool_rounds``."""
    return AgentState(
        messages=messages,
        session_id="test",
        intent=None,
        tool_results=[],
        final_answer="",
        tool_rounds=tool_rounds,
        max_tool_rounds=3,
        max_tool_rounds_user_manual=2,
    )


@pytest.mark.asyncio
async def test_trimming_only_at_tool_rounds_zero():
    """``call_model`` must trim only when ``tool_rounds == 0``."""
    # Long history: 10 search rounds with large tool outputs.
    large_messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="请搜索信息"),
    ]
    for i in range(10):
        large_messages.extend([
            AIMessage(content=f"搜索{i}", tool_calls=[
                {"id": f"call_{i}", "name": "search", "args": {"query": f"test{i}"}}
            ]),
            ToolMessage(content="很长的搜索结果内容 " * 500, tool_call_id=f"call_{i}", name="search"),
        ])

    with patch('service.graph.graph.create_conversation_trimmer') as mock_trimmer_factory, \
         patch('service.graph.graph.LLMClient') as mock_llm_factory:
        mock_trimmer = _make_trimmer_mock(large_messages[:3])
        mock_trimmer_factory.return_value = mock_trimmer
        mock_llm_factory.return_value = _make_mock_llm("最终答案", ["最", "终", "答", "案"])

        # First round (tool_rounds == 0): the trimmer must be consulted and applied.
        await call_model(_make_state(large_messages, tool_rounds=0))
        mock_trimmer.should_trim.assert_called()
        mock_trimmer.trim_conversation_history.assert_called()

        # Clear recorded calls; reset_mock() keeps configured return values by default.
        mock_trimmer.reset_mock()

        # Later round (tool_rounds > 0): trimming must be skipped entirely.
        await call_model(_make_state(large_messages, tool_rounds=1))
        mock_trimmer.should_trim.assert_not_called()
        mock_trimmer.trim_conversation_history.assert_not_called()


@pytest.mark.asyncio
async def test_user_manual_agent_trimming_behavior():
    """``user_manual_agent_node`` must trim only when ``tool_rounds == 0``."""
    # Long history: 8 manual-retrieval rounds with large tool outputs.
    large_messages = [
        SystemMessage(content="You are a helpful user manual assistant."),
        HumanMessage(content="请帮我查找系统使用说明"),
    ]
    for i in range(8):
        large_messages.extend([
            AIMessage(content=f"查找{i}", tool_calls=[
                {"id": f"manual_call_{i}", "name": "retrieve_system_usermanual", "args": {"query": f"功能{i}"}}
            ]),
            ToolMessage(content="很长的手册内容 " * 400, tool_call_id=f"manual_call_{i}", name="retrieve_system_usermanual"),
        ])

    with patch('service.graph.user_manual_rag.create_conversation_trimmer') as mock_trimmer_factory, \
         patch('service.graph.user_manual_rag.LLMClient') as mock_llm_factory, \
         patch('service.graph.user_manual_rag.get_user_manual_tool_schemas') as mock_tools, \
         patch('service.graph.user_manual_rag.get_cached_config') as mock_config:
        mock_trimmer = _make_trimmer_mock(large_messages[:3])
        mock_trimmer_factory.return_value = mock_trimmer
        mock_llm_factory.return_value = _make_mock_llm("用户手册答案", ["答", "案"])
        mock_tools.return_value = []

        # Minimal app config: round limit plus the prompt template the node formats.
        mock_app_config = MagicMock()
        mock_app_config.app.max_tool_rounds_user_manual = 2
        mock_app_config.get_rag_prompts.return_value = {
            "user_manual_prompt": "System prompt for user manual: {conversation_history} {context_content} {current_query}"
        }
        mock_config.return_value = mock_app_config

        # First round (tool_rounds == 0): the trimmer must be consulted and applied.
        await user_manual_agent_node(_make_state(large_messages, tool_rounds=0))
        mock_trimmer.should_trim.assert_called()
        mock_trimmer.trim_conversation_history.assert_called()

        # Clear recorded calls; reset_mock() keeps configured return values by default.
        mock_trimmer.reset_mock()

        # Later round (tool_rounds > 0): trimming must be skipped entirely.
        await user_manual_agent_node(_make_state(large_messages, tool_rounds=1))
        mock_trimmer.should_trim.assert_not_called()
        mock_trimmer.trim_conversation_history.assert_not_called()


if __name__ == "__main__":
    pytest.main([__file__, "-v"])