"""
Shared pytest fixtures and configuration for the agentic-rag test suite.
"""
import asyncio
from unittest.mock import AsyncMock, Mock, patch

import httpx
import pytest
from fastapi.testclient import TestClient

from service.graph.state import Message, TurnState
from service.main import create_app
from service.memory.postgresql_memory import PostgreSQLMemoryManager


@pytest.fixture(scope="session")
def event_loop():
    """Create an instance of the default event loop for the test session."""
    policy = asyncio.get_event_loop_policy()
    loop = policy.new_event_loop()
    yield loop
    loop.close()
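
# NOTE: recent pytest-asyncio releases deprecate overriding `event_loop`; on
# upgrade, prefer the library's loop-scope configuration options instead.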


@pytest.fixture(autouse=True)
def config_mock():
    """Mock configuration for all tests."""
    config = Mock()
    config.retrieval.endpoint = "http://test-endpoint"
    config.retrieval.api_key = "test-key"
    config.llm.provider = "openai"
    config.llm.model = "gpt-4"
    config.llm.api_key = "test-api-key"
    config.memory.enabled = True
    config.memory.type = "in_memory"
    config.memory.ttl_days = 7
    config.postgresql.enabled = False
    with patch('service.config.get_config', return_value=config):
        with patch('service.retrieval.retrieval.get_config', return_value=config):
            with patch('service.graph.graph.get_config', return_value=config):
                yield config
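
# Because config_mock is autouse, every test already runs against the patched
# config; a test can still request it by name to tweak values, e.g. (sketch):
#
#     def test_memory_disabled(config_mock):
#         config_mock.memory.enabled = False
#         ...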


@pytest.fixture
def test_config():
    """Test configuration with safe defaults."""
    return {
        "provider": "openai",
        "openai": {
            "api_key": "test-openai-key",
            "model": "gpt-4o",
            "base_url": "https://api.openai.com/v1",
            "temperature": 0.2
        },
        "retrieval": {
            "endpoint": "http://test-retrieval-endpoint",
            "api_key": "test-retrieval-key"
        },
        "postgresql": {
            "host": "localhost",
            "port": 5432,
            "database": "test_agent_memory",
            "username": "test",
            "password": "test",
            "ttl_days": 1
        },
        "app": {
            "name": "agentic-rag-test",
            "memory_ttl_days": 1,
            "max_tool_loops": 3,
            "cors_origins": ["*"]
        },
        "llm": {
            "rag": {
                "temperature": 0,
                "max_context_length": 32000,
                "agent_system_prompt": "You are a test assistant."
            }
        }
    }


@pytest.fixture
def app(test_config):
    """Create test FastAPI app with mocked configuration."""
    with patch('service.config.load_config') as mock_load_config:
        mock_load_config.return_value = test_config
        # Mock the memory manager to avoid PostgreSQL dependency in tests
        with patch('service.memory.postgresql_memory.get_memory_manager') as mock_memory:
            mock_memory_manager = Mock()
            mock_memory_manager.test_connection.return_value = True
            mock_memory.return_value = mock_memory_manager
            # Mock the graph builder to avoid complex dependencies
            with patch('service.graph.graph.build_graph') as mock_build_graph:
                mock_graph = Mock()
                mock_build_graph.return_value = mock_graph
                app = create_app()
                app.state.memory_manager = mock_memory_manager
                app.state.graph = mock_graph
                return app


@pytest.fixture
def client(app):
    """Create test client."""
    return TestClient(app)
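
# Example of exercising the mocked app through the client fixture; the
# "/health" path is illustrative only, not taken from the service code:
#
#     def test_health(client):
#         response = client.get("/health")
#         assert response.status_code == 200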


@pytest.fixture
def mock_llm_client():
    """Mock LLM client for testing."""

    async def _token_stream():
        # Async generator so callers can consume tokens with `async for`;
        # a plain iter(...) would break async iteration.
        for token in ["Test", " response", " token"]:
            yield token

    mock = AsyncMock()
    # astream() is assumed to return an async iterator directly (LangChain-style
    # streaming); a synchronous Mock with a side_effect builds a fresh generator
    # for every call.
    mock.astream = Mock(side_effect=lambda *args, **kwargs: _token_stream())
    mock.ainvoke_with_tools.return_value = Mock(
        content="Test response",
        tool_calls=[
            {
                "id": "test_tool_call_1",
                "function": {
                    "name": "retrieve_standard_regulation",
                    "arguments": '{"query": "test query"}'
                }
            }
        ]
    )
    return mock
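
# Consumption sketch for the mock above, assuming the service iterates the
# stream with `async for`:
#
#     @pytest.mark.asyncio
#     async def test_stream(mock_llm_client):
#         tokens = [t async for t in mock_llm_client.astream("prompt")]
#         assert tokens == ["Test", " response", " token"]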


@pytest.fixture
def mock_retrieval_response():
    """Mock response from retrieval API."""
    return {
        "results": [
            {
                "id": "test_result_1",
                "title": "ISO 26262-1:2018",
                "content": "Road vehicles — Functional safety — Part 1: Vocabulary",
                "score": 0.95,
                "url": "https://iso.org/26262-1",
                "metadata": {
                    "@tool_call_id": "test_tool_call_1",
                    "@order_num": 0
                }
            },
            {
                "id": "test_result_2",
                "title": "ISO 26262-3:2018",
                "content": "Road vehicles — Functional safety — Part 3: Concept phase",
                "score": 0.88,
                "url": "https://iso.org/26262-3",
                "metadata": {
                    "@tool_call_id": "test_tool_call_1",
                    "@order_num": 1
                }
            }
        ],
        "metadata": {
            "total": 2,
            "took_ms": 150,
            "query": "test query"
        }
    }
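
# The "@tool_call_id" and "@order_num" metadata keys appear to correlate each
# hit with the tool call that produced it; a test might assert on them, e.g.:
#
#     def test_result_order(mock_retrieval_response):
#         orders = [r["metadata"]["@order_num"]
#                   for r in mock_retrieval_response["results"]]
#         assert orders == sorted(orders)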


@pytest.fixture
def sample_chat_request():
    """Sample chat request for testing."""
    return {
        "session_id": "test_session_123",
        "messages": [
            {"role": "user", "content": "What is ISO 26262?"}
        ]
    }


@pytest.fixture
def sample_turn_state():
    """Sample TurnState for testing."""
    return TurnState(
        session_id="test_session_123",
        messages=[
            Message(role="user", content="What is ISO 26262?")
        ]
    )


@pytest.fixture
def mock_httpx_client():
    """Mock httpx client for API requests."""
    mock_client = AsyncMock()
    # Default response for retrieval API
    mock_response = Mock()
    mock_response.status_code = 200
    mock_response.json.return_value = {
        "results": [
            {
                "id": "test_result",
                "title": "Test Standard",
                "content": "Test content",
                "score": 0.9
            }
        ]
    }
    mock_client.post.return_value = mock_response
    return mock_client
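
# Sketch of wiring the mock into code under test; the patch target below is an
# assumption based on this file's import paths, not a confirmed module layout:
#
#     async def test_retrieval_call(mock_httpx_client):
#         with patch("service.retrieval.retrieval.httpx.AsyncClient",
#                    return_value=mock_httpx_client):
#             ...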


@pytest.fixture
def mock_postgresql_memory():
    """Mock PostgreSQL memory manager."""
    mock_manager = Mock(spec=PostgreSQLMemoryManager)
    mock_manager.test_connection.return_value = True
    mock_checkpointer = Mock()
    mock_checkpointer.setup.return_value = None
    mock_manager.get_checkpointer.return_value = mock_checkpointer
    return mock_manager


@pytest.fixture
def mock_streaming_response():
    """Mock streaming response events."""
    return [
        'event: tool_start\ndata: {"id": "test_tool_1", "name": "retrieve_standard_regulation", "args": {"query": "test"}}\n\n',
        'event: tokens\ndata: {"delta": "Based on the retrieved standards", "tool_call_id": null}\n\n',
        'event: tool_result\ndata: {"id": "test_tool_1", "name": "retrieve_standard_regulation", "results": [], "took_ms": 100}\n\n',
        'event: tokens\ndata: {"delta": " this is a test response.", "tool_call_id": null}\n\n'
    ]
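

def parse_sse_frame(raw: str) -> tuple[str, str]:
    """Split one raw SSE frame into (event, data) strings.

    A minimal helper (sketch) for asserting against mock_streaming_response
    entries; not a full SSE parser (no multi-line data, comments, or retry).
    """
    lines = raw.strip().split("\n")
    event = lines[0].removeprefix("event: ")
    data = lines[1].removeprefix("data: ")
    return event, data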


# Async test helpers
@pytest.fixture
def mock_agent_state():
    """Mock agent state for graph testing."""
    return {
        "messages": [],
        "session_id": "test_session",
        "tool_results": [],
        "final_answer": ""
    }


@pytest.fixture
async def async_test_client():
    """Async test client for integration tests.

    Note: requires pytest-asyncio to treat async fixtures natively (e.g.
    asyncio_mode = "auto"); otherwise decorate with @pytest_asyncio.fixture.
    """
    async with httpx.AsyncClient() as client:
        yield client


# Database fixtures for integration tests
@pytest.fixture
def test_database_url():
    """Test database URL (only for integration tests with real DB)."""
    return "postgresql://test:test@localhost:5432/test_agent_memory"


@pytest.fixture
def integration_test_config(test_database_url):
    """Configuration for integration tests with a real database."""
    return {
        "provider": "openai",
        "openai": {
            "api_key": "test-key",
            "model": "gpt-4o"
        },
        "retrieval": {
            "endpoint": "http://localhost:8000/search",  # assumes a local test retrieval server
            "api_key": "test-key"
        },
        "postgresql": {
            "connection_string": test_database_url
        }
    }


# Skip markers for different test types
def pytest_configure(config):
    """Configure pytest markers."""
    config.addinivalue_line("markers", "unit: mark test as unit test")
    config.addinivalue_line("markers", "integration: mark test as integration test")
    config.addinivalue_line("markers", "e2e: mark test as end-to-end test")
    config.addinivalue_line("markers", "slow: mark test as slow running")


def pytest_runtest_setup(item):
    """Setup for test items."""
    # Skip integration tests if not explicitly requested
    if "integration" in item.keywords and not item.config.getoption("--run-integration"):
        pytest.skip("Integration tests not requested")
    # Skip E2E tests if not explicitly requested
    if "e2e" in item.keywords and not item.config.getoption("--run-e2e"):
        pytest.skip("E2E tests not requested")


def pytest_addoption(parser):
    """Add custom command line options."""
    parser.addoption(
        "--run-integration",
        action="store_true",
        default=False,
        help="run integration tests"
    )
    parser.addoption(
        "--run-e2e",
        action="store_true",
        default=False,
        help="run end-to-end tests"
    )
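
# Typical invocations given the markers and options defined above:
#
#     pytest -m unit                # fast unit tests only
#     pytest --run-integration      # also run integration-marked tests
#     pytest --run-e2e -m e2e       # end-to-end tests only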