This commit is contained in:
2025-09-26 17:15:54 +08:00
commit db0e5965ec
211 changed files with 40437 additions and 0 deletions

View File

@@ -0,0 +1 @@
# Empty __init__.py files to make test packages

View File

@@ -0,0 +1,114 @@
#!/usr/bin/env python3
"""
测试新的积极修剪策略即使token数很少也要修剪历史工具调用结果
"""
import pytest
from service.graph.message_trimmer import ConversationTrimmer
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage
from langchain_core.messages.utils import count_tokens_approximately
def test_aggressive_tool_history_trimming():
    """Aggressive pruning strategy: historical tool-call results are trimmed
    even when the overall token count is very low."""
    # Build the trimmer directly so the test has no configuration dependency.
    trimmer = ConversationTrimmer(max_context_length=100000)
    # A conversation with two completed historical tool rounds plus a fresh
    # user query (the point at which trimming is evaluated).
    messages = [
        SystemMessage(content='You are a helpful assistant.'),
        # Historical round 1
        HumanMessage(content='搜索汽车标准'),
        AIMessage(content='搜索中', tool_calls=[{'id': 'call_1', 'name': 'search', 'args': {'query': '汽车标准'}}]),
        ToolMessage(content='历史结果1', tool_call_id='call_1', name='search'),
        AIMessage(content='汽车标准信息'),
        # Historical round 2
        HumanMessage(content='搜索电池标准'),
        AIMessage(content='搜索中', tool_calls=[{'id': 'call_2', 'name': 'search', 'args': {'query': '电池标准'}}]),
        ToolMessage(content='历史结果2', tool_call_id='call_2', name='search'),
        AIMessage(content='电池标准信息'),
        # New user query that triggers the trimming decision.
        HumanMessage(content='搜索安全标准'),
    ]
    # Sanity check: token usage is far below the history budget.
    token_count = count_tokens_approximately(messages)
    assert token_count < 1000, f"Token count should be low, got {token_count}"
    assert token_count < trimmer.history_token_limit, "Token count should be well below limit"
    # Two distinct tool rounds must be detected.
    tool_rounds = trimmer._identify_tool_rounds(messages)
    assert len(tool_rounds) == 2, f"Should identify 2 tool rounds, got {len(tool_rounds)}"
    # Multiple tool rounds alone should be enough to trigger trimming.
    assert trimmer.should_trim(messages), "Should trigger trimming due to multiple tool rounds"
    # Perform the trim and inspect the result.
    trimmed = trimmer.trim_conversation_history(messages)
    assert len(trimmed) < len(messages), "Should have fewer messages after trimming"
    # The system message and the initial human query must survive.
    assert isinstance(trimmed[0], SystemMessage), "Should preserve system message"
    assert isinstance(trimmed[1], HumanMessage), "Should preserve initial human message"
    # Only the most recent round's tool result may remain.
    tool_messages = [msg for msg in trimmed if isinstance(msg, ToolMessage)]
    assert len(tool_messages) == 1, f"Should only keep 1 tool message, got {len(tool_messages)}"
    assert tool_messages[0].content == '历史结果2', "Should keep the most recent tool result"
def test_single_tool_round_no_trimming():
    """A single tool round with a low token count must not trigger trimming."""
    trimmer = ConversationTrimmer(max_context_length=100000)
    # Conversation containing exactly one completed tool round.
    messages = [
        SystemMessage(content='You are a helpful assistant.'),
        HumanMessage(content='搜索信息'),
        AIMessage(content='搜索中', tool_calls=[{'id': 'call_1', 'name': 'search', 'args': {'query': '信息'}}]),
        ToolMessage(content='搜索结果', tool_call_id='call_1', name='search'),
        AIMessage(content='这是搜索到的信息'),
        HumanMessage(content='新的问题'),
    ]
    # Exactly one tool round should be detected.
    tool_rounds = trimmer._identify_tool_rounds(messages)
    assert len(tool_rounds) == 1, f"Should identify 1 tool round, got {len(tool_rounds)}"
    # One round + low tokens -> no trimming.
    assert not trimmer.should_trim(messages), "Should not trigger trimming for single tool round with low tokens"


def test_no_tool_rounds_no_trimming():
    """A conversation without any tool calls must not trigger trimming."""
    trimmer = ConversationTrimmer(max_context_length=100000)
    # Plain chat exchange, no tool usage at all.
    messages = [
        SystemMessage(content='You are a helpful assistant.'),
        HumanMessage(content='Hello'),
        AIMessage(content='Hi there!'),
        HumanMessage(content='How are you?'),
        AIMessage(content='I am doing well, thank you!'),
    ]
    # No tool rounds should be detected.
    tool_rounds = trimmer._identify_tool_rounds(messages)
    assert len(tool_rounds) == 0, f"Should identify 0 tool rounds, got {len(tool_rounds)}"
    # And therefore no trimming.
    assert not trimmer.should_trim(messages), "Should not trigger trimming without tool rounds"
# Allow running this module directly: delegate to pytest in verbose mode.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,143 @@
"""
Test assistant-ui best practices implementation
"""
import json
import os
def test_package_json_dependencies():
    """Verify web/package.json declares the required assistant-ui packages.

    Checks presence of the four core @assistant-ui packages and that the
    core react package version string is not from an outdated release line.
    """
    package_json_path = os.path.join(os.path.dirname(__file__), "../../web/package.json")
    # Fix: explicit encoding — package.json may contain non-ASCII characters
    # and the platform default encoding is not guaranteed to be UTF-8.
    with open(package_json_path, 'r', encoding='utf-8') as f:
        package_data = json.load(f)
    deps = package_data.get("dependencies", {})
    # Check for essential assistant-ui packages
    assert "@assistant-ui/react" in deps, "Missing @assistant-ui/react"
    assert "@assistant-ui/react-ui" in deps, "Missing @assistant-ui/react-ui"
    assert "@assistant-ui/react-markdown" in deps, "Missing @assistant-ui/react-markdown"
    assert "@assistant-ui/react-data-stream" in deps, "Missing @assistant-ui/react-data-stream"
    # Check versions are reasonable (not too old); substring match is a
    # pragmatic heuristic, not full semver comparison.
    react_version = deps["@assistant-ui/react"]
    assert "0.10" in react_version or "0.9" in react_version, f"Version too old: {react_version}"
    print("✅ Package dependencies test passed")
