init
This commit is contained in:
170
vw-agentic-rag/tests/integration/test_2phase_retrieval.py
Normal file
170
vw-agentic-rag/tests/integration/test_2phase_retrieval.py
Normal file
@@ -0,0 +1,170 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test 2-phase retrieval strategy
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import httpx
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
|
||||
# Configure logging
# Timestamped, level-tagged output for the whole script; all progress and
# the final tool-usage analysis are reported exclusively through `logger`.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
# Module-level logger (standard `getLogger(__name__)` idiom).
logger = logging.getLogger(__name__)
|
||||
|
||||
async def test_2phase_retrieval():
    """Verify that the agent uses 2-phase retrieval for content-focused queries.

    Sends a detail-oriented Chinese query ("how to test EV charging
    performance") to a locally running chat API at http://127.0.0.1:8000
    and consumes its server-sent-event stream, counting tool invocations:

    * Phase 1 (metadata):  ``retrieve_standard_regulation``
    * Phase 2 (content):   ``retrieve_doc_chunk_standard_regulation``

    Success criteria (logged, not asserted): at least one Phase 1 tool,
    at least one Phase 2 tool, and at least two tools overall.

    Returns:
        None. Returns early (after logging) on a non-200 response or a
        transport-level exception.
    """
    # Random suffix keeps concurrent runs from colliding on server-side
    # session state.
    session_id = f"2phase-test-{random.randint(1000000000, 9999999999)}"
    base_url = "http://127.0.0.1:8000"

    # Test query that should trigger 2-phase retrieval
    query = "如何测试电动汽车的充电性能?请详细说明测试方法和步骤。"

    logger.info("🎯 2-PHASE RETRIEVAL TEST")
    logger.info("=" * 80)
    logger.info(f"📝 Session: {session_id}")
    logger.info(f"📝 Query: {query}")
    logger.info("-" * 60)

    # Create the request payload (OpenAI-style message list plus session id).
    payload = {
        "messages": [
            {
                "role": "user",
                "content": query
            }
        ],
        "session_id": session_id
    }

    # Track tool usage per phase.
    metadata_tools = 0
    content_tools = 0
    total_tools = 0

    timeout = httpx.Timeout(120.0)  # 2 minute timeout

    try:
        async with httpx.AsyncClient(timeout=timeout) as client:
            logger.info("✅ Streaming response started")

            async with client.stream(
                "POST",
                f"{base_url}/api/chat",
                json=payload,
                headers={"Content-Type": "application/json"}
            ) as response:

                # Check if the response started successfully; on error we
                # must aread() the body since the response is streamed.
                if response.status_code != 200:
                    error_body = await response.aread()
                    logger.error(f"❌ HTTP {response.status_code}: {error_body.decode()}")
                    return

                # SSE frames arrive as an "event: <type>" line followed by
                # a "data: <json>" line; remember the type until its data
                # line shows up.
                current_event_type = None

                async for line in response.aiter_lines():
                    if not line.strip():
                        continue

                    if line.startswith("event: "):
                        current_event_type = line[7:]  # Remove "event: " prefix
                        continue

                    if line.startswith("data: "):
                        data_str = line[6:]  # Remove "data: " prefix

                        if data_str == "[DONE]":
                            logger.info("✅ Stream completed with [DONE]")
                            break

                        try:
                            event_data = json.loads(data_str)
                            event_type = current_event_type or "unknown"

                            if event_type == "tool_start":
                                total_tools += 1
                                tool_name = event_data.get("name", "unknown")
                                args = event_data.get("args", {})
                                # Hoist the lookup; truncate long queries
                                # for readable log lines.
                                raw_query = args.get("query", "")
                                query_arg = raw_query[:50] + "..." if len(raw_query) > 50 else raw_query

                                if tool_name == "retrieve_standard_regulation":
                                    metadata_tools += 1
                                    logger.info(f"📋 Phase 1 Tool {metadata_tools}: {tool_name}")
                                    logger.info(f" Query: {query_arg}")
                                elif tool_name == "retrieve_doc_chunk_standard_regulation":
                                    content_tools += 1
                                    logger.info(f"📄 Phase 2 Tool {content_tools}: {tool_name}")
                                    logger.info(f" Query: {query_arg}")
                                else:
                                    logger.info(f"🔧 Tool {total_tools}: {tool_name}")

                            elif event_type == "tool_result":
                                tool_name = event_data.get("name", "unknown")
                                results_count = len(event_data.get("results", []))
                                took_ms = event_data.get("took_ms", 0)
                                logger.info(f"✅ Tool completed: {tool_name} ({results_count} results, {took_ms}ms)")

                            elif event_type == "tokens":
                                # Don't log every token, just count them
                                pass

                            # Reset event type for next event
                            current_event_type = None

                            # Break after many tools to avoid too much output.
                            # `>= 20` so we really stop after the 20th tool,
                            # matching the message below (was `> 20`, which
                            # only broke after the 21st).
                            if total_tools >= 20:
                                logger.info(" ⚠️ Breaking after 20 tools...")
                                break

                        except json.JSONDecodeError as e:
                            logger.warning(f"⚠️ Failed to parse event: {e}")
                            current_event_type = None

    except Exception as e:
        # Top-level boundary for transport errors (connect refused, timeout);
        # log and bail out rather than crash the script.
        logger.error(f"❌ Request failed: {e}")
        return

    # Results
    logger.info("=" * 80)
    logger.info("📊 2-PHASE RETRIEVAL ANALYSIS")
    logger.info("=" * 80)
    logger.info(f"Phase 1 (Metadata) tools: {metadata_tools}")
    logger.info(f"Phase 2 (Content) tools: {content_tools}")
    logger.info(f"Total tools executed: {total_tools}")
    logger.info("-" * 60)

    # Success criteria: (passed?, human-readable report line) pairs.
    success_criteria = [
        (metadata_tools > 0, f"Phase 1 metadata retrieval: {'✅' if metadata_tools > 0 else '❌'} ({metadata_tools} tools)"),
        (content_tools > 0, f"Phase 2 content retrieval: {'✅' if content_tools > 0 else '❌'} ({content_tools} tools)"),
        (total_tools >= 2, f"Multi-tool execution: {'✅' if total_tools >= 2 else '❌'} ({total_tools} tools)")
    ]

    logger.info("✅ SUCCESS CRITERIA:")
    all_passed = True
    for passed, message in success_criteria:
        logger.info(f" {message}")
        if not passed:
            all_passed = False

    if all_passed:
        logger.info("🎉 2-PHASE RETRIEVAL TEST PASSED!")
        logger.info(" ✅ Agent correctly uses both metadata and content retrieval tools")
    else:
        logger.info("❌ 2-PHASE RETRIEVAL TEST FAILED!")
        if metadata_tools == 0:
            logger.info(" ❌ No metadata retrieval tools used")
        if content_tools == 0:
            logger.info(" ❌ No content retrieval tools used - this is the main issue!")
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: run the async integration test directly
    # (this file is meant to be executed, not collected by pytest).
    asyncio.run(test_2phase_retrieval())
|
||||
Reference in New Issue
Block a user