init
1 vw-agentic-rag/tests/unit/__init__.py Normal file
@@ -0,0 +1 @@
# Empty __init__.py files to make test packages
114 vw-agentic-rag/tests/unit/test_aggressive_trimming.py Normal file
@@ -0,0 +1,114 @@
#!/usr/bin/env python3
"""
Test the new aggressive trimming strategy: historical tool-call results are trimmed
even when the token count is low.
"""
import pytest
from service.graph.message_trimmer import ConversationTrimmer
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage
from langchain_core.messages.utils import count_tokens_approximately


def test_aggressive_tool_history_trimming():
    """Test the aggressive tool-history trimming strategy"""

    # Create the trimmer directly to avoid configuration dependencies
    trimmer = ConversationTrimmer(max_context_length=100000)

    # Build a conversation with multiple tool-call rounds (very low token count)
    messages = [
        SystemMessage(content='You are a helpful assistant.'),

        # Historical conversation round 1
        HumanMessage(content='搜索汽车标准'),
        AIMessage(content='搜索中', tool_calls=[{'id': 'call_1', 'name': 'search', 'args': {'query': '汽车标准'}}]),
        ToolMessage(content='历史结果1', tool_call_id='call_1', name='search'),
        AIMessage(content='汽车标准信息'),

        # Historical conversation round 2
        HumanMessage(content='搜索电池标准'),
        AIMessage(content='搜索中', tool_calls=[{'id': 'call_2', 'name': 'search', 'args': {'query': '电池标准'}}]),
        ToolMessage(content='历史结果2', tool_call_id='call_2', name='search'),
        AIMessage(content='电池标准信息'),

        # New user query (the point at which trimming is triggered)
        HumanMessage(content='搜索安全标准'),
    ]

    # Verify the token count is low, well below the threshold
    token_count = count_tokens_approximately(messages)
    assert token_count < 1000, f"Token count should be low, got {token_count}"
    assert token_count < trimmer.history_token_limit, "Token count should be well below limit"

    # Verify that multiple tool rounds are identified
    tool_rounds = trimmer._identify_tool_rounds(messages)
    assert len(tool_rounds) == 2, f"Should identify 2 tool rounds, got {len(tool_rounds)}"

    # Verify trimming is triggered (because there are multiple tool rounds)
    should_trim = trimmer.should_trim(messages)
    assert should_trim, "Should trigger trimming due to multiple tool rounds"

    # Perform trimming
    trimmed = trimmer.trim_conversation_history(messages)

    # Verify the trimming result
    assert len(trimmed) < len(messages), "Should have fewer messages after trimming"

    # Verify the system message and initial query are preserved
    assert isinstance(trimmed[0], SystemMessage), "Should preserve system message"
    assert isinstance(trimmed[1], HumanMessage), "Should preserve initial human message"

    # Verify only the latest round's tool-call result is kept
    tool_messages = [msg for msg in trimmed if isinstance(msg, ToolMessage)]
    assert len(tool_messages) == 1, f"Should only keep 1 tool message, got {len(tool_messages)}"
    assert tool_messages[0].content == '历史结果2', "Should keep the most recent tool result"


def test_single_tool_round_no_trimming():
    """A single tool-call round should not trigger trimming"""

    trimmer = ConversationTrimmer(max_context_length=100000)

    # Conversation with only one tool-call round
    messages = [
        SystemMessage(content='You are a helpful assistant.'),
        HumanMessage(content='搜索信息'),
        AIMessage(content='搜索中', tool_calls=[{'id': 'call_1', 'name': 'search', 'args': {'query': '信息'}}]),
        ToolMessage(content='搜索结果', tool_call_id='call_1', name='search'),
        AIMessage(content='这是搜索到的信息'),
        HumanMessage(content='新的问题'),
    ]

    # Verify there is exactly one tool round
    tool_rounds = trimmer._identify_tool_rounds(messages)
    assert len(tool_rounds) == 1, f"Should identify 1 tool round, got {len(tool_rounds)}"

    # Verify trimming is not triggered (single tool round and low token count)
    should_trim = trimmer.should_trim(messages)
    assert not should_trim, "Should not trigger trimming for single tool round with low tokens"


def test_no_tool_rounds_no_trimming():
    """A conversation without tool calls should not trigger trimming"""

    trimmer = ConversationTrimmer(max_context_length=100000)

    # Conversation without any tool calls
    messages = [
        SystemMessage(content='You are a helpful assistant.'),
        HumanMessage(content='Hello'),
        AIMessage(content='Hi there!'),
        HumanMessage(content='How are you?'),
        AIMessage(content='I am doing well, thank you!'),
    ]

    # Verify there are no tool rounds
    tool_rounds = trimmer._identify_tool_rounds(messages)
    assert len(tool_rounds) == 0, f"Should identify 0 tool rounds, got {len(tool_rounds)}"

    # Verify trimming is not triggered
    should_trim = trimmer.should_trim(messages)
    assert not should_trim, "Should not trigger trimming without tool rounds"


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
143 vw-agentic-rag/tests/unit/test_assistant_ui_best_practices.py Normal file
@@ -0,0 +1,143 @@
"""
Test assistant-ui best practices implementation
"""
import json
import os


def test_package_json_dependencies():
    """Test that package.json has the correct assistant-ui dependencies"""
    package_json_path = os.path.join(os.path.dirname(__file__), "../../web/package.json")

    with open(package_json_path, 'r') as f:
        package_data = json.load(f)

    deps = package_data.get("dependencies", {})

    # Check for essential assistant-ui packages
    assert "@assistant-ui/react" in deps, "Missing @assistant-ui/react"
    assert "@assistant-ui/react-ui" in deps, "Missing @assistant-ui/react-ui"
    assert "@assistant-ui/react-markdown" in deps, "Missing @assistant-ui/react-markdown"
    assert "@assistant-ui/react-data-stream" in deps, "Missing @assistant-ui/react-data-stream"

    # Check versions are reasonable (not too old)
    react_version = deps["@assistant-ui/react"]
    assert "0.10" in react_version or "0.9" in react_version, f"Version too old: {react_version}"

    print("✅ Package dependencies test passed")