def test_env_configuration():
    """Verify web/.env.local exists and defines the LangGraph connection vars."""
    env_local_path = os.path.join(os.path.dirname(__file__), "../../web/.env.local")
    assert os.path.exists(env_local_path), "Missing .env.local file"
    # Fix: explicit encoding — don't rely on the platform default being UTF-8.
    with open(env_local_path, 'r', encoding='utf-8') as f:
        env_content = f.read()
    assert "NEXT_PUBLIC_LANGGRAPH_API_URL" in env_content, "Missing API URL config"
    assert "NEXT_PUBLIC_LANGGRAPH_ASSISTANT_ID" in env_content, "Missing Assistant ID config"
    print("✅ Environment configuration test passed")
def test_api_route_structure():
    """Verify the chat API route exists and follows the expected patterns."""
    # Check main chat API route exists
    chat_route_path = os.path.join(os.path.dirname(__file__), "../../web/src/app/api/chat/route.ts")
    assert os.path.exists(chat_route_path), "Missing chat API route"
    # Fix: explicit encoding — TypeScript source may contain non-ASCII text.
    with open(chat_route_path, 'r', encoding='utf-8') as f:
        route_content = f.read()
    # Check for essential API patterns
    assert "export async function POST" in route_content, "Missing POST handler"
    assert "Response" in route_content, "Missing Response handling"
    assert "x-vercel-ai-data-stream" in route_content, "Missing AI SDK compatibility header"
    print("✅ API route structure test passed")
def test_component_structure():
    """Verify the main page component follows React/assistant-ui conventions."""
    # Check main page component
    page_path = os.path.join(os.path.dirname(__file__), "../../web/src/app/page.tsx")
    assert os.path.exists(page_path), "Missing main page component"
    # Fix: explicit encoding — TSX source may contain non-ASCII text.
    with open(page_path, 'r', encoding='utf-8') as f:
        page_content = f.read()
    # Check for key React patterns and components
    assert '"use client"' in page_content, "Missing client-side directive"
    assert "Assistant" in page_content, "Missing Assistant component"
    assert "export default function" in page_content, "Missing default function export"
    # Check for proper structure
    assert "className=" in page_content, "Missing CSS class usage"
    assert "h-screen" in page_content or "h-full" in page_content, "Missing full height layout"
    print("✅ Component structure test passed")
def test_markdown_component():
    """Verify the markdown component is present and configured with GFM."""
    markdown_path = os.path.join(os.path.dirname(__file__), "../../web/src/components/ui/markdown-text.tsx")
    assert os.path.exists(markdown_path), "Missing markdown component"
    # Fix: explicit encoding — don't rely on the platform default being UTF-8.
    with open(markdown_path, 'r', encoding='utf-8') as f:
        markdown_content = f.read()
    assert "MarkdownTextPrimitive" in markdown_content, "Missing markdown primitive"
    assert "remarkGfm" in markdown_content, "Missing GFM support"
    print("✅ Markdown component test passed")
def test_best_practices_documentation():
    """Verify the best-practices document exists and covers the key sections."""
    docs_path = os.path.join(os.path.dirname(__file__), "../../docs/topics/ASSISTANT_UI_BEST_PRACTICES.md")
    assert os.path.exists(docs_path), "Missing best practices documentation"
    # Fix: explicit encoding — markdown docs frequently contain non-ASCII
    # characters (emoji, CJK text) and would crash on non-UTF-8 locales.
    with open(docs_path, 'r', encoding='utf-8') as f:
        docs_content = f.read()
    # Check for key sections
    assert "Assistant-UI + LangGraph + FastAPI" in docs_content, "Missing main title"
    assert "Implementation Status" in docs_content, "Missing implementation status"
    assert "Package Dependencies Updated" in docs_content, "Missing dependencies section"
    assert "Server-Side API Routes" in docs_content, "Missing API routes explanation"
    print("✅ Best practices documentation test passed")
def run_all_tests():
    """Execute every validation check in order; return True iff all pass."""
    print("🧪 Running assistant-ui best practices validation tests...")
    try:
        # Run the checks in a fixed, documented order.
        for check in (
            test_package_json_dependencies,
            test_env_configuration,
            test_api_route_structure,
            test_component_structure,
            test_markdown_component,
            test_best_practices_documentation,
        ):
            check()
        print("\n🎉 All assistant-ui best practices tests passed!")
        print("✅ Your implementation follows the recommended patterns for:")
        print(" - Package dependencies and versions")
        print(" - Environment configuration")
        print(" - API route structure")
        print(" - Component composition")
        print(" - Markdown rendering")
        print(" - Documentation completeness")
        return True
    except Exception as e:
        # Broad catch is deliberate: any assertion or I/O failure is reported
        # as a single failed run instead of an unhandled traceback.
        print(f"\n❌ Test failed: {e}")
        return False
# Allow running this validation script directly (without pytest); the process
# exit code reflects overall success.
if __name__ == "__main__":
    success = run_all_tests()
    exit(0 if success else 1)

View File

@@ -0,0 +1,141 @@
import pytest
from service.memory.store import InMemoryStore
from service.graph.state import TurnState, Message
from datetime import datetime, timedelta
@pytest.fixture
def memory_store():
    """Create memory store for testing"""
    # Short TTL keeps the store small; expiry itself is exercised separately
    # in test_ttl_cleanup with its own store.
    return InMemoryStore(ttl_days=1)  # Short TTL for testing


@pytest.fixture
def sample_state():
    """Create sample turn state"""
    # Minimal two-message conversation: one user turn plus one assistant reply.
    state = TurnState(session_id="test_session")
    state.messages = [
        Message(role="user", content="Hello"),
        Message(role="assistant", content="Hi there!")
    ]
    return state
def test_create_new_session(memory_store):
    """create_new_session returns a fresh empty state that is also persisted."""
    created = memory_store.create_new_session("new_session")
    assert created.session_id == "new_session"
    assert len(created.messages) == 0
    # The store itself must now hold the session.
    stored = memory_store.get("new_session")
    assert stored is not None
    assert stored.session_id == "new_session"


