init
1  vw-agentic-rag/tests/__init__.py  Normal file
@@ -0,0 +1 @@
# Empty __init__.py files to make test packages
317  vw-agentic-rag/tests/conftest.py  Normal file
@@ -0,0 +1,317 @@
"""
Shared pytest fixtures and configuration for the agentic-rag test suite.
"""
import pytest
import asyncio
import httpx
from unittest.mock import Mock, AsyncMock, patch
from fastapi.testclient import TestClient

from service.main import create_app
from service.config import Config
from service.graph.state import TurnState, Message, ToolResult
from service.memory.postgresql_memory import PostgreSQLMemoryManager


@pytest.fixture(scope="session")
def event_loop():
    """Create an instance of the default event loop for the test session."""
    policy = asyncio.get_event_loop_policy()
    loop = policy.new_event_loop()
    yield loop
    loop.close()


@pytest.fixture(autouse=True)
def config_mock():
    """Mock configuration for all tests."""
    config = Mock()
    config.retrieval.endpoint = "http://test-endpoint"
    config.retrieval.api_key = "test-key"
    config.llm.provider = "openai"
    config.llm.model = "gpt-4"
    config.llm.api_key = "test-api-key"
    config.memory.enabled = True
    config.memory.type = "in_memory"
    config.memory.ttl_days = 7
    config.postgresql.enabled = False

    with patch('service.config.get_config', return_value=config):
        with patch('service.retrieval.retrieval.get_config', return_value=config):
            with patch('service.graph.graph.get_config', return_value=config):
                yield config


@pytest.fixture
def test_config():
    """Test configuration with safe defaults."""
    return {
        "provider": "openai",
        "openai": {
            "api_key": "test-openai-key",
            "model": "gpt-4o",
            "base_url": "https://api.openai.com/v1",
            "temperature": 0.2
        },
        "retrieval": {
            "endpoint": "http://test-retrieval-endpoint",
            "api_key": "test-retrieval-key"
        },
        "postgresql": {
            "host": "localhost",
            "port": 5432,
            "database": "test_agent_memory",
            "username": "test",
            "password": "test",
            "ttl_days": 1
        },
        "app": {
            "name": "agentic-rag-test",
            "memory_ttl_days": 1,
            "max_tool_loops": 3,
            "cors_origins": ["*"]
        },
        "llm": {
            "rag": {
                "temperature": 0,
                "max_context_length": 32000,
                "agent_system_prompt": "You are a test assistant."
            }
        }
    }


@pytest.fixture
def app(test_config):
    """Create test FastAPI app with mocked configuration."""
    with patch('service.config.load_config') as mock_load_config:
        mock_load_config.return_value = test_config

        # Mock the memory manager to avoid PostgreSQL dependency in tests
        with patch('service.memory.postgresql_memory.get_memory_manager') as mock_memory:
            mock_memory_manager = Mock()
            mock_memory_manager.test_connection.return_value = True
            mock_memory.return_value = mock_memory_manager

            # Mock the graph builder to avoid complex dependencies
            with patch('service.graph.graph.build_graph') as mock_build_graph:
                mock_graph = Mock()
                mock_build_graph.return_value = mock_graph

                app = create_app()
                app.state.memory_manager = mock_memory_manager
                app.state.graph = mock_graph

                return app


@pytest.fixture
def client(app):
    """Create test client."""
    return TestClient(app)


@pytest.fixture
def mock_llm_client():
    """Mock LLM client for testing."""
    mock = AsyncMock()
    mock.astream.return_value = iter(["Test", " response", " token"])
    mock.ainvoke_with_tools.return_value = Mock(
        content="Test response",
        tool_calls=[
            {
                "id": "test_tool_call_1",
                "function": {
                    "name": "retrieve_standard_regulation",
                    "arguments": '{"query": "test query"}'
                }
            }
        ]
    )
    return mock


@pytest.fixture
def mock_retrieval_response():
    """Mock response from retrieval API."""
    return {
        "results": [
            {
                "id": "test_result_1",
                "title": "ISO 26262-1:2018",
                "content": "Road vehicles — Functional safety — Part 1: Vocabulary",
                "score": 0.95,
                "url": "https://iso.org/26262-1",
                "metadata": {
                    "@tool_call_id": "test_tool_call_1",
                    "@order_num": 0
                }
            },
            {
                "id": "test_result_2",
                "title": "ISO 26262-3:2018",
                "content": "Road vehicles — Functional safety — Part 3: Concept phase",
                "score": 0.88,
                "url": "https://iso.org/26262-3",
                "metadata": {
                    "@tool_call_id": "test_tool_call_1",
                    "@order_num": 1
                }
            }
        ],
        "metadata": {
            "total": 2,
            "took_ms": 150,
            "query": "test query"
        }
    }


@pytest.fixture
def sample_chat_request():
    """Sample chat request for testing."""
    return {
        "session_id": "test_session_123",
        "messages": [
            {"role": "user", "content": "What is ISO 26262?"}
        ]
    }


@pytest.fixture
def sample_turn_state():
    """Sample TurnState for testing."""
    return TurnState(
        session_id="test_session_123",
        messages=[
            Message(role="user", content="What is ISO 26262?")
        ]
    )


@pytest.fixture
def mock_httpx_client():
    """Mock httpx client for API requests."""
    mock_client = AsyncMock()

    # Default response for retrieval API
    mock_response = Mock()
    mock_response.status_code = 200
    mock_response.json.return_value = {
        "results": [
            {
                "id": "test_result",
                "title": "Test Standard",
                "content": "Test content",
                "score": 0.9
            }
        ]
    }

    mock_client.post.return_value = mock_response
    return mock_client


@pytest.fixture
def mock_postgresql_memory():
    """Mock PostgreSQL memory manager."""
    mock_manager = Mock(spec=PostgreSQLMemoryManager)
    mock_manager.test_connection.return_value = True

    mock_checkpointer = Mock()
    mock_checkpointer.setup.return_value = None
    mock_manager.get_checkpointer.return_value = mock_checkpointer

    return mock_manager


@pytest.fixture
def mock_streaming_response():
    """Mock streaming response events."""
    return [
        'event: tool_start\ndata: {"id": "test_tool_1", "name": "retrieve_standard_regulation", "args": {"query": "test"}}\n\n',
        'event: tokens\ndata: {"delta": "Based on the retrieved standards", "tool_call_id": null}\n\n',
        'event: tool_result\ndata: {"id": "test_tool_1", "name": "retrieve_standard_regulation", "results": [], "took_ms": 100}\n\n',
        'event: tokens\ndata: {"delta": " this is a test response.", "tool_call_id": null}\n\n'
    ]


# Async test helpers
@pytest.fixture
def mock_agent_state():
    """Mock agent state for graph testing."""
    return {
        "messages": [],
        "session_id": "test_session",
        "tool_results": [],
        "final_answer": ""
    }


@pytest.fixture
async def async_test_client():
    """Async test client for integration tests."""
    async with httpx.AsyncClient() as client:
        yield client


# Database fixtures for integration tests
@pytest.fixture
def test_database_url():
    """Test database URL (only for integration tests with real DB)."""
    return "postgresql://test:test@localhost:5432/test_agent_memory"


@pytest.fixture
def integration_test_config(test_database_url):
    """Configuration for integration tests with real database."""
    return {
        "provider": "openai",
        "openai": {
            "api_key": "test-key",
            "model": "gpt-4o"
        },
        "retrieval": {
            "endpoint": "http://localhost:8000/search",  # Assume test retrieval server
            "api_key": "test-key"
        },
        "postgresql": {
            "connection_string": test_database_url
        }
    }


# Skip markers for different test types
def pytest_configure(config):
    """Configure pytest markers."""
    config.addinivalue_line("markers", "unit: mark test as unit test")
    config.addinivalue_line("markers", "integration: mark test as integration test")
    config.addinivalue_line("markers", "e2e: mark test as end-to-end test")
    config.addinivalue_line("markers", "slow: mark test as slow running")


def pytest_runtest_setup(item):
    """Setup for test items."""
    # Skip integration tests if not explicitly requested
    if "integration" in item.keywords and not item.config.getoption("--run-integration"):
        pytest.skip("Integration tests not requested")

    # Skip E2E tests if not explicitly requested
    if "e2e" in item.keywords and not item.config.getoption("--run-e2e"):
        pytest.skip("E2E tests not requested")


def pytest_addoption(parser):
    """Add custom command line options."""
    parser.addoption(
        "--run-integration",
        action="store_true",
        default=False,
        help="run integration tests"
    )
    parser.addoption(
        "--run-e2e",
        action="store_true",
        default=False,
        help="run end-to-end tests"
    )
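With the marker and option plumbing above, the integration and E2E suites are opt-in: `pytest_runtest_setup` skips them unless the matching flag is passed. A typical invocation, assuming this commit's layout, would look like:

    # Unit-style tests only; integration/e2e items are skipped automatically
    pytest vw-agentic-rag/tests

    # Opt in to the suites that need a running service
    pytest vw-agentic-rag/tests --run-integration --run-e2e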
33  vw-agentic-rag/tests/func_test.py  Normal file
@@ -0,0 +1,33 @@
import os

import httpx


def get_embedding(text: str) -> list[float]:
    """Get embedding vector for text using the configured embedding service"""

    # Key redacted: read it from the environment rather than hardcoding a
    # secret in source control.
    api_key = os.environ["AZURE_OPENAI_API_KEY"]
    model = "text-embedding-3-small"
    base_url = "https://aoai-lab-jpe-fl.openai.azure.com/openai/deployments/text-embedding-3-small/embeddings?api-version=2024-12-01-preview"
    headers = {
        "Content-Type": "application/json",
        # Note: Azure OpenAI deployments normally authenticate API keys via an
        # "api-key" header; switch to that if this Bearer form is rejected.
        "Authorization": f"Bearer {api_key}"
    }

    payload = {
        "input": text,
        "model": model
    }

    try:
        response = httpx.post(base_url, json=payload, headers=headers)
        response.raise_for_status()
        result = response.json()
        print(result)
        return result["data"][0]["embedding"]
    except Exception as e:
        print(f"Failed to get embedding: {e}")
        raise RuntimeError(f"Embedding generation failed: {e}") from e


if __name__ == "__main__":
    print("Begin")
    text = "Sample text for embedding"
    result = get_embedding(text)
    print(result)
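For context, the `result["data"][0]["embedding"]` lookup above assumes the standard OpenAI-style embeddings response body. A sketch of the shape this script relies on, with illustrative values:

    # Expected response shape (values illustrative):
    # {
    #     "object": "list",
    #     "data": [
    #         {"object": "embedding", "index": 0, "embedding": [0.0123, -0.0456, ...]}
    #     ],
    #     "model": "text-embedding-3-small",
    #     "usage": {"prompt_tokens": 4, "total_tokens": 4}
    # }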
1  vw-agentic-rag/tests/integration/__init__.py  Normal file
@@ -0,0 +1 @@
# Empty __init__.py files to make test packages
170  vw-agentic-rag/tests/integration/test_2phase_retrieval.py  Normal file
@@ -0,0 +1,170 @@
#!/usr/bin/env python3
"""
Test 2-phase retrieval strategy
"""

import asyncio
import httpx
import json
import logging
import random
import time

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)


async def test_2phase_retrieval():
    """Test that agent uses 2-phase retrieval for content-focused queries"""

    session_id = f"2phase-test-{random.randint(1000000000, 9999999999)}"
    base_url = "http://127.0.0.1:8000"

    # Test query that should trigger 2-phase retrieval
    # (Chinese: "How do I test the charging performance of an electric
    # vehicle? Please describe the test methods and steps in detail.")
    query = "如何测试电动汽车的充电性能?请详细说明测试方法和步骤。"

    logger.info("🎯 2-PHASE RETRIEVAL TEST")
    logger.info("=" * 80)
    logger.info(f"📝 Session: {session_id}")
    logger.info(f"📝 Query: {query}")
    logger.info("-" * 60)

    # Create the request payload
    payload = {
        "messages": [
            {
                "role": "user",
                "content": query
            }
        ],
        "session_id": session_id
    }

    # Track tool usage
    metadata_tools = 0
    content_tools = 0
    total_tools = 0

    timeout = httpx.Timeout(120.0)  # 2 minute timeout

    try:
        async with httpx.AsyncClient(timeout=timeout) as client:
            logger.info("✅ Streaming response started")

            async with client.stream(
                "POST",
                f"{base_url}/api/chat",
                json=payload,
                headers={"Content-Type": "application/json"}
            ) as response:

                # Check if the response started successfully
                if response.status_code != 200:
                    error_body = await response.aread()
                    logger.error(f"❌ HTTP {response.status_code}: {error_body.decode()}")
                    return

                # Process the streaming response
                current_event_type = None

                async for line in response.aiter_lines():
                    if not line.strip():
                        continue

                    if line.startswith("event: "):
                        current_event_type = line[7:]  # Remove "event: " prefix
                        continue

                    if line.startswith("data: "):
                        data_str = line[6:]  # Remove "data: " prefix

                        if data_str == "[DONE]":
                            logger.info("✅ Stream completed with [DONE]")
                            break

                        try:
                            event_data = json.loads(data_str)
                            event_type = current_event_type or "unknown"

                            if event_type == "tool_start":
                                total_tools += 1
                                tool_name = event_data.get("name", "unknown")
                                args = event_data.get("args", {})
                                query_text = args.get("query", "")
                                query_arg = query_text[:50] + "..." if len(query_text) > 50 else query_text

                                if tool_name == "retrieve_standard_regulation":
                                    metadata_tools += 1
                                    logger.info(f"📋 Phase 1 Tool {metadata_tools}: {tool_name}")
                                    logger.info(f"   Query: {query_arg}")
                                elif tool_name == "retrieve_doc_chunk_standard_regulation":
                                    content_tools += 1
                                    logger.info(f"📄 Phase 2 Tool {content_tools}: {tool_name}")
                                    logger.info(f"   Query: {query_arg}")
                                else:
                                    logger.info(f"🔧 Tool {total_tools}: {tool_name}")

                            elif event_type == "tool_result":
                                tool_name = event_data.get("name", "unknown")
                                results_count = len(event_data.get("results", []))
                                took_ms = event_data.get("took_ms", 0)
                                logger.info(f"✅ Tool completed: {tool_name} ({results_count} results, {took_ms}ms)")

                            elif event_type == "tokens":
                                # Don't log every token, just count them
                                pass

                            # Reset event type for next event
                            current_event_type = None

                            # Break after 20 tools to avoid too much output
                            if total_tools >= 20:
                                logger.info("   ⚠️ Breaking after 20 tools...")
                                break

                        except json.JSONDecodeError as e:
                            logger.warning(f"⚠️ Failed to parse event: {e}")
                            current_event_type = None

    except Exception as e:
        logger.error(f"❌ Request failed: {e}")
        return

    # Results
    logger.info("=" * 80)
    logger.info("📊 2-PHASE RETRIEVAL ANALYSIS")
    logger.info("=" * 80)
    logger.info(f"Phase 1 (Metadata) tools: {metadata_tools}")
    logger.info(f"Phase 2 (Content) tools: {content_tools}")
    logger.info(f"Total tools executed: {total_tools}")
    logger.info("-" * 60)

    # Success criteria
    success_criteria = [
        (metadata_tools > 0, f"Phase 1 metadata retrieval: {'✅' if metadata_tools > 0 else '❌'} ({metadata_tools} tools)"),
        (content_tools > 0, f"Phase 2 content retrieval: {'✅' if content_tools > 0 else '❌'} ({content_tools} tools)"),
        (total_tools >= 2, f"Multi-tool execution: {'✅' if total_tools >= 2 else '❌'} ({total_tools} tools)")
    ]

    logger.info("✅ SUCCESS CRITERIA:")
    all_passed = True
    for passed, message in success_criteria:
        logger.info(f"  {message}")
        if not passed:
            all_passed = False

    if all_passed:
        logger.info("🎉 2-PHASE RETRIEVAL TEST PASSED!")
        logger.info("   ✅ Agent correctly uses both metadata and content retrieval tools")
    else:
        logger.info("❌ 2-PHASE RETRIEVAL TEST FAILED!")
        if metadata_tools == 0:
            logger.info("   ❌ No metadata retrieval tools used")
        if content_tools == 0:
            logger.info("   ❌ No content retrieval tools used - this is the main issue!")


if __name__ == "__main__":
    asyncio.run(test_2phase_retrieval())
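The line parser above assumes standard Server-Sent Events framing: an `event:` line naming the event type, a `data:` line carrying a JSON payload, and a blank line terminating the frame. The event names match the handlers in the loop and the `mock_streaming_response` fixture in conftest.py; the payload fields below are illustrative:

    event: tool_start
    data: {"id": "call_1", "name": "retrieve_standard_regulation", "args": {"query": "..."}}

    event: tool_result
    data: {"id": "call_1", "name": "retrieve_standard_regulation", "results": [], "took_ms": 120}

    event: tokens
    data: {"delta": "partial answer text", "tool_call_id": null}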
372  vw-agentic-rag/tests/integration/test_api.py  Normal file
@@ -0,0 +1,372 @@
"""
Remote Integration Tests for Agentic RAG API

These tests connect to a running service instance remotely to validate:
- API endpoints and responses
- Request/response schemas
- Basic functionality without external dependencies
"""
import pytest
import asyncio
import json
import httpx
from typing import Optional, Dict, Any
import time
import os


# Configuration for remote service connection
DEFAULT_SERVICE_URL = "http://127.0.0.1:8000"
SERVICE_URL = os.getenv("AGENTIC_RAG_SERVICE_URL", DEFAULT_SERVICE_URL)


@pytest.fixture(scope="session")
def service_url() -> str:
    """Get the service URL for testing"""
    return SERVICE_URL


class TestBasicAPI:
    """Test basic API endpoints and functionality"""

    @pytest.mark.asyncio
    async def test_health_endpoint(self, service_url: str):
        """Test service health endpoint"""
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.get(f"{service_url}/health")
            assert response.status_code == 200

            data = response.json()
            assert data["status"] == "healthy"
            assert data["service"] == "agentic-rag"

    @pytest.mark.asyncio
    async def test_root_endpoint(self, service_url: str):
        """Test root API endpoint"""
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.get(f"{service_url}/")
            assert response.status_code == 200

            data = response.json()
            assert "message" in data
            assert "Agentic RAG API" in data["message"]

    @pytest.mark.asyncio
    async def test_openapi_docs(self, service_url: str):
        """Test OpenAPI documentation endpoint"""
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.get(f"{service_url}/openapi.json")
            assert response.status_code == 200

            data = response.json()
            assert "openapi" in data
            assert "info" in data
            assert data["info"]["title"] == "Agentic RAG API"

    @pytest.mark.asyncio
    async def test_docs_endpoint(self, service_url: str):
        """Test Swagger UI docs endpoint"""
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.get(f"{service_url}/docs")
            assert response.status_code == 200
            assert "text/html" in response.headers["content-type"]


class TestChatAPI:
    """Test chat API endpoints with valid requests"""

    def _create_chat_request(self, message: str, session_id: Optional[str] = None) -> Dict[str, Any]:
        """Create a valid chat request"""
        return {
            "session_id": session_id or f"test_session_{int(time.time())}",
            "messages": [
                {
                    "role": "user",
                    "content": message
                }
            ]
        }

    @pytest.mark.asyncio
    async def test_chat_endpoint_basic_request(self, service_url: str):
        """Test basic chat endpoint request/response structure"""
        request_data = self._create_chat_request("Hello, can you help me?")

        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(
                f"{service_url}/api/chat",
                json=request_data,
                headers={"Content-Type": "application/json"}
            )

            assert response.status_code == 200
            # Response should be streaming text/event-stream
            content_type = response.headers.get("content-type", "")
            assert "text/event-stream" in content_type or "text/plain" in content_type

    @pytest.mark.asyncio
    async def test_ai_sdk_chat_endpoint_basic_request(self, service_url: str):
        """Test AI SDK compatible chat endpoint"""
        request_data = self._create_chat_request("What is ISO 26262?")

        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(
                f"{service_url}/api/ai-sdk/chat",
                json=request_data,
                headers={"Content-Type": "application/json"}
            )

            assert response.status_code == 200
            # AI SDK endpoint returns plain text stream
            assert "text/plain" in response.headers.get("content-type", "")

    @pytest.mark.asyncio
    async def test_chat_endpoint_invalid_request(self, service_url: str):
        """Test chat endpoint with invalid request data"""
        invalid_requests = [
            {},  # Empty request
            {"session_id": "test"},  # Missing messages
            {"messages": []},  # Missing session_id
            {"session_id": "test", "messages": [{"role": "invalid"}]},  # Invalid message format
        ]

        async with httpx.AsyncClient(timeout=30.0) as client:
            for invalid_request in invalid_requests:
                response = await client.post(
                    f"{service_url}/api/chat",
                    json=invalid_request,
                    headers={"Content-Type": "application/json"}
                )
                # Should return 422 for validation errors
                assert response.status_code == 422

    @pytest.mark.asyncio
    async def test_session_persistence(self, service_url: str):
        """Test that sessions persist across multiple requests"""
        session_id = f"persistent_session_{int(time.time())}"

        async with httpx.AsyncClient(timeout=30.0) as client:
            # First message
            request1 = self._create_chat_request("My name is John", session_id)
            response1 = await client.post(
                f"{service_url}/api/chat",
                json=request1,
                headers={"Content-Type": "application/json"}
            )
            assert response1.status_code == 200

            # Wait a moment for processing
            await asyncio.sleep(1)

            # Second message referring to previous context
            request2 = self._create_chat_request("What did I just tell you my name was?", session_id)
            response2 = await client.post(
                f"{service_url}/api/chat",
                json=request2,
                headers={"Content-Type": "application/json"}
            )
            assert response2.status_code == 200


class TestRequestValidation:
    """Test request validation and error handling"""

    @pytest.mark.asyncio
    async def test_malformed_json(self, service_url: str):
        """Test endpoint with malformed JSON"""
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(
                f"{service_url}/api/chat",
                content="invalid json{",
                headers={"Content-Type": "application/json"}
            )
            assert response.status_code == 422

    @pytest.mark.asyncio
    async def test_missing_content_type(self, service_url: str):
        """Test endpoint without proper content type"""
        request_data = {
            "session_id": "test_session",
            "messages": [{"role": "user", "content": "test"}]
        }

        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(
                f"{service_url}/api/chat",
                content=json.dumps(request_data)
                # No Content-Type header
            )
            # FastAPI should handle this gracefully
            assert response.status_code in [415, 422]

    @pytest.mark.asyncio
    async def test_oversized_request(self, service_url: str):
        """Test endpoint with very large request"""
        large_content = "x" * 100000  # 100KB message
        request_data = {
            "session_id": "test_session",
            "messages": [{"role": "user", "content": large_content}]
        }

        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(
                f"{service_url}/api/chat",
                json=request_data,
                headers={"Content-Type": "application/json"}
            )
            # Should either process or reject gracefully
            assert response.status_code in [200, 413, 422]


class TestCORSAndHeaders:
    """Test CORS and security headers"""

    @pytest.mark.asyncio
    async def test_cors_headers(self, service_url: str):
        """Test CORS headers are properly set"""
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.options(
                f"{service_url}/api/chat",
                headers={
                    "Origin": "http://localhost:3000",
                    "Access-Control-Request-Method": "POST",
                    "Access-Control-Request-Headers": "Content-Type"
                }
            )

            # CORS preflight should be handled
            assert response.status_code in [200, 204]

            # Check for CORS headers in actual request
            request_data = {
                "session_id": "cors_test",
                "messages": [{"role": "user", "content": "test"}]
            }

            response = await client.post(
                f"{service_url}/api/chat",
                json=request_data,
                headers={
                    "Content-Type": "application/json",
                    "Origin": "http://localhost:3000"
                }
            )

            assert response.status_code == 200
            # Should have CORS headers
            assert "access-control-allow-origin" in response.headers

    @pytest.mark.asyncio
    async def test_security_headers(self, service_url: str):
        """Test basic security headers"""
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.get(f"{service_url}/health")
            assert response.status_code == 200

            # Check for basic security practices
            # Note: Specific headers depend on deployment configuration
            headers = response.headers

            # FastAPI should include some basic headers
            assert "content-length" in headers or "transfer-encoding" in headers


class TestErrorHandling:
    """Test error handling and edge cases"""

    @pytest.mark.asyncio
    async def test_nonexistent_endpoint(self, service_url: str):
        """Test request to non-existent endpoint"""
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.get(f"{service_url}/nonexistent")
            assert response.status_code == 404

    @pytest.mark.asyncio
    async def test_method_not_allowed(self, service_url: str):
        """Test wrong HTTP method on endpoint"""
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.get(f"{service_url}/api/chat")  # GET instead of POST
            assert response.status_code == 405

    @pytest.mark.asyncio
    async def test_timeout_handling(self, service_url: str):
        """Test request timeout handling"""
        # Use a very short timeout to test timeout handling
        async with httpx.AsyncClient(timeout=0.001) as short_timeout_client:
            try:
                response = await short_timeout_client.get(f"{service_url}/health")
                # If it doesn't timeout, that's also fine
                assert response.status_code == 200
            except httpx.TimeoutException:
                # Expected timeout - this is fine
                pass


class TestServiceIntegration:
    """Test integration with actual service features"""

    @pytest.mark.asyncio
    async def test_manufacturing_standards_query(self, service_url: str):
        """Test query related to manufacturing standards"""
        request_data = {
            "session_id": f"standards_test_{int(time.time())}",
            "messages": [
                {
                    "role": "user",
                    "content": "What are the key safety requirements in ISO 26262?"
                }
            ]
        }

        async with httpx.AsyncClient(timeout=60.0) as client:
            response = await client.post(
                f"{service_url}/api/ai-sdk/chat",
                json=request_data,
                headers={"Content-Type": "application/json"}
            )

            assert response.status_code == 200

            # Read some of the streaming response
            content = ""
            async for chunk in response.aiter_text():
                content += chunk
                if len(content) > 100:  # Read enough to verify it's working
                    break

            # Should have some content indicating it's processing
            assert len(content) > 0

    @pytest.mark.asyncio
    async def test_general_conversation(self, service_url: str):
        """Test general conversation capability"""
        request_data = {
            "session_id": f"general_test_{int(time.time())}",
            "messages": [
                {
                    "role": "user",
                    "content": "Hello! How can you help me today?"
                }
            ]
        }

        async with httpx.AsyncClient(timeout=60.0) as client:
            response = await client.post(
                f"{service_url}/api/chat",
                json=request_data,
                headers={"Content-Type": "application/json"}
            )

            assert response.status_code == 200

            # Verify we get streaming response
            content = ""
            chunk_count = 0
            async for chunk in response.aiter_text():
                content += chunk
                chunk_count += 1
                if chunk_count > 10:  # Read several chunks
                    break

            # Should receive streaming content
            assert len(content) > 0
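All of the remote suites read their target from the `AGENTIC_RAG_SERVICE_URL` environment variable, falling back to `http://127.0.0.1:8000`. Pointing a run at another deployment is therefore a one-liner (the URL below is a placeholder):

    AGENTIC_RAG_SERVICE_URL=https://staging.example.com pytest --run-integration vw-agentic-rag/tests/integration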
415  vw-agentic-rag/tests/integration/test_e2e_tool_ui.py  Normal file
@@ -0,0 +1,415 @@
"""
End-to-End Integration Tests for Tool UI

These tests validate the complete user experience by connecting to a running service.
They test tool calling, response formatting, and user interface integration.
"""
import pytest
import asyncio
import httpx
import time
import os


# Configuration for remote service connection
DEFAULT_SERVICE_URL = "http://127.0.0.1:8000"
SERVICE_URL = os.getenv("AGENTIC_RAG_SERVICE_URL", DEFAULT_SERVICE_URL)


@pytest.fixture(scope="session")
def service_url() -> str:
    """Get the service URL for testing"""
    return SERVICE_URL


class TestEndToEndWorkflows:
    """Test complete end-to-end user workflows"""

    @pytest.mark.asyncio
    async def test_standards_research_with_tools(self, service_url: str):
        """Test standards research workflow with tool calls"""
        session_id = f"e2e_standards_{int(time.time())}"

        request_data = {
            "session_id": session_id,
            "messages": [
                {
                    "role": "user",
                    "content": "What are the safety requirements for automotive braking systems according to ISO 26262?"
                }
            ]
        }

        async with httpx.AsyncClient(timeout=90.0) as client:
            response = await client.post(
                f"{service_url}/api/chat",
                json=request_data,
                headers={"Content-Type": "application/json"}
            )

            assert response.status_code == 200

            # Collect the full response to analyze tool usage
            full_content = ""
            async for chunk in response.aiter_text():
                full_content += chunk
                if len(full_content) > 1000:  # Get substantial content
                    break

            # Verify we got meaningful content
            assert len(full_content) > 100
            print(f"Standards research response length: {len(full_content)} chars")

    @pytest.mark.asyncio
    async def test_manufacturing_compliance_workflow(self, service_url: str):
        """Test manufacturing compliance workflow"""
        session_id = f"e2e_compliance_{int(time.time())}"

        request_data = {
            "session_id": session_id,
            "messages": [
                {
                    "role": "user",
                    "content": "I need to understand compliance requirements for manufacturing equipment safety. What standards apply?"
                }
            ]
        }

        async with httpx.AsyncClient(timeout=90.0) as client:
            response = await client.post(
                f"{service_url}/api/ai-sdk/chat",
                json=request_data,
                headers={"Content-Type": "application/json"}
            )

            assert response.status_code == 200

            # Test AI SDK format response
            content = ""
            async for chunk in response.aiter_text():
                content += chunk
                if len(content) > 500:
                    break

            assert len(content) > 50
            print(f"Compliance workflow response length: {len(content)} chars")

    @pytest.mark.asyncio
    async def test_technical_documentation_workflow(self, service_url: str):
        """Test technical documentation research workflow"""
        session_id = f"e2e_technical_{int(time.time())}"

        request_data = {
            "session_id": session_id,
            "messages": [
                {
                    "role": "user",
                    "content": "How do I implement functional safety according to IEC 61508 for industrial control systems?"
                }
            ]
        }

        async with httpx.AsyncClient(timeout=90.0) as client:
            response = await client.post(
                f"{service_url}/api/chat",
                json=request_data,
                headers={"Content-Type": "application/json"}
            )

            assert response.status_code == 200

            # Collect response
            content = ""
            async for chunk in response.aiter_text():
                content += chunk
                if len(content) > 800:
                    break

            assert len(content) > 100
            print(f"Technical documentation response length: {len(content)} chars")


class TestMultiTurnConversations:
    """Test multi-turn conversation workflows"""

    @pytest.mark.asyncio
    async def test_progressive_standards_exploration(self, service_url: str):
        """Test progressive exploration of standards through multiple turns"""
        session_id = f"e2e_progressive_{int(time.time())}"

        conversation_steps = [
            "What is ISO 26262?",
            "What are the ASIL levels?",
            "How do I determine ASIL D requirements?",
            "What testing is required for ASIL D systems?"
        ]

        async with httpx.AsyncClient(timeout=90.0) as client:
            for i, question in enumerate(conversation_steps):
                request_data = {
                    "session_id": session_id,
                    "messages": [{"role": "user", "content": question}]
                }

                response = await client.post(
                    f"{service_url}/api/chat",
                    json=request_data,
                    headers={"Content-Type": "application/json"}
                )

                assert response.status_code == 200

                # Read response
                content = ""
                async for chunk in response.aiter_text():
                    content += chunk
                    if len(content) > 300:
                        break

                assert len(content) > 30
                print(f"Turn {i+1}: {len(content)} chars")

                # Brief pause between turns
                await asyncio.sleep(1)

    @pytest.mark.asyncio
    async def test_comparative_analysis_workflow(self, service_url: str):
        """Test comparative analysis across multiple standards"""
        session_id = f"e2e_comparative_{int(time.time())}"

        comparison_questions = [
            "What are the differences between ISO 26262 and IEC 61508?",
            "Which standard is more appropriate for automotive applications?",
            "How do the safety integrity levels compare between these standards?"
        ]

        async with httpx.AsyncClient(timeout=90.0) as client:
            for question in comparison_questions:
                request_data = {
                    "session_id": session_id,
                    "messages": [{"role": "user", "content": question}]
                }

                response = await client.post(
                    f"{service_url}/api/ai-sdk/chat",
                    json=request_data,
                    headers={"Content-Type": "application/json"}
                )

                assert response.status_code == 200

                # Collect comparison response
                content = ""
                async for chunk in response.aiter_text():
                    content += chunk
                    if len(content) > 400:
                        break

                assert len(content) > 50
                await asyncio.sleep(1.5)


class TestSpecializedQueries:
    """Test specialized query types and edge cases"""

    @pytest.mark.asyncio
    async def test_specific_standard_section_query(self, service_url: str):
        """Test queries about specific sections of standards"""
        session_id = f"e2e_specific_{int(time.time())}"

        request_data = {
            "session_id": session_id,
            "messages": [
                {
                    "role": "user",
                    "content": "What does section 4.3 of ISO 26262-3 say about software architectural design?"
                }
            ]
        }

        async with httpx.AsyncClient(timeout=90.0) as client:
            response = await client.post(
                f"{service_url}/api/chat",
                json=request_data,
                headers={"Content-Type": "application/json"}
            )

            assert response.status_code == 200

            content = ""
            async for chunk in response.aiter_text():
                content += chunk
                if len(content) > 600:
                    break

            assert len(content) > 50

    @pytest.mark.asyncio
    async def test_implementation_guidance_query(self, service_url: str):
        """Test queries asking for implementation guidance"""
        session_id = f"e2e_implementation_{int(time.time())}"

        request_data = {
            "session_id": session_id,
            "messages": [
                {
                    "role": "user",
                    "content": "How should I implement a safety management system according to ISO 45001?"
                }
            ]
        }

        async with httpx.AsyncClient(timeout=90.0) as client:
            response = await client.post(
                f"{service_url}/api/ai-sdk/chat",
                json=request_data,
                headers={"Content-Type": "application/json"}
            )

            assert response.status_code == 200

            content = ""
            async for chunk in response.aiter_text():
                content += chunk
                if len(content) > 500:
                    break

            assert len(content) > 100

    @pytest.mark.asyncio
    async def test_cross_domain_standards_query(self, service_url: str):
        """Test queries spanning multiple domains"""
        session_id = f"e2e_cross_domain_{int(time.time())}"

        request_data = {
            "session_id": session_id,
            "messages": [
                {
                    "role": "user",
                    "content": "How do cybersecurity standards like ISO 27001 relate to functional safety standards like ISO 26262?"
                }
            ]
        }

        async with httpx.AsyncClient(timeout=90.0) as client:
            response = await client.post(
                f"{service_url}/api/chat",
                json=request_data,
                headers={"Content-Type": "application/json"}
            )

            assert response.status_code == 200

            content = ""
            async for chunk in response.aiter_text():
                content += chunk
                if len(content) > 700:
                    break

            assert len(content) > 100


class TestUserExperience:
    """Test overall user experience aspects"""

    @pytest.mark.asyncio
    async def test_response_quality_indicators(self, service_url: str):
        """Test that responses have quality indicators (good structure, citations, etc.)"""
        session_id = f"e2e_quality_{int(time.time())}"

        request_data = {
            "session_id": session_id,
            "messages": [
                {
                    "role": "user",
                    "content": "What are the key principles of risk assessment in ISO 31000?"
                }
            ]
        }

        async with httpx.AsyncClient(timeout=90.0) as client:
            response = await client.post(
                f"{service_url}/api/chat",
                json=request_data,
                headers={"Content-Type": "application/json"}
            )

            assert response.status_code == 200

            # Collect full response to analyze quality
            full_content = ""
            async for chunk in response.aiter_text():
                full_content += chunk
                if len(full_content) > 1200:
                    break

            # Basic quality checks
            assert len(full_content) > 100

            # Content should contain structured information
            # (These are basic heuristics for response quality)
            assert len(full_content.split()) > 20  # At least 20 words

            print(f"Quality response length: {len(full_content)} chars")

    @pytest.mark.asyncio
    async def test_error_recovery_experience(self, service_url: str):
        """Test user experience when recovering from errors"""
        session_id = f"e2e_error_recovery_{int(time.time())}"

        async with httpx.AsyncClient(timeout=90.0) as client:
            # Start with a good question
            good_request = {
                "session_id": session_id,
                "messages": [{"role": "user", "content": "What is ISO 9001?"}]
            }

            response = await client.post(
                f"{service_url}/api/chat",
                json=good_request,
                headers={"Content-Type": "application/json"}
            )
            assert response.status_code == 200

            await asyncio.sleep(1)

            # Try a potentially problematic request
            try:
                problematic_request = {
                    "session_id": session_id,
                    "messages": [{"role": "user", "content": ""}]  # Empty content
                }

                await client.post(
                    f"{service_url}/api/chat",
                    json=problematic_request,
                    headers={"Content-Type": "application/json"}
                )
            except Exception:
                pass  # Expected to potentially fail

            await asyncio.sleep(1)

            # Recovery with another good question
            recovery_request = {
                "session_id": session_id,
                "messages": [{"role": "user", "content": "Can you help me understand quality management?"}]
            }

            recovery_response = await client.post(
                f"{service_url}/api/chat",
                json=recovery_request,
                headers={"Content-Type": "application/json"}
            )

            # Should recover successfully
            assert recovery_response.status_code == 200

            content = ""
            async for chunk in recovery_response.aiter_text():
                content += chunk
                if len(content) > 200:
                    break

            assert len(content) > 30
            print("📤 Sending to backend...")
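The collect-until-N-chars loop recurs in nearly every test in this file. A small helper along these lines (hypothetical, not part of this commit) would keep the assertions in focus:

    async def read_stream_prefix(response: httpx.Response, limit: int = 500) -> str:
        """Collect streamed text until roughly `limit` characters have arrived."""
        content = ""
        async for chunk in response.aiter_text():
            content += chunk
            if len(content) > limit:
                break
        return content

    # usage: content = await read_stream_prefix(response, 800)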
402  vw-agentic-rag/tests/integration/test_full_workflow.py  Normal file
@@ -0,0 +1,402 @@
"""
Full Workflow Integration Tests

These tests validate complete end-to-end workflows by connecting to a running service.
They test realistic user scenarios and complex interactions.
"""
import pytest
import asyncio
import httpx
import time
import os
from typing import List, Dict, Any


# Configuration for remote service connection
DEFAULT_SERVICE_URL = "http://127.0.0.1:8000"
SERVICE_URL = os.getenv("AGENTIC_RAG_SERVICE_URL", DEFAULT_SERVICE_URL)


@pytest.fixture(scope="session")
def service_url() -> str:
    """Get the service URL for testing"""
    return SERVICE_URL


class TestCompleteWorkflows:
    """Test complete user workflows"""

    @pytest.mark.asyncio
    async def test_standards_research_workflow(self, service_url: str):
        """Test a complete standards research workflow"""
        session_id = f"standards_workflow_{int(time.time())}"

        # Simulate a user researching ISO 26262
        conversation_flow = [
            "What is ISO 26262 and what does it cover?",
            "What are the ASIL levels in ISO 26262?",
            "Can you explain ASIL D requirements in detail?",
            "How does ISO 26262 relate to vehicle cybersecurity?"
        ]

        async with httpx.AsyncClient(timeout=60.0) as client:
            for i, question in enumerate(conversation_flow):
                request_data = {
                    "session_id": session_id,
                    "messages": [{"role": "user", "content": question}]
                }

                response = await client.post(
                    f"{service_url}/api/ai-sdk/chat",
                    json=request_data,
                    headers={"Content-Type": "application/json"}
                )

                assert response.status_code == 200

                # Read the streaming response
                content = ""
                async for chunk in response.aiter_text():
                    content += chunk
                    if len(content) > 200:  # Get substantial response
                        break

                # Verify we get meaningful content
                assert len(content) > 50
                print(f"Question {i+1} response length: {len(content)} chars")

                # Small delay between questions
                await asyncio.sleep(0.5)

    @pytest.mark.asyncio
    async def test_manufacturing_safety_workflow(self, service_url: str):
        """Test manufacturing safety standards workflow"""
        session_id = f"manufacturing_workflow_{int(time.time())}"

        conversation_flow = [
            "What are the key safety standards for manufacturing equipment?",
            "How do ISO 13849 and IEC 62061 compare?",
            "What is the process for safety risk assessment in manufacturing?"
        ]

        async with httpx.AsyncClient(timeout=60.0) as client:
            responses = []

            for question in conversation_flow:
                request_data = {
                    "session_id": session_id,
                    "messages": [{"role": "user", "content": question}]
                }

                response = await client.post(
                    f"{service_url}/api/chat",
                    json=request_data,
                    headers={"Content-Type": "application/json"}
                )

                assert response.status_code == 200

                # Collect response content
                content = ""
                async for chunk in response.aiter_text():
                    content += chunk
                    if len(content) > 300:
                        break

                responses.append(content)
                await asyncio.sleep(0.5)

            # Verify we got responses for all questions
            assert len(responses) == len(conversation_flow)
            for response_content in responses:
                assert len(response_content) > 30

    @pytest.mark.asyncio
    async def test_session_context_continuity(self, service_url: str):
        """Test that session context is maintained across requests"""
        session_id = f"context_test_{int(time.time())}"

        async with httpx.AsyncClient(timeout=60.0) as client:
            # First message - establish context
            request1 = {
                "session_id": session_id,
                "messages": [{"role": "user", "content": "I'm working on a safety system for automotive braking. What standard should I follow?"}]
            }

            response1 = await client.post(
                f"{service_url}/api/chat",
                json=request1,
                headers={"Content-Type": "application/json"}
            )
            assert response1.status_code == 200

            # Wait for processing
            await asyncio.sleep(2)

            # Follow-up question that depends on context
            request2 = {
                "session_id": session_id,
                "messages": [{"role": "user", "content": "What are the specific testing requirements for this standard?"}]
            }

            response2 = await client.post(
                f"{service_url}/api/chat",
                json=request2,
                headers={"Content-Type": "application/json"}
            )
            assert response2.status_code == 200

            # Verify both responses are meaningful
            content1 = ""
            async for chunk in response1.aiter_text():
                content1 += chunk
                if len(content1) > 100:
                    break

            content2 = ""
            async for chunk in response2.aiter_text():
                content2 += chunk
                if len(content2) > 100:
                    break

            assert len(content1) > 50
            assert len(content2) > 50


class TestErrorRecoveryWorkflows:
    """Test error recovery and edge case workflows"""

    @pytest.mark.asyncio
    async def test_session_recovery_after_error(self, service_url: str):
        """Test that sessions can recover after encountering errors"""
        session_id = f"error_recovery_{int(time.time())}"

        async with httpx.AsyncClient(timeout=60.0) as client:
            # Valid request
            valid_request = {
                "session_id": session_id,
                "messages": [{"role": "user", "content": "What is ISO 9001?"}]
            }

            response = await client.post(
                f"{service_url}/api/chat",
                json=valid_request,
                headers={"Content-Type": "application/json"}
            )
            assert response.status_code == 200

            # Try an invalid request that might cause issues
            invalid_request = {
                "session_id": session_id,
                "messages": [{"role": "user", "content": ""}]  # Empty content
            }

            try:
                await client.post(
                    f"{service_url}/api/chat",
                    json=invalid_request,
                    headers={"Content-Type": "application/json"}
                )
            except Exception:
                pass  # Expected to potentially fail

            await asyncio.sleep(1)

            # Another valid request to test recovery
            recovery_request = {
                "session_id": session_id,
                "messages": [{"role": "user", "content": "Can you summarize what we discussed?"}]
            }

            recovery_response = await client.post(
                f"{service_url}/api/chat",
                json=recovery_request,
                headers={"Content-Type": "application/json"}
            )

            # Session should still work
            assert recovery_response.status_code == 200

    @pytest.mark.asyncio
    async def test_concurrent_sessions(self, service_url: str):
        """Test multiple concurrent sessions"""
        base_time = int(time.time())
        sessions = [f"concurrent_{base_time}_{i}" for i in range(3)]

        async def test_session(session_id: str, question: str):
            """Test a single session"""
            async with httpx.AsyncClient(timeout=60.0) as client:
                request = {
                    "session_id": session_id,
                    "messages": [{"role": "user", "content": question}]
                }

                response = await client.post(
                    f"{service_url}/api/chat",
                    json=request,
                    headers={"Content-Type": "application/json"}
                )

                assert response.status_code == 200
                return session_id

        # Run concurrent sessions
        questions = [
            "What is ISO 27001?",
            "What is NIST Cybersecurity Framework?",
            "What is GDPR compliance?"
        ]

        tasks = [
            test_session(session_id, question)
            for session_id, question in zip(sessions, questions)
        ]

        results = await asyncio.gather(*tasks, return_exceptions=True)

        # All sessions should complete successfully
        assert len(results) == 3
        for result in results:
            assert not isinstance(result, Exception)


class TestPerformanceWorkflows:
    """Test performance-related workflows"""

    @pytest.mark.asyncio
    async def test_rapid_fire_requests(self, service_url: str):
        """Test rapid consecutive requests in same session"""
        session_id = f"rapid_fire_{int(time.time())}"

        questions = [
            "Hello",
            "What is ISO 14001?",
            "Thank you",
            "Goodbye"
        ]

        async with httpx.AsyncClient(timeout=60.0) as client:
            for i, question in enumerate(questions):
                request = {
                    "session_id": session_id,
                    "messages": [{"role": "user", "content": question}]
                }

                response = await client.post(
                    f"{service_url}/api/chat",
                    json=request,
                    headers={"Content-Type": "application/json"}
                )

                assert response.status_code == 200
                print(f"Rapid request {i+1} completed")

                # Very short delay
                await asyncio.sleep(0.1)

    @pytest.mark.asyncio
    async def test_large_context_workflow(self, service_url: str):
        """Test workflow with gradually increasing context"""
        session_id = f"large_context_{int(time.time())}"

        async with httpx.AsyncClient(timeout=60.0) as client:
            # Build up context over multiple turns
            conversation = [
                "I need to understand automotive safety standards",
                "Specifically, tell me about ISO 26262 functional safety",
                "What are the different ASIL levels and their requirements?",
                "How do I implement ASIL D for a braking system?",
                "What testing and validation is required for ASIL D?",
                "Can you provide a summary of everything we've discussed?"
            ]

            for i, message in enumerate(conversation):
                request = {
                    "session_id": session_id,
                    "messages": [{"role": "user", "content": message}]
                }

                response = await client.post(
                    f"{service_url}/api/chat",
                    json=request,
                    headers={"Content-Type": "application/json"}
                )

                assert response.status_code == 200
                print(f"Context turn {i+1} completed")

                # Allow time for processing
                await asyncio.sleep(1)


class TestRealWorldScenarios:
    """Test realistic user scenarios"""

    @pytest.mark.asyncio
    async def test_compliance_officer_scenario(self, service_url: str):
        """Simulate a compliance officer's typical workflow"""
        session_id = f"compliance_officer_{int(time.time())}"

        # Typical compliance questions
        scenario_questions = [
            "I need to ensure our new product meets regulatory requirements. What standards apply to automotive safety systems?",
            "Our system is classified as ASIL C. What does this mean for our development process?",
            "What documentation do we need to prepare for safety assessment?",
            "How often do we need to review and update our safety processes?"
        ]

        async with httpx.AsyncClient(timeout=90.0) as client:
            for i, question in enumerate(scenario_questions):
                request = {
                    "session_id": session_id,
                    "messages": [{"role": "user", "content": question}]
                }

                response = await client.post(
                    f"{service_url}/api/ai-sdk/chat",
                    json=request,
                    headers={"Content-Type": "application/json"}
                )

                assert response.status_code == 200

                # Allow realistic time between questions
                await asyncio.sleep(2)
                print(f"Compliance scenario step {i+1} completed")

    @pytest.mark.asyncio
    async def test_engineer_research_scenario(self, service_url: str):
        """Simulate an engineer researching technical details"""
        session_id = f"engineer_research_{int(time.time())}"

        research_flow = [
            "I'm designing a safety-critical system. What's the difference between ISO 26262 and IEC 61508?",
            "For automotive applications, which standard takes precedence?",
            "What are the specific requirements for software development under ISO 26262?",
            "Can you explain the V-model development process required by the standard?"
        ]

        async with httpx.AsyncClient(timeout=90.0) as client:
            for question in research_flow:
                request = {
                    "session_id": session_id,
                    "messages": [{"role": "user", "content": question}]
                }

                response = await client.post(
                    f"{service_url}/api/chat",
                    json=request,
                    headers={"Content-Type": "application/json"}
                )

                assert response.status_code == 200

                # Read some response to verify it's working
                content = ""
                async for chunk in response.aiter_text():
                    content += chunk
                    if len(content) > 150:
                        break

                assert len(content) > 50
                await asyncio.sleep(1.5)
406
vw-agentic-rag/tests/integration/test_streaming_integration.py
Normal file
@@ -0,0 +1,406 @@
"""
Streaming Integration Tests

These tests validate streaming behavior by connecting to a running service.
They focus on real-time response patterns and streaming event handling.
"""
import pytest
import asyncio
import httpx
import time
import os


# Configuration for remote service connection
DEFAULT_SERVICE_URL = "http://127.0.0.1:8000"
SERVICE_URL = os.getenv("AGENTIC_RAG_SERVICE_URL", DEFAULT_SERVICE_URL)
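
# For example (hypothetical host), the suite can be pointed at another
# deployment with:
#   AGENTIC_RAG_SERVICE_URL=http://staging-host:8000 pytest tests/integration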


@pytest.fixture(scope="session")
def service_url() -> str:
    """Get the service URL for testing"""
    return SERVICE_URL


class TestStreamingBehavior:
    """Test streaming response behavior"""

    @pytest.mark.asyncio
    async def test_basic_streaming_response(self, service_url: str):
        """Test that responses are properly streamed"""
        session_id = f"streaming_test_{int(time.time())}"

        request_data = {
            "session_id": session_id,
            "messages": [{"role": "user", "content": "What is ISO 26262?"}]
        }

        async with httpx.AsyncClient(timeout=60.0) as client:
            response = await client.post(
                f"{service_url}/api/chat",
                json=request_data,
                headers={"Content-Type": "application/json"}
            )

            assert response.status_code == 200

            # Collect streaming chunks
            chunks = []
            async for chunk in response.aiter_text():
                chunks.append(chunk)
                if len(chunks) > 10:  # Get enough chunks to verify streaming
                    break

            # Should receive multiple chunks (indicating streaming)
            assert len(chunks) > 1

            # Chunks should have content
            total_content = "".join(chunks)
            assert len(total_content) > 0

    @pytest.mark.asyncio
    async def test_ai_sdk_streaming_format(self, service_url: str):
        """Test AI SDK compatible streaming format"""
        session_id = f"ai_sdk_streaming_{int(time.time())}"

        request_data = {
            "session_id": session_id,
            "messages": [{"role": "user", "content": "Explain vehicle safety testing"}]
        }

        async with httpx.AsyncClient(timeout=60.0) as client:
            response = await client.post(
                f"{service_url}/api/ai-sdk/chat",
                json=request_data,
                headers={"Content-Type": "application/json"}
            )

            assert response.status_code == 200
            assert "text/plain" in response.headers.get("content-type", "")

            # Test streaming behavior
            chunk_count = 0
            total_length = 0

            async for chunk in response.aiter_text():
                chunk_count += 1
                total_length += len(chunk)

                if chunk_count > 15:  # Collect enough chunks
                    break

            # Verify streaming characteristics
            assert chunk_count > 1  # Multiple chunks
            assert total_length > 50  # Meaningful content
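
    # Note: the AI SDK data stream frames each part as a prefixed line
    # (for example, text deltas such as 0:"token"); the assertions above
    # only check chunking behavior, not the exact wire framing.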

    @pytest.mark.asyncio
    async def test_streaming_performance(self, service_url: str):
        """Test streaming response timing and performance"""
        session_id = f"streaming_perf_{int(time.time())}"

        request_data = {
            "session_id": session_id,
            "messages": [{"role": "user", "content": "What are automotive safety standards?"}]
        }

        async with httpx.AsyncClient(timeout=60.0) as client:
            start_time = time.time()

            response = await client.post(
                f"{service_url}/api/chat",
                json=request_data,
                headers={"Content-Type": "application/json"}
            )

            assert response.status_code == 200

            first_chunk_time = None
            chunk_count = 0

            async for chunk in response.aiter_text():
                if first_chunk_time is None:
                    first_chunk_time = time.time()

                chunk_count += 1
                if chunk_count > 5:  # Get a few chunks for timing
                    break

            # Time to first chunk should be reasonable (< 10 seconds)
            if first_chunk_time:
                time_to_first_chunk = first_chunk_time - start_time
                assert time_to_first_chunk < 10.0

    @pytest.mark.asyncio
    async def test_streaming_interruption_handling(self, service_url: str):
        """Test behavior when streaming is interrupted"""
        session_id = f"streaming_interrupt_{int(time.time())}"

        request_data = {
            "session_id": session_id,
            "messages": [{"role": "user", "content": "Tell me about ISO standards"}]
        }

        async with httpx.AsyncClient(timeout=60.0) as client:
            response = await client.post(
                f"{service_url}/api/chat",
                json=request_data,
                headers={"Content-Type": "application/json"}
            )

            assert response.status_code == 200

            # Read only a few chunks then stop
            chunk_count = 0
            async for chunk in response.aiter_text():
                chunk_count += 1
                if chunk_count >= 3:
                    break  # Interrupt streaming

            # Should have received some chunks
            assert chunk_count > 0

class TestConcurrentStreaming:
    """Test concurrent streaming scenarios"""

    @pytest.mark.asyncio
    async def test_multiple_concurrent_streams(self, service_url: str):
        """Test multiple concurrent streaming requests"""
        base_time = int(time.time())

        async def stream_request(session_suffix: str, question: str):
            """Make a single streaming request"""
            session_id = f"concurrent_stream_{base_time}_{session_suffix}"

            async with httpx.AsyncClient(timeout=60.0) as client:
                response = await client.post(
                    f"{service_url}/api/chat",
                    json={
                        "session_id": session_id,
                        "messages": [{"role": "user", "content": question}]
                    },
                    headers={"Content-Type": "application/json"}
                )

                assert response.status_code == 200

                # Read some chunks
                chunks = 0
                async for chunk in response.aiter_text():
                    chunks += 1
                    if chunks > 5:
                        break

                return chunks

        # Run multiple concurrent streams
        questions = [
            "What is ISO 26262?",
            "Explain NIST framework",
            "What is GDPR?"
        ]

        tasks = [
            stream_request(f"session_{i}", question)
            for i, question in enumerate(questions)
        ]

        results = await asyncio.gather(*tasks, return_exceptions=True)

        # All streams should complete successfully
        assert len(results) == 3
        for result in results:
            assert not isinstance(result, Exception)
            assert result > 0  # Each stream should receive chunks
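
    # With return_exceptions=True, asyncio.gather resolves every task and
    # returns raised exceptions as values, so one failing stream cannot
    # cancel the others; the assertions above then inspect each result.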

    @pytest.mark.asyncio
    async def test_same_session_rapid_requests(self, service_url: str):
        """Test rapid requests in the same session"""
        session_id = f"rapid_session_{int(time.time())}"

        questions = [
            "Hello",
            "What is ISO 9001?",
            "Thank you"
        ]

        async with httpx.AsyncClient(timeout=60.0) as client:
            for i, question in enumerate(questions):
                request_data = {
                    "session_id": session_id,
                    "messages": [{"role": "user", "content": question}]
                }

                response = await client.post(
                    f"{service_url}/api/chat",
                    json=request_data,
                    headers={"Content-Type": "application/json"}
                )

                assert response.status_code == 200

                # Read some response
                chunk_count = 0
                async for chunk in response.aiter_text():
                    chunk_count += 1
                    if chunk_count > 3:
                        break

                print(f"Request {i+1} completed with {chunk_count} chunks")

                # Very short delay
                await asyncio.sleep(0.2)


class TestStreamingErrorHandling:
    """Test error handling during streaming"""

    @pytest.mark.asyncio
    async def test_streaming_with_invalid_session(self, service_url: str):
        """Test streaming behavior with edge case session IDs"""
        test_cases = [
            "",  # Empty session ID
            "a" * 1000,  # Very long session ID
            "session with spaces",  # Session ID with spaces
            "session/with/slashes"  # Session ID with special chars
        ]

        async with httpx.AsyncClient(timeout=60.0) as client:
            for session_id in test_cases:
                request_data = {
                    "session_id": session_id,
                    "messages": [{"role": "user", "content": "Hello"}]
                }

                try:
                    response = await client.post(
                        f"{service_url}/api/chat",
                        json=request_data,
                        headers={"Content-Type": "application/json"}
                    )

                    # Should either work or return validation error
                    assert response.status_code in [200, 422]

                except Exception as e:
                    # Some edge cases might cause exceptions, which is acceptable
                    print(f"Session ID '{session_id}' caused exception: {e}")

    @pytest.mark.asyncio
    async def test_streaming_with_large_messages(self, service_url: str):
        """Test streaming with large message content"""
        session_id = f"large_msg_stream_{int(time.time())}"

        # Create a large message
        large_content = "Please explain safety standards. " * 100  # ~3KB message

        request_data = {
            "session_id": session_id,
            "messages": [{"role": "user", "content": large_content}]
        }

        async with httpx.AsyncClient(timeout=90.0) as client:
            response = await client.post(
                f"{service_url}/api/chat",
                json=request_data,
                headers={"Content-Type": "application/json"}
            )

            # Should handle large messages appropriately
            assert response.status_code in [200, 413, 422]

            if response.status_code == 200:
                # If accepted, should stream properly
                chunk_count = 0
                async for chunk in response.aiter_text():
                    chunk_count += 1
                    if chunk_count > 5:
                        break

                assert chunk_count > 0


class TestStreamingContentValidation:
    """Test streaming content quality and format"""

    @pytest.mark.asyncio
    async def test_streaming_content_encoding(self, service_url: str):
        """Test that streaming content is properly encoded"""
        session_id = f"encoding_test_{int(time.time())}"

        # Test with special characters and unicode
        test_message = "What is ISO 26262? Please explain with émphasis on safety ñorms."

        request_data = {
            "session_id": session_id,
            "messages": [{"role": "user", "content": test_message}]
        }

        async with httpx.AsyncClient(timeout=60.0) as client:
            response = await client.post(
                f"{service_url}/api/chat",
                json=request_data,
                headers={"Content-Type": "application/json"}
            )

            assert response.status_code == 200

            # Collect content and verify encoding
            content = ""
            async for chunk in response.aiter_text():
                content += chunk
                if len(content) > 100:
                    break

            # Content should be valid UTF-8
            assert isinstance(content, str)
            assert len(content) > 0

            # Should be able to encode/decode
            encoded = content.encode('utf-8')
            decoded = encoded.decode('utf-8')
            assert decoded == content

    @pytest.mark.asyncio
    async def test_streaming_response_consistency(self, service_url: str):
        """Test that streaming responses are consistent for similar queries"""
        base_session = f"consistency_test_{int(time.time())}"

        # Ask the same question multiple times
        test_question = "What is ISO 26262?"

        responses = []

        async with httpx.AsyncClient(timeout=60.0) as client:
            for i in range(3):
                session_id = f"{base_session}_{i}"

                request_data = {
                    "session_id": session_id,
                    "messages": [{"role": "user", "content": test_question}]
                }

                response = await client.post(
                    f"{service_url}/api/chat",
                    json=request_data,
                    headers={"Content-Type": "application/json"}
                )

                assert response.status_code == 200

                # Collect response
                content = ""
                async for chunk in response.aiter_text():
                    content += chunk
                    if len(content) > 200:
                        break

                responses.append(content)
                await asyncio.sleep(0.5)

        # All responses should have content
        for response_content in responses:
            assert len(response_content) > 50

        # Responses should have some consistency (all non-empty)
        assert len([r for r in responses if r.strip()]) == len(responses)
1
vw-agentic-rag/tests/unit/__init__.py
Normal file
@@ -0,0 +1 @@
# Empty __init__.py files to make test packages
114
vw-agentic-rag/tests/unit/test_aggressive_trimming.py
Normal file
@@ -0,0 +1,114 @@
#!/usr/bin/env python3
"""
Test the new aggressive trimming strategy: historical tool-call results are
trimmed even when the token count is very low.
"""
import pytest
from service.graph.message_trimmer import ConversationTrimmer
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage
from langchain_core.messages.utils import count_tokens_approximately


def test_aggressive_tool_history_trimming():
    """Test the aggressive tool-history trimming strategy"""

    # Create the trimmer directly to avoid config dependencies
    trimmer = ConversationTrimmer(max_context_length=100000)

    # Build a conversation containing multiple tool-call rounds (very low token count)
    messages = [
        SystemMessage(content='You are a helpful assistant.'),

        # Historical conversation round 1
        HumanMessage(content='搜索汽车标准'),
        AIMessage(content='搜索中', tool_calls=[{'id': 'call_1', 'name': 'search', 'args': {'query': '汽车标准'}}]),
        ToolMessage(content='历史结果1', tool_call_id='call_1', name='search'),
        AIMessage(content='汽车标准信息'),

        # Historical conversation round 2
        HumanMessage(content='搜索电池标准'),
        AIMessage(content='搜索中', tool_calls=[{'id': 'call_2', 'name': 'search', 'args': {'query': '电池标准'}}]),
        ToolMessage(content='历史结果2', tool_call_id='call_2', name='search'),
        AIMessage(content='电池标准信息'),

        # New user query (the point at which trimming is triggered)
        HumanMessage(content='搜索安全标准'),
    ]

    # Verify the token count is low, far below the threshold
    token_count = count_tokens_approximately(messages)
    assert token_count < 1000, f"Token count should be low, got {token_count}"
    assert token_count < trimmer.history_token_limit, "Token count should be well below limit"

    # Verify that multiple tool rounds are identified
    tool_rounds = trimmer._identify_tool_rounds(messages)
    assert len(tool_rounds) == 2, f"Should identify 2 tool rounds, got {len(tool_rounds)}"

    # Verify trimming is triggered (because there are multiple tool rounds)
    should_trim = trimmer.should_trim(messages)
    assert should_trim, "Should trigger trimming due to multiple tool rounds"

    # Perform the trim
    trimmed = trimmer.trim_conversation_history(messages)

    # Verify the trimming effect
    assert len(trimmed) < len(messages), "Should have fewer messages after trimming"

    # Verify the system message and initial query are preserved
    assert isinstance(trimmed[0], SystemMessage), "Should preserve system message"
    assert isinstance(trimmed[1], HumanMessage), "Should preserve initial human message"

    # Verify that only the latest round's tool-call result is kept
    tool_messages = [msg for msg in trimmed if isinstance(msg, ToolMessage)]
    assert len(tool_messages) == 1, f"Should only keep 1 tool message, got {len(tool_messages)}"
    assert tool_messages[0].content == '历史结果2', "Should keep the most recent tool result"


def test_single_tool_round_no_trimming():
    """Test that a single tool-call round does not trigger trimming"""

    trimmer = ConversationTrimmer(max_context_length=100000)

    # A conversation with only one tool-call round
    messages = [
        SystemMessage(content='You are a helpful assistant.'),
        HumanMessage(content='搜索信息'),
        AIMessage(content='搜索中', tool_calls=[{'id': 'call_1', 'name': 'search', 'args': {'query': '信息'}}]),
        ToolMessage(content='搜索结果', tool_call_id='call_1', name='search'),
        AIMessage(content='这是搜索到的信息'),
        HumanMessage(content='新的问题'),
    ]

    # Verify there is exactly one tool round
    tool_rounds = trimmer._identify_tool_rounds(messages)
    assert len(tool_rounds) == 1, f"Should identify 1 tool round, got {len(tool_rounds)}"

    # Verify trimming is not triggered (only one tool round and a low token count)
    should_trim = trimmer.should_trim(messages)
    assert not should_trim, "Should not trigger trimming for single tool round with low tokens"


def test_no_tool_rounds_no_trimming():
    """Test that a conversation without tool calls does not trigger trimming"""

    trimmer = ConversationTrimmer(max_context_length=100000)

    # A conversation without tool calls
    messages = [
        SystemMessage(content='You are a helpful assistant.'),
        HumanMessage(content='Hello'),
        AIMessage(content='Hi there!'),
        HumanMessage(content='How are you?'),
        AIMessage(content='I am doing well, thank you!'),
    ]

    # Verify there are no tool rounds
    tool_rounds = trimmer._identify_tool_rounds(messages)
    assert len(tool_rounds) == 0, f"Should identify 0 tool rounds, got {len(tool_rounds)}"

    # Verify trimming is not triggered
    should_trim = trimmer.should_trim(messages)
    assert not should_trim, "Should not trigger trimming without tool rounds"


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
143
vw-agentic-rag/tests/unit/test_assistant_ui_best_practices.py
Normal file
@@ -0,0 +1,143 @@
"""
Test assistant-ui best practices implementation
"""
import json
import os


def test_package_json_dependencies():
    """Test that package.json has the correct assistant-ui dependencies"""
    package_json_path = os.path.join(os.path.dirname(__file__), "../../web/package.json")

    with open(package_json_path, 'r') as f:
        package_data = json.load(f)

    deps = package_data.get("dependencies", {})

    # Check for essential assistant-ui packages
    assert "@assistant-ui/react" in deps, "Missing @assistant-ui/react"
    assert "@assistant-ui/react-ui" in deps, "Missing @assistant-ui/react-ui"
    assert "@assistant-ui/react-markdown" in deps, "Missing @assistant-ui/react-markdown"
    assert "@assistant-ui/react-data-stream" in deps, "Missing @assistant-ui/react-data-stream"

    # Check versions are reasonable (not too old)
    react_version = deps["@assistant-ui/react"]
    assert "0.10" in react_version or "0.9" in react_version, f"Version too old: {react_version}"

    print("✅ Package dependencies test passed")


def test_env_configuration():
    """Test that environment configuration files exist"""
    env_local_path = os.path.join(os.path.dirname(__file__), "../../web/.env.local")
    assert os.path.exists(env_local_path), "Missing .env.local file"

    with open(env_local_path, 'r') as f:
        env_content = f.read()

    assert "NEXT_PUBLIC_LANGGRAPH_API_URL" in env_content, "Missing API URL config"
    assert "NEXT_PUBLIC_LANGGRAPH_ASSISTANT_ID" in env_content, "Missing Assistant ID config"

    print("✅ Environment configuration test passed")


def test_api_route_structure():
    """Test that API routes are properly structured"""
    # Check main chat API route exists
    chat_route_path = os.path.join(os.path.dirname(__file__), "../../web/src/app/api/chat/route.ts")
    assert os.path.exists(chat_route_path), "Missing chat API route"

    with open(chat_route_path, 'r') as f:
        route_content = f.read()

    # Check for essential API patterns
    assert "export async function POST" in route_content, "Missing POST handler"
    assert "Response" in route_content, "Missing Response handling"
    assert "x-vercel-ai-data-stream" in route_content, "Missing AI SDK compatibility header"

    print("✅ API route structure test passed")


def test_component_structure():
    """Test that main components follow best practices"""
    # Check main page component
    page_path = os.path.join(os.path.dirname(__file__), "../../web/src/app/page.tsx")
    assert os.path.exists(page_path), "Missing main page component"

    with open(page_path, 'r') as f:
        page_content = f.read()

    # Check for key React patterns and components
    assert '"use client"' in page_content, "Missing client-side directive"
    assert "Assistant" in page_content, "Missing Assistant component"
    assert "export default function" in page_content, "Missing default function export"

    # Check for proper structure
    assert "className=" in page_content, "Missing CSS class usage"
    assert "h-screen" in page_content or "h-full" in page_content, "Missing full height layout"

    print("✅ Component structure test passed")


def test_markdown_component():
    """Test that markdown component is properly configured"""
    markdown_path = os.path.join(os.path.dirname(__file__), "../../web/src/components/ui/markdown-text.tsx")
    assert os.path.exists(markdown_path), "Missing markdown component"

    with open(markdown_path, 'r') as f:
        markdown_content = f.read()

    assert "MarkdownTextPrimitive" in markdown_content, "Missing markdown primitive"
    assert "remarkGfm" in markdown_content, "Missing GFM support"

    print("✅ Markdown component test passed")


def test_best_practices_documentation():
    """Test that best practices documentation exists and is comprehensive"""
    docs_path = os.path.join(os.path.dirname(__file__), "../../docs/topics/ASSISTANT_UI_BEST_PRACTICES.md")
    assert os.path.exists(docs_path), "Missing best practices documentation"

    with open(docs_path, 'r') as f:
        docs_content = f.read()

    # Check for key sections
    assert "Assistant-UI + LangGraph + FastAPI" in docs_content, "Missing main title"
    assert "Implementation Status" in docs_content, "Missing implementation status"
    assert "Package Dependencies Updated" in docs_content, "Missing dependencies section"
    assert "Server-Side API Routes" in docs_content, "Missing API routes explanation"

    print("✅ Best practices documentation test passed")


def run_all_tests():
    """Run all tests"""
    print("🧪 Running assistant-ui best practices validation tests...")

    try:
        test_package_json_dependencies()
        test_env_configuration()
        test_api_route_structure()
        test_component_structure()
        test_markdown_component()
        test_best_practices_documentation()

        print("\n🎉 All assistant-ui best practices tests passed!")
        print("✅ Your implementation follows the recommended patterns for:")
        print("   - Package dependencies and versions")
        print("   - Environment configuration")
        print("   - API route structure")
        print("   - Component composition")
        print("   - Markdown rendering")
        print("   - Documentation completeness")

        return True

    except Exception as e:
        print(f"\n❌ Test failed: {e}")
        return False


if __name__ == "__main__":
    success = run_all_tests()
    exit(0 if success else 1)
141
vw-agentic-rag/tests/unit/test_memory.py
Normal file
@@ -0,0 +1,141 @@
import pytest
from service.memory.store import InMemoryStore
from service.graph.state import TurnState, Message
from datetime import datetime, timedelta


@pytest.fixture
def memory_store():
    """Create memory store for testing"""
    return InMemoryStore(ttl_days=1)  # Short TTL for testing


@pytest.fixture
def sample_state():
    """Create sample turn state"""
    state = TurnState(session_id="test_session")
    state.messages = [
        Message(role="user", content="Hello"),
        Message(role="assistant", content="Hi there!")
    ]
    return state


def test_create_new_session(memory_store):
    """Test creating a new session"""
    state = memory_store.create_new_session("new_session")
    assert state.session_id == "new_session"
    assert len(state.messages) == 0

    # Verify it's stored
    retrieved = memory_store.get("new_session")
    assert retrieved is not None
    assert retrieved.session_id == "new_session"


def test_put_and_get_state(memory_store, sample_state):
    """Test storing and retrieving state"""
    memory_store.put("test_session", sample_state)

    retrieved = memory_store.get("test_session")
    assert retrieved is not None
    assert retrieved.session_id == "test_session"
    assert len(retrieved.messages) == 2
    assert retrieved.messages[0].content == "Hello"


def test_get_nonexistent_session(memory_store):
    """Test getting non-existent session returns None"""
    result = memory_store.get("nonexistent")
    assert result is None


def test_add_message(memory_store):
    """Test adding messages to conversation"""
    memory_store.create_new_session("test_session")

    message = Message(role="user", content="Test message")
    memory_store.add_message("test_session", message)

    state = memory_store.get("test_session")
    assert len(state.messages) == 1
    assert state.messages[0].content == "Test message"


def test_add_message_to_nonexistent_session(memory_store):
    """Test adding message creates new session if it doesn't exist"""
    message = Message(role="user", content="Test message")
    memory_store.add_message("new_session", message)

    state = memory_store.get("new_session")
    assert state is not None
    assert len(state.messages) == 1


def test_get_conversation_history(memory_store):
    """Test conversation history formatting"""
    memory_store.create_new_session("test_session")

    messages = [
        Message(role="user", content="Hello"),
        Message(role="assistant", content="Hi there!"),
        Message(role="user", content="How are you?"),
        Message(role="assistant", content="I'm doing well!")
    ]

    for msg in messages:
        memory_store.add_message("test_session", msg)

    history = memory_store.get_conversation_history("test_session")

    assert "User: Hello" in history
    assert "Assistant: Hi there!" in history
    assert "User: How are you?" in history
    assert "Assistant: I'm doing well!" in history


def test_get_conversation_history_empty(memory_store):
    """Test conversation history for empty session"""
    history = memory_store.get_conversation_history("nonexistent")
    assert history == ""


def test_trim_messages(memory_store):
    """Test message trimming"""
    memory_store.create_new_session("test_session")

    # Add many messages
    for i in range(25):
        memory_store.add_message("test_session", Message(role="user", content=f"Message {i}"))

    # Trim to 10 messages
    memory_store.trim("test_session", max_messages=10)

    state = memory_store.get("test_session")
    assert len(state.messages) <= 10


def test_trim_nonexistent_session(memory_store):
    """Test trimming non-existent session doesn't crash"""
    memory_store.trim("nonexistent", max_messages=10)
    # Should not raise an exception


def test_ttl_cleanup(memory_store):
    """Test TTL-based cleanup"""
    # Create store with very short TTL for testing
    short_ttl_store = InMemoryStore(ttl_days=0.001)  # ~1.5 minutes

    # Add a session
    state = TurnState(session_id="test_session")
    short_ttl_store.put("test_session", state)

    # Verify it exists
    assert short_ttl_store.get("test_session") is not None

    # Manually expire it by manipulating internal timestamp
    short_ttl_store.store["test_session"]["last_updated"] = datetime.now() - timedelta(days=1)

    # Try to get it - should trigger cleanup
    assert short_ttl_store.get("test_session") is None
    assert "test_session" not in short_ttl_store.store
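
# Note: as exercised in test_ttl_cleanup, the store appears to clean up
# lazily -- expired sessions are evicted on access via get() rather than by
# a background task.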
194
vw-agentic-rag/tests/unit/test_message_trimmer.py
Normal file
@@ -0,0 +1,194 @@
"""
Test conversation history trimming functionality.
"""
import pytest
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage

from service.graph.message_trimmer import ConversationTrimmer, create_conversation_trimmer

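
# The tests below assume the trimmer applies multi-round tool-call
# optimization first and token-based trimming second, with a message-count
# fallback (see test_full_trimming_with_optimization and
# test_conversation_trimmer_fallback).

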
def test_conversation_trimmer_basic():
    """Test basic trimming functionality."""
    trimmer = ConversationTrimmer(max_context_length=50)  # Very low limit for testing (85% = 42 tokens)

    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="What is the capital of France?"),
        AIMessage(content="The capital of France is Paris."),
        HumanMessage(content="What about Germany?"),
        AIMessage(content="The capital of Germany is Berlin."),
        HumanMessage(content="And Italy?"),
        AIMessage(content="The capital of Italy is Rome."),
    ]

    # Should trigger trimming due to low token limit
    assert trimmer.should_trim(messages)

    trimmed = trimmer.trim_conversation_history(messages)

    # Should preserve system message and keep recent messages
    assert len(trimmed) < len(messages)
    assert isinstance(trimmed[0], SystemMessage)
    assert "helpful assistant" in trimmed[0].content


def test_conversation_trimmer_no_trim_needed():
    """Test when no trimming is needed."""
    trimmer = ConversationTrimmer(max_context_length=100000)  # High limit

    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Hello"),
        AIMessage(content="Hi there!"),
    ]

    # Should not need trimming
    assert not trimmer.should_trim(messages)

    trimmed = trimmer.trim_conversation_history(messages)

    # Should return same messages
    assert len(trimmed) == len(messages)


def test_conversation_trimmer_fallback():
    """Test fallback trimming logic."""
    trimmer = ConversationTrimmer(max_context_length=100)

    # Create many messages to trigger fallback
    messages = [SystemMessage(content="System")] + [
        HumanMessage(content=f"Message {i}") for i in range(50)
    ]

    trimmed = trimmer._fallback_trim(messages, max_messages=5)

    # Should keep system message + 4 recent messages
    assert len(trimmed) == 5
    assert isinstance(trimmed[0], SystemMessage)


def test_create_conversation_trimmer():
    """Test trimmer factory function."""
    # Test with explicit configuration
    custom_trimmer = create_conversation_trimmer(max_context_length=128000)
    assert custom_trimmer.max_context_length == 128000
    assert isinstance(custom_trimmer, ConversationTrimmer)
    assert custom_trimmer.preserve_system is True


def test_multi_round_tool_optimization():
    """Test multi-round tool call optimization."""
    trimmer = ConversationTrimmer(max_context_length=100000)  # High limit to focus on optimization

    # Create a multi-round tool calling scenario
    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Search for food safety and cosmetics testing standards."),

        # First tool round
        AIMessage(content="I'll search for information.", tool_calls=[
            {"id": "call_1", "name": "search_tool", "args": {"query": "food safety"}}
        ]),
        ToolMessage(content='{"results": ["Large food safety result data..." * 100]}', tool_call_id="call_1", name="search_tool"),

        # Second tool round
        AIMessage(content="Let me search for more specific information.", tool_calls=[
            {"id": "call_2", "name": "search_tool", "args": {"query": "cosmetics testing"}}
        ]),
        ToolMessage(content='{"results": ["Large cosmetics testing data..." * 100]}', tool_call_id="call_2", name="search_tool"),

        # Third tool round (most recent)
        AIMessage(content="Now let me get equipment information.", tool_calls=[
            {"id": "call_3", "name": "search_tool", "args": {"query": "testing equipment"}}
        ]),
        ToolMessage(content='{"results": ["Testing equipment info..." * 50]}', tool_call_id="call_3", name="search_tool"),
    ]

    # Test tool round identification
    tool_rounds = trimmer._identify_tool_rounds(messages)
    assert len(tool_rounds) == 3  # Should identify 3 tool rounds

    # Test optimization
    optimized = trimmer._optimize_multi_round_tool_calls(messages)

    # Should preserve: SystemMessage + HumanMessage + latest tool round (AI + Tool messages)
    expected_length = 1 + 1 + 2  # system + human + (latest AI + latest Tool)
    assert len(optimized) == expected_length

    # Check preserved messages
    assert isinstance(optimized[0], SystemMessage)
    assert isinstance(optimized[1], HumanMessage)
    assert isinstance(optimized[2], AIMessage)
    assert optimized[2].tool_calls[0]["id"] == "call_3"  # Should be the latest tool call
    assert isinstance(optimized[3], ToolMessage)
    assert optimized[3].tool_call_id == "call_3"  # Should be the latest tool result


def test_multi_round_optimization_single_round():
    """Test that single tool round is not optimized."""
    trimmer = ConversationTrimmer(max_context_length=100000)

    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Search for something."),
        AIMessage(content="I'll search.", tool_calls=[{"id": "call_1", "name": "search_tool", "args": {}}]),
        ToolMessage(content='{"results": ["data"]}', tool_call_id="call_1", name="search_tool"),
    ]

    optimized = trimmer._optimize_multi_round_tool_calls(messages)

    # Should return all messages unchanged for single round
    assert len(optimized) == len(messages)


def test_multi_round_optimization_no_tools():
    """Test that conversations without tools are not optimized."""
    trimmer = ConversationTrimmer(max_context_length=100000)

    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Hello"),
        AIMessage(content="Hi there!"),
        HumanMessage(content="How are you?"),
        AIMessage(content="I'm doing well, thanks!"),
    ]

    optimized = trimmer._optimize_multi_round_tool_calls(messages)

    # Should return all messages unchanged
    assert len(optimized) == len(messages)


def test_full_trimming_with_optimization():
    """Test complete trimming process with multi-round optimization."""
    trimmer = ConversationTrimmer(max_context_length=50)  # Very low limit

    # Create messages that will trigger both optimization and regular trimming
    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Search for food safety and cosmetics testing standards."),

        # First tool round (should be removed by optimization)
        AIMessage(content="I'll search for information.", tool_calls=[
            {"id": "call_1", "name": "search_tool", "args": {"query": "food safety"}}
        ]),
        ToolMessage(content='{"results": ["Large food safety result data..." * 100]}', tool_call_id="call_1", name="search_tool"),

        # Second tool round (most recent, should be preserved)
        AIMessage(content="Let me search for more information.", tool_calls=[
            {"id": "call_2", "name": "search_tool", "args": {"query": "cosmetics testing"}}
        ]),
        ToolMessage(content='{"results": ["Large cosmetics testing data..." * 100]}', tool_call_id="call_2", name="search_tool"),
    ]

    trimmed = trimmer.trim_conversation_history(messages)

    # Should apply optimization first, then regular trimming if needed
    assert len(trimmed) <= len(messages)

    # System message should always be preserved
    assert any(isinstance(msg, SystemMessage) for msg in trimmed)


if __name__ == "__main__":
    pytest.main([__file__])
353
vw-agentic-rag/tests/unit/test_retrieval.py
Normal file
@@ -0,0 +1,353 @@
"""
Unit tests for agentic retrieval tools.
Tests the HTTP wrapper tools that interface with the retrieval API.
"""
import pytest
from unittest.mock import AsyncMock, patch, Mock
import httpx

from service.retrieval.retrieval import AgenticRetrieval, RetrievalResponse
from service.retrieval.clients import RetrievalAPIError, normalize_search_result


class TestAgenticRetrieval:
    """Test the agentic retrieval HTTP client."""

    @pytest.fixture
    def retrieval_client(self):
        """Create retrieval client for testing."""
        with patch('service.retrieval.retrieval.get_config') as mock_config:
            # Mock the new config structure
            mock_config.return_value.retrieval.endpoint = "https://test-search.search.azure.com"
            mock_config.return_value.retrieval.api_key = "test-key"
            mock_config.return_value.retrieval.api_version = "2024-11-01-preview"
            mock_config.return_value.retrieval.semantic_configuration = "default"

            # Mock embedding config
            mock_config.return_value.retrieval.embedding.base_url = "http://test-embedding"
            mock_config.return_value.retrieval.embedding.api_key = "test-embedding-key"
            mock_config.return_value.retrieval.embedding.model = "test-embedding-model"
            mock_config.return_value.retrieval.embedding.dimension = 1536

            # Mock index config
            mock_config.return_value.retrieval.index.standard_regulation_index = "index-catonline-standard-regulation-v2-prd"
            mock_config.return_value.retrieval.index.chunk_index = "index-catonline-chunk-v2-prd"
            mock_config.return_value.retrieval.index.chunk_user_manual_index = "index-cat-usermanual-chunk-prd"

            return AgenticRetrieval()

    @pytest.mark.asyncio
    async def test_retrieve_standard_regulation_success(self, retrieval_client):
        """Test successful standards regulation retrieval."""
        mock_response_data = {
            "value": [
                {
                    "id": "test_id_1",
                    "title": "ISO 26262-1:2018",
                    "document_code": "ISO26262-1",
                    "document_category": "Standard",
                    "@order_num": 1  # Should be preserved
                },
                {
                    "id": "test_id_2",
                    "title": "ISO 26262-3:2018",
                    "document_code": "ISO26262-3",
                    "document_category": "Standard",
                    "@order_num": 2
                }
            ]
        }

        with patch.object(retrieval_client.search_client, 'search_azure_ai', return_value=mock_response_data) as mock_search:
            result = await retrieval_client.retrieve_standard_regulation("ISO 26262")

            # Verify search call
            mock_search.assert_called_once()
            call_args = mock_search.call_args[1]  # keyword arguments
            assert call_args['search_text'] == "ISO 26262"
            assert call_args['index_name'] == "index-catonline-standard-regulation-v2-prd"
            assert call_args['top_k'] == 10

            # Verify response structure
            assert isinstance(result, RetrievalResponse)
            assert len(result.results) == 2
            assert result.total_count == 2
            assert result.took_ms is not None

            # Verify normalization
            first_result = result.results[0]
            assert first_result['id'] == "test_id_1"
            assert first_result['title'] == "ISO 26262-1:2018"
            # Verify order number is preserved
            assert "@order_num" in first_result
            assert "content" not in first_result  # Should not be included for standards

    @pytest.mark.asyncio
    async def test_retrieve_doc_chunk_success(self, retrieval_client):
        """Test successful document chunk retrieval."""
        mock_response_data = {
            "value": [
                {
                    "id": "chunk_id_1",
                    "title": "Functional Safety Requirements",
                    "content": "Detailed content about functional safety...",
                    "document_code": "ISO26262-1",
                    "@order_num": 1  # Should be preserved
                }
            ]
        }

        with patch.object(retrieval_client.search_client, 'search_azure_ai', return_value=mock_response_data) as mock_search:
            result = await retrieval_client.retrieve_doc_chunk_standard_regulation(
                "functional safety",
                conversation_history="Previous question about ISO 26262"
            )

            # Verify search call
            mock_search.assert_called_once()
            call_args = mock_search.call_args[1]  # keyword arguments
            assert call_args['search_text'] == "functional safety"
            assert call_args['index_name'] == "index-catonline-chunk-v2-prd"
            assert "filter_query" in call_args  # Should have document filter

            # Verify response
            assert isinstance(result, RetrievalResponse)
            assert len(result.results) == 1

            # Verify content is included for chunks
            first_result = result.results[0]
            assert "content" in first_result
            assert first_result['content'] == "Detailed content about functional safety..."
            # Verify order number is preserved
            assert "@order_num" in first_result

    @pytest.mark.asyncio
    async def test_http_error_handling(self, retrieval_client):
        """Test HTTP error handling."""
        with patch.object(retrieval_client.search_client, 'search_azure_ai') as mock_search:
            mock_search.side_effect = RetrievalAPIError("API request failed: 404")

            with pytest.raises(RetrievalAPIError) as exc_info:
                await retrieval_client.retrieve_standard_regulation("nonexistent")

            assert "404" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_timeout_handling(self, retrieval_client):
        """Test timeout handling."""
        with patch.object(retrieval_client.search_client, 'search_azure_ai') as mock_search:
            mock_search.side_effect = RetrievalAPIError("Request timeout")

            with pytest.raises(RetrievalAPIError) as exc_info:
                await retrieval_client.retrieve_standard_regulation("test query")

            assert "timeout" in str(exc_info.value)

    def test_normalize_result_removes_unwanted_fields(self, retrieval_client):
        """Test result normalization removes unwanted fields."""
        raw_result = {
            "id": "test_id",
            "title": "Test Title",
            "content": "Test content",
            "@search.score": 0.95,  # Should be removed
            "@search.rerankerScore": 2.5,  # Should be removed
            "@search.captions": [],  # Should be removed
            "@subquery_id": 1,  # Should be removed
            "empty_field": "",  # Should be removed
            "null_field": None,  # Should be removed
            "empty_list": [],  # Should be removed
            "empty_dict": {},  # Should be removed
            "valid_field": "valid_value"
        }

        # Test normalization
        normalized = normalize_search_result(raw_result)

        # Verify unwanted fields are removed
        assert "@search.score" not in normalized
        assert "@search.rerankerScore" not in normalized
        assert "@search.captions" not in normalized
        assert "@subquery_id" not in normalized
        assert "empty_field" not in normalized
        assert "null_field" not in normalized
        assert "empty_list" not in normalized
        assert "empty_dict" not in normalized

        # Verify valid fields are preserved (including content)
        assert normalized["id"] == "test_id"
        assert normalized["title"] == "Test Title"
        assert normalized["content"] == "Test content"  # Content is now always preserved
        assert normalized["valid_field"] == "valid_value"

    def test_normalize_result_includes_content_when_requested(self, retrieval_client):
        """Test result normalization always includes content."""
        raw_result = {
            "id": "test_id",
            "title": "Test Title",
            "content": "Test content"
        }

        # Test normalization (content is always preserved now)
        normalized = normalize_search_result(raw_result)
        assert "content" in normalized
        assert normalized["content"] == "Test content"

    @pytest.mark.asyncio
    async def test_empty_results_handling(self, retrieval_client):
        """Test handling of empty search results."""
        with patch.object(retrieval_client.search_client, 'search_azure_ai', return_value={"value": []}):
            result = await retrieval_client.retrieve_standard_regulation("no results query")

            assert isinstance(result, RetrievalResponse)
            assert result.results == []
            assert result.total_count == 0

    @pytest.mark.asyncio
    async def test_malformed_response_handling(self, retrieval_client):
        """Test handling of malformed API responses."""
        with patch.object(retrieval_client.search_client, 'search_azure_ai', return_value={"unexpected": "format"}):
            result = await retrieval_client.retrieve_standard_regulation("test query")

            # Should handle gracefully and return empty results
            assert isinstance(result, RetrievalResponse)
            assert result.results == []
            assert result.total_count == 0

    @pytest.mark.asyncio
    async def test_kwargs_override_payload(self, retrieval_client):
        """Test that kwargs can override default values."""
        with patch.object(retrieval_client.search_client, 'search_azure_ai', return_value={"value": []}) as mock_search:
            await retrieval_client.retrieve_standard_regulation(
                "test query",
                top_k=5,
                score_threshold=2.0
            )

            call_args = mock_search.call_args[1]  # keyword arguments
            assert call_args['top_k'] == 5  # Override default 10
            assert call_args['score_threshold'] == 2.0  # Override default 1.5


class TestRetrievalTools:
    """Test the LangGraph tool decorators."""

    @pytest.mark.asyncio
    async def test_retrieve_standard_regulation_tool(self):
        """Test the @tool decorated function for standards search."""
        from service.graph.tools import retrieve_standard_regulation

        mock_response = RetrievalResponse(
            results=[{"title": "Test Standard", "id": "test_id"}],
            took_ms=100,
            total_count=1
        )

        with patch('service.retrieval.retrieval.AgenticRetrieval') as mock_retrieval_class:
            mock_instance = AsyncMock()
            mock_instance.retrieve_standard_regulation.return_value = mock_response
            mock_retrieval_class.return_value.__aenter__.return_value = mock_instance
            mock_retrieval_class.return_value.__aexit__.return_value = None

            # Note: retrieve_standard_regulation is a LangGraph @tool, so we call it directly
            result = await retrieve_standard_regulation.ainvoke({"query": "test query", "conversation_history": "test history"})

            # Verify result format matches expected tool output
            assert isinstance(result, dict)
            assert "results" in result

    @pytest.mark.asyncio
    async def test_retrieve_doc_chunk_tool(self):
        """Test the @tool decorated function for document chunks search."""
        from service.graph.tools import retrieve_doc_chunk_standard_regulation

        mock_response = RetrievalResponse(
            results=[{"title": "Test Chunk", "content": "Test content", "id": "chunk_id"}],
            took_ms=150,
            total_count=1
        )

        with patch('service.retrieval.retrieval.AgenticRetrieval') as mock_retrieval_class:
            mock_instance = AsyncMock()
            mock_instance.retrieve_doc_chunk_standard_regulation.return_value = mock_response
            mock_retrieval_class.return_value.__aenter__.return_value = mock_instance
            mock_retrieval_class.return_value.__aexit__.return_value = None

            # Call the tool with proper input format
            result = await retrieve_doc_chunk_standard_regulation.ainvoke({"query": "test query"})

            # Verify result format
            assert isinstance(result, dict)
            assert "results" in result

    @pytest.mark.asyncio
    async def test_tool_error_handling(self):
        """Test tool error handling."""
        from service.graph.tools import retrieve_standard_regulation

        with patch('service.retrieval.clients.AzureSearchClient.search_azure_ai') as mock_search:
            mock_search.side_effect = RetrievalAPIError("API Error")

            # Tool should handle error gracefully and return error result
            result = await retrieve_standard_regulation.ainvoke({"query": "test query"})

            # Should return error information in result
            assert isinstance(result, dict)
            assert "error" in result
            assert "API Error" in result["error"]


class TestRetrievalIntegration:
    """Integration-style tests for retrieval functionality."""

    @pytest.mark.asyncio
    async def test_full_retrieval_workflow(self):
        """Test complete retrieval workflow with both tools."""
        # Mock both retrieval calls
        standards_response = RetrievalResponse(
            results=[
                {"id": "std_1", "title": "ISO 26262-1", "document_category": "Standard"}
            ],
            took_ms=100,
            total_count=1
        )

        chunks_response = RetrievalResponse(
            results=[
                {"id": "chunk_1", "title": "Safety Requirements", "content": "Detailed safety requirements..."}
            ],
            took_ms=150,
            total_count=1
        )

        with patch('service.retrieval.clients.AzureSearchClient.search_azure_ai') as mock_search:
            # Set up mock to return different responses for different calls
            def mock_response_side_effect(*args, **kwargs):
                if kwargs.get('index_name') == 'index-catonline-standard-regulation-v2-prd':
                    return {
                        "value": [{"id": "std_1", "title": "ISO 26262-1", "document_category": "Standard", "@order_num": 1}]
                    }
                else:
                    return {
                        "value": [{"id": "chunk_1", "title": "Safety Requirements", "content": "Detailed safety requirements...", "@order_num": 1}]
                    }

            mock_search.side_effect = mock_response_side_effect

            # Import and test both tools
            from service.graph.tools import retrieve_standard_regulation, retrieve_doc_chunk_standard_regulation

            # Test standards search
            std_result = await retrieve_standard_regulation.ainvoke({"query": "ISO 26262"})
            assert isinstance(std_result, dict)
            assert "results" in std_result
            assert len(std_result["results"]) == 1

            # Test chunks search
            chunk_result = await retrieve_doc_chunk_standard_regulation.ainvoke({"query": "safety requirements"})
            assert isinstance(chunk_result, dict)
            assert "results" in chunk_result
            assert len(chunk_result["results"]) == 1
73
vw-agentic-rag/tests/unit/test_sse.py
Normal file
@@ -0,0 +1,73 @@
import pytest
from service.sse import (
    format_sse_event, create_token_event, create_tool_start_event,
    create_tool_result_event, create_tool_error_event,
    create_error_event
)


def test_format_sse_event():
    """Test SSE event formatting"""
    event = format_sse_event("test", {"message": "hello"})
    expected = 'event: test\ndata: {"message": "hello"}\n\n'
    assert event == expected


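# A minimal sketch (not part of the service) of how a client might split one
# SSE block in the wire format asserted above back into its event and data.
def _parse_sse_event(raw: str) -> tuple:
    """Parse a single 'event: ...' / 'data: ...' block into (event_name, data)."""
    event_name, data = "", ""
    for line in raw.strip().splitlines():
        if line.startswith("event: "):
            event_name = line[len("event: "):]
        elif line.startswith("data: "):
            data = line[len("data: "):]
    return event_name, data


def test_parse_sse_event_roundtrip():
    """Round-trip the helper above against format_sse_event (illustrative)."""
    event = format_sse_event("test", {"message": "hello"})
    name, data = _parse_sse_event(event)
    assert name == "test"
    assert data == '{"message": "hello"}'

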
def test_create_token_event():
    """Test token event creation"""
    event = create_token_event("hello", "tool_123")
    assert "event: tokens" in event
    assert '"delta": "hello"' in event
    assert '"tool_call_id": "tool_123"' in event


def test_create_token_event_no_tool_id():
    """Test token event without tool call ID"""
    event = create_token_event("hello")
    assert "event: tokens" in event
    assert '"delta": "hello"' in event
    assert '"tool_call_id": null' in event


def test_create_tool_start_event():
    """Test tool start event"""
    event = create_tool_start_event("tool_123", "retrieve_standard_regulation", {"query": "test"})
    assert "event: tool_start" in event
    assert '"id": "tool_123"' in event
    assert '"name": "retrieve_standard_regulation"' in event
    assert '"args": {"query": "test"}' in event


def test_create_tool_result_event():
    """Test tool result event"""
    results = [{"id": "1", "title": "Test Standard"}]
    event = create_tool_result_event("tool_123", "retrieve_standard_regulation", results, 500)
    assert "event: tool_result" in event
    assert '"id": "tool_123"' in event
    assert '"took_ms": 500' in event
    assert '"results"' in event


def test_create_tool_error_event():
    """Test tool error event"""
    event = create_tool_error_event("tool_123", "retrieve_standard_regulation", "API timeout")
    assert "event: tool_error" in event
    assert '"id": "tool_123"' in event
    assert '"error": "API timeout"' in event


def test_create_error_event():
    """Test error event"""
    event = create_error_event("Something went wrong")
    assert "event: error" in event
    assert '"error": "Something went wrong"' in event


def test_create_error_event_with_details():
    """Test error event with details"""
    details = {"code": 500, "source": "llm"}
    event = create_error_event("Something went wrong", details)
    assert "event: error" in event
    assert '"error": "Something went wrong"' in event
    assert '"details"' in event
    assert '"code": 500' in event
200
vw-agentic-rag/tests/unit/test_trimming_fix.py
Normal file
@@ -0,0 +1,200 @@
|
||||
#!/usr/bin/env python3
"""
Tests for the revised trimming logic: verify that trimming happens only when
tool_rounds=0 and is skipped on all later rounds.
"""
import pytest
from unittest.mock import patch, MagicMock
import asyncio
from datetime import datetime

import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from service.graph.state import AgentState
from service.graph.graph import call_model
from service.graph.user_manual_rag import user_manual_agent_node
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage


@pytest.mark.asyncio
async def test_trimming_only_at_tool_rounds_zero():
    """Trimming should only happen when tool_rounds=0."""

    # Build a large message list to trigger trimming
    large_messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Please search for information"),
    ]

    # Add bulky content to push the history over the trimming threshold
    for i in range(10):
        large_messages.extend([
            AIMessage(content=f"search {i}", tool_calls=[
                {"id": f"call_{i}", "name": "search", "args": {"query": f"test{i}"}}
            ]),
            ToolMessage(content="very long search result content " * 500, tool_call_id=f"call_{i}", name="search"),
        ])

    # Mock the trimmer to track calls
    with patch('service.graph.graph.create_conversation_trimmer') as mock_trimmer_factory:
        mock_trimmer = MagicMock()
        mock_trimmer.should_trim.return_value = True  # always report that trimming is needed
        mock_trimmer.trim_conversation_history.return_value = large_messages[:3]  # return the trimmed messages
        mock_trimmer_factory.return_value = mock_trimmer

        # Mock the LLM client
        with patch('service.graph.graph.LLMClient') as mock_llm_factory:
            mock_llm = MagicMock()

            # Return values for the async methods
            async def mock_ainvoke_with_tools(*args, **kwargs):
                return AIMessage(content="final answer")

            async def mock_astream(*args, **kwargs):
                for token in ["fin", "al ", "an", "swer"]:
                    yield token

            mock_llm.ainvoke_with_tools = mock_ainvoke_with_tools
            mock_llm.astream = mock_astream
            mock_llm.bind_tools = MagicMock()
            mock_llm_factory.return_value = mock_llm

            # Case tool_rounds=0 (should trim)
            state_round_0 = AgentState(
                messages=large_messages,
                session_id="test",
                intent=None,
                tool_results=[],
                final_answer="",
                tool_rounds=0,  # first round
                max_tool_rounds=3,
                max_tool_rounds_user_manual=2
            )

            result = await call_model(state_round_0)

            # Verify the trimmer was called
            mock_trimmer.should_trim.assert_called()
            mock_trimmer.trim_conversation_history.assert_called()

            # Reset the mock
            mock_trimmer.reset_mock()

            # Case tool_rounds>0 (should not trim)
            state_round_1 = AgentState(
                messages=large_messages,
                session_id="test",
                intent=None,
                tool_results=[],
                final_answer="",
                tool_rounds=1,  # second round
                max_tool_rounds=3,
                max_tool_rounds_user_manual=2
            )

            result = await call_model(state_round_1)

            # Verify the trimmer was not called
            mock_trimmer.should_trim.assert_not_called()
            mock_trimmer.trim_conversation_history.assert_not_called()


@pytest.mark.asyncio
async def test_user_manual_agent_trimming_behavior():
    """Test the trimming behavior of user_manual_agent_node."""

    # Build a large message list
    large_messages = [
        SystemMessage(content="You are a helpful user manual assistant."),
        HumanMessage(content="Please help me find the system usage instructions"),
    ]

    # Add bulky content
    for i in range(8):
        large_messages.extend([
            AIMessage(content=f"lookup {i}", tool_calls=[
                {"id": f"manual_call_{i}", "name": "retrieve_system_usermanual", "args": {"query": f"feature {i}"}}
            ]),
            ToolMessage(content="very long manual content " * 400, tool_call_id=f"manual_call_{i}", name="retrieve_system_usermanual"),
        ])

    # Mock dependencies
    with patch('service.graph.user_manual_rag.create_conversation_trimmer') as mock_trimmer_factory, \
         patch('service.graph.user_manual_rag.LLMClient') as mock_llm_factory, \
         patch('service.graph.user_manual_rag.get_user_manual_tool_schemas') as mock_tools, \
         patch('service.graph.user_manual_rag.get_cached_config') as mock_config:

        # Set up the mocks
        mock_trimmer = MagicMock()
        mock_trimmer.should_trim.return_value = True
        mock_trimmer.trim_conversation_history.return_value = large_messages[:3]
        mock_trimmer_factory.return_value = mock_trimmer

        mock_llm = MagicMock()

        # Return values for the async methods
        async def mock_ainvoke_with_tools(*args, **kwargs):
            return AIMessage(content="user manual answer")

        async def mock_astream(*args, **kwargs):
            for token in ["an", "swer"]:
                yield token

        mock_llm.ainvoke_with_tools = mock_ainvoke_with_tools
        mock_llm.astream = mock_astream
        mock_llm.bind_tools = MagicMock()
        mock_llm_factory.return_value = mock_llm

        mock_tools.return_value = []

        mock_app_config = MagicMock()
        mock_app_config.app.max_tool_rounds_user_manual = 2
        mock_app_config.get_rag_prompts.return_value = {
            "user_manual_prompt": "System prompt for user manual: {conversation_history} {context_content} {current_query}"
        }
        mock_config.return_value = mock_app_config

        # Case tool_rounds=0 (should trim)
        state_round_0 = AgentState(
            messages=large_messages,
            session_id="test",
            intent=None,
            tool_results=[],
            final_answer="",
            tool_rounds=0,
            max_tool_rounds=3,
            max_tool_rounds_user_manual=2
        )

        result = await user_manual_agent_node(state_round_0)

        # Verify the trimmer was called
        mock_trimmer.should_trim.assert_called()
        mock_trimmer.trim_conversation_history.assert_called()

        # Reset the mock
        mock_trimmer.reset_mock()

        # Case tool_rounds=1 (should not trim)
        state_round_1 = AgentState(
            messages=large_messages,
            session_id="test",
            intent=None,
            tool_results=[],
            final_answer="",
            tool_rounds=1,
            max_tool_rounds=3,
            max_tool_rounds_user_manual=2
        )

        result = await user_manual_agent_node(state_round_1)

        # Verify the trimmer was not called
        mock_trimmer.should_trim.assert_not_called()
        mock_trimmer.trim_conversation_history.assert_not_called()


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
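The invariant both tests encode: history trimming is legal only before the first tool round, since trimming mid-loop can orphan a ToolMessage whose tool_call_id points at a trimmed AIMessage. A minimal sketch of that gate follows, with the trimmer passed in explicitly; the interface mirrors the mocks above, and the real call_model / user_manual_agent_node additionally bind tools and stream the answer.

from typing import Any


async def call_model_gate_sketch(state: dict, trimmer: Any) -> dict:
    # `trimmer` stands in for the object returned by
    # create_conversation_trimmer: should_trim(messages) -> bool and
    # trim_conversation_history(messages) -> list (assumed interface).
    messages = state["messages"]
    if state["tool_rounds"] == 0:
        # First round only: later rounds skip trimming entirely.
        if trimmer.should_trim(messages):
            messages = trimmer.trim_conversation_history(messages)
    # ...the real node would now invoke the LLM with `messages`...
    return {"messages": messages}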
95
vw-agentic-rag/tests/unit/test_user_manual_tool.py
Normal file
@@ -0,0 +1,95 @@
"""
|
||||
Unit test for the new retrieve_system_usermanual tool
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
import asyncio
|
||||
|
||||
from service.graph.user_manual_tools import retrieve_system_usermanual, user_manual_tools, get_user_manual_tool_schemas, get_user_manual_tools_by_name
|
||||
from service.graph.tools import get_tool_schemas, get_tools_by_name
|
||||
from service.retrieval.retrieval import RetrievalResponse
|
||||
|
||||
|
||||
class TestRetrieveSystemUsermanualTool:
|
||||
"""Test the new user manual retrieval tool"""
|
||||
|
||||
def test_tool_in_tools_list(self):
|
||||
"""Test that the new tool is in the tools list"""
|
||||
tool_names = [tool.name for tool in user_manual_tools]
|
||||
assert "retrieve_system_usermanual" in tool_names
|
||||
assert len(user_manual_tools) == 1 # Should have 1 user manual tool
|
||||
|
||||
def test_tool_schemas_generation(self):
|
||||
"""Test that tool schemas are generated correctly"""
|
||||
schemas = get_user_manual_tool_schemas()
|
||||
tool_names = [schema["function"]["name"] for schema in schemas]
|
||||
assert "retrieve_system_usermanual" in tool_names
|
||||
|
||||
# Find the user manual tool schema
|
||||
user_manual_schema = next(
|
||||
schema for schema in schemas
|
||||
if schema["function"]["name"] == "retrieve_system_usermanual"
|
||||
)
|
||||
|
||||
assert user_manual_schema["function"]["description"] == "Search for document content chunks of user manual of this system(CATOnline)"
|
||||
assert "query" in user_manual_schema["function"]["parameters"]["properties"]
|
||||
assert user_manual_schema["function"]["parameters"]["required"] == ["query"]
|
||||
|
||||
def test_tools_by_name_mapping(self):
|
||||
"""Test that tools_by_name mapping includes the new tool"""
|
||||
tools_mapping = get_user_manual_tools_by_name()
|
||||
assert "retrieve_system_usermanual" in tools_mapping
|
||||
assert tools_mapping["retrieve_system_usermanual"] == retrieve_system_usermanual
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_system_usermanual_success(self):
|
||||
"""Test successful user manual retrieval"""
|
||||
|
||||
mock_response = RetrievalResponse(
|
||||
results=[
|
||||
{"title": "User Manual Chapter 1", "content": "How to use the system", "id": "manual_1"}
|
||||
],
|
||||
took_ms=120,
|
||||
total_count=1
|
||||
)
|
||||
|
||||
with patch('service.retrieval.retrieval.AgenticRetrieval') as mock_retrieval_class:
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.retrieve_doc_chunk_user_manual.return_value = mock_response
|
||||
mock_retrieval_class.return_value.__aenter__.return_value = mock_instance
|
||||
mock_retrieval_class.return_value.__aexit__.return_value = None
|
||||
|
||||
# Test the tool
|
||||
result = await retrieve_system_usermanual.ainvoke({"query": "how to use system"})
|
||||
|
||||
# Verify result format
|
||||
assert isinstance(result, dict)
|
||||
assert result["tool_name"] == "retrieve_system_usermanual"
|
||||
assert result["results_count"] == 1
|
||||
assert len(result["results"]) == 1
|
||||
assert result["results"][0]["title"] == "User Manual Chapter 1"
|
||||
assert result["took_ms"] == 120
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_system_usermanual_error(self):
|
||||
"""Test error handling in user manual retrieval"""
|
||||
|
||||
with patch('service.retrieval.retrieval.AgenticRetrieval') as mock_retrieval_class:
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.retrieve_doc_chunk_user_manual.side_effect = Exception("Search API Error")
|
||||
mock_retrieval_class.return_value.__aenter__.return_value = mock_instance
|
||||
mock_retrieval_class.return_value.__aexit__.return_value = None
|
||||
|
||||
# Test error handling
|
||||
result = await retrieve_system_usermanual.ainvoke({"query": "test query"})
|
||||
|
||||
# Should return error information
|
||||
assert isinstance(result, dict)
|
||||
assert "error" in result
|
||||
assert "Search API Error" in result["error"]
|
||||
assert result["results_count"] == 0
|
||||
assert result["results"] == []
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
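The error test implies the tool never raises: failures are folded into the same dict shape the success path returns. A minimal sketch of that wrapper, under the assumption that the real LangChain tool in service.graph.user_manual_tools is structured this way; only the asserted keys (tool_name, results, results_count, took_ms, error) are taken from the tests.

from service.retrieval.retrieval import AgenticRetrieval


async def retrieve_system_usermanual_sketch(query: str) -> dict:
    try:
        async with AgenticRetrieval() as retrieval:  # patched in the tests above
            response = await retrieval.retrieve_doc_chunk_user_manual(query)
        return {
            "tool_name": "retrieve_system_usermanual",
            "results": response.results,
            "results_count": len(response.results),
            "took_ms": response.took_ms,
        }
    except Exception as exc:
        # Fold retrieval failures into a structured payload so the agent
        # loop can emit a tool_error SSE event instead of crashing.
        return {
            "tool_name": "retrieve_system_usermanual",
            "error": str(exc),
            "results": [],
            "results_count": 0,
        }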