def test_env_configuration():
    """Test that environment configuration files exist"""
    env_local_path = os.path.join(os.path.dirname(__file__), "../../web/.env.local")
    assert os.path.exists(env_local_path), "Missing .env.local file"

    with open(env_local_path, 'r') as f:
        env_content = f.read()

    assert "NEXT_PUBLIC_LANGGRAPH_API_URL" in env_content, "Missing API URL config"
    assert "NEXT_PUBLIC_LANGGRAPH_ASSISTANT_ID" in env_content, "Missing Assistant ID config"

    print("✅ Environment configuration test passed")


def test_api_route_structure():
    """Test that API routes are properly structured"""
    # Check main chat API route exists
    chat_route_path = os.path.join(os.path.dirname(__file__), "../../web/src/app/api/chat/route.ts")
    assert os.path.exists(chat_route_path), "Missing chat API route"

    with open(chat_route_path, 'r') as f:
        route_content = f.read()

    # Check for essential API patterns
    assert "export async function POST" in route_content, "Missing POST handler"
    assert "Response" in route_content, "Missing Response handling"
    assert "x-vercel-ai-data-stream" in route_content, "Missing AI SDK compatibility header"

    print("✅ API route structure test passed")


def test_component_structure():
    """Test that main components follow best practices"""
    # Check main page component
    page_path = os.path.join(os.path.dirname(__file__), "../../web/src/app/page.tsx")
    assert os.path.exists(page_path), "Missing main page component"

    with open(page_path, 'r') as f:
        page_content = f.read()

    # Check for key React patterns and components
    assert '"use client"' in page_content, "Missing client-side directive"
    assert "Assistant" in page_content, "Missing Assistant component"
    assert "export default function" in page_content, "Missing default function export"

    # Check for proper structure
    assert "className=" in page_content, "Missing CSS class usage"
    assert "h-screen" in page_content or "h-full" in page_content, "Missing full height layout"

    print("✅ Component structure test passed")


def test_markdown_component():
    """Test that markdown component is properly configured"""
    markdown_path = os.path.join(os.path.dirname(__file__), "../../web/src/components/ui/markdown-text.tsx")
    assert os.path.exists(markdown_path), "Missing markdown component"

    with open(markdown_path, 'r') as f:
        markdown_content = f.read()

    assert "MarkdownTextPrimitive" in markdown_content, "Missing markdown primitive"
    assert "remarkGfm" in markdown_content, "Missing GFM support"

    print("✅ Markdown component test passed")


def test_best_practices_documentation():
    """Test that best practices documentation exists and is comprehensive"""
    docs_path = os.path.join(os.path.dirname(__file__), "../../docs/topics/ASSISTANT_UI_BEST_PRACTICES.md")
    assert os.path.exists(docs_path), "Missing best practices documentation"

    with open(docs_path, 'r') as f:
        docs_content = f.read()

    # Check for key sections
    assert "Assistant-UI + LangGraph + FastAPI" in docs_content, "Missing main title"
    assert "Implementation Status" in docs_content, "Missing implementation status"
    assert "Package Dependencies Updated" in docs_content, "Missing dependencies section"
    assert "Server-Side API Routes" in docs_content, "Missing API routes explanation"

    print("✅ Best practices documentation test passed")


def run_all_tests():
    """Run all tests"""
    print("🧪 Running assistant-ui best practices validation tests...")

    try:
        test_package_json_dependencies()
        test_env_configuration()
        test_api_route_structure()
        test_component_structure()
        test_markdown_component()
        test_best_practices_documentation()

        print("\n🎉 All assistant-ui best practices tests passed!")
        print("✅ Your implementation follows the recommended patterns for:")
        print(" - Package dependencies and versions")
        print(" - Environment configuration")
        print(" - API route structure")
        print(" - Component composition")
        print(" - Markdown rendering")
        print(" - Documentation completeness")

        return True

    except Exception as e:
        print(f"\n❌ Test failed: {e}")
        return False


if __name__ == "__main__":
    success = run_all_tests()
    exit(0 if success else 1)
141 vw-agentic-rag/tests/unit/test_memory.py Normal file
@@ -0,0 +1,141 @@
import pytest
from service.memory.store import InMemoryStore
from service.graph.state import TurnState, Message
from datetime import datetime, timedelta


@pytest.fixture
def memory_store():
    """Create memory store for testing"""
    return InMemoryStore(ttl_days=1)  # Short TTL for testing


@pytest.fixture
def sample_state():
    """Create sample turn state"""
    state = TurnState(session_id="test_session")
    state.messages = [
        Message(role="user", content="Hello"),
        Message(role="assistant", content="Hi there!")
    ]
    return state


def test_create_new_session(memory_store):
    """Test creating a new session"""
    state = memory_store.create_new_session("new_session")
    assert state.session_id == "new_session"
    assert len(state.messages) == 0

    # Verify it's stored
    retrieved = memory_store.get("new_session")
    assert retrieved is not None
    assert retrieved.session_id == "new_session"


def test_put_and_get_state(memory_store, sample_state):
    """Test storing and retrieving state"""
    memory_store.put("test_session", sample_state)

    retrieved = memory_store.get("test_session")
    assert retrieved is not None
    assert retrieved.session_id == "test_session"
    assert len(retrieved.messages) == 2
    assert retrieved.messages[0].content == "Hello"


def test_get_nonexistent_session(memory_store):
    """Test getting non-existent session returns None"""
    result = memory_store.get("nonexistent")
    assert result is None


def test_add_message(memory_store):
    """Test adding messages to conversation"""
    memory_store.create_new_session("test_session")

    message = Message(role="user", content="Test message")
    memory_store.add_message("test_session", message)

    state = memory_store.get("test_session")
    assert len(state.messages) == 1
    assert state.messages[0].content == "Test message"


