#!/usr/bin/env python3
"""
Test 2-phase retrieval strategy.

Sends one content-focused query to the chat API's streaming endpoint
(``POST {base_url}/api/chat``) and counts which retrieval tools the agent
invokes while answering:

* Phase 1 (metadata): ``retrieve_standard_regulation``
* Phase 2 (content):  ``retrieve_doc_chunk_standard_regulation``

The run passes when both phases are used and at least two tools execute.
When run as a script, the process exit status is 0 on pass and 1 otherwise.
"""
import asyncio
import json
import logging
import random
import sys
import time
from typing import Tuple

import httpx

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

# Tool names the agent is expected to call in each retrieval phase.
METADATA_TOOL = "retrieve_standard_regulation"
CONTENT_TOOL = "retrieve_doc_chunk_standard_regulation"

DEFAULT_BASE_URL = "http://127.0.0.1:8000"
# Query that should trigger 2-phase retrieval ("How do you test EV charging
# performance? Please describe the test methods and steps in detail.")
DEFAULT_QUERY = "如何测试电动汽车的充电性能?请详细说明测试方法和步骤。"


def _truncate(text: str, limit: int = 50) -> str:
    """Return *text* cut to *limit* characters, with "..." appended if it was cut."""
    return text[:limit] + "..." if len(text) > limit else text


def _parse_sse_line(line: str) -> Tuple[str, str]:
    """Split one server-sent-events line into ``(field, value)``.

    Only the first colon separates field from value, so JSON values containing
    colons survive intact. A single space after the colon is optional in SSE
    and is stripped. Comment lines (starting with ":") yield an empty field.
    """
    field, _, value = line.partition(":")
    if value.startswith(" "):
        value = value[1:]
    return field, value


async def _consume_stream(response, max_tools: int) -> Tuple[int, int, int]:
    """Read SSE lines from *response*, log tool activity and count tool calls.

    Stops at the ``[DONE]`` sentinel, at end of stream, or once *max_tools*
    ``tool_start`` events have been seen.

    Returns:
        ``(metadata_tools, content_tools, total_tools)``.
    """
    metadata_tools = content_tools = total_tools = 0
    current_event_type = None

    async for line in response.aiter_lines():
        if not line.strip():
            # A blank line terminates an SSE event; forget its event type.
            current_event_type = None
            continue

        field, value = _parse_sse_line(line)
        if field == "event":
            current_event_type = value
            continue
        if field != "data":
            continue  # id:, retry: and comment lines are irrelevant here

        if value == "[DONE]":
            logger.info("✅ Stream completed with [DONE]")
            break

        event_type = current_event_type or "unknown"
        # Each data line is handled as one complete event.
        current_event_type = None

        try:
            event_data = json.loads(value)
        except json.JSONDecodeError as e:
            logger.warning(f"⚠️ Failed to parse event: {e}")
            continue
        if not isinstance(event_data, dict):
            # Without this guard a non-object payload would raise on .get()
            # and abort the whole run.
            logger.warning(f"⚠️ Ignoring non-object event payload: {_truncate(value, 80)}")
            continue

        if event_type == "tool_start":
            total_tools += 1
            tool_name = event_data.get("name", "unknown")
            args = event_data.get("args")
            if not isinstance(args, dict):
                args = {}
            query_arg = _truncate(str(args.get("query") or ""))

            if tool_name == METADATA_TOOL:
                metadata_tools += 1
                logger.info(f"📋 Phase 1 Tool {metadata_tools}: {tool_name}")
                logger.info(f"   Query: {query_arg}")
            elif tool_name == CONTENT_TOOL:
                content_tools += 1
                logger.info(f"📄 Phase 2 Tool {content_tools}: {tool_name}")
                logger.info(f"   Query: {query_arg}")
            else:
                logger.info(f"🔧 Tool {total_tools}: {tool_name}")

        elif event_type == "tool_result":
            tool_name = event_data.get("name", "unknown")
            results = event_data.get("results")
            results_count = len(results) if isinstance(results, list) else 0
            took_ms = event_data.get("took_ms", 0)
            logger.info(f"✅ Tool completed: {tool_name} ({results_count} results, {took_ms}ms)")

        # "tokens" and any other event types are deliberately not logged.

        # Cap output: stop once exactly max_tools tool calls have been seen.
        if total_tools >= max_tools:
            logger.info(f"   ⚠️ Breaking after {max_tools} tools...")
            break

    return metadata_tools, content_tools, total_tools