def test_put_and_get_state(memory_store, sample_state):
    """A state round-trips through put/get without losing its messages."""
    memory_store.put("test_session", sample_state)
    loaded = memory_store.get("test_session")
    assert loaded is not None
    assert loaded.session_id == "test_session"
    assert len(loaded.messages) == 2
    assert loaded.messages[0].content == "Hello"


def test_get_nonexistent_session(memory_store):
    """Looking up an unknown session yields None rather than raising."""
    assert memory_store.get("nonexistent") is None
def test_add_message(memory_store):
    """Messages appended via add_message appear in the stored conversation."""
    memory_store.create_new_session("test_session")
    memory_store.add_message("test_session", Message(role="user", content="Test message"))
    state = memory_store.get("test_session")
    assert len(state.messages) == 1
    assert state.messages[0].content == "Test message"


def test_add_message_to_nonexistent_session(memory_store):
    """add_message transparently creates the session when it is missing."""
    memory_store.add_message("new_session", Message(role="user", content="Test message"))
    state = memory_store.get("new_session")
    assert state is not None
    assert len(state.messages) == 1
def test_get_conversation_history(memory_store):
    """History formatting labels user and assistant turns with role prefixes."""
    memory_store.create_new_session("test_session")
    # Four alternating turns.
    turns = [
        ("user", "Hello"),
        ("assistant", "Hi there!"),
        ("user", "How are you?"),
        ("assistant", "I'm doing well!"),
    ]
    for role, text in turns:
        memory_store.add_message("test_session", Message(role=role, content=text))
    history = memory_store.get_conversation_history("test_session")
    # Every turn must appear with its role label.
    expected_lines = (
        "User: Hello",
        "Assistant: Hi there!",
        "User: How are you?",
        "Assistant: I'm doing well!",
    )
    for line in expected_lines:
        assert line in history


def test_get_conversation_history_empty(memory_store):
    """An unknown session produces an empty history string."""
    assert memory_store.get_conversation_history("nonexistent") == ""
def test_trim_messages(memory_store):
    """trim() caps the stored conversation at max_messages."""
    memory_store.create_new_session("test_session")
    # Overfill the session well past the trim target.
    for i in range(25):
        memory_store.add_message("test_session", Message(role="user", content=f"Message {i}"))
    memory_store.trim("test_session", max_messages=10)
    assert len(memory_store.get("test_session").messages) <= 10


def test_trim_nonexistent_session(memory_store):
    """Trimming an unknown session is a silent no-op (must not raise)."""
    memory_store.trim("nonexistent", max_messages=10)
def test_ttl_cleanup(memory_store):
    """Expired sessions are evicted lazily when they are next read."""
    # Dedicated store with a tiny TTL (0.001 days ≈ 1.5 minutes).
    short_ttl_store = InMemoryStore(ttl_days=0.001)
    short_ttl_store.put("test_session", TurnState(session_id="test_session"))
    # The entry is readable immediately after insertion.
    assert short_ttl_store.get("test_session") is not None
    # Force expiry by backdating the internal last-updated timestamp.
    short_ttl_store.store["test_session"]["last_updated"] = datetime.now() - timedelta(days=1)
    # The next read should purge the expired entry from the store.
    assert short_ttl_store.get("test_session") is None
    assert "test_session" not in short_ttl_store.store

View File

@@ -0,0 +1,194 @@
"""
Test conversation history trimming functionality.
"""
import pytest
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage
from service.graph.message_trimmer import ConversationTrimmer, create_conversation_trimmer
def test_conversation_trimmer_basic():
    """Under a tiny token budget the trimmer drops older turns but keeps the
    system prompt."""
    # 50-token budget (85% effective ≈ 42 tokens) guarantees trimming fires.
    trimmer = ConversationTrimmer(max_context_length=50)
    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="What is the capital of France?"),
        AIMessage(content="The capital of France is Paris."),
        HumanMessage(content="What about Germany?"),
        AIMessage(content="The capital of Germany is Berlin."),
        HumanMessage(content="And Italy?"),
        AIMessage(content="The capital of Italy is Rome."),
    ]
    # The low budget must trigger trimming.
    assert trimmer.should_trim(messages)
    trimmed = trimmer.trim_conversation_history(messages)
    # Fewer messages overall, with the system prompt intact at the front.
    assert len(trimmed) < len(messages)
    assert isinstance(trimmed[0], SystemMessage)
    assert "helpful assistant" in trimmed[0].content
def test_conversation_trimmer_no_trim_needed():
    """A short conversation under a generous budget passes through untouched."""
    trimmer = ConversationTrimmer(max_context_length=100000)  # generous limit
    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Hello"),
        AIMessage(content="Hi there!"),
    ]
    # No trimming needed, and the message list is returned at full length.
    assert not trimmer.should_trim(messages)
    assert len(trimmer.trim_conversation_history(messages)) == len(messages)


def test_conversation_trimmer_fallback():
    """The fallback path keeps the system message plus the most recent turns."""
    trimmer = ConversationTrimmer(max_context_length=100)
    # One system message followed by 50 human turns forces the fallback.
    messages = [SystemMessage(content="System")]
    messages += [HumanMessage(content=f"Message {i}") for i in range(50)]
    trimmed = trimmer._fallback_trim(messages, max_messages=5)
    # Expect the system message + 4 most recent messages.
    assert len(trimmed) == 5
    assert isinstance(trimmed[0], SystemMessage)
def test_create_conversation_trimmer():
    """The factory wires an explicit max_context_length into the trimmer."""
    custom = create_conversation_trimmer(max_context_length=128000)
    assert isinstance(custom, ConversationTrimmer)
    assert custom.max_context_length == 128000
    assert custom.preserve_system is True