def test_add_message_to_nonexistent_session(memory_store):
    """Test adding message creates new session if it doesn't exist"""
    message = Message(role="user", content="Test message")
    memory_store.add_message("new_session", message)

    state = memory_store.get("new_session")
    assert state is not None
    assert len(state.messages) == 1


def test_get_conversation_history(memory_store):
    """Test conversation history formatting"""
    memory_store.create_new_session("test_session")

    messages = [
        Message(role="user", content="Hello"),
        Message(role="assistant", content="Hi there!"),
        Message(role="user", content="How are you?"),
        Message(role="assistant", content="I'm doing well!")
    ]

    for msg in messages:
        memory_store.add_message("test_session", msg)

    history = memory_store.get_conversation_history("test_session")

    assert "User: Hello" in history
    assert "Assistant: Hi there!" in history
    assert "User: How are you?" in history
    assert "Assistant: I'm doing well!" in history


def test_get_conversation_history_empty(memory_store):
    """Test conversation history for empty session"""
    history = memory_store.get_conversation_history("nonexistent")
    assert history == ""


def test_trim_messages(memory_store):
    """Test message trimming"""
    memory_store.create_new_session("test_session")

    # Add many messages
    for i in range(25):
        memory_store.add_message("test_session", Message(role="user", content=f"Message {i}"))

    # Trim to 10 messages
    memory_store.trim("test_session", max_messages=10)

    state = memory_store.get("test_session")
    assert len(state.messages) <= 10


def test_trim_nonexistent_session(memory_store):
    """Test trimming non-existent session doesn't crash"""
    memory_store.trim("nonexistent", max_messages=10)
    # Should not raise an exception


def test_ttl_cleanup(memory_store):
    """Test TTL-based cleanup"""
    # Create store with very short TTL for testing
    short_ttl_store = InMemoryStore(ttl_days=0.001)  # ~1.5 minutes

    # Add a session
    state = TurnState(session_id="test_session")
    short_ttl_store.put("test_session", state)

    # Verify it exists
    assert short_ttl_store.get("test_session") is not None

    # Manually expire it by manipulating internal timestamp
    short_ttl_store.store["test_session"]["last_updated"] = datetime.now() - timedelta(days=1)

    # Try to get it - should trigger cleanup
    assert short_ttl_store.get("test_session") is None
    assert "test_session" not in short_ttl_store.store
194 vw-agentic-rag/tests/unit/test_message_trimmer.py Normal file
@@ -0,0 +1,194 @@
"""
Test conversation history trimming functionality.
"""
import pytest
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage

from service.graph.message_trimmer import ConversationTrimmer, create_conversation_trimmer


def test_conversation_trimmer_basic():
    """Test basic trimming functionality."""
    trimmer = ConversationTrimmer(max_context_length=50)  # Very low limit for testing (85% = 42 tokens)

    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="What is the capital of France?"),
        AIMessage(content="The capital of France is Paris."),
        HumanMessage(content="What about Germany?"),
        AIMessage(content="The capital of Germany is Berlin."),
        HumanMessage(content="And Italy?"),
        AIMessage(content="The capital of Italy is Rome."),
    ]

    # Should trigger trimming due to low token limit
    assert trimmer.should_trim(messages)

    trimmed = trimmer.trim_conversation_history(messages)

    # Should preserve system message and keep recent messages
    assert len(trimmed) < len(messages)
    assert isinstance(trimmed[0], SystemMessage)
    assert "helpful assistant" in trimmed[0].content


def test_conversation_trimmer_no_trim_needed():
    """Test when no trimming is needed."""
    trimmer = ConversationTrimmer(max_context_length=100000)  # High limit

    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Hello"),
        AIMessage(content="Hi there!"),
    ]

    # Should not need trimming
    assert not trimmer.should_trim(messages)

    trimmed = trimmer.trim_conversation_history(messages)

    # Should return same messages
    assert len(trimmed) == len(messages)


def test_conversation_trimmer_fallback():
    """Test fallback trimming logic."""
    trimmer = ConversationTrimmer(max_context_length=100)

    # Create many messages to trigger fallback
    messages = [SystemMessage(content="System")] + [
        HumanMessage(content=f"Message {i}") for i in range(50)
    ]

    trimmed = trimmer._fallback_trim(messages, max_messages=5)

    # Should keep system message + 4 recent messages
    assert len(trimmed) == 5
    assert isinstance(trimmed[0], SystemMessage)


def test_create_conversation_trimmer():
    """Test trimmer factory function."""
    # Test with explicit configuration
    custom_trimmer = create_conversation_trimmer(max_context_length=128000)
    assert custom_trimmer.max_context_length == 128000
    assert isinstance(custom_trimmer, ConversationTrimmer)
    assert custom_trimmer.preserve_system is True