def _report(metadata_tools: int, content_tools: int, total_tools: int) -> bool:
    """Log the tool-usage analysis and success criteria.

    Returns:
        True when every success criterion passed.
    """
    logger.info("=" * 80)
    logger.info("📊 2-PHASE RETRIEVAL ANALYSIS")
    logger.info("=" * 80)
    logger.info(f"Phase 1 (Metadata) tools: {metadata_tools}")
    logger.info(f"Phase 2 (Content) tools: {content_tools}")
    logger.info(f"Total tools executed: {total_tools}")
    logger.info("-" * 60)

    success_criteria = [
        (metadata_tools > 0,
         f"Phase 1 metadata retrieval: {'✅' if metadata_tools > 0 else '❌'} ({metadata_tools} tools)"),
        (content_tools > 0,
         f"Phase 2 content retrieval: {'✅' if content_tools > 0 else '❌'} ({content_tools} tools)"),
        (total_tools >= 2,
         f"Multi-tool execution: {'✅' if total_tools >= 2 else '❌'} ({total_tools} tools)"),
    ]

    logger.info("✅ SUCCESS CRITERIA:")
    for _, message in success_criteria:
        logger.info(f"   {message}")
    all_passed = all(passed for passed, _ in success_criteria)

    if all_passed:
        logger.info("🎉 2-PHASE RETRIEVAL TEST PASSED!")
        logger.info("   ✅ Agent correctly uses both metadata and content retrieval tools")
    else:
        logger.info("❌ 2-PHASE RETRIEVAL TEST FAILED!")
        if metadata_tools == 0:
            logger.info("   ❌ No metadata retrieval tools used")
        if content_tools == 0:
            logger.info("   ❌ No content retrieval tools used - this is the main issue!")
    return all_passed


async def test_2phase_retrieval(
    base_url: str = DEFAULT_BASE_URL,
    query: str = DEFAULT_QUERY,
    max_tools: int = 20,
    timeout_s: float = 120.0,
) -> bool:
    """Test that the agent uses 2-phase retrieval for content-focused queries.

    Args:
        base_url: Root URL of the chat server; ``/api/chat`` is appended.
        query: User message sent to the agent.
        max_tools: Stop reading the stream after this many tool calls.
        timeout_s: httpx timeout in seconds for the request.

    Returns:
        True if all success criteria passed. Returns False on failed criteria,
        on a non-200 HTTP status, or when the request raises.
    """
    session_id = f"2phase-test-{random.randint(1000000000, 9999999999)}"

    logger.info("🎯 2-PHASE RETRIEVAL TEST")
    logger.info("=" * 80)
    logger.info(f"📝 Session: {session_id}")
    logger.info(f"📝 Query: {query}")
    logger.info("-" * 60)

    payload = {
        "messages": [
            {
                "role": "user",
                "content": query
            }
        ],
        "session_id": session_id
    }

    timeout = httpx.Timeout(timeout_s)

    try:
        async with httpx.AsyncClient(timeout=timeout) as client:
            async with client.stream(
                "POST",
                f"{base_url}/api/chat",
                json=payload,
                headers={"Content-Type": "application/json"}
            ) as response:
                if response.status_code != 200:
                    error_body = await response.aread()
                    # errors="replace": a non-UTF-8 error body must not mask the HTTP error.
                    error_text = error_body.decode("utf-8", errors="replace")
                    logger.error(f"❌ HTTP {response.status_code}: {error_text}")
                    return False

                # Only now is the stream actually open.
                logger.info("✅ Streaming response started")
                metadata_tools, content_tools, total_tools = await _consume_stream(
                    response, max_tools
                )
    except Exception as e:  # top-level boundary: report any transport/runtime failure
        # repr() because some httpx exceptions (e.g. timeouts) have an empty str().
        logger.error(f"❌ Request failed: {e!r}")
        return False

    return _report(metadata_tools, content_tools, total_tools)


if __name__ == "__main__":
    # Propagate the verdict as the exit status so CI/shell callers can detect failures.
    sys.exit(0 if asyncio.run(test_2phase_retrieval()) else 1)