init
This commit is contained in:
317
vw-agentic-rag/tests/conftest.py
Normal file
317
vw-agentic-rag/tests/conftest.py
Normal file
@@ -0,0 +1,317 @@
|
||||
"""
|
||||
Shared pytest fixtures and configuration for the agentic-rag test suite.
|
||||
"""
|
||||
import pytest
|
||||
import asyncio
|
||||
import httpx
|
||||
from unittest.mock import Mock, AsyncMock, patch
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from service.main import create_app
|
||||
from service.config import Config
|
||||
from service.graph.state import TurnState, Message, ToolResult
|
||||
from service.memory.postgresql_memory import PostgreSQLMemoryManager
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop():
|
||||
"""Create an instance of the default event loop for the test session."""
|
||||
policy = asyncio.get_event_loop_policy()
|
||||
loop = policy.new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def config_mock():
|
||||
"""Mock configuration for all tests."""
|
||||
config = Mock()
|
||||
config.retrieval.endpoint = "http://test-endpoint"
|
||||
config.retrieval.api_key = "test-key"
|
||||
config.llm.provider = "openai"
|
||||
config.llm.model = "gpt-4"
|
||||
config.llm.api_key = "test-api-key"
|
||||
config.memory.enabled = True
|
||||
config.memory.type = "in_memory"
|
||||
config.memory.ttl_days = 7
|
||||
config.postgresql.enabled = False
|
||||
|
||||
with patch('service.config.get_config', return_value=config):
|
||||
with patch('service.retrieval.retrieval.get_config', return_value=config):
|
||||
with patch('service.graph.graph.get_config', return_value=config):
|
||||
yield config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_config():
|
||||
"""Test configuration with safe defaults."""
|
||||
return {
|
||||
"provider": "openai",
|
||||
"openai": {
|
||||
"api_key": "test-openai-key",
|
||||
"model": "gpt-4o",
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
"temperature": 0.2
|
||||
},
|
||||
"retrieval": {
|
||||
"endpoint": "http://test-retrieval-endpoint",
|
||||
"api_key": "test-retrieval-key"
|
||||
},
|
||||
"postgresql": {
|
||||
"host": "localhost",
|
||||
"port": 5432,
|
||||
"database": "test_agent_memory",
|
||||
"username": "test",
|
||||
"password": "test",
|
||||
"ttl_days": 1
|
||||
},
|
||||
"app": {
|
||||
"name": "agentic-rag-test",
|
||||
"memory_ttl_days": 1,
|
||||
"max_tool_loops": 3,
|
||||
"cors_origins": ["*"]
|
||||
},
|
||||
"llm": {
|
||||
"rag": {
|
||||
"temperature": 0,
|
||||
"max_context_length": 32000,
|
||||
"agent_system_prompt": "You are a test assistant."
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app(test_config):
|
||||
"""Create test FastAPI app with mocked configuration."""
|
||||
with patch('service.config.load_config') as mock_load_config:
|
||||
mock_load_config.return_value = test_config
|
||||
|
||||
# Mock the memory manager to avoid PostgreSQL dependency in tests
|
||||
with patch('service.memory.postgresql_memory.get_memory_manager') as mock_memory:
|
||||
mock_memory_manager = Mock()
|
||||
mock_memory_manager.test_connection.return_value = True
|
||||
mock_memory.return_value = mock_memory_manager
|
||||
|
||||
# Mock the graph builder to avoid complex dependencies
|
||||
with patch('service.graph.graph.build_graph') as mock_build_graph:
|
||||
mock_graph = Mock()
|
||||
mock_build_graph.return_value = mock_graph
|
||||
|
||||
app = create_app()
|
||||
app.state.memory_manager = mock_memory_manager
|
||||
app.state.graph = mock_graph
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
"""Create test client."""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm_client():
|
||||
"""Mock LLM client for testing."""
|
||||
mock = AsyncMock()
|
||||
mock.astream.return_value = iter(["Test", " response", " token"])
|
||||
mock.ainvoke_with_tools.return_value = Mock(
|
||||
content="Test response",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "test_tool_call_1",
|
||||
"function": {
|
||||
"name": "retrieve_standard_regulation",
|
||||
"arguments": '{"query": "test query"}'
|
||||
}
|
||||
}
|
||||
]
|
||||
)
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_retrieval_response():
|
||||
"""Mock response from retrieval API."""
|
||||
return {
|
||||
"results": [
|
||||
{
|
||||
"id": "test_result_1",
|
||||
"title": "ISO 26262-1:2018",
|
||||
"content": "Road vehicles — Functional safety — Part 1: Vocabulary",
|
||||
"score": 0.95,
|
||||
"url": "https://iso.org/26262-1",
|
||||
"metadata": {
|
||||
"@tool_call_id": "test_tool_call_1",
|
||||
"@order_num": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "test_result_2",
|
||||
"title": "ISO 26262-3:2018",
|
||||
"content": "Road vehicles — Functional safety — Part 3: Concept phase",
|
||||
"score": 0.88,
|
||||
"url": "https://iso.org/26262-3",
|
||||
"metadata": {
|
||||
"@tool_call_id": "test_tool_call_1",
|
||||
"@order_num": 1
|
||||
}
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"total": 2,
|
||||
"took_ms": 150,
|
||||
"query": "test query"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_chat_request():
|
||||
"""Sample chat request for testing."""
|
||||
return {
|
||||
"session_id": "test_session_123",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is ISO 26262?"}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_turn_state():
|
||||
"""Sample TurnState for testing."""
|
||||
return TurnState(
|
||||
session_id="test_session_123",
|
||||
messages=[
|
||||
Message(role="user", content="What is ISO 26262?")
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_httpx_client():
|
||||
"""Mock httpx client for API requests."""
|
||||
mock_client = AsyncMock()
|
||||
|
||||
# Default response for retrieval API
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"results": [
|
||||
{
|
||||
"id": "test_result",
|
||||
"title": "Test Standard",
|
||||
"content": "Test content",
|
||||
"score": 0.9
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
mock_client.post.return_value = mock_response
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_postgresql_memory():
|
||||
"""Mock PostgreSQL memory manager."""
|
||||
mock_manager = Mock(spec=PostgreSQLMemoryManager)
|
||||
mock_manager.test_connection.return_value = True
|
||||
|
||||
mock_checkpointer = Mock()
|
||||
mock_checkpointer.setup.return_value = None
|
||||
mock_manager.get_checkpointer.return_value = mock_checkpointer
|
||||
|
||||
return mock_manager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_streaming_response():
|
||||
"""Mock streaming response events."""
|
||||
return [
|
||||
'event: tool_start\ndata: {"id": "test_tool_1", "name": "retrieve_standard_regulation", "args": {"query": "test"}}\n\n',
|
||||
'event: tokens\ndata: {"delta": "Based on the retrieved standards", "tool_call_id": null}\n\n',
|
||||
'event: tool_result\ndata: {"id": "test_tool_1", "name": "retrieve_standard_regulation", "results": [], "took_ms": 100}\n\n',
|
||||
'event: tokens\ndata: {"delta": " this is a test response.", "tool_call_id": null}\n\n'
|
||||
]
|
||||
|
||||
|
||||
# Async test helpers
|
||||
@pytest.fixture
|
||||
def mock_agent_state():
|
||||
"""Mock agent state for graph testing."""
|
||||
return {
|
||||
"messages": [],
|
||||
"session_id": "test_session",
|
||||
"tool_results": [],
|
||||
"final_answer": ""
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_test_client():
|
||||
"""Async test client for integration tests."""
|
||||
async with httpx.AsyncClient() as client:
|
||||
yield client
|
||||
|
||||
|
||||
# Database fixtures for integration tests
|
||||
@pytest.fixture
|
||||
def test_database_url():
|
||||
"""Test database URL (only for integration tests with real DB)."""
|
||||
return "postgresql://test:test@localhost:5432/test_agent_memory"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def integration_test_config(test_database_url):
|
||||
"""Configuration for integration tests with real database."""
|
||||
return {
|
||||
"provider": "openai",
|
||||
"openai": {
|
||||
"api_key": "test-key",
|
||||
"model": "gpt-4o"
|
||||
},
|
||||
"retrieval": {
|
||||
"endpoint": "http://localhost:8000/search", # Assume test retrieval server
|
||||
"api_key": "test-key"
|
||||
},
|
||||
"postgresql": {
|
||||
"connection_string": test_database_url
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# Skip markers for different test types
|
||||
def pytest_configure(config):
|
||||
"""Configure pytest markers."""
|
||||
config.addinivalue_line("markers", "unit: mark test as unit test")
|
||||
config.addinivalue_line("markers", "integration: mark test as integration test")
|
||||
config.addinivalue_line("markers", "e2e: mark test as end-to-end test")
|
||||
config.addinivalue_line("markers", "slow: mark test as slow running")
|
||||
|
||||
|
||||
def pytest_runtest_setup(item):
|
||||
"""Setup for test items."""
|
||||
# Skip integration tests if not explicitly requested
|
||||
if "integration" in item.keywords and not item.config.getoption("--run-integration"):
|
||||
pytest.skip("Integration tests not requested")
|
||||
|
||||
# Skip E2E tests if not explicitly requested
|
||||
if "e2e" in item.keywords and not item.config.getoption("--run-e2e"):
|
||||
pytest.skip("E2E tests not requested")
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
"""Add custom command line options."""
|
||||
parser.addoption(
|
||||
"--run-integration",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="run integration tests"
|
||||
)
|
||||
parser.addoption(
|
||||
"--run-e2e",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="run end-to-end tests"
|
||||
)
|
||||
Reference in New Issue
Block a user