def test_multi_round_tool_optimization():
    """Test multi-round tool call optimization."""
    trimmer = ConversationTrimmer(max_context_length=100000)  # High limit to focus on optimization

    # Create a multi-round tool calling scenario
    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Search for food safety and cosmetics testing standards."),

        # First tool round
        AIMessage(content="I'll search for information.", tool_calls=[
            {"id": "call_1", "name": "search_tool", "args": {"query": "food safety"}}
        ]),
        ToolMessage(content='{"results": ["Large food safety result data..." * 100]}', tool_call_id="call_1", name="search_tool"),

        # Second tool round
        AIMessage(content="Let me search for more specific information.", tool_calls=[
            {"id": "call_2", "name": "search_tool", "args": {"query": "cosmetics testing"}}
        ]),
        ToolMessage(content='{"results": ["Large cosmetics testing data..." * 100]}', tool_call_id="call_2", name="search_tool"),

        # Third tool round (most recent)
        AIMessage(content="Now let me get equipment information.", tool_calls=[
            {"id": "call_3", "name": "search_tool", "args": {"query": "testing equipment"}}
        ]),
        ToolMessage(content='{"results": ["Testing equipment info..." * 50]}', tool_call_id="call_3", name="search_tool"),
    ]

    # Test tool round identification
    tool_rounds = trimmer._identify_tool_rounds(messages)
    assert len(tool_rounds) == 3  # Should identify 3 tool rounds

    # Test optimization
    optimized = trimmer._optimize_multi_round_tool_calls(messages)

    # Should preserve: SystemMessage + HumanMessage + latest tool round (AI + Tool messages)
    expected_length = 1 + 1 + 2  # system + human + (latest AI + latest Tool)
    assert len(optimized) == expected_length

    # Check preserved messages
    assert isinstance(optimized[0], SystemMessage)
    assert isinstance(optimized[1], HumanMessage)
    assert isinstance(optimized[2], AIMessage)
    assert optimized[2].tool_calls[0]["id"] == "call_3"  # Should be the latest tool call
    assert isinstance(optimized[3], ToolMessage)
    assert optimized[3].tool_call_id == "call_3"  # Should be the latest tool result


def test_multi_round_optimization_single_round():
    """Test that single tool round is not optimized."""
    trimmer = ConversationTrimmer(max_context_length=100000)

    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Search for something."),
        AIMessage(content="I'll search.", tool_calls=[{"id": "call_1", "name": "search_tool", "args": {}}]),
        ToolMessage(content='{"results": ["data"]}', tool_call_id="call_1", name="search_tool"),
    ]

    optimized = trimmer._optimize_multi_round_tool_calls(messages)

    # Should return all messages unchanged for single round
    assert len(optimized) == len(messages)


def test_multi_round_optimization_no_tools():
    """Test that conversations without tools are not optimized."""
    trimmer = ConversationTrimmer(max_context_length=100000)

    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Hello"),
        AIMessage(content="Hi there!"),
        HumanMessage(content="How are you?"),
        AIMessage(content="I'm doing well, thanks!"),
    ]

    optimized = trimmer._optimize_multi_round_tool_calls(messages)

    # Should return all messages unchanged
    assert len(optimized) == len(messages)


def test_full_trimming_with_optimization():
    """Test complete trimming process with multi-round optimization."""
    trimmer = ConversationTrimmer(max_context_length=50)  # Very low limit

    # Create messages that will trigger both optimization and regular trimming
    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Search for food safety and cosmetics testing standards."),

        # First tool round (should be removed by optimization)
        AIMessage(content="I'll search for information.", tool_calls=[
            {"id": "call_1", "name": "search_tool", "args": {"query": "food safety"}}
        ]),
        ToolMessage(content='{"results": ["Large food safety result data..." * 100]}', tool_call_id="call_1", name="search_tool"),

        # Second tool round (most recent, should be preserved)
        AIMessage(content="Let me search for more information.", tool_calls=[
            {"id": "call_2", "name": "search_tool", "args": {"query": "cosmetics testing"}}
        ]),
        ToolMessage(content='{"results": ["Large cosmetics testing data..." * 100]}', tool_call_id="call_2", name="search_tool"),
    ]

    trimmed = trimmer.trim_conversation_history(messages)

    # Should apply optimization first, then regular trimming if needed
    assert len(trimmed) <= len(messages)

    # System message should always be preserved
    assert any(isinstance(msg, SystemMessage) for msg in trimmed)


if __name__ == "__main__":
    pytest.main([__file__])
353 vw-agentic-rag/tests/unit/test_retrieval.py Normal file
@@ -0,0 +1,353 @@
"""
Unit tests for agentic retrieval tools.
Tests the HTTP wrapper tools that interface with the retrieval API.
"""
import pytest
from unittest.mock import AsyncMock, patch, Mock
import httpx

from service.retrieval.retrieval import AgenticRetrieval, RetrievalResponse
from service.retrieval.clients import RetrievalAPIError, normalize_search_result