def test_multi_round_tool_optimization():
    """Multi-round tool-call optimization keeps only the latest tool round.

    Builds three consecutive tool rounds with large tool payloads and checks
    that _optimize_multi_round_tool_calls preserves the system message, the
    human query, and only the most recent AI tool call with its ToolMessage.
    """
    trimmer = ConversationTrimmer(max_context_length=100000)  # High limit to focus on optimization
    # Bug fix: previously the `* 100` repetition sat INSIDE the string literal
    # ('"...data..." * 100]}'), so the payloads were never actually large and
    # the test did not exercise the "large result" scenario it describes.
    food_results = '{"results": ["' + "Large food safety result data..." * 100 + '"]}'
    cosmetics_results = '{"results": ["' + "Large cosmetics testing data..." * 100 + '"]}'
    equipment_results = '{"results": ["' + "Testing equipment info..." * 50 + '"]}'
    # Create a multi-round tool calling scenario
    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Search for food safety and cosmetics testing standards."),
        # First tool round
        AIMessage(content="I'll search for information.", tool_calls=[
            {"id": "call_1", "name": "search_tool", "args": {"query": "food safety"}}
        ]),
        ToolMessage(content=food_results, tool_call_id="call_1", name="search_tool"),
        # Second tool round
        AIMessage(content="Let me search for more specific information.", tool_calls=[
            {"id": "call_2", "name": "search_tool", "args": {"query": "cosmetics testing"}}
        ]),
        ToolMessage(content=cosmetics_results, tool_call_id="call_2", name="search_tool"),
        # Third tool round (most recent)
        AIMessage(content="Now let me get equipment information.", tool_calls=[
            {"id": "call_3", "name": "search_tool", "args": {"query": "testing equipment"}}
        ]),
        ToolMessage(content=equipment_results, tool_call_id="call_3", name="search_tool"),
    ]
    # Test tool round identification
    tool_rounds = trimmer._identify_tool_rounds(messages)
    assert len(tool_rounds) == 3  # Should identify 3 tool rounds
    # Test optimization
    optimized = trimmer._optimize_multi_round_tool_calls(messages)
    # Should preserve: SystemMessage + HumanMessage + latest tool round (AI + Tool messages)
    expected_length = 1 + 1 + 2  # system + human + (latest AI + latest Tool)
    assert len(optimized) == expected_length
    # Check preserved messages
    assert isinstance(optimized[0], SystemMessage)
    assert isinstance(optimized[1], HumanMessage)
    assert isinstance(optimized[2], AIMessage)
    assert optimized[2].tool_calls[0]["id"] == "call_3"  # Should be the latest tool call
    assert isinstance(optimized[3], ToolMessage)
    assert optimized[3].tool_call_id == "call_3"  # Should be the latest tool result
def test_multi_round_optimization_single_round():
    """A conversation with exactly one tool round is left untouched."""
    trimmer = ConversationTrimmer(max_context_length=100000)
    conversation = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Search for something."),
        AIMessage(content="I'll search.", tool_calls=[{"id": "call_1", "name": "search_tool", "args": {}}]),
        ToolMessage(content='{"results": ["data"]}', tool_call_id="call_1", name="search_tool"),
    ]
    # Single round: the optimizer must return everything unchanged.
    assert len(trimmer._optimize_multi_round_tool_calls(conversation)) == len(conversation)


def test_multi_round_optimization_no_tools():
    """A conversation without any tool calls is left untouched."""
    trimmer = ConversationTrimmer(max_context_length=100000)
    conversation = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Hello"),
        AIMessage(content="Hi there!"),
        HumanMessage(content="How are you?"),
        AIMessage(content="I'm doing well, thanks!"),
    ]
    # No tool usage: the optimizer must return everything unchanged.
    assert len(trimmer._optimize_multi_round_tool_calls(conversation)) == len(conversation)
def test_full_trimming_with_optimization():
    """Complete trimming pipeline: optimization runs first, then token-based
    trimming if the budget is still exceeded."""
    trimmer = ConversationTrimmer(max_context_length=50)  # Very low limit
    # Bug fix: previously the `* 100` repetition sat INSIDE the string literal,
    # so the payloads were tiny; build genuinely large payloads so both the
    # multi-round optimization and the regular trim are actually exercised.
    food_results = '{"results": ["' + "Large food safety result data..." * 100 + '"]}'
    cosmetics_results = '{"results": ["' + "Large cosmetics testing data..." * 100 + '"]}'
    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Search for food safety and cosmetics testing standards."),
        # First tool round (should be removed by optimization)
        AIMessage(content="I'll search for information.", tool_calls=[
            {"id": "call_1", "name": "search_tool", "args": {"query": "food safety"}}
        ]),
        ToolMessage(content=food_results, tool_call_id="call_1", name="search_tool"),
        # Second tool round (most recent, should be preserved)
        AIMessage(content="Let me search for more information.", tool_calls=[
            {"id": "call_2", "name": "search_tool", "args": {"query": "cosmetics testing"}}
        ]),
        ToolMessage(content=cosmetics_results, tool_call_id="call_2", name="search_tool"),
    ]
    trimmed = trimmer.trim_conversation_history(messages)
    # Should apply optimization first, then regular trimming if needed
    assert len(trimmed) <= len(messages)
    # System message should always be preserved
    assert any(isinstance(msg, SystemMessage) for msg in trimmed)
# Allow running this module directly via pytest.
if __name__ == "__main__":
    pytest.main([__file__])

View File

