"""
Shared pytest fixtures and configuration for the agentic-rag test suite.
"""
|
|
import pytest
|
|
import asyncio
|
|
import httpx
|
|
from unittest.mock import Mock, AsyncMock, patch
|
|
from fastapi.testclient import TestClient
|
|
|
|
from service.main import create_app
|
|
from service.config import Config
|
|
from service.graph.state import TurnState, Message, ToolResult
|
|
from service.memory.postgresql_memory import PostgreSQLMemoryManager
|
|
|
|
|
|
@pytest.fixture(scope="session")
def event_loop():
    """Provide one fresh asyncio event loop shared by the whole test session.

    The loop comes from the current event loop policy and is closed once the
    session's fixtures are torn down.

    NOTE(review): recent pytest-asyncio releases deprecate overriding the
    ``event_loop`` fixture; confirm the pinned plugin version still honours it.
    """
    session_loop = asyncio.get_event_loop_policy().new_event_loop()
    yield session_loop
    session_loop.close()
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def config_mock():
    """Patch every known ``get_config`` lookup with one shared Mock config.

    Applied automatically to every test. The three patches stay active until
    the test finishes.

    NOTE(review): ``patch`` raises AttributeError if a target module has no
    ``get_config`` attribute, which would error every test because this
    fixture is autouse.

    Yields:
        The Mock configuration object, so individual tests can tweak it.
    """
    config = Mock()

    # Retrieval service settings.
    config.retrieval.endpoint = "http://test-endpoint"
    config.retrieval.api_key = "test-key"

    # LLM provider settings.
    config.llm.provider = "openai"
    config.llm.model = "gpt-4"
    config.llm.api_key = "test-api-key"

    # Conversation memory: in-memory store, PostgreSQL disabled.
    config.memory.enabled = True
    config.memory.type = "in_memory"
    config.memory.ttl_days = 7
    config.postgresql.enabled = False

    with patch('service.config.get_config', return_value=config), \
            patch('service.retrieval.retrieval.get_config', return_value=config), \
            patch('service.graph.graph.get_config', return_value=config):
        yield config
|
|
|
|
|
|
@pytest.fixture
def test_config():
    """Return a raw configuration dict filled with safe, non-production values.

    The ``app`` fixture in this file feeds this dict to the app as the return
    value of the patched ``service.config.load_config``. A new dict is built
    on every call, so tests may mutate it freely.
    """
    openai_section = {
        "api_key": "test-openai-key",
        "model": "gpt-4o",
        "base_url": "https://api.openai.com/v1",
        "temperature": 0.2,
    }
    retrieval_section = {
        "endpoint": "http://test-retrieval-endpoint",
        "api_key": "test-retrieval-key",
    }
    postgresql_section = {
        "host": "localhost",
        "port": 5432,
        "database": "test_agent_memory",
        "username": "test",
        "password": "test",
        "ttl_days": 1,
    }
    app_section = {
        "name": "agentic-rag-test",
        "memory_ttl_days": 1,
        "max_tool_loops": 3,
        "cors_origins": ["*"],
    }
    llm_section = {
        "rag": {
            "temperature": 0,
            "max_context_length": 32000,
            "agent_system_prompt": "You are a test assistant.",
        },
    }
    return {
        "provider": "openai",
        "openai": openai_section,
        "retrieval": retrieval_section,
        "postgresql": postgresql_section,
        "app": app_section,
        "llm": llm_section,
    }
|
|
|
|
|
|
@pytest.fixture
def app(test_config):
    """Build the FastAPI app under test with its external dependencies stubbed.

    While ``create_app()`` runs:
      * ``service.config.load_config`` returns ``test_config``;
      * ``service.memory.postgresql_memory.get_memory_manager`` returns a Mock
        whose ``test_connection()`` reports ``True``, so no PostgreSQL is needed;
      * ``service.graph.graph.build_graph`` returns a bare Mock graph.

    The same two Mocks are then attached to ``app.state.memory_manager`` and
    ``app.state.graph`` so tests can inspect or configure them.
    """
    memory_manager_stub = Mock()
    memory_manager_stub.test_connection.return_value = True
    graph_stub = Mock()

    with patch('service.config.load_config', return_value=test_config), \
            patch('service.memory.postgresql_memory.get_memory_manager',
                  return_value=memory_manager_stub), \
            patch('service.graph.graph.build_graph', return_value=graph_stub):
        fastapi_app = create_app()

    fastapi_app.state.memory_manager = memory_manager_stub
    fastapi_app.state.graph = graph_stub
    return fastapi_app