class TestAgenticRetrieval:
    """Test the agentic retrieval HTTP client."""

    @pytest.fixture
    def retrieval_client(self):
        """Create retrieval client for testing."""
        with patch('service.retrieval.retrieval.get_config') as mock_config:
            # Mock the new config structure
            mock_config.return_value.retrieval.endpoint = "https://test-search.search.azure.com"
            mock_config.return_value.retrieval.api_key = "test-key"
            mock_config.return_value.retrieval.api_version = "2024-11-01-preview"
            mock_config.return_value.retrieval.semantic_configuration = "default"

            # Mock embedding config
            mock_config.return_value.retrieval.embedding.base_url = "http://test-embedding"
            mock_config.return_value.retrieval.embedding.api_key = "test-embedding-key"
            mock_config.return_value.retrieval.embedding.model = "test-embedding-model"
            mock_config.return_value.retrieval.embedding.dimension = 1536

            # Mock index config
            mock_config.return_value.retrieval.index.standard_regulation_index = "index-catonline-standard-regulation-v2-prd"
            mock_config.return_value.retrieval.index.chunk_index = "index-catonline-chunk-v2-prd"
            mock_config.return_value.retrieval.index.chunk_user_manual_index = "index-cat-usermanual-chunk-prd"

            return AgenticRetrieval()

    @pytest.mark.asyncio
    async def test_retrieve_standard_regulation_success(self, retrieval_client):
        """Test successful standards regulation retrieval."""
        mock_response_data = {
            "value": [
                {
                    "id": "test_id_1",
                    "title": "ISO 26262-1:2018",
                    "document_code": "ISO26262-1",
                    "document_category": "Standard",
                    "@order_num": 1  # Should be preserved
                },
                {
                    "id": "test_id_2",
                    "title": "ISO 26262-3:2018",
                    "document_code": "ISO26262-3",
                    "document_category": "Standard",
                    "@order_num": 2
                }
            ]
        }

        with patch.object(retrieval_client.search_client, 'search_azure_ai', return_value=mock_response_data) as mock_search:
            result = await retrieval_client.retrieve_standard_regulation("ISO 26262")

            # Verify search call
            mock_search.assert_called_once()
            call_args = mock_search.call_args[1]  # keyword arguments
            assert call_args['search_text'] == "ISO 26262"
            assert call_args['index_name'] == "index-catonline-standard-regulation-v2-prd"
            assert call_args['top_k'] == 10

            # Verify response structure
            assert isinstance(result, RetrievalResponse)
            assert len(result.results) == 2
            assert result.total_count == 2
            assert result.took_ms is not None

            # Verify normalization
            first_result = result.results[0]
            assert first_result['id'] == "test_id_1"
            assert first_result['title'] == "ISO 26262-1:2018"
            # Verify order number is preserved
            assert "@order_num" in first_result
            assert "content" not in first_result  # Should not be included for standards

    @pytest.mark.asyncio
    async def test_retrieve_doc_chunk_success(self, retrieval_client):
        """Test successful document chunk retrieval."""
        mock_response_data = {
            "value": [
                {
                    "id": "chunk_id_1",
                    "title": "Functional Safety Requirements",
                    "content": "Detailed content about functional safety...",
                    "document_code": "ISO26262-1",
                    "@order_num": 1  # Should be preserved
                }
            ]
        }

        with patch.object(retrieval_client.search_client, 'search_azure_ai', return_value=mock_response_data) as mock_search:
            result = await retrieval_client.retrieve_doc_chunk_standard_regulation(
                "functional safety",
                conversation_history="Previous question about ISO 26262"
            )

            # Verify search call
            mock_search.assert_called_once()
            call_args = mock_search.call_args[1]  # keyword arguments
            assert call_args['search_text'] == "functional safety"
            assert call_args['index_name'] == "index-catonline-chunk-v2-prd"
            assert "filter_query" in call_args  # Should have document filter

            # Verify response
            assert isinstance(result, RetrievalResponse)
            assert len(result.results) == 1

            # Verify content is included for chunks
            first_result = result.results[0]
            assert "content" in first_result
            assert first_result['content'] == "Detailed content about functional safety..."
            # Verify order number is preserved
            assert "@order_num" in first_result

    @pytest.mark.asyncio
    async def test_http_error_handling(self, retrieval_client):
        """Test HTTP error handling."""
        with patch.object(retrieval_client.search_client, 'search_azure_ai') as mock_search:
            mock_search.side_effect = RetrievalAPIError("API request failed: 404")

            with pytest.raises(RetrievalAPIError) as exc_info:
                await retrieval_client.retrieve_standard_regulation("nonexistent")

            assert "404" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_timeout_handling(self, retrieval_client):
        """Test timeout handling."""
        with patch.object(retrieval_client.search_client, 'search_azure_ai') as mock_search:
            mock_search.side_effect = RetrievalAPIError("Request timeout")

            with pytest.raises(RetrievalAPIError) as exc_info:
                await retrieval_client.retrieve_standard_regulation("test query")

            assert "timeout" in str(exc_info.value)

    def test_normalize_result_removes_unwanted_fields(self, retrieval_client):
        """Test result normalization removes unwanted fields."""
        from service.retrieval.clients import normalize_search_result

        raw_result = {
            "id": "test_id",
            "title": "Test Title",
            "content": "Test content",
            "@search.score": 0.95,  # Should be removed
            "@search.rerankerScore": 2.5,  # Should be removed
            "@search.captions": [],  # Should be removed
            "@subquery_id": 1,  # Should be removed
            "empty_field": "",  # Should be removed
            "null_field": None,  # Should be removed
            "empty_list": [],  # Should be removed
            "empty_dict": {},  # Should be removed
            "valid_field": "valid_value"
        }

        # Test normalization
        normalized = normalize_search_result(raw_result)

        # Verify unwanted fields are removed
        assert "@search.score" not in normalized
        assert "@search.rerankerScore" not in normalized
        assert "@search.captions" not in normalized
        assert "@subquery_id" not in normalized
        assert "empty_field" not in normalized
        assert "null_field" not in normalized
        assert "empty_list" not in normalized
        assert "empty_dict" not in normalized

        # Verify valid fields are preserved (including content)
        assert normalized["id"] == "test_id"
        assert normalized["title"] == "Test Title"
        assert normalized["content"] == "Test content"  # Content is now always preserved
        assert normalized["valid_field"] == "valid_value"

    def test_normalize_result_includes_content_when_requested(self, retrieval_client):
        """Test result normalization always includes content."""
        from service.retrieval.clients import normalize_search_result

        raw_result = {
            "id": "test_id",
            "title": "Test Title",
            "content": "Test content"
        }

        # Test normalization (content is always preserved now)
        normalized = normalize_search_result(raw_result)
        assert "content" in normalized
        assert normalized["content"] == "Test content"

    @pytest.mark.asyncio
    async def test_empty_results_handling(self, retrieval_client):
        """Test handling of empty search results."""
        with patch.object(retrieval_client.search_client, 'search_azure_ai', return_value={"value": []}):
            result = await retrieval_client.retrieve_standard_regulation("no results query")

            assert isinstance(result, RetrievalResponse)
            assert result.results == []
            assert result.total_count == 0

    @pytest.mark.asyncio
    async def test_malformed_response_handling(self, retrieval_client):
        """Test handling of malformed API responses."""
        with patch.object(retrieval_client.search_client, 'search_azure_ai', return_value={"unexpected": "format"}):
            result = await retrieval_client.retrieve_standard_regulation("test query")

            # Should handle gracefully and return empty results
            assert isinstance(result, RetrievalResponse)
            assert result.results == []
            assert result.total_count == 0

    @pytest.mark.asyncio
    async def test_kwargs_override_payload(self, retrieval_client):
        """Test that kwargs can override default values."""
        with patch.object(retrieval_client.search_client, 'search_azure_ai', return_value={"value": []}) as mock_search:
            await retrieval_client.retrieve_standard_regulation(
                "test query",
                top_k=5,
                score_threshold=2.0
            )

            call_args = mock_search.call_args[1]  # keyword arguments
            assert call_args['top_k'] == 5  # Override default 10
            assert call_args['score_threshold'] == 2.0  # Override default 1.5