@@ -0,0 +1,353 @@
"""
Unit tests for agentic retrieval tools.
Tests the HTTP wrapper tools that interface with the retrieval API.
"""
import pytest
from unittest.mock import AsyncMock, patch, Mock
import httpx
from service.retrieval.retrieval import AgenticRetrieval, RetrievalResponse
from service.retrieval.clients import RetrievalAPIError, normalize_search_result
class TestAgenticRetrieval:
    """Test the agentic retrieval HTTP client.

    All network traffic is mocked at the AzureSearchClient.search_azure_ai
    boundary; the tests verify the client's call construction, response
    normalization, and error propagation.
    """

    @pytest.fixture
    def retrieval_client(self):
        """Create retrieval client for testing.

        Patches get_config so AgenticRetrieval is constructed against a fully
        mocked configuration (search endpoint, embedding settings, and index
        names).
        """
        with patch('service.retrieval.retrieval.get_config') as mock_config:
            # Mock the new config structure
            mock_config.return_value.retrieval.endpoint = "https://test-search.search.azure.com"
            mock_config.return_value.retrieval.api_key = "test-key"
            mock_config.return_value.retrieval.api_version = "2024-11-01-preview"
            mock_config.return_value.retrieval.semantic_configuration = "default"
            # Mock embedding config
            mock_config.return_value.retrieval.embedding.base_url = "http://test-embedding"
            mock_config.return_value.retrieval.embedding.api_key = "test-embedding-key"
            mock_config.return_value.retrieval.embedding.model = "test-embedding-model"
            mock_config.return_value.retrieval.embedding.dimension = 1536
            # Mock index config
            mock_config.return_value.retrieval.index.standard_regulation_index = "index-catonline-standard-regulation-v2-prd"
            mock_config.return_value.retrieval.index.chunk_index = "index-catonline-chunk-v2-prd"
            mock_config.return_value.retrieval.index.chunk_user_manual_index = "index-cat-usermanual-chunk-prd"
            # NOTE(review): the client is constructed while get_config is still
            # patched — presumably it reads config in __init__; confirm.
            return AgenticRetrieval()

    @pytest.mark.asyncio
    async def test_retrieve_standard_regulation_success(self, retrieval_client):
        """Test successful standards regulation retrieval."""
        mock_response_data = {
            "value": [
                {
                    "id": "test_id_1",
                    "title": "ISO 26262-1:2018",
                    "document_code": "ISO26262-1",
                    "document_category": "Standard",
                    "@order_num": 1  # Should be preserved
                },
                {
                    "id": "test_id_2",
                    "title": "ISO 26262-3:2018",
                    "document_code": "ISO26262-3",
                    "document_category": "Standard",
                    "@order_num": 2
                }
            ]
        }
        with patch.object(retrieval_client.search_client, 'search_azure_ai', return_value=mock_response_data) as mock_search:
            result = await retrieval_client.retrieve_standard_regulation("ISO 26262")
            # Verify search call
            mock_search.assert_called_once()
            call_args = mock_search.call_args[1]  # keyword arguments
            assert call_args['search_text'] == "ISO 26262"
            assert call_args['index_name'] == "index-catonline-standard-regulation-v2-prd"
            assert call_args['top_k'] == 10
            # Verify response structure
            assert isinstance(result, RetrievalResponse)
            assert len(result.results) == 2
            assert result.total_count == 2
            assert result.took_ms is not None
            # Verify normalization
            first_result = result.results[0]
            assert first_result['id'] == "test_id_1"
            assert first_result['title'] == "ISO 26262-1:2018"
            # Verify order number is preserved
            assert "@order_num" in first_result
            assert "content" not in first_result  # Should not be included for standards

    @pytest.mark.asyncio
    async def test_retrieve_doc_chunk_success(self, retrieval_client):
        """Test successful document chunk retrieval."""
        mock_response_data = {
            "value": [
                {
                    "id": "chunk_id_1",
                    "title": "Functional Safety Requirements",
                    "content": "Detailed content about functional safety...",
                    "document_code": "ISO26262-1",
                    "@order_num": 1  # Should be preserved
                }
            ]
        }
        with patch.object(retrieval_client.search_client, 'search_azure_ai', return_value=mock_response_data) as mock_search:
            result = await retrieval_client.retrieve_doc_chunk_standard_regulation(
                "functional safety",
                conversation_history="Previous question about ISO 26262"
            )
            # Verify search call
            mock_search.assert_called_once()
            call_args = mock_search.call_args[1]  # keyword arguments
            assert call_args['search_text'] == "functional safety"
            assert call_args['index_name'] == "index-catonline-chunk-v2-prd"
            assert "filter_query" in call_args  # Should have document filter
            # Verify response
            assert isinstance(result, RetrievalResponse)
            assert len(result.results) == 1
            # Verify content is included for chunks
            first_result = result.results[0]
            assert "content" in first_result
            assert first_result['content'] == "Detailed content about functional safety..."
            # Verify order number is preserved
            assert "@order_num" in first_result

    @pytest.mark.asyncio
    async def test_http_error_handling(self, retrieval_client):
        """Test HTTP error handling: search failures surface as RetrievalAPIError."""
        with patch.object(retrieval_client.search_client, 'search_azure_ai') as mock_search:
            mock_search.side_effect = RetrievalAPIError("API request failed: 404")
            with pytest.raises(RetrievalAPIError) as exc_info:
                await retrieval_client.retrieve_standard_regulation("nonexistent")
            assert "404" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_timeout_handling(self, retrieval_client):
        """Test timeout handling: timeouts propagate as RetrievalAPIError."""
        with patch.object(retrieval_client.search_client, 'search_azure_ai') as mock_search:
            mock_search.side_effect = RetrievalAPIError("Request timeout")
            with pytest.raises(RetrievalAPIError) as exc_info:
                await retrieval_client.retrieve_standard_regulation("test query")
            assert "timeout" in str(exc_info.value)

    def test_normalize_result_removes_unwanted_fields(self, retrieval_client):
        """Test result normalization removes unwanted fields."""
        # Local import kept as-is; normalize_search_result is also imported at
        # module level.
        from service.retrieval.clients import normalize_search_result
        raw_result = {
            "id": "test_id",
            "title": "Test Title",
            "content": "Test content",
            "@search.score": 0.95,  # Should be removed
            "@search.rerankerScore": 2.5,  # Should be removed
            "@search.captions": [],  # Should be removed
            "@subquery_id": 1,  # Should be removed
            "empty_field": "",  # Should be removed
            "null_field": None,  # Should be removed
            "empty_list": [],  # Should be removed
            "empty_dict": {},  # Should be removed
            "valid_field": "valid_value"
        }
        # Test normalization
        normalized = normalize_search_result(raw_result)
        # Verify unwanted fields are removed
        assert "@search.score" not in normalized
        assert "@search.rerankerScore" not in normalized
        assert "@search.captions" not in normalized
        assert "@subquery_id" not in normalized
        assert "empty_field" not in normalized
        assert "null_field" not in normalized
        assert "empty_list" not in normalized
        assert "empty_dict" not in normalized
        # Verify valid fields are preserved (including content)
        assert normalized["id"] == "test_id"
        assert normalized["title"] == "Test Title"
        assert normalized["content"] == "Test content"  # Content is now always preserved
        assert normalized["valid_field"] == "valid_value"

    def test_normalize_result_includes_content_when_requested(self, retrieval_client):
        """Test result normalization always includes content."""
        from service.retrieval.clients import normalize_search_result
        raw_result = {
            "id": "test_id",
            "title": "Test Title",
            "content": "Test content"
        }
        # Test normalization (content is always preserved now)
        normalized = normalize_search_result(raw_result)
        assert "content" in normalized
        assert normalized["content"] == "Test content"

    @pytest.mark.asyncio
    async def test_empty_results_handling(self, retrieval_client):
        """Test handling of empty search results."""
        with patch.object(retrieval_client.search_client, 'search_azure_ai', return_value={"value": []}):
            result = await retrieval_client.retrieve_standard_regulation("no results query")
            assert isinstance(result, RetrievalResponse)
            assert result.results == []
            assert result.total_count == 0

    @pytest.mark.asyncio
    async def test_malformed_response_handling(self, retrieval_client):
        """Test handling of malformed API responses (missing "value" key)."""
        with patch.object(retrieval_client.search_client, 'search_azure_ai', return_value={"unexpected": "format"}):
            result = await retrieval_client.retrieve_standard_regulation("test query")
            # Should handle gracefully and return empty results
            assert isinstance(result, RetrievalResponse)
            assert result.results == []
            assert result.total_count == 0

    @pytest.mark.asyncio
    async def test_kwargs_override_payload(self, retrieval_client):
        """Test that kwargs can override default values."""
        with patch.object(retrieval_client.search_client, 'search_azure_ai', return_value={"value": []}) as mock_search:
            await retrieval_client.retrieve_standard_regulation(
                "test query",
                top_k=5,
                score_threshold=2.0
            )
            call_args = mock_search.call_args[1]  # keyword arguments
            assert call_args['top_k'] == 5  # Override default 10
            assert call_args['score_threshold'] == 2.0  # Override default 1.5