|
|
|
|
|
|
@pytest.fixture
def client(app):
    """Yield a TestClient bound to the ``app`` fixture and close it on teardown.

    The client is not entered as a context manager. Starlette's TestClient
    only runs the app's lifespan (startup/shutdown) handlers in that mode, so
    they are not run here. ``close()`` releases the underlying HTTP client
    once the test is done.
    """
    test_client = TestClient(app)
    try:
        yield test_client
    finally:
        test_client.close()
|
|
|
|
|
|
@pytest.fixture
def mock_llm_client():
    """Return an AsyncMock standing in for the LLM client.

    * ``await mock.astream(...)`` resolves to an iterator over the tokens
      ``"Test"``, ``" response"``, ``" token"``.
    * ``await mock.ainvoke_with_tools(...)`` resolves to a Mock with
      ``content="Test response"`` and one ``retrieve_standard_regulation``
      tool call whose id is ``test_tool_call_1``.
    """
    tool_call = {
        "id": "test_tool_call_1",
        "function": {
            "name": "retrieve_standard_regulation",
            "arguments": '{"query": "test query"}',
        },
    }
    llm_reply = Mock(content="Test response", tool_calls=[tool_call])

    llm = AsyncMock()
    # NOTE(review): this is one shared, single-use iterator, so a second
    # astream() call in the same test yields nothing. Also, AsyncMock makes
    # astream() awaitable rather than async-iterable; code doing
    # ``async for tok in client.astream(...)`` would fail. Confirm intended usage.
    llm.astream.return_value = iter(["Test", " response", " token"])
    llm.ainvoke_with_tools.return_value = llm_reply
    return llm
|
|
|
|
|
|
@pytest.fixture
def mock_retrieval_response():
    """Return a canned retrieval-API payload with two ISO 26262 hits.

    Both hits are attributed to tool call ``test_tool_call_1`` and numbered
    by ``@order_num`` (0 and 1). The top-level ``metadata`` carries the total,
    the latency in ms and the originating query.
    """
    hit_rows = [
        ("test_result_1", "ISO 26262-1:2018",
         "Road vehicles — Functional safety — Part 1: Vocabulary",
         0.95, "https://iso.org/26262-1"),
        ("test_result_2", "ISO 26262-3:2018",
         "Road vehicles — Functional safety — Part 3: Concept phase",
         0.88, "https://iso.org/26262-3"),
    ]
    results = [
        {
            "id": hit_id,
            "title": title,
            "content": content,
            "score": score,
            "url": url,
            "metadata": {
                "@tool_call_id": "test_tool_call_1",
                "@order_num": position,
            },
        }
        for position, (hit_id, title, content, score, url) in enumerate(hit_rows)
    ]
    return {
        "results": results,
        "metadata": {
            "total": 2,
            "took_ms": 150,
            "query": "test query",
        },
    }
|
|
|
|
|
|
@pytest.fixture
def sample_chat_request():
    """Return a minimal chat-request payload: one user question in one session."""
    user_turn = {"role": "user", "content": "What is ISO 26262?"}
    return {"session_id": "test_session_123", "messages": [user_turn]}
|
|
|
|
|
|
@pytest.fixture
def sample_turn_state():
    """Return a TurnState for session ``test_session_123`` holding one user question."""
    opening_question = Message(role="user", content="What is ISO 26262?")
    return TurnState(session_id="test_session_123", messages=[opening_question])
|
|
|
|
|
|
@pytest.fixture
def mock_httpx_client():
    """Return an AsyncMock standing in for ``httpx.AsyncClient``.

    ``await client.post(...)`` resolves to a Mock response with
    ``status_code == 200``; its ``json()`` returns one ``Test Standard`` hit.
    The mock also works as ``async with client as c:``. In that case ``c`` is
    the same mock, so the configured ``post`` response is reachable there too.
    """
    mock_client = AsyncMock()

    # Default response for the retrieval API.
    mock_response = Mock()
    mock_response.status_code = 200
    mock_response.json.return_value = {
        "results": [
            {
                "id": "test_result",
                "title": "Test Standard",
                "content": "Test content",
                "score": 0.9,
            }
        ]
    }
    mock_client.post.return_value = mock_response

    # By default AsyncMock's __aenter__ resolves to a separate child mock whose
    # post() is unconfigured. Returning the client itself lets `async with`
    # usage see the canned response. Tests can still override this.
    mock_client.__aenter__.return_value = mock_client
    return mock_client
|
|
|
|
|
|
@pytest.fixture
def mock_postgresql_memory():
    """Return a Mock spec'd on PostgreSQLMemoryManager that never touches a database.

    ``test_connection()`` returns ``True``. ``get_checkpointer()`` returns a
    Mock checkpointer whose ``setup()`` returns ``None``.
    """
    checkpointer = Mock()
    checkpointer.setup.return_value = None

    manager = Mock(spec=PostgreSQLMemoryManager)
    manager.test_connection.return_value = True
    manager.get_checkpointer.return_value = checkpointer
    return manager