class TestRetrievalTools:
    """Test the LangGraph tool decorators."""

    @pytest.mark.asyncio
    async def test_retrieve_standard_regulation_tool(self):
        """Test the @tool decorated function for standards search."""
        from service.graph.tools import retrieve_standard_regulation

        mock_response = RetrievalResponse(
            results=[{"title": "Test Standard", "id": "test_id"}],
            took_ms=100,
            total_count=1
        )

        with patch('service.retrieval.retrieval.AgenticRetrieval') as mock_retrieval_class:
            mock_instance = AsyncMock()
            mock_instance.retrieve_standard_regulation.return_value = mock_response
            mock_retrieval_class.return_value.__aenter__.return_value = mock_instance
            mock_retrieval_class.return_value.__aexit__.return_value = None

            # Note: retrieve_standard_regulation is a LangGraph @tool, so we call it directly
            result = await retrieve_standard_regulation.ainvoke({"query": "test query", "conversation_history": "test history"})

            # Verify result format matches expected tool output
            assert isinstance(result, dict)
            assert "results" in result

    @pytest.mark.asyncio
    async def test_retrieve_doc_chunk_tool(self):
        """Test the @tool decorated function for document chunks search."""
        from service.graph.tools import retrieve_doc_chunk_standard_regulation

        mock_response = RetrievalResponse(
            results=[{"title": "Test Chunk", "content": "Test content", "id": "chunk_id"}],
            took_ms=150,
            total_count=1
        )

        with patch('service.retrieval.retrieval.AgenticRetrieval') as mock_retrieval_class:
            mock_instance = AsyncMock()
            mock_instance.retrieve_doc_chunk_standard_regulation.return_value = mock_response
            mock_retrieval_class.return_value.__aenter__.return_value = mock_instance
            mock_retrieval_class.return_value.__aexit__.return_value = None

            # Call the tool with proper input format
            result = await retrieve_doc_chunk_standard_regulation.ainvoke({"query": "test query"})

            # Verify result format
            assert isinstance(result, dict)
            assert "results" in result

    @pytest.mark.asyncio
    async def test_tool_error_handling(self):
        """Test tool error handling."""
        from service.graph.tools import retrieve_standard_regulation

        with patch('service.retrieval.clients.AzureSearchClient.search_azure_ai') as mock_search:
            mock_search.side_effect = RetrievalAPIError("API Error")

            # Tool should handle error gracefully and return error result
            result = await retrieve_standard_regulation.ainvoke({"query": "test query"})

            # Should return error information in result
            assert isinstance(result, dict)
            assert "error" in result
            assert "API Error" in result["error"]


class TestRetrievalIntegration:
    """Integration-style tests for retrieval functionality."""

    @pytest.mark.asyncio
    async def test_full_retrieval_workflow(self):
        """Test complete retrieval workflow with both tools."""
        # Mock both retrieval calls
        standards_response = RetrievalResponse(
            results=[
                {"id": "std_1", "title": "ISO 26262-1", "document_category": "Standard"}
            ],
            took_ms=100,
            total_count=1
        )

        chunks_response = RetrievalResponse(
            results=[
                {"id": "chunk_1", "title": "Safety Requirements", "content": "Detailed safety requirements..."}
            ],
            took_ms=150,
            total_count=1
        )

        with patch('service.retrieval.clients.AzureSearchClient.search_azure_ai') as mock_search:
            # Set up mock to return different responses for different calls
            def mock_response_side_effect(*args, **kwargs):
                if kwargs.get('index_name') == 'index-catonline-standard-regulation-v2-prd':
                    return {
                        "value": [{"id": "std_1", "title": "ISO 26262-1", "document_category": "Standard", "@order_num": 1}]
                    }
                else:
                    return {
                        "value": [{"id": "chunk_1", "title": "Safety Requirements", "content": "Detailed safety requirements...", "@order_num": 1}]
                    }

            mock_search.side_effect = mock_response_side_effect

            # Import and test both tools
            from service.graph.tools import retrieve_standard_regulation, retrieve_doc_chunk_standard_regulation

            # Test standards search
            std_result = await retrieve_standard_regulation.ainvoke({"query": "ISO 26262"})
            assert isinstance(std_result, dict)
            assert "results" in std_result
            assert len(std_result["results"]) == 1

            # Test chunks search
            chunk_result = await retrieve_doc_chunk_standard_regulation.ainvoke({"query": "safety requirements"})
            assert isinstance(chunk_result, dict)
            assert "results" in chunk_result
            assert len(chunk_result["results"]) == 1
73 vw-agentic-rag/tests/unit/test_sse.py Normal file
@@ -0,0 +1,73 @@
import pytest
from service.sse import (
    format_sse_event, create_token_event, create_tool_start_event,
    create_tool_result_event, create_tool_error_event,
    create_error_event
)


def test_format_sse_event():
    """Test SSE event formatting"""
    event = format_sse_event("test", {"message": "hello"})
    expected = 'event: test\ndata: {"message": "hello"}\n\n'
    assert event == expected


def test_create_token_event():
    """Test token event creation"""
    event = create_token_event("hello", "tool_123")
    assert "event: tokens" in event
    assert '"delta": "hello"' in event
    assert '"tool_call_id": "tool_123"' in event


def test_create_token_event_no_tool_id():
    """Test token event without tool call ID"""
    event = create_token_event("hello")
    assert "event: tokens" in event
    assert '"delta": "hello"' in event
    assert '"tool_call_id": null' in event