class TestRetrievalTools:
    """Test the LangGraph tool decorators.

    The @tool-decorated functions are invoked via .ainvoke with a dict input,
    with the underlying retrieval client mocked out.
    """

    @pytest.mark.asyncio
    async def test_retrieve_standard_regulation_tool(self):
        """Test the @tool decorated function for standards search."""
        from service.graph.tools import retrieve_standard_regulation
        mock_response = RetrievalResponse(
            results=[{"title": "Test Standard", "id": "test_id"}],
            took_ms=100,
            total_count=1
        )
        with patch('service.retrieval.retrieval.AgenticRetrieval') as mock_retrieval_class:
            mock_instance = AsyncMock()
            mock_instance.retrieve_standard_regulation.return_value = mock_response
            # The tool uses the client as an async context manager, so both
            # __aenter__ and __aexit__ are stubbed.
            mock_retrieval_class.return_value.__aenter__.return_value = mock_instance
            mock_retrieval_class.return_value.__aexit__.return_value = None
            # Note: retrieve_standard_regulation is a LangGraph @tool, so we call it directly
            result = await retrieve_standard_regulation.ainvoke({"query": "test query", "conversation_history": "test history"})
            # Verify result format matches expected tool output
            assert isinstance(result, dict)
            assert "results" in result

    @pytest.mark.asyncio
    async def test_retrieve_doc_chunk_tool(self):
        """Test the @tool decorated function for document chunks search."""
        from service.graph.tools import retrieve_doc_chunk_standard_regulation
        mock_response = RetrievalResponse(
            results=[{"title": "Test Chunk", "content": "Test content", "id": "chunk_id"}],
            took_ms=150,
            total_count=1
        )
        with patch('service.retrieval.retrieval.AgenticRetrieval') as mock_retrieval_class:
            mock_instance = AsyncMock()
            mock_instance.retrieve_doc_chunk_standard_regulation.return_value = mock_response
            mock_retrieval_class.return_value.__aenter__.return_value = mock_instance
            mock_retrieval_class.return_value.__aexit__.return_value = None
            # Call the tool with proper input format
            result = await retrieve_doc_chunk_standard_regulation.ainvoke({"query": "test query"})
            # Verify result format
            assert isinstance(result, dict)
            assert "results" in result

    @pytest.mark.asyncio
    async def test_tool_error_handling(self):
        """Test tool error handling: API failures become an "error" entry in the
        result dict rather than an exception."""
        from service.graph.tools import retrieve_standard_regulation
        with patch('service.retrieval.clients.AzureSearchClient.search_azure_ai') as mock_search:
            mock_search.side_effect = RetrievalAPIError("API Error")
            # Tool should handle error gracefully and return error result
            result = await retrieve_standard_regulation.ainvoke({"query": "test query"})
            # Should return error information in result
            assert isinstance(result, dict)
            assert "error" in result
            assert "API Error" in result["error"]
class TestRetrievalIntegration:
    """Integration-style tests for retrieval functionality."""

    @pytest.mark.asyncio
    async def test_full_retrieval_workflow(self):
        """Test complete retrieval workflow with both tools."""
        # Mock both retrieval calls
        # NOTE(review): these two RetrievalResponse objects are built but never
        # asserted against; the mocked search side-effect below drives the test.
        standards_response = RetrievalResponse(
            results=[
                {"id": "std_1", "title": "ISO 26262-1", "document_category": "Standard"}
            ],
            took_ms=100,
            total_count=1
        )
        chunks_response = RetrievalResponse(
            results=[
                {"id": "chunk_1", "title": "Safety Requirements", "content": "Detailed safety requirements..."}
            ],
            took_ms=150,
            total_count=1
        )
        with patch('service.retrieval.clients.AzureSearchClient.search_azure_ai') as mock_search:
            # Set up mock to return different responses for different calls
            def mock_response_side_effect(*args, **kwargs):
                # Dispatch on the target index so each tool receives its own payload.
                if kwargs.get('index_name') == 'index-catonline-standard-regulation-v2-prd':
                    return {
                        "value": [{"id": "std_1", "title": "ISO 26262-1", "document_category": "Standard", "@order_num": 1}]
                    }
                else:
                    return {
                        "value": [{"id": "chunk_1", "title": "Safety Requirements", "content": "Detailed safety requirements...", "@order_num": 1}]
                    }
            mock_search.side_effect = mock_response_side_effect
            # Import and test both tools
            from service.graph.tools import retrieve_standard_regulation, retrieve_doc_chunk_standard_regulation
            # Test standards search
            std_result = await retrieve_standard_regulation.ainvoke({"query": "ISO 26262"})
            assert isinstance(std_result, dict)
            assert "results" in std_result
            assert len(std_result["results"]) == 1
            # Test chunks search
            chunk_result = await retrieve_doc_chunk_standard_regulation.ainvoke({"query": "safety requirements"})
            assert isinstance(chunk_result, dict)
            assert "results" in chunk_result
            assert len(chunk_result["results"]) == 1