|
|
|
|
|
|
@pytest.fixture
def mock_streaming_response():
    """Return SSE-style text frames for a short tool-augmented answer.

    Frame order: tool_start, tokens, tool_result, tokens. Each frame has the
    form ``event: <name>\\ndata: <json>\\n\\n``.
    """
    frames = [
        ("tool_start", '{"id": "test_tool_1", "name": "retrieve_standard_regulation", "args": {"query": "test"}}'),
        ("tokens", '{"delta": "Based on the retrieved standards", "tool_call_id": null}'),
        ("tool_result", '{"id": "test_tool_1", "name": "retrieve_standard_regulation", "results": [], "took_ms": 100}'),
        ("tokens", '{"delta": " this is a test response.", "tool_call_id": null}'),
    ]
    return [f"event: {event}\ndata: {payload}\n\n" for event, payload in frames]
|
|
|
|
|
|
# Graph-state and async test helpers
|
|
@pytest.fixture
def mock_agent_state():
    """Return an empty graph-state dict for session ``test_session``."""
    return dict(
        messages=[],
        session_id="test_session",
        tool_results=[],
        final_answer="",
    )
|
|
|
|
|
|
@pytest.fixture
async def async_test_client():
    """Yield an ``httpx.AsyncClient`` for integration tests; it is closed on teardown.

    NOTE(review): this is an async generator behind a plain ``@pytest.fixture``.
    It only resolves to a client when pytest-asyncio runs in ``auto`` mode
    (or another plugin handles async fixtures). Confirm the project's asyncio_mode.
    """
    http_client = httpx.AsyncClient()
    async with http_client:
        yield http_client
|
|
|
|
|
|
# Database fixtures for integration tests
|
|
@pytest.fixture
def test_database_url():
    """Return the PostgreSQL DSN used by integration tests that need a real database.

    It targets the local ``test_agent_memory`` database with ``test``/``test``
    credentials, so it is only usable when such a server is running.
    """
    dsn = "postgresql://test:test@localhost:5432/test_agent_memory"
    return dsn
|
|
|
|
|
|
@pytest.fixture
def integration_test_config(test_database_url):
    """Return a configuration dict for integration tests against real services.

    Unlike ``test_config``, PostgreSQL is configured through a single
    ``connection_string`` taken from the ``test_database_url`` fixture.
    """
    config = {"provider": "openai"}
    config["openai"] = {"api_key": "test-key", "model": "gpt-4o"}
    # Assumes a retrieval server is listening locally during integration runs.
    config["retrieval"] = {
        "endpoint": "http://localhost:8000/search",
        "api_key": "test-key",
    }
    config["postgresql"] = {"connection_string": test_database_url}
    return config
|
|
|
|
|
|
# Marker registration, opt-in command-line flags, and skip rules per test type
|
|
def pytest_configure(config):
    """Register the custom markers used to classify tests.

    Declaring them keeps pytest from warning about unknown markers (or
    erroring under ``--strict-markers``).
    """
    marker_lines = (
        "unit: mark test as unit test",
        "integration: mark test as integration test",
        "e2e: mark test as end-to-end test",
        "slow: mark test as slow running",
    )
    for line in marker_lines:
        config.addinivalue_line("markers", line)
|
|
|
|
|
|
def pytest_runtest_setup(item):
    """Skip opt-in test categories unless their command-line flag was given.

    * ``integration`` tests run only with ``--run-integration``.
    * ``e2e`` tests run only with ``--run-e2e``.

    NOTE(review): ``item.keywords`` holds marker names but also node names,
    so e.g. a test directory named ``integration`` matches too. Confirm that
    this is intended.
    """
    opt_in_gates = (
        ("integration", "--run-integration", "Integration tests not requested"),
        ("e2e", "--run-e2e", "E2E tests not requested"),
    )
    for keyword, flag, reason in opt_in_gates:
        if keyword in item.keywords and not item.config.getoption(flag):
            pytest.skip(reason)
|
|
|
|
|
|
def pytest_addoption(parser):
    """Register the opt-in boolean flags that ``pytest_runtest_setup`` checks."""
    opt_in_flags = (
        ("--run-integration", "run integration tests"),
        ("--run-e2e", "run end-to-end tests"),
    )
    for flag, description in opt_in_flags:
        parser.addoption(
            flag,
            action="store_true",
            default=False,
            help=description,
        )
|