def test_create_tool_start_event():
    """Test tool start event"""
    event = create_tool_start_event("tool_123", "retrieve_standard_regulation", {"query": "test"})
    assert "event: tool_start" in event
    assert '"id": "tool_123"' in event
    assert '"name": "retrieve_standard_regulation"' in event
    assert '"args": {"query": "test"}' in event


def test_create_tool_result_event():
    """Test tool result event"""
    results = [{"id": "1", "title": "Test Standard"}]
    event = create_tool_result_event("tool_123", "retrieve_standard_regulation", results, 500)
    assert "event: tool_result" in event
    assert '"id": "tool_123"' in event
    assert '"took_ms": 500' in event
    assert '"results"' in event


def test_create_tool_error_event():
    """Test tool error event"""
    event = create_tool_error_event("tool_123", "retrieve_standard_regulation", "API timeout")
    assert "event: tool_error" in event
    assert '"id": "tool_123"' in event
    assert '"error": "API timeout"' in event


def test_create_error_event():
    """Test error event"""
    event = create_error_event("Something went wrong")
    assert "event: error" in event
    assert '"error": "Something went wrong"' in event


def test_create_error_event_with_details():
    """Test error event with details"""
    details = {"code": 500, "source": "llm"}
    event = create_error_event("Something went wrong", details)
    assert "event: error" in event
    assert '"error": "Something went wrong"' in event
    assert '"details"' in event
    assert '"code": 500' in event
200 vw-agentic-rag/tests/unit/test_trimming_fix.py Normal file
@@ -0,0 +1,200 @@
#!/usr/bin/env python3
"""
Test the modified trimming logic: trimming happens only when tool_rounds=0 and is
skipped in all other cases.
"""
import pytest
from unittest.mock import patch, MagicMock
import asyncio
from datetime import datetime

import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from service.graph.state import AgentState
from service.graph.graph import call_model
from service.graph.user_manual_rag import user_manual_agent_node
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage


@pytest.mark.asyncio
async def test_trimming_only_at_tool_rounds_zero():
    """Trimming should only happen when tool_rounds=0"""

    # Build a large message list to trigger trimming
    large_messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="请搜索信息"),
    ]

    # Add a lot of content to trigger trimming
    for i in range(10):
        large_messages.extend([
            AIMessage(content=f"搜索{i}", tool_calls=[
                {"id": f"call_{i}", "name": "search", "args": {"query": f"test{i}"}}
            ]),
            ToolMessage(content="很长的搜索结果内容 " * 500, tool_call_id=f"call_{i}", name="search"),
        ])

    # Mock trimmer to track calls
    with patch('service.graph.graph.create_conversation_trimmer') as mock_trimmer_factory:
        mock_trimmer = MagicMock()
        mock_trimmer.should_trim.return_value = True  # always report that trimming is needed
        mock_trimmer.trim_conversation_history.return_value = large_messages[:3]  # return the trimmed messages
        mock_trimmer_factory.return_value = mock_trimmer

        # Mock LLM client
        with patch('service.graph.graph.LLMClient') as mock_llm_factory:
            mock_llm = MagicMock()
            # Set return values for the async methods
            async def mock_ainvoke_with_tools(*args, **kwargs):
                return AIMessage(content="最终答案")

            async def mock_astream(*args, **kwargs):
                for token in ["最", "终", "答", "案"]:
                    yield token

            mock_llm.ainvoke_with_tools = mock_ainvoke_with_tools
            mock_llm.astream = mock_astream
            mock_llm.bind_tools = MagicMock()
            mock_llm_factory.return_value = mock_llm

            # Case tool_rounds=0 (trimming should happen)
            from service.graph.state import AgentState
            state_round_0 = AgentState(
                messages=large_messages,
                session_id="test",
                intent=None,
                tool_results=[],
                final_answer="",
                tool_rounds=0,  # first round
                max_tool_rounds=3,
                max_tool_rounds_user_manual=2
            )

            result = await call_model(state_round_0)

            # Verify the trimmer was called
            mock_trimmer.should_trim.assert_called()
            mock_trimmer.trim_conversation_history.assert_called()

            # Reset the mock
            mock_trimmer.reset_mock()

            # Case tool_rounds>0 (trimming should not happen)
            from service.graph.state import AgentState
            state_round_1 = AgentState(
                messages=large_messages,
                session_id="test",
                intent=None,
                tool_results=[],
                final_answer="",
                tool_rounds=1,  # second round
                max_tool_rounds=3,
                max_tool_rounds_user_manual=2
            )

            result = await call_model(state_round_1)

            # Verify the trimmer was not called
            mock_trimmer.should_trim.assert_not_called()
            mock_trimmer.trim_conversation_history.assert_not_called()