View File

@@ -0,0 +1,73 @@
import pytest
from service.sse import (
format_sse_event, create_token_event, create_tool_start_event,
create_tool_result_event, create_tool_error_event,
create_error_event
)
def test_format_sse_event():
    """A formatted SSE frame is 'event: <name>', a JSON data line, then a blank line."""
    frame = format_sse_event("test", {"message": "hello"})
    assert frame == 'event: test\ndata: {"message": "hello"}\n\n'
def test_create_token_event():
    """Token events carry the delta text and the originating tool call id."""
    frame = create_token_event("hello", "tool_123")
    for fragment in ("event: tokens", '"delta": "hello"', '"tool_call_id": "tool_123"'):
        assert fragment in frame
def test_create_token_event_no_tool_id():
    """Without a tool call id, the field is serialised as JSON null."""
    frame = create_token_event("hello")
    for fragment in ("event: tokens", '"delta": "hello"', '"tool_call_id": null'):
        assert fragment in frame
def test_create_tool_start_event():
    """Tool-start events include the call id, tool name, and raw arguments."""
    frame = create_tool_start_event("tool_123", "retrieve_standard_regulation", {"query": "test"})
    for fragment in (
        "event: tool_start",
        '"id": "tool_123"',
        '"name": "retrieve_standard_regulation"',
        '"args": {"query": "test"}',
    ):
        assert fragment in frame
def test_create_tool_result_event():
    """Tool-result events include the call id, timing, and a results payload."""
    frame = create_tool_result_event(
        "tool_123",
        "retrieve_standard_regulation",
        [{"id": "1", "title": "Test Standard"}],
        500,
    )
    for fragment in ("event: tool_result", '"id": "tool_123"', '"took_ms": 500', '"results"'):
        assert fragment in frame
def test_create_tool_error_event():
    """Tool-error events include the call id and the error message."""
    frame = create_tool_error_event("tool_123", "retrieve_standard_regulation", "API timeout")
    for fragment in ("event: tool_error", '"id": "tool_123"', '"error": "API timeout"'):
        assert fragment in frame
def test_create_error_event():
    """Plain error events carry the error message."""
    frame = create_error_event("Something went wrong")
    for fragment in ("event: error", '"error": "Something went wrong"'):
        assert fragment in frame
def test_create_error_event_with_details():
    """When details are supplied they are embedded alongside the error message."""
    frame = create_error_event("Something went wrong", {"code": 500, "source": "llm"})
    for fragment in (
        "event: error",
        '"error": "Something went wrong"',
        '"details"',
        '"code": 500',
    ):
        assert fragment in frame

View File

@@ -0,0 +1,200 @@
#!/usr/bin/env python3
"""
测试修改后的修剪逻辑:验证只在 tool_rounds=0 时修剪,其他时候跳过修剪
"""
import pytest
from unittest.mock import patch, MagicMock
import asyncio
from datetime import datetime
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from service.graph.state import AgentState
from service.graph.graph import call_model
from service.graph.user_manual_rag import user_manual_agent_node
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage
@pytest.mark.asyncio
async def test_trimming_only_at_tool_rounds_zero():
    """call_model must trim history only when tool_rounds == 0, never afterwards.

    The trimmer factory is mocked so we can observe whether ``should_trim``
    and ``trim_conversation_history`` were consulted for each round.
    Redundant in-function re-imports of ``AgentState`` (already imported at
    module level) and unused ``result`` locals from the original version
    have been removed.
    """
    # Build a conversation padded with many tool-call/result pairs so a
    # real trimmer would also consider it oversized.
    large_messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="请搜索信息"),
    ]
    for i in range(10):
        large_messages.extend([
            AIMessage(content=f"搜索{i}", tool_calls=[
                {"id": f"call_{i}", "name": "search", "args": {"query": f"test{i}"}}
            ]),
            ToolMessage(content="很长的搜索结果内容 " * 500, tool_call_id=f"call_{i}", name="search"),
        ])

    # Mock trimmer so calls can be tracked.
    with patch('service.graph.graph.create_conversation_trimmer') as mock_trimmer_factory:
        mock_trimmer = MagicMock()
        mock_trimmer.should_trim.return_value = True  # always report "needs trimming"
        mock_trimmer.trim_conversation_history.return_value = large_messages[:3]
        mock_trimmer_factory.return_value = mock_trimmer

        # Mock the LLM client surface used by call_model.
        with patch('service.graph.graph.LLMClient') as mock_llm_factory:
            mock_llm = MagicMock()

            async def mock_ainvoke_with_tools(*args, **kwargs):
                return AIMessage(content="最终答案")

            async def mock_astream(*args, **kwargs):
                # NOTE(review): the token list appears as empty strings in the
                # committed source — possibly lost characters; preserved as-is.
                for token in ["", "", "", ""]:
                    yield token

            mock_llm.ainvoke_with_tools = mock_ainvoke_with_tools
            mock_llm.astream = mock_astream
            mock_llm.bind_tools = MagicMock()
            mock_llm_factory.return_value = mock_llm

            # tool_rounds == 0 (first round): trimming must happen.
            state_round_0 = AgentState(
                messages=large_messages,
                session_id="test",
                intent=None,
                tool_results=[],
                final_answer="",
                tool_rounds=0,
                max_tool_rounds=3,
                max_tool_rounds_user_manual=2
            )
            await call_model(state_round_0)
            mock_trimmer.should_trim.assert_called()
            mock_trimmer.trim_conversation_history.assert_called()

            mock_trimmer.reset_mock()

            # tool_rounds > 0 (second round): trimming must be skipped.
            state_round_1 = AgentState(
                messages=large_messages,
                session_id="test",
                intent=None,
                tool_results=[],
                final_answer="",
                tool_rounds=1,
                max_tool_rounds=3,
                max_tool_rounds_user_manual=2
            )
            await call_model(state_round_1)
            mock_trimmer.should_trim.assert_not_called()
            mock_trimmer.trim_conversation_history.assert_not_called()
