commit db0e5965ec
2025-09-26 17:15:54 +08:00
211 changed files with 40437 additions and 0 deletions

View File

@@ -0,0 +1 @@
# Empty __init__.py so this directory is treated as a test package

View File

@@ -0,0 +1,317 @@
"""
Shared pytest fixtures and configuration for the agentic-rag test suite.
"""
import pytest
import asyncio
import httpx
from unittest.mock import Mock, AsyncMock, patch
from fastapi.testclient import TestClient
from service.main import create_app
from service.config import Config
from service.graph.state import TurnState, Message, ToolResult
from service.memory.postgresql_memory import PostgreSQLMemoryManager
@pytest.fixture(scope="session")
def event_loop():
"""Create an instance of the default event loop for the test session."""
policy = asyncio.get_event_loop_policy()
loop = policy.new_event_loop()
yield loop
loop.close()
@pytest.fixture(autouse=True)
def config_mock():
"""Mock configuration for all tests."""
config = Mock()
config.retrieval.endpoint = "http://test-endpoint"
config.retrieval.api_key = "test-key"
config.llm.provider = "openai"
config.llm.model = "gpt-4"
config.llm.api_key = "test-api-key"
config.memory.enabled = True
config.memory.type = "in_memory"
config.memory.ttl_days = 7
config.postgresql.enabled = False
with patch('service.config.get_config', return_value=config):
with patch('service.retrieval.retrieval.get_config', return_value=config):
with patch('service.graph.graph.get_config', return_value=config):
yield config
@pytest.fixture
def test_config():
"""Test configuration with safe defaults."""
return {
"provider": "openai",
"openai": {
"api_key": "test-openai-key",
"model": "gpt-4o",
"base_url": "https://api.openai.com/v1",
"temperature": 0.2
},
"retrieval": {
"endpoint": "http://test-retrieval-endpoint",
"api_key": "test-retrieval-key"
},
"postgresql": {
"host": "localhost",
"port": 5432,
"database": "test_agent_memory",
"username": "test",
"password": "test",
"ttl_days": 1
},
"app": {
"name": "agentic-rag-test",
"memory_ttl_days": 1,
"max_tool_loops": 3,
"cors_origins": ["*"]
},
"llm": {
"rag": {
"temperature": 0,
"max_context_length": 32000,
"agent_system_prompt": "You are a test assistant."
}
}
}
@pytest.fixture
def app(test_config):
"""Create test FastAPI app with mocked configuration."""
with patch('service.config.load_config') as mock_load_config:
mock_load_config.return_value = test_config
# Mock the memory manager to avoid PostgreSQL dependency in tests
with patch('service.memory.postgresql_memory.get_memory_manager') as mock_memory:
mock_memory_manager = Mock()
mock_memory_manager.test_connection.return_value = True
mock_memory.return_value = mock_memory_manager
# Mock the graph builder to avoid complex dependencies
with patch('service.graph.graph.build_graph') as mock_build_graph:
mock_graph = Mock()
mock_build_graph.return_value = mock_graph
app = create_app()
app.state.memory_manager = mock_memory_manager
app.state.graph = mock_graph
return app
@pytest.fixture
def client(app):
"""Create test client."""
return TestClient(app)
@pytest.fixture
def mock_llm_client():
"""Mock LLM client for testing."""
mock = AsyncMock()
mock.astream.return_value = iter(["Test", " response", " token"])
mock.ainvoke_with_tools.return_value = Mock(
content="Test response",
tool_calls=[
{
"id": "test_tool_call_1",
"function": {
"name": "retrieve_standard_regulation",
"arguments": '{"query": "test query"}'
}
}
]
)
return mock
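# Illustrative sketch (not a fixture, and an assumption about the consumer side): one way
# a test or graph node might unpack the mocked tool calls above. The attribute layout
# mirrors the Mock defined in mock_llm_client; the service's real message class may differ.
def _extract_tool_calls_example(llm_message):
    import json
    calls = []
    for call in getattr(llm_message, "tool_calls", None) or []:
        fn = call["function"]
        # "arguments" is a JSON-encoded string in the mock, so it is parsed here
        calls.append((call["id"], fn["name"], json.loads(fn["arguments"])))
    return calls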
@pytest.fixture
def mock_retrieval_response():
"""Mock response from retrieval API."""
return {
"results": [
{
"id": "test_result_1",
"title": "ISO 26262-1:2018",
"content": "Road vehicles — Functional safety — Part 1: Vocabulary",
"score": 0.95,
"url": "https://iso.org/26262-1",
"metadata": {
"@tool_call_id": "test_tool_call_1",
"@order_num": 0
}
},
{
"id": "test_result_2",
"title": "ISO 26262-3:2018",
"content": "Road vehicles — Functional safety — Part 3: Concept phase",
"score": 0.88,
"url": "https://iso.org/26262-3",
"metadata": {
"@tool_call_id": "test_tool_call_1",
"@order_num": 1
}
}
],
"metadata": {
"total": 2,
"took_ms": 150,
"query": "test query"
}
}
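# Illustrative helper (assumption, not used by the suite): group retrieval results by the
# "@tool_call_id" metadata field from mock_retrieval_response so assertions can check
# per-tool-call grouping and "@order_num" ordering.
def _group_results_by_tool_call_example(response):
    grouped = {}
    for item in response.get("results", []):
        call_id = item.get("metadata", {}).get("@tool_call_id")
        grouped.setdefault(call_id, []).append(item)
    return grouped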
@pytest.fixture
def sample_chat_request():
"""Sample chat request for testing."""
return {
"session_id": "test_session_123",
"messages": [
{"role": "user", "content": "What is ISO 26262?"}
]
}
@pytest.fixture
def sample_turn_state():
"""Sample TurnState for testing."""
return TurnState(
session_id="test_session_123",
messages=[
Message(role="user", content="What is ISO 26262?")
]
)
@pytest.fixture
def mock_httpx_client():
"""Mock httpx client for API requests."""
mock_client = AsyncMock()
# Default response for retrieval API
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"results": [
{
"id": "test_result",
"title": "Test Standard",
"content": "Test content",
"score": 0.9
}
]
}
mock_client.post.return_value = mock_response
return mock_client
@pytest.fixture
def mock_postgresql_memory():
"""Mock PostgreSQL memory manager."""
mock_manager = Mock(spec=PostgreSQLMemoryManager)
mock_manager.test_connection.return_value = True
mock_checkpointer = Mock()
mock_checkpointer.setup.return_value = None
mock_manager.get_checkpointer.return_value = mock_checkpointer
return mock_manager
@pytest.fixture
def mock_streaming_response():
"""Mock streaming response events."""
return [
'event: tool_start\ndata: {"id": "test_tool_1", "name": "retrieve_standard_regulation", "args": {"query": "test"}}\n\n',
'event: tokens\ndata: {"delta": "Based on the retrieved standards", "tool_call_id": null}\n\n',
'event: tool_result\ndata: {"id": "test_tool_1", "name": "retrieve_standard_regulation", "results": [], "took_ms": 100}\n\n',
'event: tokens\ndata: {"delta": " this is a test response.", "tool_call_id": null}\n\n'
]
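# Illustrative parser (an assumption about the SSE framing shown above): splits the raw
# "event: ... / data: ..." strings from mock_streaming_response into (event_type, payload)
# tuples for easier assertions in streaming tests.
def _parse_sse_events_example(raw_events):
    import json
    parsed = []
    for raw in raw_events:
        event_type, payload = None, None
        for line in raw.splitlines():
            if line.startswith("event: "):
                event_type = line[len("event: "):]
            elif line.startswith("data: "):
                payload = json.loads(line[len("data: "):])
        parsed.append((event_type, payload))
    return parsed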
# Async test helpers
@pytest.fixture
def mock_agent_state():
"""Mock agent state for graph testing."""
return {
"messages": [],
"session_id": "test_session",
"tool_results": [],
"final_answer": ""
}
@pytest.fixture
async def async_test_client():
"""Async test client for integration tests."""
async with httpx.AsyncClient() as client:
yield client
# Database fixtures for integration tests
@pytest.fixture
def test_database_url():
"""Test database URL (only for integration tests with real DB)."""
return "postgresql://test:test@localhost:5432/test_agent_memory"
@pytest.fixture
def integration_test_config(test_database_url):
"""Configuration for integration tests with real database."""
return {
"provider": "openai",
"openai": {
"api_key": "test-key",
"model": "gpt-4o"
},
"retrieval": {
"endpoint": "http://localhost:8000/search", # Assume test retrieval server
"api_key": "test-key"
},
"postgresql": {
"connection_string": test_database_url
}
}
# Skip markers for different test types
def pytest_configure(config):
"""Configure pytest markers."""
config.addinivalue_line("markers", "unit: mark test as unit test")
config.addinivalue_line("markers", "integration: mark test as integration test")
config.addinivalue_line("markers", "e2e: mark test as end-to-end test")
config.addinivalue_line("markers", "slow: mark test as slow running")
def pytest_runtest_setup(item):
"""Setup for test items."""
# Skip integration tests if not explicitly requested
if "integration" in item.keywords and not item.config.getoption("--run-integration"):
pytest.skip("Integration tests not requested")
# Skip E2E tests if not explicitly requested
if "e2e" in item.keywords and not item.config.getoption("--run-e2e"):
pytest.skip("E2E tests not requested")
def pytest_addoption(parser):
"""Add custom command line options."""
parser.addoption(
"--run-integration",
action="store_true",
default=False,
help="run integration tests"
)
parser.addoption(
"--run-e2e",
action="store_true",
default=False,
help="run end-to-end tests"
)
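# Example usage of the markers and flags defined above (illustrative):
#
#   @pytest.mark.integration
#   def test_real_database_roundtrip():
#       ...
#
#   # Integration and E2E tests are skipped unless explicitly requested:
#   #   pytest --run-integration --run-e2e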

View File

@@ -0,0 +1,33 @@
import os

import httpx
def get_embedding(text: str) -> list[float]:
"""Get embedding vector for text using the configured embedding service"""
    # Do not hard-code secrets in the repo; read the key from the environment (variable name assumed).
    api_key = os.environ.get("AZURE_OPENAI_API_KEY", "")
model = "text-embedding-3-small"
base_url = "https://aoai-lab-jpe-fl.openai.azure.com/openai/deployments/text-embedding-3-small/embeddings?api-version=2024-12-01-preview"
    headers = {
        "Content-Type": "application/json",
        # Azure OpenAI API-key auth typically uses the "api-key" header;
        # "Authorization: Bearer ..." is reserved for Azure AD tokens.
        "api-key": api_key
    }
payload = {
"input": text,
"model": model
}
try:
        response = httpx.post(base_url, json=payload, headers=headers)
response.raise_for_status()
result = response.json()
print(result)
return result["data"][0]["embedding"]
except Exception as e:
print(f"Failed to get embedding: {e}")
        raise RuntimeError(f"Embedding generation failed: {e}") from e
if __name__ == "__main__":
print("Begin")
text = "Sample text for embedding"
result = get_embedding(text)
print(result)

View File

@@ -0,0 +1 @@
# Empty __init__.py so this directory is treated as a test package

View File

@@ -0,0 +1,170 @@
#!/usr/bin/env python3
"""
Test 2-phase retrieval strategy
"""
import asyncio
import httpx
import json
import logging
import random
import time
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)
async def test_2phase_retrieval():
"""Test that agent uses 2-phase retrieval for content-focused queries"""
session_id = f"2phase-test-{random.randint(1000000000, 9999999999)}"
base_url = "http://127.0.0.1:8000"
# Test query that should trigger 2-phase retrieval
query = "如何测试电动汽车的充电性能?请详细说明测试方法和步骤。"
logger.info("🎯 2-PHASE RETRIEVAL TEST")
logger.info("=" * 80)
logger.info(f"📝 Session: {session_id}")
logger.info(f"📝 Query: {query}")
logger.info("-" * 60)
# Create the request payload
payload = {
"messages": [
{
"role": "user",
"content": query
}
],
"session_id": session_id
}
# Track tool usage
metadata_tools = 0
content_tools = 0
total_tools = 0
timeout = httpx.Timeout(120.0) # 2 minute timeout
try:
async with httpx.AsyncClient(timeout=timeout) as client:
logger.info("✅ Streaming response started")
async with client.stream(
"POST",
f"{base_url}/api/chat",
json=payload,
headers={"Content-Type": "application/json"}
) as response:
# Check if the response started successfully
if response.status_code != 200:
error_body = await response.aread()
logger.error(f"❌ HTTP {response.status_code}: {error_body.decode()}")
return
# Process the streaming response
current_event_type = None
async for line in response.aiter_lines():
if not line.strip():
continue
if line.startswith("event: "):
current_event_type = line[7:] # Remove "event: " prefix
continue
if line.startswith("data: "):
data_str = line[6:] # Remove "data: " prefix
if data_str == "[DONE]":
logger.info("✅ Stream completed with [DONE]")
break
try:
event_data = json.loads(data_str)
event_type = current_event_type or "unknown"
if event_type == "tool_start":
total_tools += 1
tool_name = event_data.get("name", "unknown")
args = event_data.get("args", {})
query_arg = args.get("query", "")[:50] + "..." if len(args.get("query", "")) > 50 else args.get("query", "")
if tool_name == "retrieve_standard_regulation":
metadata_tools += 1
logger.info(f"📋 Phase 1 Tool {metadata_tools}: {tool_name}")
logger.info(f" Query: {query_arg}")
elif tool_name == "retrieve_doc_chunk_standard_regulation":
content_tools += 1
logger.info(f"📄 Phase 2 Tool {content_tools}: {tool_name}")
logger.info(f" Query: {query_arg}")
else:
logger.info(f"🔧 Tool {total_tools}: {tool_name}")
elif event_type == "tool_result":
tool_name = event_data.get("name", "unknown")
results_count = len(event_data.get("results", []))
took_ms = event_data.get("took_ms", 0)
logger.info(f"✅ Tool completed: {tool_name} ({results_count} results, {took_ms}ms)")
elif event_type == "tokens":
# Don't log every token, just count them
pass
# Reset event type for next event
current_event_type = None
# Break after many tools to avoid too much output
if total_tools > 20:
logger.info(" ⚠️ Breaking after 20 tools...")
break
except json.JSONDecodeError as e:
logger.warning(f"⚠️ Failed to parse event: {e}")
current_event_type = None
except Exception as e:
logger.error(f"❌ Request failed: {e}")
return
# Results
logger.info("=" * 80)
logger.info("📊 2-PHASE RETRIEVAL ANALYSIS")
logger.info("=" * 80)
logger.info(f"Phase 1 (Metadata) tools: {metadata_tools}")
logger.info(f"Phase 2 (Content) tools: {content_tools}")
logger.info(f"Total tools executed: {total_tools}")
logger.info("-" * 60)
# Success criteria
success_criteria = [
        (metadata_tools > 0, f"Phase 1 metadata retrieval: {'✅' if metadata_tools > 0 else '❌'} ({metadata_tools} tools)"),
        (content_tools > 0, f"Phase 2 content retrieval: {'✅' if content_tools > 0 else '❌'} ({content_tools} tools)"),
        (total_tools >= 2, f"Multi-tool execution: {'✅' if total_tools >= 2 else '❌'} ({total_tools} tools)")
]
logger.info("✅ SUCCESS CRITERIA:")
all_passed = True
for passed, message in success_criteria:
logger.info(f" {message}")
if not passed:
all_passed = False
if all_passed:
logger.info("🎉 2-PHASE RETRIEVAL TEST PASSED!")
logger.info(" ✅ Agent correctly uses both metadata and content retrieval tools")
else:
logger.info("❌ 2-PHASE RETRIEVAL TEST FAILED!")
if metadata_tools == 0:
logger.info(" ❌ No metadata retrieval tools used")
if content_tools == 0:
logger.info(" ❌ No content retrieval tools used - this is the main issue!")
if __name__ == "__main__":
asyncio.run(test_2phase_retrieval())

View File

@@ -0,0 +1,372 @@
"""
Remote Integration Tests for Agentic RAG API
These tests connect to a running service instance remotely to validate:
- API endpoints and responses
- Request/response schemas
- Basic functionality without external dependencies
"""
import pytest
import asyncio
import json
import httpx
from typing import Optional, Dict, Any
import time
import os
# Configuration for remote service connection
DEFAULT_SERVICE_URL = "http://127.0.0.1:8000"
SERVICE_URL = os.getenv("AGENTIC_RAG_SERVICE_URL", DEFAULT_SERVICE_URL)
@pytest.fixture(scope="session")
def service_url() -> str:
"""Get the service URL for testing"""
return SERVICE_URL
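# Illustrative invocation (host and path are assumptions): the suite targets whatever
# AGENTIC_RAG_SERVICE_URL points at, falling back to the local default above:
#   AGENTIC_RAG_SERVICE_URL=http://staging-host:8000 pytest tests/integration/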
class TestBasicAPI:
"""Test basic API endpoints and functionality"""
@pytest.mark.asyncio
async def test_health_endpoint(self, service_url: str):
"""Test service health endpoint"""
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(f"{service_url}/health")
assert response.status_code == 200
data = response.json()
assert data["status"] == "healthy"
assert data["service"] == "agentic-rag"
@pytest.mark.asyncio
async def test_root_endpoint(self, service_url: str):
"""Test root API endpoint"""
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(f"{service_url}/")
assert response.status_code == 200
data = response.json()
assert "message" in data
assert "Agentic RAG API" in data["message"]
@pytest.mark.asyncio
async def test_openapi_docs(self, service_url: str):
"""Test OpenAPI documentation endpoint"""
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(f"{service_url}/openapi.json")
assert response.status_code == 200
data = response.json()
assert "openapi" in data
assert "info" in data
assert data["info"]["title"] == "Agentic RAG API"
@pytest.mark.asyncio
async def test_docs_endpoint(self, service_url: str):
"""Test Swagger UI docs endpoint"""
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(f"{service_url}/docs")
assert response.status_code == 200
assert "text/html" in response.headers["content-type"]
class TestChatAPI:
"""Test chat API endpoints with valid requests"""
def _create_chat_request(self, message: str, session_id: Optional[str] = None) -> Dict[str, Any]:
"""Create a valid chat request"""
return {
"session_id": session_id or f"test_session_{int(time.time())}",
"messages": [
{
"role": "user",
"content": message
}
]
}
@pytest.mark.asyncio
async def test_chat_endpoint_basic_request(self, service_url: str):
"""Test basic chat endpoint request/response structure"""
request_data = self._create_chat_request("Hello, can you help me?")
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.post(
f"{service_url}/api/chat",
json=request_data,
headers={"Content-Type": "application/json"}
)
assert response.status_code == 200
# Response should be streaming text/event-stream
assert "text/event-stream" in response.headers.get("content-type", "") or \
"text/plain" in response.headers.get("content-type", "")
@pytest.mark.asyncio
async def test_ai_sdk_chat_endpoint_basic_request(self, service_url: str):
"""Test AI SDK compatible chat endpoint"""
request_data = self._create_chat_request("What is ISO 26262?")
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.post(
f"{service_url}/api/ai-sdk/chat",
json=request_data,
headers={"Content-Type": "application/json"}
)
assert response.status_code == 200
# AI SDK endpoint returns plain text stream
assert "text/plain" in response.headers.get("content-type", "")
@pytest.mark.asyncio
async def test_chat_endpoint_invalid_request(self, service_url: str):
"""Test chat endpoint with invalid request data"""
invalid_requests = [
{}, # Empty request
{"session_id": "test"}, # Missing messages
{"messages": []}, # Missing session_id
{"session_id": "test", "messages": [{"role": "invalid"}]}, # Invalid message format
]
async with httpx.AsyncClient(timeout=30.0) as client:
for invalid_request in invalid_requests:
response = await client.post(
f"{service_url}/api/chat",
json=invalid_request,
headers={"Content-Type": "application/json"}
)
# Should return 422 for validation errors
assert response.status_code == 422
@pytest.mark.asyncio
async def test_session_persistence(self, service_url: str):
"""Test that sessions persist across multiple requests"""
session_id = f"persistent_session_{int(time.time())}"
async with httpx.AsyncClient(timeout=30.0) as client:
# First message
request1 = self._create_chat_request("My name is John", session_id)
response1 = await client.post(
f"{service_url}/api/chat",
json=request1,
headers={"Content-Type": "application/json"}
)
assert response1.status_code == 200
# Wait a moment for processing
await asyncio.sleep(1)
# Second message referring to previous context
request2 = self._create_chat_request("What did I just tell you my name was?", session_id)
response2 = await client.post(
f"{service_url}/api/chat",
json=request2,
headers={"Content-Type": "application/json"}
)
assert response2.status_code == 200
class TestRequestValidation:
"""Test request validation and error handling"""
@pytest.mark.asyncio
async def test_malformed_json(self, service_url: str):
"""Test endpoint with malformed JSON"""
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.post(
f"{service_url}/api/chat",
content="invalid json{",
headers={"Content-Type": "application/json"}
)
assert response.status_code == 422
@pytest.mark.asyncio
async def test_missing_content_type(self, service_url: str):
"""Test endpoint without proper content type"""
request_data = {
"session_id": "test_session",
"messages": [{"role": "user", "content": "test"}]
}
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.post(
f"{service_url}/api/chat",
content=json.dumps(request_data)
# No Content-Type header
)
# FastAPI should handle this gracefully
assert response.status_code in [415, 422]
@pytest.mark.asyncio
async def test_oversized_request(self, service_url: str):
"""Test endpoint with very large request"""
large_content = "x" * 100000 # 100KB message
request_data = {
"session_id": "test_session",
"messages": [{"role": "user", "content": large_content}]
}
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.post(
f"{service_url}/api/chat",
json=request_data,
headers={"Content-Type": "application/json"}
)
# Should either process or reject gracefully
assert response.status_code in [200, 413, 422]
class TestCORSAndHeaders:
"""Test CORS and security headers"""
@pytest.mark.asyncio
async def test_cors_headers(self, service_url: str):
"""Test CORS headers are properly set"""
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.options(
f"{service_url}/api/chat",
headers={
"Origin": "http://localhost:3000",
"Access-Control-Request-Method": "POST",
"Access-Control-Request-Headers": "Content-Type"
}
)
# CORS preflight should be handled
assert response.status_code in [200, 204]
# Check for CORS headers in actual request
request_data = {
"session_id": "cors_test",
"messages": [{"role": "user", "content": "test"}]
}
response = await client.post(
f"{service_url}/api/chat",
json=request_data,
headers={
"Content-Type": "application/json",
"Origin": "http://localhost:3000"
}
)
assert response.status_code == 200
# Should have CORS headers
assert "access-control-allow-origin" in response.headers
@pytest.mark.asyncio
async def test_security_headers(self, service_url: str):
"""Test basic security headers"""
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(f"{service_url}/health")
assert response.status_code == 200
# Check for basic security practices
# Note: Specific headers depend on deployment configuration
headers = response.headers
# FastAPI should include some basic headers
assert "content-length" in headers or "transfer-encoding" in headers
class TestErrorHandling:
"""Test error handling and edge cases"""
@pytest.mark.asyncio
async def test_nonexistent_endpoint(self, service_url: str):
"""Test request to non-existent endpoint"""
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(f"{service_url}/nonexistent")
assert response.status_code == 404
@pytest.mark.asyncio
async def test_method_not_allowed(self, service_url: str):
"""Test wrong HTTP method on endpoint"""
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(f"{service_url}/api/chat") # GET instead of POST
assert response.status_code == 405
@pytest.mark.asyncio
async def test_timeout_handling(self, service_url: str):
"""Test request timeout handling"""
# Use a very short timeout to test timeout handling
async with httpx.AsyncClient(timeout=0.001) as short_timeout_client:
try:
response = await short_timeout_client.get(f"{service_url}/health")
# If it doesn't timeout, that's also fine
assert response.status_code == 200
except httpx.TimeoutException:
# Expected timeout - this is fine
pass
class TestServiceIntegration:
"""Test integration with actual service features"""
@pytest.mark.asyncio
async def test_manufacturing_standards_query(self, service_url: str):
"""Test query related to manufacturing standards"""
request_data = {
"session_id": f"standards_test_{int(time.time())}",
"messages": [
{
"role": "user",
"content": "What are the key safety requirements in ISO 26262?"
}
]
}
async with httpx.AsyncClient(timeout=60.0) as client:
response = await client.post(
f"{service_url}/api/ai-sdk/chat",
json=request_data,
headers={"Content-Type": "application/json"}
)
assert response.status_code == 200
# Read some of the streaming response
content = ""
async for chunk in response.aiter_text():
content += chunk
if len(content) > 100: # Read enough to verify it's working
break
# Should have some content indicating it's processing
assert len(content) > 0
@pytest.mark.asyncio
async def test_general_conversation(self, service_url: str):
"""Test general conversation capability"""
request_data = {
"session_id": f"general_test_{int(time.time())}",
"messages": [
{
"role": "user",
"content": "Hello! How can you help me today?"
}
]
}
async with httpx.AsyncClient(timeout=60.0) as client:
response = await client.post(
f"{service_url}/api/chat",
json=request_data,
headers={"Content-Type": "application/json"}
)
assert response.status_code == 200
# Verify we get streaming response
content = ""
chunk_count = 0
async for chunk in response.aiter_text():
content += chunk
chunk_count += 1
if chunk_count > 10: # Read several chunks
break
# Should receive streaming content
assert len(content) > 0

View File

@@ -0,0 +1,415 @@
"""
End-to-End Integration Tests for Tool UI
These tests validate the complete user experience by connecting to a running service.
They test tool calling, response formatting, and user interface integration.
"""
import pytest
import asyncio
import httpx
import time
import os
# Configuration for remote service connection
DEFAULT_SERVICE_URL = "http://127.0.0.1:8000"
SERVICE_URL = os.getenv("AGENTIC_RAG_SERVICE_URL", DEFAULT_SERVICE_URL)
@pytest.fixture(scope="session")
def service_url() -> str:
"""Get the service URL for testing"""
return SERVICE_URL
class TestEndToEndWorkflows:
"""Test complete end-to-end user workflows"""
@pytest.mark.asyncio
async def test_standards_research_with_tools(self, service_url: str):
"""Test standards research workflow with tool calls"""
session_id = f"e2e_standards_{int(time.time())}"
request_data = {
"session_id": session_id,
"messages": [
{
"role": "user",
"content": "What are the safety requirements for automotive braking systems according to ISO 26262?"
}
]
}
async with httpx.AsyncClient(timeout=90.0) as client:
response = await client.post(
f"{service_url}/api/chat",
json=request_data,
headers={"Content-Type": "application/json"}
)
assert response.status_code == 200
# Collect the full response to analyze tool usage
full_content = ""
async for chunk in response.aiter_text():
full_content += chunk
if len(full_content) > 1000: # Get substantial content
break
# Verify we got meaningful content
assert len(full_content) > 100
print(f"Standards research response length: {len(full_content)} chars")
@pytest.mark.asyncio
async def test_manufacturing_compliance_workflow(self, service_url: str):
"""Test manufacturing compliance workflow"""
session_id = f"e2e_compliance_{int(time.time())}"
request_data = {
"session_id": session_id,
"messages": [
{
"role": "user",
"content": "I need to understand compliance requirements for manufacturing equipment safety. What standards apply?"
}
]
}
async with httpx.AsyncClient(timeout=90.0) as client:
response = await client.post(
f"{service_url}/api/ai-sdk/chat",
json=request_data,
headers={"Content-Type": "application/json"}
)
assert response.status_code == 200
# Test AI SDK format response
content = ""
async for chunk in response.aiter_text():
content += chunk
if len(content) > 500:
break
assert len(content) > 50
print(f"Compliance workflow response length: {len(content)} chars")
@pytest.mark.asyncio
async def test_technical_documentation_workflow(self, service_url: str):
"""Test technical documentation research workflow"""
session_id = f"e2e_technical_{int(time.time())}"
request_data = {
"session_id": session_id,
"messages": [
{
"role": "user",
"content": "How do I implement functional safety according to IEC 61508 for industrial control systems?"
}
]
}
async with httpx.AsyncClient(timeout=90.0) as client:
response = await client.post(
f"{service_url}/api/chat",
json=request_data,
headers={"Content-Type": "application/json"}
)
assert response.status_code == 200
# Collect response
content = ""
async for chunk in response.aiter_text():
content += chunk
if len(content) > 800:
break
assert len(content) > 100
print(f"Technical documentation response length: {len(content)} chars")
class TestMultiTurnConversations:
"""Test multi-turn conversation workflows"""
@pytest.mark.asyncio
async def test_progressive_standards_exploration(self, service_url: str):
"""Test progressive exploration of standards through multiple turns"""
session_id = f"e2e_progressive_{int(time.time())}"
conversation_steps = [
"What is ISO 26262?",
"What are the ASIL levels?",
"How do I determine ASIL D requirements?",
"What testing is required for ASIL D systems?"
]
async with httpx.AsyncClient(timeout=90.0) as client:
for i, question in enumerate(conversation_steps):
request_data = {
"session_id": session_id,
"messages": [{"role": "user", "content": question}]
}
response = await client.post(
f"{service_url}/api/chat",
json=request_data,
headers={"Content-Type": "application/json"}
)
assert response.status_code == 200
# Read response
content = ""
async for chunk in response.aiter_text():
content += chunk
if len(content) > 300:
break
assert len(content) > 30
print(f"Turn {i+1}: {len(content)} chars")
# Brief pause between turns
await asyncio.sleep(1)
@pytest.mark.asyncio
async def test_comparative_analysis_workflow(self, service_url: str):
"""Test comparative analysis across multiple standards"""
session_id = f"e2e_comparative_{int(time.time())}"
comparison_questions = [
"What are the differences between ISO 26262 and IEC 61508?",
"Which standard is more appropriate for automotive applications?",
"How do the safety integrity levels compare between these standards?"
]
async with httpx.AsyncClient(timeout=90.0) as client:
for question in comparison_questions:
request_data = {
"session_id": session_id,
"messages": [{"role": "user", "content": question}]
}
response = await client.post(
f"{service_url}/api/ai-sdk/chat",
json=request_data,
headers={"Content-Type": "application/json"}
)
assert response.status_code == 200
# Collect comparison response
content = ""
async for chunk in response.aiter_text():
content += chunk
if len(content) > 400:
break
assert len(content) > 50
await asyncio.sleep(1.5)
class TestSpecializedQueries:
"""Test specialized query types and edge cases"""
@pytest.mark.asyncio
async def test_specific_standard_section_query(self, service_url: str):
"""Test queries about specific sections of standards"""
session_id = f"e2e_specific_{int(time.time())}"
request_data = {
"session_id": session_id,
"messages": [
{
"role": "user",
"content": "What does section 4.3 of ISO 26262-3 say about software architectural design?"
}
]
}
async with httpx.AsyncClient(timeout=90.0) as client:
response = await client.post(
f"{service_url}/api/chat",
json=request_data,
headers={"Content-Type": "application/json"}
)
assert response.status_code == 200
content = ""
async for chunk in response.aiter_text():
content += chunk
if len(content) > 600:
break
assert len(content) > 50
@pytest.mark.asyncio
async def test_implementation_guidance_query(self, service_url: str):
"""Test queries asking for implementation guidance"""
session_id = f"e2e_implementation_{int(time.time())}"
request_data = {
"session_id": session_id,
"messages": [
{
"role": "user",
"content": "How should I implement a safety management system according to ISO 45001?"
}
]
}
async with httpx.AsyncClient(timeout=90.0) as client:
response = await client.post(
f"{service_url}/api/ai-sdk/chat",
json=request_data,
headers={"Content-Type": "application/json"}
)
assert response.status_code == 200
content = ""
async for chunk in response.aiter_text():
content += chunk
if len(content) > 500:
break
assert len(content) > 100
@pytest.mark.asyncio
async def test_cross_domain_standards_query(self, service_url: str):
"""Test queries spanning multiple domains"""
session_id = f"e2e_cross_domain_{int(time.time())}"
request_data = {
"session_id": session_id,
"messages": [
{
"role": "user",
"content": "How do cybersecurity standards like ISO 27001 relate to functional safety standards like ISO 26262?"
}
]
}
async with httpx.AsyncClient(timeout=90.0) as client:
response = await client.post(
f"{service_url}/api/chat",
json=request_data,
headers={"Content-Type": "application/json"}
)
assert response.status_code == 200
content = ""
async for chunk in response.aiter_text():
content += chunk
if len(content) > 700:
break
assert len(content) > 100
class TestUserExperience:
"""Test overall user experience aspects"""
@pytest.mark.asyncio
async def test_response_quality_indicators(self, service_url: str):
"""Test that responses have quality indicators (good structure, citations, etc.)"""
session_id = f"e2e_quality_{int(time.time())}"
request_data = {
"session_id": session_id,
"messages": [
{
"role": "user",
"content": "What are the key principles of risk assessment in ISO 31000?"
}
]
}
async with httpx.AsyncClient(timeout=90.0) as client:
response = await client.post(
f"{service_url}/api/chat",
json=request_data,
headers={"Content-Type": "application/json"}
)
assert response.status_code == 200
# Collect full response to analyze quality
full_content = ""
async for chunk in response.aiter_text():
full_content += chunk
if len(full_content) > 1200:
break
# Basic quality checks
assert len(full_content) > 100
# Content should contain structured information
# (These are basic heuristics for response quality)
assert len(full_content.split()) > 20 # At least 20 words
print(f"Quality response length: {len(full_content)} chars")
@pytest.mark.asyncio
async def test_error_recovery_experience(self, service_url: str):
"""Test user experience when recovering from errors"""
session_id = f"e2e_error_recovery_{int(time.time())}"
async with httpx.AsyncClient(timeout=90.0) as client:
# Start with a good question
good_request = {
"session_id": session_id,
"messages": [{"role": "user", "content": "What is ISO 9001?"}]
}
response = await client.post(
f"{service_url}/api/chat",
json=good_request,
headers={"Content-Type": "application/json"}
)
assert response.status_code == 200
await asyncio.sleep(1)
# Try a potentially problematic request
try:
problematic_request = {
"session_id": session_id,
"messages": [{"role": "user", "content": ""}] # Empty content
}
await client.post(
f"{service_url}/api/chat",
json=problematic_request,
headers={"Content-Type": "application/json"}
)
except Exception:
pass # Expected to potentially fail
await asyncio.sleep(1)
# Recovery with another good question
recovery_request = {
"session_id": session_id,
"messages": [{"role": "user", "content": "Can you help me understand quality management?"}]
}
recovery_response = await client.post(
f"{service_url}/api/chat",
json=recovery_request,
headers={"Content-Type": "application/json"}
)
# Should recover successfully
assert recovery_response.status_code == 200
content = ""
async for chunk in recovery_response.aiter_text():
content += chunk
if len(content) > 200:
break
assert len(content) > 30
print("📤 Sending to backend...")

View File

@@ -0,0 +1,402 @@
"""
Full Workflow Integration Tests
These tests validate complete end-to-end workflows by connecting to a running service.
They test realistic user scenarios and complex interactions.
"""
import pytest
import asyncio
import httpx
import time
import os
from typing import List, Dict, Any
# Configuration for remote service connection
DEFAULT_SERVICE_URL = "http://127.0.0.1:8000"
SERVICE_URL = os.getenv("AGENTIC_RAG_SERVICE_URL", DEFAULT_SERVICE_URL)
@pytest.fixture(scope="session")
def service_url() -> str:
"""Get the service URL for testing"""
return SERVICE_URL
class TestCompleteWorkflows:
"""Test complete user workflows"""
@pytest.mark.asyncio
async def test_standards_research_workflow(self, service_url: str):
"""Test a complete standards research workflow"""
session_id = f"standards_workflow_{int(time.time())}"
# Simulate a user researching ISO 26262
conversation_flow = [
"What is ISO 26262 and what does it cover?",
"What are the ASIL levels in ISO 26262?",
"Can you explain ASIL D requirements in detail?",
"How does ISO 26262 relate to vehicle cybersecurity?"
]
async with httpx.AsyncClient(timeout=60.0) as client:
for i, question in enumerate(conversation_flow):
request_data = {
"session_id": session_id,
"messages": [{"role": "user", "content": question}]
}
response = await client.post(
f"{service_url}/api/ai-sdk/chat",
json=request_data,
headers={"Content-Type": "application/json"}
)
assert response.status_code == 200
# Read the streaming response
content = ""
async for chunk in response.aiter_text():
content += chunk
if len(content) > 200: # Get substantial response
break
# Verify we get meaningful content
assert len(content) > 50
print(f"Question {i+1} response length: {len(content)} chars")
# Small delay between questions
await asyncio.sleep(0.5)
@pytest.mark.asyncio
async def test_manufacturing_safety_workflow(self, service_url: str):
"""Test manufacturing safety standards workflow"""
session_id = f"manufacturing_workflow_{int(time.time())}"
conversation_flow = [
"What are the key safety standards for manufacturing equipment?",
"How do ISO 13849 and IEC 62061 compare?",
"What is the process for safety risk assessment in manufacturing?"
]
async with httpx.AsyncClient(timeout=60.0) as client:
responses = []
for question in conversation_flow:
request_data = {
"session_id": session_id,
"messages": [{"role": "user", "content": question}]
}
response = await client.post(
f"{service_url}/api/chat",
json=request_data,
headers={"Content-Type": "application/json"}
)
assert response.status_code == 200
# Collect response content
content = ""
async for chunk in response.aiter_text():
content += chunk
if len(content) > 300:
break
responses.append(content)
await asyncio.sleep(0.5)
# Verify we got responses for all questions
assert len(responses) == len(conversation_flow)
for response_content in responses:
assert len(response_content) > 30
@pytest.mark.asyncio
async def test_session_context_continuity(self, service_url: str):
"""Test that session context is maintained across requests"""
session_id = f"context_test_{int(time.time())}"
async with httpx.AsyncClient(timeout=60.0) as client:
# First message - establish context
request1 = {
"session_id": session_id,
"messages": [{"role": "user", "content": "I'm working on a safety system for automotive braking. What standard should I follow?"}]
}
response1 = await client.post(
f"{service_url}/api/chat",
json=request1,
headers={"Content-Type": "application/json"}
)
assert response1.status_code == 200
# Wait for processing
await asyncio.sleep(2)
# Follow-up question that depends on context
request2 = {
"session_id": session_id,
"messages": [{"role": "user", "content": "What are the specific testing requirements for this standard?"}]
}
response2 = await client.post(
f"{service_url}/api/chat",
json=request2,
headers={"Content-Type": "application/json"}
)
assert response2.status_code == 200
# Verify both responses are meaningful
content1 = ""
async for chunk in response1.aiter_text():
content1 += chunk
if len(content1) > 100:
break
content2 = ""
async for chunk in response2.aiter_text():
content2 += chunk
if len(content2) > 100:
break
assert len(content1) > 50
assert len(content2) > 50
class TestErrorRecoveryWorkflows:
"""Test error recovery and edge case workflows"""
@pytest.mark.asyncio
async def test_session_recovery_after_error(self, service_url: str):
"""Test that sessions can recover after encountering errors"""
session_id = f"error_recovery_{int(time.time())}"
async with httpx.AsyncClient(timeout=60.0) as client:
# Valid request
valid_request = {
"session_id": session_id,
"messages": [{"role": "user", "content": "What is ISO 9001?"}]
}
response = await client.post(
f"{service_url}/api/chat",
json=valid_request,
headers={"Content-Type": "application/json"}
)
assert response.status_code == 200
# Try an invalid request that might cause issues
invalid_request = {
"session_id": session_id,
"messages": [{"role": "user", "content": ""}] # Empty content
}
try:
await client.post(
f"{service_url}/api/chat",
json=invalid_request,
headers={"Content-Type": "application/json"}
)
except Exception:
pass # Expected to potentially fail
await asyncio.sleep(1)
# Another valid request to test recovery
recovery_request = {
"session_id": session_id,
"messages": [{"role": "user", "content": "Can you summarize what we discussed?"}]
}
recovery_response = await client.post(
f"{service_url}/api/chat",
json=recovery_request,
headers={"Content-Type": "application/json"}
)
# Session should still work
assert recovery_response.status_code == 200
@pytest.mark.asyncio
async def test_concurrent_sessions(self, service_url: str):
"""Test multiple concurrent sessions"""
base_time = int(time.time())
sessions = [f"concurrent_{base_time}_{i}" for i in range(3)]
async def test_session(session_id: str, question: str):
"""Test a single session"""
async with httpx.AsyncClient(timeout=60.0) as client:
request = {
"session_id": session_id,
"messages": [{"role": "user", "content": question}]
}
response = await client.post(
f"{service_url}/api/chat",
json=request,
headers={"Content-Type": "application/json"}
)
assert response.status_code == 200
return session_id
# Run concurrent sessions
questions = [
"What is ISO 27001?",
"What is NIST Cybersecurity Framework?",
"What is GDPR compliance?"
]
tasks = [
test_session(session_id, question)
for session_id, question in zip(sessions, questions)
]
results = await asyncio.gather(*tasks, return_exceptions=True)
# All sessions should complete successfully
assert len(results) == 3
for result in results:
assert not isinstance(result, Exception)
class TestPerformanceWorkflows:
"""Test performance-related workflows"""
@pytest.mark.asyncio
async def test_rapid_fire_requests(self, service_url: str):
"""Test rapid consecutive requests in same session"""
session_id = f"rapid_fire_{int(time.time())}"
questions = [
"Hello",
"What is ISO 14001?",
"Thank you",
"Goodbye"
]
async with httpx.AsyncClient(timeout=60.0) as client:
for i, question in enumerate(questions):
request = {
"session_id": session_id,
"messages": [{"role": "user", "content": question}]
}
response = await client.post(
f"{service_url}/api/chat",
json=request,
headers={"Content-Type": "application/json"}
)
assert response.status_code == 200
print(f"Rapid request {i+1} completed")
# Very short delay
await asyncio.sleep(0.1)
@pytest.mark.asyncio
async def test_large_context_workflow(self, service_url: str):
"""Test workflow with gradually increasing context"""
session_id = f"large_context_{int(time.time())}"
async with httpx.AsyncClient(timeout=60.0) as client:
# Build up context over multiple turns
conversation = [
"I need to understand automotive safety standards",
"Specifically, tell me about ISO 26262 functional safety",
"What are the different ASIL levels and their requirements?",
"How do I implement ASIL D for a braking system?",
"What testing and validation is required for ASIL D?",
"Can you provide a summary of everything we've discussed?"
]
for i, message in enumerate(conversation):
request = {
"session_id": session_id,
"messages": [{"role": "user", "content": message}]
}
response = await client.post(
f"{service_url}/api/chat",
json=request,
headers={"Content-Type": "application/json"}
)
assert response.status_code == 200
print(f"Context turn {i+1} completed")
# Allow time for processing
await asyncio.sleep(1)
class TestRealWorldScenarios:
"""Test realistic user scenarios"""
@pytest.mark.asyncio
async def test_compliance_officer_scenario(self, service_url: str):
"""Simulate a compliance officer's typical workflow"""
session_id = f"compliance_officer_{int(time.time())}"
# Typical compliance questions
scenario_questions = [
"I need to ensure our new product meets regulatory requirements. What standards apply to automotive safety systems?",
"Our system is classified as ASIL C. What does this mean for our development process?",
"What documentation do we need to prepare for safety assessment?",
"How often do we need to review and update our safety processes?"
]
async with httpx.AsyncClient(timeout=90.0) as client:
for i, question in enumerate(scenario_questions):
request = {
"session_id": session_id,
"messages": [{"role": "user", "content": question}]
}
response = await client.post(
f"{service_url}/api/ai-sdk/chat",
json=request,
headers={"Content-Type": "application/json"}
)
assert response.status_code == 200
# Allow realistic time between questions
await asyncio.sleep(2)
print(f"Compliance scenario step {i+1} completed")
@pytest.mark.asyncio
async def test_engineer_research_scenario(self, service_url: str):
"""Simulate an engineer researching technical details"""
session_id = f"engineer_research_{int(time.time())}"
research_flow = [
"I'm designing a safety-critical system. What's the difference between ISO 26262 and IEC 61508?",
"For automotive applications, which standard takes precedence?",
"What are the specific requirements for software development under ISO 26262?",
"Can you explain the V-model development process required by the standard?"
]
async with httpx.AsyncClient(timeout=90.0) as client:
for question in research_flow:
request = {
"session_id": session_id,
"messages": [{"role": "user", "content": question}]
}
response = await client.post(
f"{service_url}/api/chat",
json=request,
headers={"Content-Type": "application/json"}
)
assert response.status_code == 200
# Read some response to verify it's working
content = ""
async for chunk in response.aiter_text():
content += chunk
if len(content) > 150:
break
assert len(content) > 50
await asyncio.sleep(1.5)

View File

@@ -0,0 +1,406 @@
"""
Streaming Integration Tests
These tests validate streaming behavior by connecting to a running service.
They focus on real-time response patterns and streaming event handling.
"""
import pytest
import asyncio
import httpx
import time
import os
# Configuration for remote service connection
DEFAULT_SERVICE_URL = "http://127.0.0.1:8000"
SERVICE_URL = os.getenv("AGENTIC_RAG_SERVICE_URL", DEFAULT_SERVICE_URL)
@pytest.fixture(scope="session")
def service_url() -> str:
"""Get the service URL for testing"""
return SERVICE_URL
class TestStreamingBehavior:
"""Test streaming response behavior"""
@pytest.mark.asyncio
async def test_basic_streaming_response(self, service_url: str):
"""Test that responses are properly streamed"""
session_id = f"streaming_test_{int(time.time())}"
request_data = {
"session_id": session_id,
"messages": [{"role": "user", "content": "What is ISO 26262?"}]
}
async with httpx.AsyncClient(timeout=60.0) as client:
response = await client.post(
f"{service_url}/api/chat",
json=request_data,
headers={"Content-Type": "application/json"}
)
assert response.status_code == 200
# Collect streaming chunks
chunks = []
async for chunk in response.aiter_text():
chunks.append(chunk)
if len(chunks) > 10: # Get enough chunks to verify streaming
break
# Should receive multiple chunks (indicating streaming)
assert len(chunks) > 1
# Chunks should have content
total_content = "".join(chunks)
assert len(total_content) > 0
@pytest.mark.asyncio
async def test_ai_sdk_streaming_format(self, service_url: str):
"""Test AI SDK compatible streaming format"""
session_id = f"ai_sdk_streaming_{int(time.time())}"
request_data = {
"session_id": session_id,
"messages": [{"role": "user", "content": "Explain vehicle safety testing"}]
}
async with httpx.AsyncClient(timeout=60.0) as client:
response = await client.post(
f"{service_url}/api/ai-sdk/chat",
json=request_data,
headers={"Content-Type": "application/json"}
)
assert response.status_code == 200
assert "text/plain" in response.headers.get("content-type", "")
# Test streaming behavior
chunk_count = 0
total_length = 0
async for chunk in response.aiter_text():
chunk_count += 1
total_length += len(chunk)
if chunk_count > 15: # Collect enough chunks
break
# Verify streaming characteristics
assert chunk_count > 1 # Multiple chunks
assert total_length > 50 # Meaningful content
@pytest.mark.asyncio
async def test_streaming_performance(self, service_url: str):
"""Test streaming response timing and performance"""
session_id = f"streaming_perf_{int(time.time())}"
request_data = {
"session_id": session_id,
"messages": [{"role": "user", "content": "What are automotive safety standards?"}]
}
async with httpx.AsyncClient(timeout=60.0) as client:
start_time = time.time()
response = await client.post(
f"{service_url}/api/chat",
json=request_data,
headers={"Content-Type": "application/json"}
)
assert response.status_code == 200
first_chunk_time = None
chunk_count = 0
async for chunk in response.aiter_text():
if first_chunk_time is None:
first_chunk_time = time.time()
chunk_count += 1
if chunk_count > 5: # Get a few chunks for timing
break
# Time to first chunk should be reasonable (< 10 seconds)
if first_chunk_time:
time_to_first_chunk = first_chunk_time - start_time
assert time_to_first_chunk < 10.0
@pytest.mark.asyncio
async def test_streaming_interruption_handling(self, service_url: str):
"""Test behavior when streaming is interrupted"""
session_id = f"streaming_interrupt_{int(time.time())}"
request_data = {
"session_id": session_id,
"messages": [{"role": "user", "content": "Tell me about ISO standards"}]
}
async with httpx.AsyncClient(timeout=60.0) as client:
response = await client.post(
f"{service_url}/api/chat",
json=request_data,
headers={"Content-Type": "application/json"}
)
assert response.status_code == 200
# Read only a few chunks then stop
chunk_count = 0
async for chunk in response.aiter_text():
chunk_count += 1
if chunk_count >= 3:
break # Interrupt streaming
# Should have received some chunks
assert chunk_count > 0
class TestConcurrentStreaming:
"""Test concurrent streaming scenarios"""
@pytest.mark.asyncio
async def test_multiple_concurrent_streams(self, service_url: str):
"""Test multiple concurrent streaming requests"""
base_time = int(time.time())
async def stream_request(session_suffix: str, question: str):
"""Make a single streaming request"""
session_id = f"concurrent_stream_{base_time}_{session_suffix}"
async with httpx.AsyncClient(timeout=60.0) as client:
response = await client.post(
f"{service_url}/api/chat",
json={
"session_id": session_id,
"messages": [{"role": "user", "content": question}]
},
headers={"Content-Type": "application/json"}
)
assert response.status_code == 200
# Read some chunks
chunks = 0
async for chunk in response.aiter_text():
chunks += 1
if chunks > 5:
break
return chunks
# Run multiple concurrent streams
questions = [
"What is ISO 26262?",
"Explain NIST framework",
"What is GDPR?"
]
tasks = [
stream_request(f"session_{i}", question)
for i, question in enumerate(questions)
]
results = await asyncio.gather(*tasks, return_exceptions=True)
# All streams should complete successfully
assert len(results) == 3
for result in results:
assert not isinstance(result, Exception)
assert result > 0 # Each stream should receive chunks
@pytest.mark.asyncio
async def test_same_session_rapid_requests(self, service_url: str):
"""Test rapid requests in the same session"""
session_id = f"rapid_session_{int(time.time())}"
questions = [
"Hello",
"What is ISO 9001?",
"Thank you"
]
async with httpx.AsyncClient(timeout=60.0) as client:
for i, question in enumerate(questions):
request_data = {
"session_id": session_id,
"messages": [{"role": "user", "content": question}]
}
response = await client.post(
f"{service_url}/api/chat",
json=request_data,
headers={"Content-Type": "application/json"}
)
assert response.status_code == 200
# Read some response
chunk_count = 0
async for chunk in response.aiter_text():
chunk_count += 1
if chunk_count > 3:
break
print(f"Request {i+1} completed with {chunk_count} chunks")
# Very short delay
await asyncio.sleep(0.2)
class TestStreamingErrorHandling:
"""Test error handling during streaming"""
@pytest.mark.asyncio
async def test_streaming_with_invalid_session(self, service_url: str):
"""Test streaming behavior with edge case session IDs"""
test_cases = [
"", # Empty session ID
"a" * 1000, # Very long session ID
"session with spaces", # Session ID with spaces
"session/with/slashes" # Session ID with special chars
]
async with httpx.AsyncClient(timeout=60.0) as client:
for session_id in test_cases:
request_data = {
"session_id": session_id,
"messages": [{"role": "user", "content": "Hello"}]
}
try:
response = await client.post(
f"{service_url}/api/chat",
json=request_data,
headers={"Content-Type": "application/json"}
)
# Should either work or return validation error
assert response.status_code in [200, 422]
except Exception as e:
# Some edge cases might cause exceptions, which is acceptable
print(f"Session ID '{session_id}' caused exception: {e}")
@pytest.mark.asyncio
async def test_streaming_with_large_messages(self, service_url: str):
"""Test streaming with large message content"""
session_id = f"large_msg_stream_{int(time.time())}"
# Create a large message
large_content = "Please explain safety standards. " * 100 # ~3KB message
request_data = {
"session_id": session_id,
"messages": [{"role": "user", "content": large_content}]
}
async with httpx.AsyncClient(timeout=90.0) as client:
response = await client.post(
f"{service_url}/api/chat",
json=request_data,
headers={"Content-Type": "application/json"}
)
# Should handle large messages appropriately
assert response.status_code in [200, 413, 422]
if response.status_code == 200:
# If accepted, should stream properly
chunk_count = 0
async for chunk in response.aiter_text():
chunk_count += 1
if chunk_count > 5:
break
assert chunk_count > 0
class TestStreamingContentValidation:
"""Test streaming content quality and format"""
@pytest.mark.asyncio
async def test_streaming_content_encoding(self, service_url: str):
"""Test that streaming content is properly encoded"""
session_id = f"encoding_test_{int(time.time())}"
# Test with special characters and unicode
test_message = "What is ISO 26262? Please explain with émphasis on safety ñorms."
request_data = {
"session_id": session_id,
"messages": [{"role": "user", "content": test_message}]
}
async with httpx.AsyncClient(timeout=60.0) as client:
response = await client.post(
f"{service_url}/api/chat",
json=request_data,
headers={"Content-Type": "application/json"}
)
assert response.status_code == 200
# Collect content and verify encoding
content = ""
async for chunk in response.aiter_text():
content += chunk
if len(content) > 100:
break
# Content should be valid UTF-8
assert isinstance(content, str)
assert len(content) > 0
# Should be able to encode/decode
encoded = content.encode('utf-8')
decoded = encoded.decode('utf-8')
assert decoded == content
@pytest.mark.asyncio
async def test_streaming_response_consistency(self, service_url: str):
"""Test that streaming responses are consistent for similar queries"""
base_session = f"consistency_test_{int(time.time())}"
# Ask the same question multiple times
test_question = "What is ISO 26262?"
responses = []
async with httpx.AsyncClient(timeout=60.0) as client:
for i in range(3):
session_id = f"{base_session}_{i}"
request_data = {
"session_id": session_id,
"messages": [{"role": "user", "content": test_question}]
}
response = await client.post(
f"{service_url}/api/chat",
json=request_data,
headers={"Content-Type": "application/json"}
)
assert response.status_code == 200
# Collect response
content = ""
async for chunk in response.aiter_text():
content += chunk
if len(content) > 200:
break
responses.append(content)
await asyncio.sleep(0.5)
# All responses should have content
for response_content in responses:
assert len(response_content) > 50
# Responses should have some consistency (all non-empty)
assert len([r for r in responses if r.strip()]) == len(responses)

View File

@@ -0,0 +1 @@
# Empty __init__.py so this directory is treated as a test package

View File

@@ -0,0 +1,114 @@
#!/usr/bin/env python3
"""
Test the new aggressive trimming strategy: historical tool-call results are trimmed even when the token count is very low.
"""
import pytest
from service.graph.message_trimmer import ConversationTrimmer
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage
from langchain_core.messages.utils import count_tokens_approximately
def test_aggressive_tool_history_trimming():
"""测试积极的工具历史修剪策略"""
# 直接创建修剪器,避免配置依赖
trimmer = ConversationTrimmer(max_context_length=100000)
# 创建包含多轮工具调用的对话token数很少
messages = [
SystemMessage(content='You are a helpful assistant.'),
        # Historical conversation round 1
HumanMessage(content='搜索汽车标准'),
AIMessage(content='搜索中', tool_calls=[{'id': 'call_1', 'name': 'search', 'args': {'query': '汽车标准'}}]),
ToolMessage(content='历史结果1', tool_call_id='call_1', name='search'),
AIMessage(content='汽车标准信息'),
        # Historical conversation round 2
HumanMessage(content='搜索电池标准'),
AIMessage(content='搜索中', tool_calls=[{'id': 'call_2', 'name': 'search', 'args': {'query': '电池标准'}}]),
ToolMessage(content='历史结果2', tool_call_id='call_2', name='search'),
AIMessage(content='电池标准信息'),
        # New user query (the point at which trimming is triggered)
HumanMessage(content='搜索安全标准'),
]
    # Verify the token count is low, well below the threshold
token_count = count_tokens_approximately(messages)
assert token_count < 1000, f"Token count should be low, got {token_count}"
assert token_count < trimmer.history_token_limit, "Token count should be well below limit"
    # Verify that multiple tool rounds are identified
tool_rounds = trimmer._identify_tool_rounds(messages)
assert len(tool_rounds) == 2, f"Should identify 2 tool rounds, got {len(tool_rounds)}"
    # Verify trimming is triggered (because there are multiple tool rounds)
should_trim = trimmer.should_trim(messages)
assert should_trim, "Should trigger trimming due to multiple tool rounds"
    # Perform the trimming
trimmed = trimmer.trim_conversation_history(messages)
    # Verify the effect of trimming
assert len(trimmed) < len(messages), "Should have fewer messages after trimming"
    # Verify the system message and the initial query are preserved
assert isinstance(trimmed[0], SystemMessage), "Should preserve system message"
assert isinstance(trimmed[1], HumanMessage), "Should preserve initial human message"
    # Verify that only the most recent round's tool-call result is kept
tool_messages = [msg for msg in trimmed if isinstance(msg, ToolMessage)]
assert len(tool_messages) == 1, f"Should only keep 1 tool message, got {len(tool_messages)}"
assert tool_messages[0].content == '历史结果2', "Should keep the most recent tool result"
def test_single_tool_round_no_trimming():
"""测试单轮工具调用不触发修剪"""
trimmer = ConversationTrimmer(max_context_length=100000)
# 只有一轮工具调用的对话
messages = [
SystemMessage(content='You are a helpful assistant.'),
HumanMessage(content='搜索信息'),
AIMessage(content='搜索中', tool_calls=[{'id': 'call_1', 'name': 'search', 'args': {'query': '信息'}}]),
ToolMessage(content='搜索结果', tool_call_id='call_1', name='search'),
AIMessage(content='这是搜索到的信息'),
HumanMessage(content='新的问题'),
]
# Verify only one tool round is identified
tool_rounds = trimmer._identify_tool_rounds(messages)
assert len(tool_rounds) == 1, f"Should identify 1 tool round, got {len(tool_rounds)}"
# Verify trimming is not triggered (only one tool round and a low token count)
should_trim = trimmer.should_trim(messages)
assert not should_trim, "Should not trigger trimming for single tool round with low tokens"
def test_no_tool_rounds_no_trimming():
"""测试没有工具调用的对话不触发修剪"""
trimmer = ConversationTrimmer(max_context_length=100000)
# 没有工具调用的对话
messages = [
SystemMessage(content='You are a helpful assistant.'),
HumanMessage(content='Hello'),
AIMessage(content='Hi there!'),
HumanMessage(content='How are you?'),
AIMessage(content='I am doing well, thank you!'),
]
# Verify no tool rounds are identified
tool_rounds = trimmer._identify_tool_rounds(messages)
assert len(tool_rounds) == 0, f"Should identify 0 tool rounds, got {len(tool_rounds)}"
# Verify trimming is not triggered
should_trim = trimmer.should_trim(messages)
assert not should_trim, "Should not trigger trimming without tool rounds"
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,143 @@
"""
Test assistant-ui best practices implementation
"""
import json
import os
def test_package_json_dependencies():
"""Test that package.json has the correct assistant-ui dependencies"""
package_json_path = os.path.join(os.path.dirname(__file__), "../../web/package.json")
with open(package_json_path, 'r') as f:
package_data = json.load(f)
deps = package_data.get("dependencies", {})
# Check for essential assistant-ui packages
assert "@assistant-ui/react" in deps, "Missing @assistant-ui/react"
assert "@assistant-ui/react-ui" in deps, "Missing @assistant-ui/react-ui"
assert "@assistant-ui/react-markdown" in deps, "Missing @assistant-ui/react-markdown"
assert "@assistant-ui/react-data-stream" in deps, "Missing @assistant-ui/react-data-stream"
# Check versions are reasonable (not too old)
react_version = deps["@assistant-ui/react"]
assert "0.10" in react_version or "0.9" in react_version, f"Version too old: {react_version}"
print("✅ Package dependencies test passed")
def test_env_configuration():
"""Test that environment configuration files exist"""
env_local_path = os.path.join(os.path.dirname(__file__), "../../web/.env.local")
assert os.path.exists(env_local_path), "Missing .env.local file"
with open(env_local_path, 'r') as f:
env_content = f.read()
assert "NEXT_PUBLIC_LANGGRAPH_API_URL" in env_content, "Missing API URL config"
assert "NEXT_PUBLIC_LANGGRAPH_ASSISTANT_ID" in env_content, "Missing Assistant ID config"
print("✅ Environment configuration test passed")
def test_api_route_structure():
"""Test that API routes are properly structured"""
# Check main chat API route exists
chat_route_path = os.path.join(os.path.dirname(__file__), "../../web/src/app/api/chat/route.ts")
assert os.path.exists(chat_route_path), "Missing chat API route"
with open(chat_route_path, 'r') as f:
route_content = f.read()
# Check for essential API patterns
assert "export async function POST" in route_content, "Missing POST handler"
assert "Response" in route_content, "Missing Response handling"
assert "x-vercel-ai-data-stream" in route_content, "Missing AI SDK compatibility header"
print("✅ API route structure test passed")
def test_component_structure():
"""Test that main components follow best practices"""
# Check main page component
page_path = os.path.join(os.path.dirname(__file__), "../../web/src/app/page.tsx")
assert os.path.exists(page_path), "Missing main page component"
with open(page_path, 'r') as f:
page_content = f.read()
# Check for key React patterns and components
assert '"use client"' in page_content, "Missing client-side directive"
assert "Assistant" in page_content, "Missing Assistant component"
assert "export default function" in page_content, "Missing default function export"
# Check for proper structure
assert "className=" in page_content, "Missing CSS class usage"
assert "h-screen" in page_content or "h-full" in page_content, "Missing full height layout"
print("✅ Component structure test passed")
def test_markdown_component():
"""Test that markdown component is properly configured"""
markdown_path = os.path.join(os.path.dirname(__file__), "../../web/src/components/ui/markdown-text.tsx")
assert os.path.exists(markdown_path), "Missing markdown component"
with open(markdown_path, 'r') as f:
markdown_content = f.read()
assert "MarkdownTextPrimitive" in markdown_content, "Missing markdown primitive"
assert "remarkGfm" in markdown_content, "Missing GFM support"
print("✅ Markdown component test passed")
def test_best_practices_documentation():
"""Test that best practices documentation exists and is comprehensive"""
docs_path = os.path.join(os.path.dirname(__file__), "../../docs/topics/ASSISTANT_UI_BEST_PRACTICES.md")
assert os.path.exists(docs_path), "Missing best practices documentation"
with open(docs_path, 'r') as f:
docs_content = f.read()
# Check for key sections
assert "Assistant-UI + LangGraph + FastAPI" in docs_content, "Missing main title"
assert "Implementation Status" in docs_content, "Missing implementation status"
assert "Package Dependencies Updated" in docs_content, "Missing dependencies section"
assert "Server-Side API Routes" in docs_content, "Missing API routes explanation"
print("✅ Best practices documentation test passed")
def run_all_tests():
"""Run all tests"""
print("🧪 Running assistant-ui best practices validation tests...")
try:
test_package_json_dependencies()
test_env_configuration()
test_api_route_structure()
test_component_structure()
test_markdown_component()
test_best_practices_documentation()
print("\n🎉 All assistant-ui best practices tests passed!")
print("✅ Your implementation follows the recommended patterns for:")
print(" - Package dependencies and versions")
print(" - Environment configuration")
print(" - API route structure")
print(" - Component composition")
print(" - Markdown rendering")
print(" - Documentation completeness")
return True
except Exception as e:
print(f"\n❌ Test failed: {e}")
return False
if __name__ == "__main__":
success = run_all_tests()
exit(0 if success else 1)

View File

@@ -0,0 +1,141 @@
import pytest
from service.memory.store import InMemoryStore
from service.graph.state import TurnState, Message
from datetime import datetime, timedelta
@pytest.fixture
def memory_store():
"""Create memory store for testing"""
return InMemoryStore(ttl_days=1) # Short TTL for testing
@pytest.fixture
def sample_state():
"""Create sample turn state"""
state = TurnState(session_id="test_session")
state.messages = [
Message(role="user", content="Hello"),
Message(role="assistant", content="Hi there!")
]
return state
def test_create_new_session(memory_store):
"""Test creating a new session"""
state = memory_store.create_new_session("new_session")
assert state.session_id == "new_session"
assert len(state.messages) == 0
# Verify it's stored
retrieved = memory_store.get("new_session")
assert retrieved is not None
assert retrieved.session_id == "new_session"
def test_put_and_get_state(memory_store, sample_state):
"""Test storing and retrieving state"""
memory_store.put("test_session", sample_state)
retrieved = memory_store.get("test_session")
assert retrieved is not None
assert retrieved.session_id == "test_session"
assert len(retrieved.messages) == 2
assert retrieved.messages[0].content == "Hello"
def test_get_nonexistent_session(memory_store):
"""Test getting non-existent session returns None"""
result = memory_store.get("nonexistent")
assert result is None
def test_add_message(memory_store):
"""Test adding messages to conversation"""
memory_store.create_new_session("test_session")
message = Message(role="user", content="Test message")
memory_store.add_message("test_session", message)
state = memory_store.get("test_session")
assert len(state.messages) == 1
assert state.messages[0].content == "Test message"
def test_add_message_to_nonexistent_session(memory_store):
"""Test adding message creates new session if it doesn't exist"""
message = Message(role="user", content="Test message")
memory_store.add_message("new_session", message)
state = memory_store.get("new_session")
assert state is not None
assert len(state.messages) == 1
def test_get_conversation_history(memory_store):
"""Test conversation history formatting"""
memory_store.create_new_session("test_session")
messages = [
Message(role="user", content="Hello"),
Message(role="assistant", content="Hi there!"),
Message(role="user", content="How are you?"),
Message(role="assistant", content="I'm doing well!")
]
for msg in messages:
memory_store.add_message("test_session", msg)
history = memory_store.get_conversation_history("test_session")
assert "User: Hello" in history
assert "Assistant: Hi there!" in history
assert "User: How are you?" in history
assert "Assistant: I'm doing well!" in history
def test_get_conversation_history_empty(memory_store):
"""Test conversation history for empty session"""
history = memory_store.get_conversation_history("nonexistent")
assert history == ""
def test_trim_messages(memory_store):
"""Test message trimming"""
memory_store.create_new_session("test_session")
# Add many messages
for i in range(25):
memory_store.add_message("test_session", Message(role="user", content=f"Message {i}"))
# Trim to 10 messages
memory_store.trim("test_session", max_messages=10)
state = memory_store.get("test_session")
assert len(state.messages) <= 10
def test_trim_nonexistent_session(memory_store):
"""Test trimming non-existent session doesn't crash"""
memory_store.trim("nonexistent", max_messages=10)
# Should not raise an exception
def test_ttl_cleanup(memory_store):
"""Test TTL-based cleanup"""
# Create store with very short TTL for testing
short_ttl_store = InMemoryStore(ttl_days=0.001)  # ~86 seconds
# Add a session
state = TurnState(session_id="test_session")
short_ttl_store.put("test_session", state)
# Verify it exists
assert short_ttl_store.get("test_session") is not None
# Manually expire it by manipulating internal timestamp
short_ttl_store.store["test_session"]["last_updated"] = datetime.now() - timedelta(days=1)
# Try to get it - should trigger cleanup
assert short_ttl_store.get("test_session") is None
assert "test_session" not in short_ttl_store.store

View File

@@ -0,0 +1,194 @@
"""
Test conversation history trimming functionality.
"""
import pytest
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage
from service.graph.message_trimmer import ConversationTrimmer, create_conversation_trimmer
def test_conversation_trimmer_basic():
"""Test basic trimming functionality."""
trimmer = ConversationTrimmer(max_context_length=50) # Very low limit for testing (85% = 42 tokens)
messages = [
SystemMessage(content="You are a helpful assistant."),
HumanMessage(content="What is the capital of France?"),
AIMessage(content="The capital of France is Paris."),
HumanMessage(content="What about Germany?"),
AIMessage(content="The capital of Germany is Berlin."),
HumanMessage(content="And Italy?"),
AIMessage(content="The capital of Italy is Rome."),
]
# Should trigger trimming due to low token limit
assert trimmer.should_trim(messages)
trimmed = trimmer.trim_conversation_history(messages)
# Should preserve system message and keep recent messages
assert len(trimmed) < len(messages)
assert isinstance(trimmed[0], SystemMessage)
assert "helpful assistant" in trimmed[0].content
def test_conversation_trimmer_no_trim_needed():
"""Test when no trimming is needed."""
trimmer = ConversationTrimmer(max_context_length=100000) # High limit
messages = [
SystemMessage(content="You are a helpful assistant."),
HumanMessage(content="Hello"),
AIMessage(content="Hi there!"),
]
# Should not need trimming
assert not trimmer.should_trim(messages)
trimmed = trimmer.trim_conversation_history(messages)
# Should return same messages
assert len(trimmed) == len(messages)
def test_conversation_trimmer_fallback():
"""Test fallback trimming logic."""
trimmer = ConversationTrimmer(max_context_length=100)
# Create many messages to trigger fallback
messages = [SystemMessage(content="System")] + [
HumanMessage(content=f"Message {i}") for i in range(50)
]
trimmed = trimmer._fallback_trim(messages, max_messages=5)
# Should keep system message + 4 recent messages
assert len(trimmed) == 5
assert isinstance(trimmed[0], SystemMessage)
def test_create_conversation_trimmer():
"""Test trimmer factory function."""
# Test with explicit configuration
custom_trimmer = create_conversation_trimmer(max_context_length=128000)
assert custom_trimmer.max_context_length == 128000
assert isinstance(custom_trimmer, ConversationTrimmer)
assert custom_trimmer.preserve_system is True
def test_multi_round_tool_optimization():
"""Test multi-round tool call optimization."""
trimmer = ConversationTrimmer(max_context_length=100000) # High limit to focus on optimization
# Create a multi-round tool calling scenario
messages = [
SystemMessage(content="You are a helpful assistant."),
HumanMessage(content="Search for food safety and cosmetics testing standards."),
# First tool round
AIMessage(content="I'll search for information.", tool_calls=[
{"id": "call_1", "name": "search_tool", "args": {"query": "food safety"}}
]),
ToolMessage(content='{"results": ["Large food safety result data..." * 100]}', tool_call_id="call_1", name="search_tool"),
# Second tool round
AIMessage(content="Let me search for more specific information.", tool_calls=[
{"id": "call_2", "name": "search_tool", "args": {"query": "cosmetics testing"}}
]),
ToolMessage(content='{"results": ["Large cosmetics testing data..." * 100]}', tool_call_id="call_2", name="search_tool"),
# Third tool round (most recent)
AIMessage(content="Now let me get equipment information.", tool_calls=[
{"id": "call_3", "name": "search_tool", "args": {"query": "testing equipment"}}
]),
ToolMessage(content='{"results": ["Testing equipment info..." * 50]}', tool_call_id="call_3", name="search_tool"),
]
# Test tool round identification
tool_rounds = trimmer._identify_tool_rounds(messages)
assert len(tool_rounds) == 3 # Should identify 3 tool rounds
# Test optimization
optimized = trimmer._optimize_multi_round_tool_calls(messages)
# Should preserve: SystemMessage + HumanMessage + latest tool round (AI + Tool messages)
expected_length = 1 + 1 + 2 # system + human + (latest AI + latest Tool)
assert len(optimized) == expected_length
# Check preserved messages
assert isinstance(optimized[0], SystemMessage)
assert isinstance(optimized[1], HumanMessage)
assert isinstance(optimized[2], AIMessage)
assert optimized[2].tool_calls[0]["id"] == "call_3" # Should be the latest tool call
assert isinstance(optimized[3], ToolMessage)
assert optimized[3].tool_call_id == "call_3" # Should be the latest tool result
def test_multi_round_optimization_single_round():
"""Test that single tool round is not optimized."""
trimmer = ConversationTrimmer(max_context_length=100000)
messages = [
SystemMessage(content="You are a helpful assistant."),
HumanMessage(content="Search for something."),
AIMessage(content="I'll search.", tool_calls=[{"id": "call_1", "name": "search_tool", "args": {}}]),
ToolMessage(content='{"results": ["data"]}', tool_call_id="call_1", name="search_tool"),
]
optimized = trimmer._optimize_multi_round_tool_calls(messages)
# Should return all messages unchanged for single round
assert len(optimized) == len(messages)
def test_multi_round_optimization_no_tools():
"""Test that conversations without tools are not optimized."""
trimmer = ConversationTrimmer(max_context_length=100000)
messages = [
SystemMessage(content="You are a helpful assistant."),
HumanMessage(content="Hello"),
AIMessage(content="Hi there!"),
HumanMessage(content="How are you?"),
AIMessage(content="I'm doing well, thanks!"),
]
optimized = trimmer._optimize_multi_round_tool_calls(messages)
# Should return all messages unchanged
assert len(optimized) == len(messages)
def test_full_trimming_with_optimization():
"""Test complete trimming process with multi-round optimization."""
trimmer = ConversationTrimmer(max_context_length=50) # Very low limit
# Create messages that will trigger both optimization and regular trimming
messages = [
SystemMessage(content="You are a helpful assistant."),
HumanMessage(content="Search for food safety and cosmetics testing standards."),
# First tool round (should be removed by optimization)
AIMessage(content="I'll search for information.", tool_calls=[
{"id": "call_1", "name": "search_tool", "args": {"query": "food safety"}}
]),
ToolMessage(content='{"results": ["Large food safety result data..." * 100]}', tool_call_id="call_1", name="search_tool"),
# Second tool round (most recent, should be preserved)
AIMessage(content="Let me search for more information.", tool_calls=[
{"id": "call_2", "name": "search_tool", "args": {"query": "cosmetics testing"}}
]),
ToolMessage(content='{"results": ["Large cosmetics testing data..." * 100]}', tool_call_id="call_2", name="search_tool"),
]
trimmed = trimmer.trim_conversation_history(messages)
# Should apply optimization first, then regular trimming if needed
assert len(trimmed) <= len(messages)
# System message should always be preserved
assert any(isinstance(msg, SystemMessage) for msg in trimmed)
if __name__ == "__main__":
pytest.main([__file__])

View File

@@ -0,0 +1,353 @@
"""
Unit tests for agentic retrieval tools.
Tests the HTTP wrapper tools that interface with the retrieval API.
"""
import pytest
from unittest.mock import AsyncMock, patch, Mock
import httpx
from service.retrieval.retrieval import AgenticRetrieval, RetrievalResponse
from service.retrieval.clients import RetrievalAPIError, normalize_search_result
class TestAgenticRetrieval:
"""Test the agentic retrieval HTTP client."""
@pytest.fixture
def retrieval_client(self):
"""Create retrieval client for testing."""
with patch('service.retrieval.retrieval.get_config') as mock_config:
# Mock the new config structure
mock_config.return_value.retrieval.endpoint = "https://test-search.search.azure.com"
mock_config.return_value.retrieval.api_key = "test-key"
mock_config.return_value.retrieval.api_version = "2024-11-01-preview"
mock_config.return_value.retrieval.semantic_configuration = "default"
# Mock embedding config
mock_config.return_value.retrieval.embedding.base_url = "http://test-embedding"
mock_config.return_value.retrieval.embedding.api_key = "test-embedding-key"
mock_config.return_value.retrieval.embedding.model = "test-embedding-model"
mock_config.return_value.retrieval.embedding.dimension = 1536
# Mock index config
mock_config.return_value.retrieval.index.standard_regulation_index = "index-catonline-standard-regulation-v2-prd"
mock_config.return_value.retrieval.index.chunk_index = "index-catonline-chunk-v2-prd"
mock_config.return_value.retrieval.index.chunk_user_manual_index = "index-cat-usermanual-chunk-prd"
return AgenticRetrieval()
@pytest.mark.asyncio
async def test_retrieve_standard_regulation_success(self, retrieval_client):
"""Test successful standards regulation retrieval."""
mock_response_data = {
"value": [
{
"id": "test_id_1",
"title": "ISO 26262-1:2018",
"document_code": "ISO26262-1",
"document_category": "Standard",
"@order_num": 1 # Should be preserved
},
{
"id": "test_id_2",
"title": "ISO 26262-3:2018",
"document_code": "ISO26262-3",
"document_category": "Standard",
"@order_num": 2
}
]
}
with patch.object(retrieval_client.search_client, 'search_azure_ai', return_value=mock_response_data) as mock_search:
result = await retrieval_client.retrieve_standard_regulation("ISO 26262")
# Verify search call
mock_search.assert_called_once()
call_args = mock_search.call_args[1] # keyword arguments
assert call_args['search_text'] == "ISO 26262"
assert call_args['index_name'] == "index-catonline-standard-regulation-v2-prd"
assert call_args['top_k'] == 10
# Verify response structure
assert isinstance(result, RetrievalResponse)
assert len(result.results) == 2
assert result.total_count == 2
assert result.took_ms is not None
# Verify normalization
first_result = result.results[0]
assert first_result['id'] == "test_id_1"
assert first_result['title'] == "ISO 26262-1:2018"
# Verify order number is preserved
assert "@order_num" in first_result
assert "content" not in first_result # Should not be included for standards
@pytest.mark.asyncio
async def test_retrieve_doc_chunk_success(self, retrieval_client):
"""Test successful document chunk retrieval."""
mock_response_data = {
"value": [
{
"id": "chunk_id_1",
"title": "Functional Safety Requirements",
"content": "Detailed content about functional safety...",
"document_code": "ISO26262-1",
"@order_num": 1 # Should be preserved
}
]
}
with patch.object(retrieval_client.search_client, 'search_azure_ai', return_value=mock_response_data) as mock_search:
result = await retrieval_client.retrieve_doc_chunk_standard_regulation(
"functional safety",
conversation_history="Previous question about ISO 26262"
)
# Verify search call
mock_search.assert_called_once()
call_args = mock_search.call_args[1] # keyword arguments
assert call_args['search_text'] == "functional safety"
assert call_args['index_name'] == "index-catonline-chunk-v2-prd"
assert "filter_query" in call_args # Should have document filter
# Verify response
assert isinstance(result, RetrievalResponse)
assert len(result.results) == 1
# Verify content is included for chunks
first_result = result.results[0]
assert "content" in first_result
assert first_result['content'] == "Detailed content about functional safety..."
# Verify order number is preserved
assert "@order_num" in first_result
@pytest.mark.asyncio
async def test_http_error_handling(self, retrieval_client):
"""Test HTTP error handling."""
with patch.object(retrieval_client.search_client, 'search_azure_ai') as mock_search:
mock_search.side_effect = RetrievalAPIError("API request failed: 404")
with pytest.raises(RetrievalAPIError) as exc_info:
await retrieval_client.retrieve_standard_regulation("nonexistent")
assert "404" in str(exc_info.value)
@pytest.mark.asyncio
async def test_timeout_handling(self, retrieval_client):
"""Test timeout handling."""
with patch.object(retrieval_client.search_client, 'search_azure_ai') as mock_search:
mock_search.side_effect = RetrievalAPIError("Request timeout")
with pytest.raises(RetrievalAPIError) as exc_info:
await retrieval_client.retrieve_standard_regulation("test query")
assert "timeout" in str(exc_info.value)
def test_normalize_result_removes_unwanted_fields(self, retrieval_client):
"""Test result normalization removes unwanted fields."""
from service.retrieval.clients import normalize_search_result
raw_result = {
"id": "test_id",
"title": "Test Title",
"content": "Test content",
"@search.score": 0.95, # Should be removed
"@search.rerankerScore": 2.5, # Should be removed
"@search.captions": [], # Should be removed
"@subquery_id": 1, # Should be removed
"empty_field": "", # Should be removed
"null_field": None, # Should be removed
"empty_list": [], # Should be removed
"empty_dict": {}, # Should be removed
"valid_field": "valid_value"
}
# Test normalization
normalized = normalize_search_result(raw_result)
# Verify unwanted fields are removed
assert "@search.score" not in normalized
assert "@search.rerankerScore" not in normalized
assert "@search.captions" not in normalized
assert "@subquery_id" not in normalized
assert "empty_field" not in normalized
assert "null_field" not in normalized
assert "empty_list" not in normalized
assert "empty_dict" not in normalized
# Verify valid fields are preserved (including content)
assert normalized["id"] == "test_id"
assert normalized["title"] == "Test Title"
assert normalized["content"] == "Test content" # Content is now always preserved
assert normalized["valid_field"] == "valid_value"
def test_normalize_result_includes_content_when_requested(self, retrieval_client):
"""Test result normalization always includes content."""
from service.retrieval.clients import normalize_search_result
raw_result = {
"id": "test_id",
"title": "Test Title",
"content": "Test content"
}
# Test normalization (content is always preserved now)
normalized = normalize_search_result(raw_result)
assert "content" in normalized
assert normalized["content"] == "Test content"
@pytest.mark.asyncio
async def test_empty_results_handling(self, retrieval_client):
"""Test handling of empty search results."""
with patch.object(retrieval_client.search_client, 'search_azure_ai', return_value={"value": []}):
result = await retrieval_client.retrieve_standard_regulation("no results query")
assert isinstance(result, RetrievalResponse)
assert result.results == []
assert result.total_count == 0
@pytest.mark.asyncio
async def test_malformed_response_handling(self, retrieval_client):
"""Test handling of malformed API responses."""
with patch.object(retrieval_client.search_client, 'search_azure_ai', return_value={"unexpected": "format"}):
result = await retrieval_client.retrieve_standard_regulation("test query")
# Should handle gracefully and return empty results
assert isinstance(result, RetrievalResponse)
assert result.results == []
assert result.total_count == 0
@pytest.mark.asyncio
async def test_kwargs_override_payload(self, retrieval_client):
"""Test that kwargs can override default values."""
with patch.object(retrieval_client.search_client, 'search_azure_ai', return_value={"value": []}) as mock_search:
await retrieval_client.retrieve_standard_regulation(
"test query",
top_k=5,
score_threshold=2.0
)
call_args = mock_search.call_args[1] # keyword arguments
assert call_args['top_k'] == 5 # Override default 10
assert call_args['score_threshold'] == 2.0 # Override default 1.5
class TestRetrievalTools:
"""Test the LangGraph tool decorators."""
@pytest.mark.asyncio
async def test_retrieve_standard_regulation_tool(self):
"""Test the @tool decorated function for standards search."""
from service.graph.tools import retrieve_standard_regulation
mock_response = RetrievalResponse(
results=[{"title": "Test Standard", "id": "test_id"}],
took_ms=100,
total_count=1
)
with patch('service.retrieval.retrieval.AgenticRetrieval') as mock_retrieval_class:
mock_instance = AsyncMock()
mock_instance.retrieve_standard_regulation.return_value = mock_response
mock_retrieval_class.return_value.__aenter__.return_value = mock_instance
mock_retrieval_class.return_value.__aexit__.return_value = None
# Note: retrieve_standard_regulation is a LangGraph @tool, so we call it directly
result = await retrieve_standard_regulation.ainvoke({"query": "test query", "conversation_history": "test history"})
# Verify result format matches expected tool output
assert isinstance(result, dict)
assert "results" in result
@pytest.mark.asyncio
async def test_retrieve_doc_chunk_tool(self):
"""Test the @tool decorated function for document chunks search."""
from service.graph.tools import retrieve_doc_chunk_standard_regulation
mock_response = RetrievalResponse(
results=[{"title": "Test Chunk", "content": "Test content", "id": "chunk_id"}],
took_ms=150,
total_count=1
)
with patch('service.retrieval.retrieval.AgenticRetrieval') as mock_retrieval_class:
mock_instance = AsyncMock()
mock_instance.retrieve_doc_chunk_standard_regulation.return_value = mock_response
mock_retrieval_class.return_value.__aenter__.return_value = mock_instance
mock_retrieval_class.return_value.__aexit__.return_value = None
# Call the tool with proper input format
result = await retrieve_doc_chunk_standard_regulation.ainvoke({"query": "test query"})
# Verify result format
assert isinstance(result, dict)
assert "results" in result
@pytest.mark.asyncio
async def test_tool_error_handling(self):
"""Test tool error handling."""
from service.graph.tools import retrieve_standard_regulation
with patch('service.retrieval.clients.AzureSearchClient.search_azure_ai') as mock_search:
mock_search.side_effect = RetrievalAPIError("API Error")
# Tool should handle error gracefully and return error result
result = await retrieve_standard_regulation.ainvoke({"query": "test query"})
# Should return error information in result
assert isinstance(result, dict)
assert "error" in result
assert "API Error" in result["error"]
class TestRetrievalIntegration:
"""Integration-style tests for retrieval functionality."""
@pytest.mark.asyncio
async def test_full_retrieval_workflow(self):
"""Test complete retrieval workflow with both tools."""
# Mock both retrieval calls
standards_response = RetrievalResponse(
results=[
{"id": "std_1", "title": "ISO 26262-1", "document_category": "Standard"}
],
took_ms=100,
total_count=1
)
chunks_response = RetrievalResponse(
results=[
{"id": "chunk_1", "title": "Safety Requirements", "content": "Detailed safety requirements..."}
],
took_ms=150,
total_count=1
)
with patch('service.retrieval.clients.AzureSearchClient.search_azure_ai') as mock_search:
# Set up mock to return different responses for different calls
def mock_response_side_effect(*args, **kwargs):
if kwargs.get('index_name') == 'index-catonline-standard-regulation-v2-prd':
return {
"value": [{"id": "std_1", "title": "ISO 26262-1", "document_category": "Standard", "@order_num": 1}]
}
else:
return {
"value": [{"id": "chunk_1", "title": "Safety Requirements", "content": "Detailed safety requirements...", "@order_num": 1}]
}
mock_search.side_effect = mock_response_side_effect
# Import and test both tools
from service.graph.tools import retrieve_standard_regulation, retrieve_doc_chunk_standard_regulation
# Test standards search
std_result = await retrieve_standard_regulation.ainvoke({"query": "ISO 26262"})
assert isinstance(std_result, dict)
assert "results" in std_result
assert len(std_result["results"]) == 1
# Test chunks search
chunk_result = await retrieve_doc_chunk_standard_regulation.ainvoke({"query": "safety requirements"})
assert isinstance(chunk_result, dict)
assert "results" in chunk_result
assert len(chunk_result["results"]) == 1

View File

@@ -0,0 +1,73 @@
import pytest
from service.sse import (
format_sse_event, create_token_event, create_tool_start_event,
create_tool_result_event, create_tool_error_event,
create_error_event
)
def test_format_sse_event():
"""Test SSE event formatting"""
event = format_sse_event("test", {"message": "hello"})
expected = 'event: test\ndata: {"message": "hello"}\n\n'
assert event == expected
def test_create_token_event():
"""Test token event creation"""
event = create_token_event("hello", "tool_123")
assert "event: tokens" in event
assert '"delta": "hello"' in event
assert '"tool_call_id": "tool_123"' in event
def test_create_token_event_no_tool_id():
"""Test token event without tool call ID"""
event = create_token_event("hello")
assert "event: tokens" in event
assert '"delta": "hello"' in event
assert '"tool_call_id": null' in event
def test_create_tool_start_event():
"""Test tool start event"""
event = create_tool_start_event("tool_123", "retrieve_standard_regulation", {"query": "test"})
assert "event: tool_start" in event
assert '"id": "tool_123"' in event
assert '"name": "retrieve_standard_regulation"' in event
assert '"args": {"query": "test"}' in event
def test_create_tool_result_event():
"""Test tool result event"""
results = [{"id": "1", "title": "Test Standard"}]
event = create_tool_result_event("tool_123", "retrieve_standard_regulation", results, 500)
assert "event: tool_result" in event
assert '"id": "tool_123"' in event
assert '"took_ms": 500' in event
assert '"results"' in event
def test_create_tool_error_event():
"""Test tool error event"""
event = create_tool_error_event("tool_123", "retrieve_standard_regulation", "API timeout")
assert "event: tool_error" in event
assert '"id": "tool_123"' in event
assert '"error": "API timeout"' in event
def test_create_error_event():
"""Test error event"""
event = create_error_event("Something went wrong")
assert "event: error" in event
assert '"error": "Something went wrong"' in event
def test_create_error_event_with_details():
"""Test error event with details"""
details = {"code": 500, "source": "llm"}
event = create_error_event("Something went wrong", details)
assert "event: error" in event
assert '"error": "Something went wrong"' in event
assert '"details"' in event
assert '"code": 500' in event

View File

@@ -0,0 +1,200 @@
#!/usr/bin/env python3
"""
Tests for the updated trimming logic: trimming happens only when tool_rounds == 0 and is skipped otherwise.
"""
import pytest
from unittest.mock import patch, MagicMock
import asyncio
from datetime import datetime
import sys
import os
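# Make the repository root importable so the `service` package resolves when run directly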
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from service.graph.state import AgentState
from service.graph.graph import call_model
from service.graph.user_manual_rag import user_manual_agent_node
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage
@pytest.mark.asyncio
async def test_trimming_only_at_tool_rounds_zero():
"""测试修剪只在 tool_rounds=0 时发生"""
# 创建一个大的消息列表来触发修剪
large_messages = [
SystemMessage(content="You are a helpful assistant."),
HumanMessage(content="请搜索信息"),
]
# Add a large amount of content to trigger trimming
for i in range(10):
large_messages.extend([
AIMessage(content=f"搜索{i}", tool_calls=[
{"id": f"call_{i}", "name": "search", "args": {"query": f"test{i}"}}
]),
ToolMessage(content="很长的搜索结果内容 " * 500, tool_call_id=f"call_{i}", name="search"),
])
# Mock trimmer to track calls
with patch('service.graph.graph.create_conversation_trimmer') as mock_trimmer_factory:
mock_trimmer = MagicMock()
mock_trimmer.should_trim.return_value = True  # Always report that trimming is needed
mock_trimmer.trim_conversation_history.return_value = large_messages[:3]  # Return the trimmed messages
mock_trimmer_factory.return_value = mock_trimmer
# Mock LLM client
with patch('service.graph.graph.LLMClient') as mock_llm_factory:
mock_llm = MagicMock()
# Set up return values for the async methods
async def mock_ainvoke_with_tools(*args, **kwargs):
return AIMessage(content="最终答案")
async def mock_astream(*args, **kwargs):
for token in ["", "", "", ""]:
yield token
mock_llm.ainvoke_with_tools = mock_ainvoke_with_tools
mock_llm.astream = mock_astream
mock_llm.bind_tools = MagicMock()
mock_llm_factory.return_value = mock_llm
# Case: tool_rounds=0 (trimming should happen)
from service.graph.state import AgentState
state_round_0 = AgentState(
messages=large_messages,
session_id="test",
intent=None,
tool_results=[],
final_answer="",
tool_rounds=0,  # First round
max_tool_rounds=3,
max_tool_rounds_user_manual=2
)
result = await call_model(state_round_0)
# Verify the trimmer was called
mock_trimmer.should_trim.assert_called()
mock_trimmer.trim_conversation_history.assert_called()
# Reset the mock
mock_trimmer.reset_mock()
# Case: tool_rounds>0 (trimming should not happen)
from service.graph.state import AgentState
state_round_1 = AgentState(
messages=large_messages,
session_id="test",
intent=None,
tool_results=[],
final_answer="",
tool_rounds=1,  # Second round
max_tool_rounds=3,
max_tool_rounds_user_manual=2
)
result = await call_model(state_round_1)
# Verify the trimmer was not called
mock_trimmer.should_trim.assert_not_called()
mock_trimmer.trim_conversation_history.assert_not_called()
@pytest.mark.asyncio
async def test_user_manual_agent_trimming_behavior():
"""测试 user_manual_agent_node 的修剪行为"""
# 创建大的消息列表
large_messages = [
SystemMessage(content="You are a helpful user manual assistant."),
HumanMessage(content="请帮我查找系统使用说明"),
]
# Add a large amount of content
for i in range(8):
large_messages.extend([
AIMessage(content=f"查找{i}", tool_calls=[
{"id": f"manual_call_{i}", "name": "retrieve_system_usermanual", "args": {"query": f"功能{i}"}}
]),
ToolMessage(content="很长的手册内容 " * 400, tool_call_id=f"manual_call_{i}", name="retrieve_system_usermanual"),
])
# Mock dependencies
with patch('service.graph.user_manual_rag.create_conversation_trimmer') as mock_trimmer_factory, \
patch('service.graph.user_manual_rag.LLMClient') as mock_llm_factory, \
patch('service.graph.user_manual_rag.get_user_manual_tool_schemas') as mock_tools, \
patch('service.graph.user_manual_rag.get_cached_config') as mock_config:
# Setup mocks
mock_trimmer = MagicMock()
mock_trimmer.should_trim.return_value = True
mock_trimmer.trim_conversation_history.return_value = large_messages[:3]
mock_trimmer_factory.return_value = mock_trimmer
mock_llm = MagicMock()
# Set up return values for the async methods
async def mock_ainvoke_with_tools(*args, **kwargs):
return AIMessage(content="用户手册答案")
async def mock_astream(*args, **kwargs):
for token in ["", ""]:
yield token
mock_llm.ainvoke_with_tools = mock_ainvoke_with_tools
mock_llm.astream = mock_astream
mock_llm.bind_tools = MagicMock()
mock_llm_factory.return_value = mock_llm
mock_tools.return_value = []
mock_app_config = MagicMock()
mock_app_config.app.max_tool_rounds_user_manual = 2
mock_app_config.get_rag_prompts.return_value = {
"user_manual_prompt": "System prompt for user manual: {conversation_history} {context_content} {current_query}"
}
mock_config.return_value = mock_app_config
# Case: tool_rounds=0 (trimming should happen)
from service.graph.state import AgentState
state_round_0 = AgentState(
messages=large_messages,
session_id="test",
intent=None,
tool_results=[],
final_answer="",
tool_rounds=0,
max_tool_rounds=3,
max_tool_rounds_user_manual=2
)
result = await user_manual_agent_node(state_round_0)
# Verify the trimmer was called
mock_trimmer.should_trim.assert_called()
mock_trimmer.trim_conversation_history.assert_called()
# Reset the mock
mock_trimmer.reset_mock()
# Case: tool_rounds=1 (trimming should not happen)
state_round_1 = AgentState(
messages=large_messages,
session_id="test",
intent=None,
tool_results=[],
final_answer="",
tool_rounds=1,
max_tool_rounds=3,
max_tool_rounds_user_manual=2
)
result = await user_manual_agent_node(state_round_1)
# Verify the trimmer was not called
mock_trimmer.should_trim.assert_not_called()
mock_trimmer.trim_conversation_history.assert_not_called()
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,95 @@
"""
Unit test for the new retrieve_system_usermanual tool
"""
import pytest
from unittest.mock import AsyncMock, patch
import asyncio
from service.graph.user_manual_tools import retrieve_system_usermanual, user_manual_tools, get_user_manual_tool_schemas, get_user_manual_tools_by_name
from service.graph.tools import get_tool_schemas, get_tools_by_name
from service.retrieval.retrieval import RetrievalResponse
class TestRetrieveSystemUsermanualTool:
"""Test the new user manual retrieval tool"""
def test_tool_in_tools_list(self):
"""Test that the new tool is in the tools list"""
tool_names = [tool.name for tool in user_manual_tools]
assert "retrieve_system_usermanual" in tool_names
assert len(user_manual_tools) == 1 # Should have 1 user manual tool
def test_tool_schemas_generation(self):
"""Test that tool schemas are generated correctly"""
schemas = get_user_manual_tool_schemas()
tool_names = [schema["function"]["name"] for schema in schemas]
assert "retrieve_system_usermanual" in tool_names
# Find the user manual tool schema
user_manual_schema = next(
schema for schema in schemas
if schema["function"]["name"] == "retrieve_system_usermanual"
)
assert user_manual_schema["function"]["description"] == "Search for document content chunks of user manual of this system(CATOnline)"
assert "query" in user_manual_schema["function"]["parameters"]["properties"]
assert user_manual_schema["function"]["parameters"]["required"] == ["query"]
def test_tools_by_name_mapping(self):
"""Test that tools_by_name mapping includes the new tool"""
tools_mapping = get_user_manual_tools_by_name()
assert "retrieve_system_usermanual" in tools_mapping
assert tools_mapping["retrieve_system_usermanual"] == retrieve_system_usermanual
@pytest.mark.asyncio
async def test_retrieve_system_usermanual_success(self):
"""Test successful user manual retrieval"""
mock_response = RetrievalResponse(
results=[
{"title": "User Manual Chapter 1", "content": "How to use the system", "id": "manual_1"}
],
took_ms=120,
total_count=1
)
with patch('service.retrieval.retrieval.AgenticRetrieval') as mock_retrieval_class:
mock_instance = AsyncMock()
mock_instance.retrieve_doc_chunk_user_manual.return_value = mock_response
mock_retrieval_class.return_value.__aenter__.return_value = mock_instance
mock_retrieval_class.return_value.__aexit__.return_value = None
# Test the tool
result = await retrieve_system_usermanual.ainvoke({"query": "how to use system"})
# Verify result format
assert isinstance(result, dict)
assert result["tool_name"] == "retrieve_system_usermanual"
assert result["results_count"] == 1
assert len(result["results"]) == 1
assert result["results"][0]["title"] == "User Manual Chapter 1"
assert result["took_ms"] == 120
@pytest.mark.asyncio
async def test_retrieve_system_usermanual_error(self):
"""Test error handling in user manual retrieval"""
with patch('service.retrieval.retrieval.AgenticRetrieval') as mock_retrieval_class:
mock_instance = AsyncMock()
mock_instance.retrieve_doc_chunk_user_manual.side_effect = Exception("Search API Error")
mock_retrieval_class.return_value.__aenter__.return_value = mock_instance
mock_retrieval_class.return_value.__aexit__.return_value = None
# Test error handling
result = await retrieve_system_usermanual.ainvoke({"query": "test query"})
# Should return error information
assert isinstance(result, dict)
assert "error" in result
assert "Search API Error" in result["error"]
assert result["results_count"] == 0
assert result["results"] == []
if __name__ == "__main__":
pytest.main([__file__, "-v"])