@pytest.mark.asyncio
async def test_user_manual_agent_trimming_behavior():
    """Test the trimming behavior of user_manual_agent_node"""

    # Build a large message list
    large_messages = [
        SystemMessage(content="You are a helpful user manual assistant."),
        HumanMessage(content="请帮我查找系统使用说明"),
    ]

    # Add a lot of content
    for i in range(8):
        large_messages.extend([
            AIMessage(content=f"查找{i}", tool_calls=[
                {"id": f"manual_call_{i}", "name": "retrieve_system_usermanual", "args": {"query": f"功能{i}"}}
            ]),
            ToolMessage(content="很长的手册内容 " * 400, tool_call_id=f"manual_call_{i}", name="retrieve_system_usermanual"),
        ])

    # Mock dependencies
    with patch('service.graph.user_manual_rag.create_conversation_trimmer') as mock_trimmer_factory, \
         patch('service.graph.user_manual_rag.LLMClient') as mock_llm_factory, \
         patch('service.graph.user_manual_rag.get_user_manual_tool_schemas') as mock_tools, \
         patch('service.graph.user_manual_rag.get_cached_config') as mock_config:

        # Setup mocks
        mock_trimmer = MagicMock()
        mock_trimmer.should_trim.return_value = True
        mock_trimmer.trim_conversation_history.return_value = large_messages[:3]
        mock_trimmer_factory.return_value = mock_trimmer

        mock_llm = MagicMock()
        # Set return values for the async methods
        async def mock_ainvoke_with_tools(*args, **kwargs):
            return AIMessage(content="用户手册答案")

        async def mock_astream(*args, **kwargs):
            for token in ["答", "案"]:
                yield token

        mock_llm.ainvoke_with_tools = mock_ainvoke_with_tools
        mock_llm.astream = mock_astream
        mock_llm.bind_tools = MagicMock()
        mock_llm_factory.return_value = mock_llm

        mock_tools.return_value = []

        mock_app_config = MagicMock()
        mock_app_config.app.max_tool_rounds_user_manual = 2
        mock_app_config.get_rag_prompts.return_value = {
            "user_manual_prompt": "System prompt for user manual: {conversation_history} {context_content} {current_query}"
        }
        mock_config.return_value = mock_app_config

        # Case tool_rounds=0 (trimming should happen)
        from service.graph.state import AgentState
        state_round_0 = AgentState(
            messages=large_messages,
            session_id="test",
            intent=None,
            tool_results=[],
            final_answer="",
            tool_rounds=0,
            max_tool_rounds=3,
            max_tool_rounds_user_manual=2
        )

        result = await user_manual_agent_node(state_round_0)

        # Verify the trimmer was called
        mock_trimmer.should_trim.assert_called()
        mock_trimmer.trim_conversation_history.assert_called()

        # Reset the mock
        mock_trimmer.reset_mock()

        # Case tool_rounds=1 (trimming should not happen)
        state_round_1 = AgentState(
            messages=large_messages,
            session_id="test",
            intent=None,
            tool_results=[],
            final_answer="",
            tool_rounds=1,
            max_tool_rounds=3,
            max_tool_rounds_user_manual=2
        )

        result = await user_manual_agent_node(state_round_1)

        # Verify the trimmer was not called
        mock_trimmer.should_trim.assert_not_called()
        mock_trimmer.trim_conversation_history.assert_not_called()


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
95 vw-agentic-rag/tests/unit/test_user_manual_tool.py Normal file
@@ -0,0 +1,95 @@
"""
Unit test for the new retrieve_system_usermanual tool
"""

import pytest
from unittest.mock import AsyncMock, patch
import asyncio

from service.graph.user_manual_tools import retrieve_system_usermanual, user_manual_tools, get_user_manual_tool_schemas, get_user_manual_tools_by_name
from service.graph.tools import get_tool_schemas, get_tools_by_name
from service.retrieval.retrieval import RetrievalResponse


class TestRetrieveSystemUsermanualTool:
    """Test the new user manual retrieval tool"""

    def test_tool_in_tools_list(self):
        """Test that the new tool is in the tools list"""
        tool_names = [tool.name for tool in user_manual_tools]
        assert "retrieve_system_usermanual" in tool_names
        assert len(user_manual_tools) == 1  # Should have 1 user manual tool

    def test_tool_schemas_generation(self):
        """Test that tool schemas are generated correctly"""
        schemas = get_user_manual_tool_schemas()
        tool_names = [schema["function"]["name"] for schema in schemas]
        assert "retrieve_system_usermanual" in tool_names

        # Find the user manual tool schema
        user_manual_schema = next(
            schema for schema in schemas
            if schema["function"]["name"] == "retrieve_system_usermanual"
        )

        assert user_manual_schema["function"]["description"] == "Search for document content chunks of user manual of this system(CATOnline)"
        assert "query" in user_manual_schema["function"]["parameters"]["properties"]
        assert user_manual_schema["function"]["parameters"]["required"] == ["query"]

    def test_tools_by_name_mapping(self):
        """Test that tools_by_name mapping includes the new tool"""
        tools_mapping = get_user_manual_tools_by_name()
        assert "retrieve_system_usermanual" in tools_mapping
        assert tools_mapping["retrieve_system_usermanual"] == retrieve_system_usermanual

    @pytest.mark.asyncio
    async def test_retrieve_system_usermanual_success(self):
        """Test successful user manual retrieval"""

        mock_response = RetrievalResponse(
            results=[
                {"title": "User Manual Chapter 1", "content": "How to use the system", "id": "manual_1"}
            ],
            took_ms=120,
            total_count=1
        )

        with patch('service.retrieval.retrieval.AgenticRetrieval') as mock_retrieval_class:
            mock_instance = AsyncMock()
            mock_instance.retrieve_doc_chunk_user_manual.return_value = mock_response
            mock_retrieval_class.return_value.__aenter__.return_value = mock_instance
            mock_retrieval_class.return_value.__aexit__.return_value = None

            # Test the tool
            result = await retrieve_system_usermanual.ainvoke({"query": "how to use system"})

            # Verify result format
            assert isinstance(result, dict)
            assert result["tool_name"] == "retrieve_system_usermanual"
            assert result["results_count"] == 1
            assert len(result["results"]) == 1
            assert result["results"][0]["title"] == "User Manual Chapter 1"
            assert result["took_ms"] == 120

    @pytest.mark.asyncio
    async def test_retrieve_system_usermanual_error(self):
        """Test error handling in user manual retrieval"""

        with patch('service.retrieval.retrieval.AgenticRetrieval') as mock_retrieval_class:
            mock_instance = AsyncMock()
            mock_instance.retrieve_doc_chunk_user_manual.side_effect = Exception("Search API Error")
            mock_retrieval_class.return_value.__aenter__.return_value = mock_instance
            mock_retrieval_class.return_value.__aexit__.return_value = None

            # Test error handling
            result = await retrieve_system_usermanual.ainvoke({"query": "test query"})

            # Should return error information
            assert isinstance(result, dict)
            assert "error" in result
            assert "Search API Error" in result["error"]
            assert result["results_count"] == 0
            assert result["results"] == []


if __name__ == "__main__":
    pytest.main([__file__, "-v"])