@pytest.mark.asyncio
async def test_user_manual_agent_trimming_behavior():
    """user_manual_agent_node trims history only on the first tool round."""
    # Seed the conversation, then pad it with tool-call/result pairs.
    history = [
        SystemMessage(content="You are a helpful user manual assistant."),
        HumanMessage(content="请帮我查找系统使用说明"),
    ]
    for i in range(8):
        history.append(
            AIMessage(content=f"查找{i}", tool_calls=[
                {"id": f"manual_call_{i}", "name": "retrieve_system_usermanual", "args": {"query": f"功能{i}"}}
            ])
        )
        history.append(
            ToolMessage(content="很长的手册内容 " * 400, tool_call_id=f"manual_call_{i}", name="retrieve_system_usermanual")
        )

    def build_state(rounds):
        # Everything except tool_rounds is identical between the two cases.
        return AgentState(
            messages=history,
            session_id="test",
            intent=None,
            tool_results=[],
            final_answer="",
            tool_rounds=rounds,
            max_tool_rounds=3,
            max_tool_rounds_user_manual=2,
        )

    with patch('service.graph.user_manual_rag.create_conversation_trimmer') as mock_trimmer_factory, \
         patch('service.graph.user_manual_rag.LLMClient') as mock_llm_factory, \
         patch('service.graph.user_manual_rag.get_user_manual_tool_schemas') as mock_tools, \
         patch('service.graph.user_manual_rag.get_cached_config') as mock_config:
        # Trimmer stub that always claims trimming is needed.
        trimmer = MagicMock()
        trimmer.should_trim.return_value = True
        trimmer.trim_conversation_history.return_value = history[:3]
        mock_trimmer_factory.return_value = trimmer

        # LLM client stub with async methods matching the real surface.
        llm = MagicMock()

        async def fake_ainvoke_with_tools(*args, **kwargs):
            return AIMessage(content="用户手册答案")

        async def fake_astream(*args, **kwargs):
            for token in ["", ""]:
                yield token

        llm.ainvoke_with_tools = fake_ainvoke_with_tools
        llm.astream = fake_astream
        llm.bind_tools = MagicMock()
        mock_llm_factory.return_value = llm

        mock_tools.return_value = []

        # Minimal application config exposing the prompt template.
        app_config = MagicMock()
        app_config.app.max_tool_rounds_user_manual = 2
        app_config.get_rag_prompts.return_value = {
            "user_manual_prompt": "System prompt for user manual: {conversation_history} {context_content} {current_query}"
        }
        mock_config.return_value = app_config

        # tool_rounds == 0 -> the trimmer must be consulted and applied.
        await user_manual_agent_node(build_state(0))
        trimmer.should_trim.assert_called()
        trimmer.trim_conversation_history.assert_called()

        trimmer.reset_mock()

        # tool_rounds == 1 -> trimming must be skipped.
        await user_manual_agent_node(build_state(1))
        trimmer.should_trim.assert_not_called()
        trimmer.trim_conversation_history.assert_not_called()
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,95 @@
"""
Unit test for the new retrieve_system_usermanual tool
"""
import pytest
from unittest.mock import AsyncMock, patch
import asyncio
from service.graph.user_manual_tools import retrieve_system_usermanual, user_manual_tools, get_user_manual_tool_schemas, get_user_manual_tools_by_name
from service.graph.tools import get_tool_schemas, get_tools_by_name
from service.retrieval.retrieval import RetrievalResponse
class TestRetrieveSystemUsermanualTool:
    """Test the new user manual retrieval tool"""

    def test_tool_in_tools_list(self):
        """The user-manual tool list exposes exactly the one expected tool."""
        names = [t.name for t in user_manual_tools]
        assert "retrieve_system_usermanual" in names
        assert len(user_manual_tools) == 1  # only one user manual tool

    def test_tool_schemas_generation(self):
        """Generated schemas describe the tool and its required query parameter."""
        schemas = get_user_manual_tool_schemas()
        assert "retrieve_system_usermanual" in [s["function"]["name"] for s in schemas]

        # Locate the user-manual tool schema and inspect its function spec.
        schema = next(
            s for s in schemas
            if s["function"]["name"] == "retrieve_system_usermanual"
        )
        fn = schema["function"]
        assert fn["description"] == "Search for document content chunks of user manual of this system(CATOnline)"
        assert "query" in fn["parameters"]["properties"]
        assert fn["parameters"]["required"] == ["query"]

    def test_tools_by_name_mapping(self):
        """The name->tool mapping resolves the user manual tool."""
        mapping = get_user_manual_tools_by_name()
        assert "retrieve_system_usermanual" in mapping
        assert mapping["retrieve_system_usermanual"] == retrieve_system_usermanual

    @pytest.mark.asyncio
    async def test_retrieve_system_usermanual_success(self):
        """A successful retrieval is normalised into the standard result dict."""
        canned = RetrievalResponse(
            results=[
                {"title": "User Manual Chapter 1", "content": "How to use the system", "id": "manual_1"}
            ],
            took_ms=120,
            total_count=1
        )
        with patch('service.retrieval.retrieval.AgenticRetrieval') as retrieval_cls:
            # Async context manager yielding a backend that returns the canned response.
            backend = AsyncMock()
            backend.retrieve_doc_chunk_user_manual.return_value = canned
            retrieval_cls.return_value.__aenter__.return_value = backend
            retrieval_cls.return_value.__aexit__.return_value = None

            payload = await retrieve_system_usermanual.ainvoke({"query": "how to use system"})

        assert isinstance(payload, dict)
        assert payload["tool_name"] == "retrieve_system_usermanual"
        assert payload["results_count"] == 1
        assert len(payload["results"]) == 1
        assert payload["results"][0]["title"] == "User Manual Chapter 1"
        assert payload["took_ms"] == 120

    @pytest.mark.asyncio
    async def test_retrieve_system_usermanual_error(self):
        """Backend failures surface as an error payload, not an exception."""
        with patch('service.retrieval.retrieval.AgenticRetrieval') as retrieval_cls:
            backend = AsyncMock()
            backend.retrieve_doc_chunk_user_manual.side_effect = Exception("Search API Error")
            retrieval_cls.return_value.__aenter__.return_value = backend
            retrieval_cls.return_value.__aexit__.return_value = None

            payload = await retrieve_system_usermanual.ainvoke({"query": "test query"})

        assert isinstance(payload, dict)
        assert "error" in payload
        assert "Search API Error" in payload["error"]
        assert payload["results_count"] == 0
        assert payload["results"] == []
if __name__ == "__main__":
pytest.main([__file__, "-v"])