
🧪 Testing Guide

This guide covers the testing strategy, test structure, and best practices for the Agentic RAG system, including unit, integration, end-to-end, and performance testing approaches.

Testing Philosophy

Our testing strategy follows the testing pyramid:

        /\
       /  \
      / E2E \ (Few, Slow, High Confidence)
     /______\
    /        \
   /Integration\ (Some, Medium Speed)
  /____________\
 /              \
/   Unit Tests   \ (Many, Fast, Low Level)
/________________\

Test Categories

  • Unit Tests: Fast, isolated tests for individual functions and classes
  • Integration Tests: Test component interactions with real dependencies
  • End-to-End Tests: Full workflow tests simulating real user scenarios
  • Performance Tests: Load testing and performance benchmarks

Test Structure

tests/
├── conftest.py                    # Shared pytest fixtures
├── unit/                         # Unit tests (fast, isolated)
│   ├── test_config.py
│   ├── test_retrieval.py
│   ├── test_memory.py
│   ├── test_graph.py
│   ├── test_llm_client.py
│   └── test_sse.py
├── integration/                  # Integration tests  
│   ├── test_api.py
│   ├── test_streaming.py
│   ├── test_full_workflow.py
│   ├── test_mocked_streaming.py
│   └── test_e2e_tool_ui.py
└── performance/                  # Performance tests
    ├── test_load.py
    ├── test_memory_usage.py
    └── test_concurrent_users.py

Running Tests

Quick Test Commands

# Run all tests
make test

# Run specific test categories
make test-unit              # Unit tests only
make test-integration       # Integration tests only  
make test-e2e              # End-to-end tests

# Run with coverage
uv run pytest --cov=service --cov-report=html tests/

# Run specific test file
uv run pytest tests/unit/test_retrieval.py -v

# Run specific test method
uv run pytest tests/integration/test_api.py::test_chat_endpoint -v

# Run tests in parallel (faster)
uv run pytest -n auto tests/

# Run tests with detailed output
uv run pytest -s -vvv tests/

Test Configuration

The test configuration is defined in conftest.py:

# conftest.py
import pytest
import asyncio
import httpx
from unittest.mock import Mock, AsyncMock
from fastapi.testclient import TestClient

from service.main import create_app
from service.config import Config

@pytest.fixture(scope="session")
def event_loop():
    """Create an instance of the default event loop for the test session."""
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    loop.close()

@pytest.fixture
def test_config():
    """Test configuration with safe defaults."""
    return Config(
        provider="openai",
        openai_api_key="test-key",
        retrieval_endpoint="http://test-endpoint",
        retrieval_api_key="test-key",
        postgresql_host="localhost",
        postgresql_database="test_db",
        memory_ttl_days=1
    )

@pytest.fixture
def app(test_config):
    """Create test FastAPI app."""
    app = create_app()
    app.state.config = test_config
    return app

@pytest.fixture
def client(app):
    """Create test client."""
    return TestClient(app)

@pytest.fixture
def mock_llm():
    """Mock LLM client for testing."""
    mock = AsyncMock()
    mock.agenerate.return_value = Mock(
        generations=[[Mock(text="Mocked response")]]
    )
    return mock
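
Individual tests pull these fixtures in by name. For example, a minimal smoke test against the health endpoint covered later might look like this (a sketch; the file name is hypothetical):

# tests/unit/test_smoke.py
def test_health_with_shared_fixtures(client):
    """The client fixture from conftest.py is injected by pytest."""
    response = client.get("/health")
    assert response.status_code == 200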

Unit Tests

Unit tests focus on testing individual components in isolation.

Testing Retrieval Tools

# tests/unit/test_retrieval.py
import pytest
from unittest.mock import Mock, AsyncMock, patch
import httpx

from service.retrieval.agentic_retrieval import RetrievalTool

class TestRetrievalTool:
    
    @pytest.fixture
    def tool(self):
        return RetrievalTool(
            endpoint="http://test-endpoint",
            api_key="test-key"
        )
    
    @pytest.mark.asyncio
    async def test_search_standards_success(self, tool):
        mock_response = {
            "results": [
                {"title": "ISO 26262", "content": "Functional safety"},
                {"title": "UN 38.3", "content": "Battery safety"}
            ],
            "metadata": {"total": 2, "took_ms": 150}
        }
        
        with patch('httpx.AsyncClient.post') as mock_post:
            mock_post.return_value.json.return_value = mock_response
            mock_post.return_value.status_code = 200
            
            result = await tool.search_standards("battery safety")
            
            assert len(result["results"]) == 2
            assert result["results"][0]["title"] == "ISO 26262"
            assert result["metadata"]["took_ms"] == 150
    
    @pytest.mark.asyncio
    async def test_search_standards_http_error(self, tool):
        with patch('httpx.AsyncClient.post') as mock_post:
            mock_post.side_effect = httpx.HTTPStatusError(
                message="Not Found",
                request=Mock(),
                response=Mock(status_code=404)
            )
            
            with pytest.raises(Exception) as exc_info:
                await tool.search_standards("nonexistent")
            
            assert "HTTP error" in str(exc_info.value)
    
    def test_format_query(self, tool):
        query = tool._format_query("test query", {"history": "previous"})
        assert "test query" in query
        assert "previous" in query

Testing Configuration

# tests/unit/test_config.py
import pytest
from unittest.mock import patch
from pydantic import ValidationError

from service.config import Config, load_config

class TestConfig:
    
    def test_config_validation_success(self):
        config = Config(
            provider="openai",
            openai_api_key="test-key",
            retrieval_endpoint="http://test.com",
            retrieval_api_key="test-key"
        )
        assert config.provider == "openai"
        assert config.openai_api_key == "test-key"
    
    def test_config_validation_missing_required(self):
        with pytest.raises(ValidationError):
            Config(provider="openai")  # Missing required fields
    
    def test_load_config_from_env(self, monkeypatch):
        monkeypatch.setenv("OPENAI_API_KEY", "env-key")
        monkeypatch.setenv("RETRIEVAL_API_KEY", "env-retrieval-key")
        
        # Mock config file loading
        with patch('service.config.yaml.safe_load') as mock_yaml:
            mock_yaml.return_value = {
                "provider": "openai",
                "retrieval": {"endpoint": "http://test.com"}
            }
            
            config = load_config()
            assert config.openai_api_key == "env-key"
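
If environment variables are expected to take precedence over values in the config file, that behavior can be pinned down with a similar test (a sketch; it assumes such precedence and mirrors the yaml mock used above):

def test_env_overrides_file_value(monkeypatch):
    monkeypatch.setenv("OPENAI_API_KEY", "env-key")
    monkeypatch.setenv("RETRIEVAL_API_KEY", "env-retrieval-key")

    with patch('service.config.yaml.safe_load') as mock_yaml:
        mock_yaml.return_value = {
            "provider": "openai",
            "openai_api_key": "file-key",  # hypothetical value from the config file
            "retrieval": {"endpoint": "http://test.com"}
        }

        config = load_config()
        # Assumed behavior: environment variables win over file values
        assert config.openai_api_key == "env-key"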

Testing LLM Client

# tests/unit/test_llm_client.py
import pytest
from unittest.mock import Mock, AsyncMock, patch

from service.llm_client import get_llm_client, OpenAIClient

class TestLLMClient:
    
    @pytest.mark.asyncio
    async def test_openai_client_generate(self):
        with patch('openai.AsyncOpenAI') as mock_openai:
            mock_client = AsyncMock()
            mock_openai.return_value = mock_client
            
            mock_response = Mock()
            mock_response.choices = [
                Mock(message=Mock(content="Generated response"))
            ]
            mock_client.chat.completions.create.return_value = mock_response
            
            client = OpenAIClient(api_key="test", model="gpt-4")
            result = await client.generate([{"role": "user", "content": "test"}])
            
            assert result == "Generated response"
    
    def test_get_llm_client_openai(self, test_config):
        test_config.provider = "openai"
        test_config.openai_api_key = "test-key"
        
        client = get_llm_client(test_config)
        assert isinstance(client, OpenAIClient)
    
    def test_get_llm_client_unsupported(self, test_config):
        test_config.provider = "unsupported"
        
        with pytest.raises(ValueError, match="Unsupported provider"):
            get_llm_client(test_config)
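
The error path is worth covering too. A sketch that assumes OpenAIClient.generate lets provider exceptions propagate rather than swallowing them:

@pytest.mark.asyncio
async def test_openai_client_propagates_errors():
    with patch('openai.AsyncOpenAI') as mock_openai:
        mock_client = AsyncMock()
        mock_openai.return_value = mock_client
        mock_client.chat.completions.create.side_effect = RuntimeError("rate limited")

        client = OpenAIClient(api_key="test", model="gpt-4")

        with pytest.raises(RuntimeError, match="rate limited"):
            await client.generate([{"role": "user", "content": "test"}])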

Integration Tests

Integration tests verify that components work together correctly.

Testing API Endpoints

# tests/integration/test_api.py
import pytest
import json
import httpx
from fastapi.testclient import TestClient

def test_health_endpoint(client):
    """Test health check endpoint."""
    response = client.get("/health")
    assert response.status_code == 200
    assert response.json() == {"status": "healthy", "service": "agentic-rag"}

def test_root_endpoint(client):
    """Test root endpoint."""
    response = client.get("/")
    assert response.status_code == 200
    data = response.json()
    assert "Agentic RAG API" in data["message"]

@pytest.mark.asyncio
async def test_chat_endpoint_integration():
    """Integration test for chat endpoint using httpx client."""
    async with httpx.AsyncClient() as client:
        request_data = {
            "messages": [{"role": "user", "content": "test question"}],
            "session_id": "test_session_123"
        }
        
        response = await client.post(
            "http://localhost:8000/api/chat",
            json=request_data,
            timeout=30.0
        )
        
        assert response.status_code == 200
        assert response.headers["content-type"].startswith("text/event-stream")

def test_chat_request_validation(client):
    """Test chat request validation."""
    # Missing messages
    response = client.post("/api/chat", json={})
    assert response.status_code == 422
    
    # Invalid message format
    response = client.post("/api/chat", json={
        "messages": [{"role": "invalid", "content": "test"}]
    })
    assert response.status_code == 422
    
    # Valid request
    response = client.post("/api/chat", json={
        "messages": [{"role": "user", "content": "test"}],
        "session_id": "test_session"
    })
    assert response.status_code == 200
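
The async test above needs a server running on localhost:8000. When an in-process check is enough, the client fixture from conftest.py can exercise the same route without a separate server (a sketch):

def test_chat_endpoint_in_process(client):
    """Exercise /api/chat through the in-process TestClient."""
    response = client.post("/api/chat", json={
        "messages": [{"role": "user", "content": "test question"}],
        "session_id": "in_process_test"
    })
    assert response.status_code == 200
    # Some servers append a charset, so compare only the prefix
    assert response.headers["content-type"].startswith("text/event-stream")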

Testing Streaming

# tests/integration/test_streaming.py
import pytest
import json
import asyncio
from httpx import AsyncClient

@pytest.mark.asyncio
async def test_streaming_event_format():
    """Test streaming response format."""
    async with AsyncClient() as client:
        request_data = {
            "messages": [{"role": "user", "content": "What is ISO 26262?"}],
            "session_id": "stream_test_session"
        }
        
        async with client.stream(
            "POST",
            "http://localhost:8000/api/chat",
            json=request_data,
            timeout=60.0
        ) as response:
            assert response.status_code == 200
            
            events = []
            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    try:
                        data = json.loads(line[6:])  # Remove "data: " prefix
                        events.append(data)
                    except json.JSONDecodeError:
                        continue
            
            # Verify we got expected event types
            event_types = [event.get("type") for event in events if "type" in event]
            assert "tool_start" in event_types
            assert "tokens" in event_types
            assert "tool_result" in event_types

@pytest.mark.asyncio
async def test_concurrent_streaming():
    """Test concurrent streaming requests."""
    async def single_request(session_id: str):
        async with AsyncClient() as client:
            request_data = {
                "messages": [{"role": "user", "content": f"Test {session_id}"}],
                "session_id": session_id
            }
            
            response = await client.post(
                "http://localhost:8000/api/chat",
                json=request_data,
                timeout=30.0
            )
            return response.status_code
    
    # Run 5 concurrent requests
    tasks = [
        single_request(f"concurrent_test_{i}")
        for i in range(5)
    ]
    
    results = await asyncio.gather(*tasks)
    assert all(status == 200 for status in results)
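
Parsing data: lines recurs across these tests; a small helper module (hypothetical, not part of the service code) keeps that logic in one place:

# tests/integration/sse_utils.py (hypothetical helper)
import json
from typing import Dict, List
from httpx import Response

async def collect_sse_events(response: Response) -> List[Dict]:
    """Collect parsed JSON payloads from an SSE response's data: lines."""
    events = []
    async for line in response.aiter_lines():
        if line.startswith("data: "):
            try:
                events.append(json.loads(line[len("data: "):]))
            except json.JSONDecodeError:
                continue  # skip keep-alives and non-JSON lines
    return events

Tests can then call events = await collect_sse_events(response) inside the client.stream(...) block instead of re-implementing the loop.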

Testing Memory Persistence

# tests/integration/test_memory.py
import pytest
from service.memory.postgresql_memory import PostgreSQLMemoryManager

@pytest.mark.asyncio
async def test_session_persistence():
    """Test that conversations persist across requests."""
    memory_manager = PostgreSQLMemoryManager("postgresql://test:test@localhost/test")
    
    if not memory_manager.test_connection():
        pytest.skip("PostgreSQL not available for testing")
    
    checkpointer = memory_manager.get_checkpointer()
    
    # Simulate first conversation turn
    session_id = "memory_test_session"
    initial_state = {
        "messages": [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"}
        ]
    }
    
    # Save state (channel_values carries the conversation turns defined above)
    await checkpointer.aput(
        config={"configurable": {"session_id": session_id}},
        checkpoint={
            "id": "checkpoint_1",
            "ts": "2024-01-01T00:00:00Z",
            "channel_values": initial_state
        },
        metadata={},
        new_versions={}
    )
    
    # Retrieve state
    retrieved = await checkpointer.aget_tuple(
        config={"configurable": {"session_id": session_id}}
    )
    
    assert retrieved is not None
    assert retrieved.checkpoint["id"] == "checkpoint_1"

End-to-End Tests

E2E tests simulate complete user workflows.

Full Workflow Test

# tests/integration/test_full_workflow.py
import pytest
import asyncio
import json
from httpx import AsyncClient

@pytest.mark.asyncio
async def test_complete_rag_workflow():
    """Test complete RAG workflow from query to citation."""
    
    async with AsyncClient() as client:
        # Step 1: Send initial query
        request_data = {
            "messages": [
                {"role": "user", "content": "What are the safety standards for lithium-ion batteries?"}
            ],
            "session_id": "e2e_workflow_test"
        }
        
        response = await client.post(
            "http://localhost:8000/api/chat",
            json=request_data,
            timeout=120.0
        )
        
        assert response.status_code == 200
        
        # Step 2: Parse streaming response
        events = []
        tool_calls = []
        final_answer = None
        citations = None
        
        async for line in response.aiter_lines():
            if line.startswith("data: "):
                try:
                    data = json.loads(line[6:])
                    events.append(data)
                    
                    if data.get("type") == "tool_start":
                        tool_calls.append(data["name"])
                    elif data.get("type") == "post_append_1":
                        final_answer = data.get("answer")
                        citations = data.get("citations_mapping_csv")
                        
                except json.JSONDecodeError:
                    continue
        
        # Step 3: Verify workflow execution
        assert len(tool_calls) > 0, "No tools were called"
        assert "retrieve_standard_regulation" in tool_calls or \
               "retrieve_doc_chunk_standard_regulation" in tool_calls
        
        assert final_answer is not None, "No final answer received"
        assert "safety" in final_answer.lower() or "standard" in final_answer.lower()
        
        if citations:
            assert len(citations.split('\n')) > 0, "No citations provided"
        
        # Step 4: Follow-up question to test memory
        followup_request = {
            "messages": [
                {"role": "user", "content": "What are the safety standards for lithium-ion batteries?"},
                {"role": "assistant", "content": final_answer},
                {"role": "user", "content": "What about testing procedures?"}
            ],
            "session_id": "e2e_workflow_test"  # Same session
        }
        
        followup_response = await client.post(
            "http://localhost:8000/api/chat",
            json=followup_request,
            timeout=120.0
        )
        
        assert followup_response.status_code == 200

@pytest.mark.asyncio  
async def test_error_handling():
    """Test error handling in workflow."""
    
    async with AsyncClient() as client:
        # Test with invalid session format
        request_data = {
            "messages": [{"role": "user", "content": "test"}],
            "session_id": ""  # Invalid session ID
        }
        
        response = await client.post(
            "http://localhost:8000/api/chat",
            json=request_data,
            timeout=30.0
        )
        
        # Should handle gracefully (generate new session ID)
        assert response.status_code == 200

Frontend Integration Test

# tests/integration/test_e2e_tool_ui.py
import os
import pytest
from playwright.sync_api import sync_playwright

@pytest.mark.skipif(
    not os.getenv("RUN_E2E_TESTS"),
    reason="E2E tests require RUN_E2E_TESTS=1"
)
def test_chat_interface():
    """Test the frontend chat interface."""
    
    with sync_playwright() as p:
        browser = p.chromium.launch(headless=True)
        page = browser.new_page()
        
        # Navigate to chat interface
        page.goto("http://localhost:3000")
        
        # Wait for chat interface to load
        page.wait_for_selector('[data-testid="chat-input"]')
        
        # Send a message
        chat_input = page.locator('[data-testid="chat-input"]')
        chat_input.fill("What is ISO 26262?")
        
        send_button = page.locator('[data-testid="send-button"]')
        send_button.click()
        
        # Wait for response
        page.wait_for_selector('[data-testid="assistant-message"]', timeout=30000)
        
        # Verify response appeared
        response = page.locator('[data-testid="assistant-message"]').first
        assert response.is_visible()
        
        # Check for tool UI elements
        tool_ui = page.locator('[data-testid="tool-call"]')
        if tool_ui.count() > 0:
            assert tool_ui.first.is_visible()
        
        browser.close()
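
If more UI scenarios are added, the browser setup can move into a shared fixture so each test gets a ready page (a sketch using Playwright's sync API):

@pytest.fixture(scope="module")
def ui_page():
    """Shared Playwright page; skipped unless RUN_E2E_TESTS is set."""
    if not os.getenv("RUN_E2E_TESTS"):
        pytest.skip("E2E tests require RUN_E2E_TESTS=1")
    with sync_playwright() as p:
        browser = p.chromium.launch(headless=True)
        page = browser.new_page()
        yield page
        browser.close()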

Performance Tests

Load Testing

# tests/performance/test_load.py
import pytest
import asyncio
import time
import statistics
from httpx import AsyncClient

@pytest.mark.asyncio
async def test_concurrent_requests():
    """Test system performance under concurrent load."""
    
    async def single_request(client: AsyncClient, request_id: int):
        start_time = time.time()
        
        request_data = {
            "messages": [{"role": "user", "content": f"Test query {request_id}"}],
            "session_id": f"load_test_{request_id}"
        }
        
        try:
            response = await client.post(
                "http://localhost:8000/api/chat",
                json=request_data,
                timeout=30.0
            )
            
            end_time = time.time()
            return {
                "status_code": response.status_code,
                "response_time": end_time - start_time,
                "success": response.status_code == 200
            }
        except Exception as e:
            end_time = time.time()
            return {
                "status_code": 0,
                "response_time": end_time - start_time,
                "success": False,
                "error": str(e)
            }
    
    # Test with 20 concurrent requests
    async with AsyncClient() as client:
        tasks = [single_request(client, i) for i in range(20)]
        results = await asyncio.gather(*tasks, return_exceptions=True)
    
    # Analyze results
    successful_requests = [r for r in results if isinstance(r, dict) and r["success"]]
    response_times = [r["response_time"] for r in successful_requests]
    
    success_rate = len(successful_requests) / len(results)
    avg_response_time = statistics.mean(response_times) if response_times else 0
    p95_response_time = statistics.quantiles(response_times, n=20)[18] if len(response_times) > 5 else 0
    
    print(f"Success rate: {success_rate:.2%}")
    print(f"Average response time: {avg_response_time:.2f}s")
    print(f"95th percentile: {p95_response_time:.2f}s")
    
    # Performance assertions
    assert success_rate >= 0.95, f"Success rate too low: {success_rate:.2%}"
    assert avg_response_time < 10.0, f"Average response time too high: {avg_response_time:.2f}s"
    assert p95_response_time < 20.0, f"95th percentile too high: {p95_response_time:.2f}s"

@pytest.mark.asyncio
async def test_memory_usage():
    """Test memory usage under load."""
    import psutil
    import gc
    
    # Note: psutil.Process() measures the test process itself; to profile the
    # API server, attach to the server's PID instead.
    process = psutil.Process()
    initial_memory = process.memory_info().rss / 1024 / 1024  # MB
    
    # Run multiple requests
    async with AsyncClient() as client:
        for i in range(50):
            request_data = {
                "messages": [{"role": "user", "content": f"Memory test {i}"}],
                "session_id": f"memory_test_{i}"
            }
            
            await client.post(
                "http://localhost:8000/api/chat",
                json=request_data,
                timeout=30.0
            )
            
            if i % 10 == 0:
                gc.collect()  # Force garbage collection
    
    final_memory = process.memory_info().rss / 1024 / 1024  # MB
    memory_increase = final_memory - initial_memory
    
    print(f"Initial memory: {initial_memory:.1f} MB")
    print(f"Final memory: {final_memory:.1f} MB")
    print(f"Memory increase: {memory_increase:.1f} MB")
    
    # Memory assertions (adjust based on expected usage)
    assert memory_increase < 100, f"Memory increase too high: {memory_increase:.1f} MB"
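
For a streaming API, time-to-first-event is often a more useful latency signal than total response time. A sketch (the 5-second threshold is illustrative, not a project requirement):

@pytest.mark.asyncio
async def test_time_to_first_event():
    """Measure how long the first SSE data line takes to arrive."""
    async with AsyncClient() as client:
        request_data = {
            "messages": [{"role": "user", "content": "Latency probe"}],
            "session_id": "ttft_test"
        }

        start = time.time()
        first_event_latency = None

        async with client.stream(
            "POST",
            "http://localhost:8000/api/chat",
            json=request_data,
            timeout=60.0
        ) as response:
            assert response.status_code == 200
            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    first_event_latency = time.time() - start
                    break

        assert first_event_latency is not None, "No SSE events received"
        assert first_event_latency < 5.0, f"First event too slow: {first_event_latency:.2f}s"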

Test Data Management

Test Fixtures

# tests/fixtures.py
import pytest
from typing import List, Dict

@pytest.fixture
def sample_messages() -> List[Dict]:
    """Sample message history for testing."""
    return [
        {"role": "user", "content": "What is ISO 26262?"},
        {"role": "assistant", "content": "ISO 26262 is a functional safety standard..."},
        {"role": "user", "content": "What about testing procedures?"}
    ]

@pytest.fixture
def mock_retrieval_response() -> Dict:
    """Mock response from retrieval API."""
    return {
        "results": [
            {
                "title": "ISO 26262-1:2018",
                "content": "Road vehicles — Functional safety — Part 1: Vocabulary",
                "source": "ISO",
                "url": "https://iso.org/26262-1",
                "score": 0.95
            },
            {
                "title": "ISO 26262-3:2018", 
                "content": "Road vehicles — Functional safety — Part 3: Concept phase",
                "source": "ISO",
                "url": "https://iso.org/26262-3",
                "score": 0.88
            }
        ],
        "metadata": {
            "total": 2,
            "took_ms": 150,
            "query": "ISO 26262"
        }
    }

@pytest.fixture
def mock_llm_response() -> str:
    """Mock LLM response with citations."""
    return """ISO 26262 is an international standard for functional safety of electrical and electronic systems in road vehicles <sup>1</sup>. 

The standard consists of multiple parts:
- Part 1: Vocabulary <sup>1</sup>
- Part 3: Concept phase <sup>2</sup>

These standards ensure that safety-critical automotive systems operate reliably even in the presence of faults."""
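
When fixed histories are too rigid, a factory fixture (hypothetical) can build conversations of arbitrary length on demand, which also supports the "avoid hardcoded values" guideline below:

@pytest.fixture
def make_messages():
    """Factory fixture: build a message history with a given number of turns."""
    def _make(num_turns: int = 1, topic: str = "ISO 26262") -> List[Dict]:
        messages = []
        for i in range(num_turns):
            messages.append({"role": "user", "content": f"Question {i + 1} about {topic}"})
            messages.append({"role": "assistant", "content": f"Answer {i + 1} about {topic}"})
        messages.append({"role": "user", "content": f"Follow-up question about {topic}"})
        return messages
    return _make

A test then requests make_messages and calls it, for example history = make_messages(num_turns=5).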

Database Test Setup

# tests/database_setup.py
import asyncio
import pytest
from sqlalchemy import create_engine, text
from service.memory.postgresql_memory import PostgreSQLMemoryManager

@pytest.fixture(scope="session")
async def test_database():
    """Set up test database."""
    
    # Create the test database; CREATE/DROP DATABASE cannot run inside a
    # transaction, so the admin connection uses AUTOCOMMIT
    engine = create_engine(
        "postgresql://test:test@localhost/postgres",
        isolation_level="AUTOCOMMIT"
    )
    with engine.connect() as conn:
        conn.execute(text("DROP DATABASE IF EXISTS test_agentic_rag"))
        conn.execute(text("CREATE DATABASE test_agentic_rag"))
    
    # Initialize schema
    test_connection_string = "postgresql://test:test@localhost/test_agentic_rag"
    memory_manager = PostgreSQLMemoryManager(test_connection_string)
    checkpointer = memory_manager.get_checkpointer()
    checkpointer.setup()
    
    yield test_connection_string
    
    # Cleanup
    with engine.connect() as conn:
        conn.execute(text("DROP DATABASE test_agentic_rag"))
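
A test can depend on the test_database fixture to run against the isolated schema (a sketch; it reuses the checkpointer calls shown in the memory persistence test and assumes pytest-asyncio handles the async fixture):

@pytest.mark.asyncio
async def test_checkpointer_roundtrip(test_database):
    memory_manager = PostgreSQLMemoryManager(test_database)
    checkpointer = memory_manager.get_checkpointer()

    config = {"configurable": {"session_id": "db_setup_session"}}
    await checkpointer.aput(
        config=config,
        checkpoint={"id": "checkpoint_db_1", "ts": "2024-01-01T00:00:00Z"},
        metadata={},
        new_versions={}
    )

    retrieved = await checkpointer.aget_tuple(config=config)
    assert retrieved is not None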

Continuous Integration

GitHub Actions Workflow

# .github/workflows/test.yml
name: Tests

on:
  push:
    branches: [ main, develop ]
  pull_request:
    branches: [ main ]

jobs:
  test:
    runs-on: ubuntu-latest
    
    services:
      postgres:
        image: postgres:15
        env:
          POSTGRES_PASSWORD: test
          POSTGRES_USER: test
          POSTGRES_DB: test
        options: >-
          --health-cmd pg_isready
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5
        ports:
          - 5432:5432
    
    steps:
    - uses: actions/checkout@v4
    
    - name: Set up Python
      uses: actions/setup-python@v4
      with:
        python-version: '3.12'
    
    - name: Install uv
      uses: astral-sh/setup-uv@v1
    
    - name: Install dependencies
      run: uv sync --dev
    
    - name: Run unit tests
      run: uv run pytest tests/unit/ -v --cov=service --cov-report=xml
      env:
        DATABASE_URL: postgresql://test:test@localhost:5432/test
        OPENAI_API_KEY: test-key
        RETRIEVAL_API_KEY: test-key
    
    - name: Start test server
      run: |
        uv run uvicorn service.main:app --host 0.0.0.0 --port 8000 &
        sleep 10
      env:
        DATABASE_URL: postgresql://test:test@localhost:5432/test
        OPENAI_API_KEY: test-key
        RETRIEVAL_API_KEY: test-key
    
    - name: Run integration tests
      run: uv run pytest tests/integration/ -v
      env:
        DATABASE_URL: postgresql://test:test@localhost:5432/test
        OPENAI_API_KEY: test-key
        RETRIEVAL_API_KEY: test-key
    
    - name: Upload coverage to Codecov
      uses: codecov/codecov-action@v3
      with:
        file: ./coverage.xml

Testing Best Practices

1. Test Organization

  • Keep tests close to code: Mirror the source structure in test directories
  • Use descriptive names: Test names should clearly describe what they test
  • Group related tests: Use test classes to group related functionality

2. Test Data

  • Use fixtures: Create reusable test data with pytest fixtures
  • Avoid hardcoded values: Use factories or builders for test data generation
  • Clean up after tests: Ensure tests don't affect each other

3. Mocking Strategy

# Good: Mock external dependencies
@pytest.mark.asyncio
@patch('service.retrieval.httpx.AsyncClient')
async def test_retrieval_with_mock(mock_client):
    # Test implementation
    pass

# Good: Mock at the right level
@pytest.mark.asyncio
@patch('service.llm_client.OpenAIClient.generate')
async def test_agent_workflow(mock_generate):
    # Test workflow logic without hitting LLM API
    pass

# Avoid: Over-mocking (mocking everything)
# Avoid: Under-mocking (hitting real APIs in unit tests)

4. Async Testing

# Proper async test setup
@pytest.mark.asyncio
async def test_async_function():
    result = await async_function()
    assert result is not None

# Use async context managers
@pytest.mark.asyncio
async def test_with_async_client():
    async with AsyncClient() as client:
        response = await client.get("/")
        assert response.status_code == 200

5. Performance Testing

  • Set realistic timeouts: Don't make tests too strict or too loose
  • Test under load: Verify system behavior with concurrent requests
  • Monitor resource usage: Check memory leaks and CPU usage

6. Error Testing

def test_error_handling():
    """Test that errors are handled gracefully."""
    
    # Test invalid input
    with pytest.raises(ValueError):
        function_with_validation("")
    
    # Test network errors
    with patch('httpx.post', side_effect=httpx.ConnectError("Connection failed")):
        result = robust_function()
        assert result["error"] is not None

This testing guide provides a comprehensive framework for ensuring the quality and reliability of the Agentic RAG system. Regular testing at all levels helps maintain code quality and prevents regressions as the system evolves.