2025-05-06 12:56:55 +08:00
|
|
|
#!/usr/bin/env python
|
|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
Doris MCP SSE Server Implementation
|
|
|
|
|
|
|
|
|
|
Implements a standard MCP SSE server based on MCP's SseServerTransport,
|
|
|
|
|
supports bidirectional communication with clients, and integrates with the existing Doris-MCP-Server.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
import asyncio
|
|
|
|
|
import json
|
|
|
|
|
import uuid
|
|
|
|
|
import logging
|
|
|
|
|
import time
|
|
|
|
|
from typing import Any, Optional
|
|
|
|
|
from fastapi import FastAPI, Request
|
|
|
|
|
from fastapi.responses import JSONResponse
|
|
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
|
from sse_starlette.sse import EventSourceResponse
|
|
|
|
|
|
|
|
|
|
# Get logger
|
|
|
|
|
logger = logging.getLogger("doris-mcp-sse")
|
|
|
|
|
|
|
|
|
|
class DorisMCPSseServer:
|
|
|
|
|
"""Doris MCP SSE Server Implementation"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, mcp_server, app: FastAPI):
    """
    Initialize the Doris MCP SSE server

    Stores the MCP server and the application, installs the CORS
    middleware, registers the HTTP/SSE routes and hooks the two background
    maintenance coroutines into the application's startup event.

    Args:
        mcp_server: FastMCP server instance
        app: FastAPI application instance
    """
    self.mcp_server = mcp_server

    # A non-FastAPI object only triggers a warning; it is still stored and
    # used below exactly as it was passed in.
    if not isinstance(app, FastAPI):
        logger.warning("Passed application is not a FastAPI instance, will use the existing FastAPI instance")

    self.app = app

    # Add CORS middleware
    self.app.add_middleware(
        CORSMiddleware,
        allow_origins=["http://localhost:3100"],  # Specify frontend domain
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
        expose_headers=["*"]
    )

    # Client session management: maps session ID -> per-client session dict
    self.client_sessions = {}

    # Strong references to the background tasks started at application
    # startup. The event loop itself only keeps weak references to tasks,
    # so a task whose handle is dropped may be garbage-collected while it
    # is still running.
    self._background_tasks = []

    # Set up SSE routes
    self.setup_sse_routes()

    # Register startup event
    @self.app.on_event("startup")
    async def startup_event():
        # Start session cleanup task
        self._background_tasks.append(asyncio.create_task(self.cleanup_idle_sessions()))
        # Start task to send periodic status updates
        self._background_tasks.append(asyncio.create_task(self.send_periodic_updates()))
|
|
|
|
|
|
|
|
|
|
def setup_sse_routes(self):
    """Set up SSE related routes

    Registers the following handlers on ``self.app``:

    * ``GET /health`` - static liveness payload
    * ``GET /status`` - names of the registered tools, resources and
      prompts plus the number of connected clients
    * ``GET /sse`` - creates a client session and opens its SSE stream
    * ``OPTIONS /mcp/messages`` - reply to CORS preflight requests
    * ``POST /mcp/messages`` - delegates to ``self.mcp_message``
    """
    # CORS headers that this method attaches by hand to the responses it builds.
    cors_headers = {
        "Access-Control-Allow-Origin": "*",
        "Access-Control-Allow-Credentials": "true",
        "Access-Control-Allow-Methods": "*",
        "Access-Control-Allow-Headers": "*",
        "Access-Control-Expose-Headers": "*",
    }

    @self.app.get("/health")
    async def health_check():
        """Health check endpoint"""
        # Building this static payload cannot fail, so no error branch is needed.
        return {
            "status": "healthy",
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "server": "Doris MCP Server"
        }

    @self.app.get("/status")
    async def status():
        """Get server status"""
        try:
            # Get tool list
            tools = await self.mcp_server.list_tools()
            tool_names = [tool.name if hasattr(tool, 'name') else str(tool) for tool in tools]
            logger.info(f"Getting tool list, currently registered tools: {tool_names}")

            # Get resource list
            resources = await self.mcp_server.list_resources()
            resource_names = [res.name if hasattr(res, 'name') else str(res) for res in resources]

            # Get prompt template list
            prompts = await self.mcp_server.list_prompts()
            prompt_names = [prompt.name if hasattr(prompt, 'name') else str(prompt) for prompt in prompts]

            return {
                "status": "running",
                "name": self.mcp_server.name,
                "mode": "mcp_sse",
                "clients": len(self.client_sessions),
                "tools": tool_names,
                "resources": resource_names,
                "prompts": prompt_names
            }
        except Exception as e:
            logger.error(f"Error getting status: {str(e)}")
            return {
                "status": "error",
                "error": str(e)
            }

    @self.app.get("/sse")
    async def mcp_sse_init(request: Request):
        """SSE service entry point, establishes client connection (New Endpoint)"""
        # Generate session ID
        session_id = str(uuid.uuid4())
        logger.info(f"New SSE connection [Session ID: {session_id}] at /sse")

        # The queue is held in a local so the generator below keeps working
        # on it even if the session entry is removed from the registry
        # (e.g. by the idle-session cleanup) between two iterations.
        queue = asyncio.Queue()

        # Create client session
        self.client_sessions[session_id] = {
            "client_id": request.headers.get("X-Client-ID", f"client_{str(uuid.uuid4())[:8]}"),
            "created_at": time.time(),
            "last_active": time.time(),
            "queue": queue
        }

        # Immediately put endpoint information into the queue so the client
        # learns where to POST its messages.
        await queue.put({
            "event": "endpoint",
            "data": f"/mcp/messages?session_id={session_id}"
        })

        # Create event generator
        async def event_generator():
            try:
                while True:
                    # Wait with a timeout so an idle stream still emits
                    # periodic keepalive pings.
                    try:
                        message = await asyncio.wait_for(queue.get(), timeout=30)
                    except asyncio.TimeoutError:
                        # Send ping to keep connection alive
                        yield {
                            "event": "ping",
                            "data": "keepalive"
                        }
                        continue

                    if isinstance(message, dict) and "event" in message:
                        # Dicts carrying an "event" key are system events
                        if message["event"] == "close":
                            logger.info(f"Received close command [Session ID: {session_id}]")
                            break
                        yield {
                            "event": message["event"],
                            "data": message["data"]
                        }
                    elif isinstance(message, str):
                        # Strings are sent as-is
                        yield {
                            "event": "message",
                            "data": message
                        }
                    else:
                        # Everything else (including plain dicts) is JSON-encoded
                        yield {
                            "event": "message",
                            "data": json.dumps(message)
                        }
            except asyncio.CancelledError:
                # Connection cancelled
                logger.info(f"SSE connection cancelled [Session ID: {session_id}]")
            except Exception as e:
                # Other error occurred
                logger.error(f"SSE event generator error [Session ID: {session_id}]: {str(e)}")
            finally:
                # Clean up session
                if session_id in self.client_sessions:
                    logger.info(f"Cleaning up session [Session ID: {session_id}]")
                    del self.client_sessions[session_id]

        # Return standard SSE response
        return EventSourceResponse(
            event_generator(),
            headers={
                **cors_headers,
                "Cache-Control": "no-cache",
                "Connection": "keep-alive"
            }
        )

    @self.app.options("/mcp/messages")
    async def mcp_messages_options(request: Request):
        """Handle preflight requests"""
        return JSONResponse({}, headers=dict(cors_headers))

    @self.app.post("/mcp/messages")
    async def mcp_messages_handler(request: Request):
        """Handle client message requests, using class method"""
        return await self.mcp_message(request)
|
|
|
|
|
|
|
|
|
|
async def cleanup_idle_sessions(self, idle_timeout=300, check_interval=60):
    """Clean up idle client sessions

    Runs forever: after every ``check_interval`` seconds, each session
    whose ``last_active`` timestamp is more than ``idle_timeout`` seconds
    old is sent a ``close`` event through its queue and is then removed
    from ``self.client_sessions``.

    Args:
        idle_timeout: Seconds of inactivity after which a session is
            dropped (default 300, i.e. 5 minutes).
        check_interval: Seconds to sleep between two scans (default 60).
    """
    while True:
        await asyncio.sleep(check_interval)
        current_time = time.time()

        # Collect the IDs first so the dictionary is not modified while it
        # is being iterated.
        idle_sessions = [
            session_id
            for session_id, session in self.client_sessions.items()
            if current_time - session["last_active"] > idle_timeout
        ]

        # Close and remove idle sessions
        for session_id in idle_sessions:
            try:
                # Send close message
                await self.client_sessions[session_id]["queue"].put({"event": "close"})
                logger.info(f"Cleaned up idle session: {session_id}")
            except Exception as e:
                logger.error(f"Error cleaning up session: {str(e)}")
            finally:
                # Ensure session is removed even if the close message
                # could not be delivered
                self.client_sessions.pop(session_id, None)
|
|
|
|
|
|
|
|
|
|
async def send_periodic_updates(self):
    """Periodically send status updates to all clients

    Loops forever. Each round sleeps 5 seconds and, when at least one
    client session exists, broadcasts a snapshot holding the current
    timestamp, the number of sessions and a fixed "running" marker.
    Errors are logged and followed by a 1 second pause.
    """
    while True:
        try:
            # Wait 5 seconds between two rounds
            await asyncio.sleep(5)

            # Only build and send a snapshot when somebody is connected
            if self.client_sessions:
                snapshot = {
                    "timestamp": time.time(),
                    "clients_count": len(self.client_sessions),
                    "server_status": "running"
                }
                await self.broadcast_status_update(snapshot)
        except Exception as exc:
            logger.error(f"Error sending periodic updates: {str(exc)}")
            # Short back-off before the next round
            await asyncio.sleep(1)
|
|
|
|
|
|
|
|
|
|
async def broadcast_status_update(self, status_data):
    """Broadcast status update to all clients

    Wraps ``status_data`` in a JSON-RPC 2.0 ``notifications/status``
    message and hands it to ``broadcast_message``.

    Args:
        status_data: Status data
    """
    logger.debug(f"Broadcasting status update: {status_data}")

    # Assemble the notification envelope step by step
    notification = {"jsonrpc": "2.0", "method": "notifications/status"}
    notification["params"] = {"type": "status_update", "data": status_data}

    await self.broadcast_message(notification)
|
|
|
|
|
|
|
|
|
|
async def broadcast_visualization_data(self, visualization_data):
    """Broadcast visualization data to all clients

    The payload must be a non-empty dict containing a "type" key;
    anything else is logged as invalid and dropped. Valid payloads are
    wrapped in a JSON-RPC 2.0 ``notifications/visualization`` message and
    passed to ``broadcast_message``.

    Args:
        visualization_data: Visualization data, should include type field
    """
    payload_ok = (
        bool(visualization_data)
        and isinstance(visualization_data, dict)
        and "type" in visualization_data
    )
    if not payload_ok:
        logger.warning(f"Invalid visualization data: {visualization_data}")
        return

    logger.info(f"Broadcasting visualization data: {visualization_data['type']}")

    notification = {"jsonrpc": "2.0", "method": "notifications/visualization"}
    notification["params"] = {"type": "visualization", "data": visualization_data}

    await self.broadcast_message(notification)
|
|
|
|
|
|
|
|
|
|
async def send_visualization_data(self, session_id, visualization_data):
    """Send visualization data to a specific client

    The payload is validated first (non-empty dict with a "type" key),
    then the session is looked up; either failure is logged as a warning
    and nothing is sent. Otherwise a JSON-RPC 2.0
    ``notifications/visualization`` message is put on the session's queue.

    Args:
        session_id: Session ID
        visualization_data: Visualization data, should include type field
    """
    payload_ok = (
        bool(visualization_data)
        and isinstance(visualization_data, dict)
        and "type" in visualization_data
    )
    if not payload_ok:
        logger.warning(f"Invalid visualization data: {visualization_data}")
        return

    sessions = self.client_sessions
    if session_id not in sessions:
        logger.warning(f"Session does not exist: {session_id}")
        return

    logger.info(f"Sending visualization data to session {session_id}: {visualization_data['type']}")

    notification = {"jsonrpc": "2.0", "method": "notifications/visualization"}
    notification["params"] = {"type": "visualization", "data": visualization_data}

    target_queue = sessions[session_id]["queue"]
    await target_queue.put(notification)
|
|
|
|
|
|
|
|
|
|
async def send_tool_result(self, session_id, tool_name, result_data, is_final=True):
    """Send tool execution result to the client

    Puts a JSON-RPC 2.0 ``notifications/tool_result`` message on the queue
    of the given session. An unknown session is logged as a warning and
    nothing is sent.

    Args:
        session_id: Session ID
        tool_name: Tool name
        result_data: Result data
        is_final: Whether it is the final result
    """
    sessions = self.client_sessions
    if session_id not in sessions:
        logger.warning(f"Session does not exist: {session_id}")
        return

    logger.info(f"Sending tool result to session {session_id}: {tool_name}")

    details = {
        "type": "tool_result",
        "tool": tool_name,
        "result": result_data,
        "is_final": is_final
    }
    notification = {
        "jsonrpc": "2.0",
        "method": "notifications/tool_result",
        "params": details
    }

    target_queue = sessions[session_id]["queue"]
    await target_queue.put(notification)
|
|
|
|
|
|
|
|
|
|
async def broadcast_message(self, message):
|
|
|
|
|
"""Broadcast a message to all active sessions
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
message: Message to broadcast
|
|
|
|
|
"""
|
|
|
|
|
# If no clients are connected, return immediately
|
|
|
|
|
if not self.client_sessions:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
# Create a copy of the session ID list so the original dictionary can be safely modified during iteration
|
|
|
|
|
session_ids = list(self.client_sessions.keys())
|
|
|
|
|
|
|
|
|
|
# Send message to all sessions
|
|
|
|
|
for session_id in session_ids:
|
|
|
|
|
try:
|
|
|
|
|
if session_id in self.client_sessions: # Check again, as session might have been removed during iteration
|
|
|
|
|
await self.client_sessions[session_id]["queue"].put(message)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"Error sending message to session {session_id}: {str(e)}")
|
|
|
|
|
|
|
|
|
|
async def get_status(self):
|
|
|
|
|
"""Get server status"""
|
|
|
|
|
return {
|
|
|
|
|
"status": "running",
|
|
|
|
|
"name": self.mcp_server.name,
|
|
|
|
|
"mode": "mcp_sse",
|
|
|
|
|
"clients": len(self.client_sessions)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
async def mcp_message(self, request: Request):
    """Endpoint to receive client messages

    Validates the session referenced by the request, parses the JSON-RPC
    2.0 body and dispatches on its ``method``. The JSON-RPC reply is put
    on the session's SSE queue; the HTTP response returned here only
    acknowledges (or rejects) the request.

    Handled methods: ``initialize``, ``mcp/listOfferings``,
    ``mcp/listTools`` / ``tools/list``, ``mcp/listResources`` and
    ``mcp/callTool`` / ``tools/call``. Any other method is acknowledged
    with a generic "unable to process" result.

    Args:
        request: Incoming request carrying the JSON-RPC body.

    Returns:
        JSONResponse describing the HTTP-level outcome.
    """
    try:
        # Parse request parameters
        session_id = self._get_session_id(request)

        # Check if session exists
        if not session_id or session_id not in self.client_sessions:
            logger.warning(f"Session does not exist: {session_id}")
            return JSONResponse(
                {"jsonrpc": "2.0", "error": {"code": -32000, "message": "Session does not exist or has expired"}},
                status_code=401
            )

        # Update session last active time
        self.client_sessions[session_id]["last_active"] = time.time()

        # Fallback id, used when the body cannot be parsed or has no "id".
        message_id = str(uuid.uuid4())

        try:
            body = await request.json()
            logger.info(f"Received message [Session ID: {session_id}]: {json.dumps(body)}")

            message_id = body.get("id", message_id)

            # Only JSON-RPC 2.0 requests with a method are accepted
            if body.get("jsonrpc") != "2.0" or "method" not in body:
                return JSONResponse(
                    {"jsonrpc": "2.0", "id": message_id, "error": {"code": -32600, "message": "Invalid request, must be JSON-RPC 2.0 format"}},
                    status_code=400
                )

            method = body.get("method")
            # An explicit "params": null is treated like missing params
            params = body.get("params") or {}

            if method == "initialize":
                logger.info(f"Processing initialize command [Session ID: {session_id}]")
                response = {
                    "jsonrpc": "2.0",
                    "id": message_id,
                    "result": {
                        "protocolVersion": "2024-11-05",
                        "name": self.mcp_server.name,
                        "instructions": "This is an MCP server for Apache Doris database",
                        "serverInfo": {
                            "version": "0.1.0",
                            "name": "Doris MCP Server"
                        },
                        "capabilities": {
                            "tools": {
                                "supportsStreaming": True,
                                "supportsProgress": True
                            },
                            "resources": {
                                "supportsStreaming": False
                            },
                            "prompts": {
                                "supported": True
                            }
                        }
                    }
                }
                await self._mcp_enqueue(session_id, response)
                return self._mcp_ack({"status": "success"})

            if method == "mcp/listOfferings":
                # List all available features
                logger.info(f"Processing listOfferings command [Session ID: {session_id}]")
                tools_json = await self._mcp_describe_tools()
                resources = await self.mcp_server.list_resources()
                resources_json = [res.model_dump() if hasattr(res, "model_dump") else res for res in resources]
                prompts = await self.mcp_server.list_prompts()
                prompts_json = [prompt.model_dump() if hasattr(prompt, "model_dump") else prompt for prompt in prompts]
                response = {
                    "jsonrpc": "2.0",
                    "id": message_id,
                    "result": {
                        "tools": tools_json,
                        "resources": resources_json,
                        "prompts": prompts_json
                    }
                }
                await self._mcp_enqueue(session_id, response)
                return self._mcp_ack({"status": "success"})

            if method in ("mcp/listTools", "tools/list"):
                # List all tools
                logger.info(f"Processing listTools command [Session ID: {session_id}]")
                response = {
                    "jsonrpc": "2.0",
                    "id": message_id,
                    "result": {
                        "tools": await self._mcp_describe_tools()
                    }
                }
                await self._mcp_enqueue(session_id, response)
                return self._mcp_ack({"status": "success"})

            if method == "mcp/listResources":
                # List all resources
                logger.info(f"Processing listResources command [Session ID: {session_id}]")
                resources = await self.mcp_server.list_resources()
                response = {
                    "jsonrpc": "2.0",
                    "id": message_id,
                    "result": {
                        "resources": [res.model_dump() if hasattr(res, "model_dump") else res for res in resources]
                    }
                }
                await self._mcp_enqueue(session_id, response)
                return self._mcp_ack({"status": "success"})

            if method in ("mcp/callTool", "tools/call"):
                return await self._mcp_handle_tool_call(request, session_id, message_id, params)

            # Any other method: nothing can process it, so reply with a
            # generic result.
            logger.info(f"Processing general message [Session ID: {session_id}]")
            response = {
                "jsonrpc": "2.0",
                "id": message_id,
                "result": {
                    "status": "ok",
                    "message": "Message received, but unable to process unrecognized message type"
                }
            }
            await self._mcp_enqueue(session_id, response)
            return self._mcp_ack({"status": "received"})
        except Exception as e:
            logger.error(f"Error processing message: {str(e)}")
            error_response = {
                "jsonrpc": "2.0",
                "id": message_id,
                "error": {
                    "code": -32000,
                    "message": str(e)
                }
            }
            # If the session is gone this raises and is reported by the
            # outer handler below.
            await self._mcp_enqueue(session_id, error_response)
            return self._mcp_ack({"status": "error", "message": str(e)}, status_code=500)
    except Exception as e:
        logger.error(f"Error processing request: {str(e)}", exc_info=True)
        return self._mcp_ack(
            {
                "jsonrpc": "2.0",
                "id": "unknown",
                "error": {
                    "code": -32000,
                    "message": f"Error processing request: {str(e)}"
                }
            },
            status_code=500
        )

@staticmethod
def _mcp_ack(payload, status_code=200):
    """Build the HTTP JSON reply for mcp_message with its fixed CORS headers."""
    return JSONResponse(
        payload,
        status_code=status_code,
        headers={
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Credentials": "true",
            "Access-Control-Allow-Methods": "*",
            "Access-Control-Allow-Headers": "*",
            "Access-Control-Expose-Headers": "*"
        }
    )

async def _mcp_enqueue(self, session_id, message):
    """Put ``message`` on the SSE queue of ``session_id`` (KeyError if the session is gone)."""
    await self.client_sessions[session_id]["queue"].put(message)

async def _mcp_describe_tools(self):
    """Return the name/description/inputSchema dicts of all tools of the MCP server."""
    empty_schema = {
        "type": "object",
        "properties": {},
        "required": []
    }
    described = []
    for tool in await self.mcp_server.list_tools():
        if hasattr(tool, "parameters"):
            schema = tool.parameters
        else:
            # NOTE(review): presumably list_tools() yields MCP SDK ``Tool``
            # objects, which carry the schema as ``inputSchema`` - confirm.
            schema = getattr(tool, "inputSchema", empty_schema)
        described.append({
            "name": tool.name if hasattr(tool, "name") else str(tool),
            "description": tool.description if hasattr(tool, "description") else "",
            "inputSchema": schema
        })
    return described

async def _mcp_tool_exists(self, tool_name):
    """Tell whether the MCP server lists a tool whose name equals ``tool_name``."""
    tools = await self.mcp_server.list_tools()
    return any(getattr(tool, 'name', '') == tool_name for tool in tools)

async def _mcp_handle_tool_call(self, request, session_id, message_id, params):
    """Validate a tool-call request and route it to the streaming or standard path."""
    tool_name = params.get("name")
    arguments = params.get("arguments", {})

    if not tool_name:
        error_response = {
            "jsonrpc": "2.0",
            "id": message_id,
            "error": {
                "code": -32602,
                "message": "Invalid params: tool name is required"
            }
        }
        await self._mcp_enqueue(session_id, error_response)
        return self._mcp_ack({"status": "error", "message": "Tool name is required"}, status_code=400)

    # Streaming is requested through a "stream" query parameter or a
    # truthy "stream" entry in params.
    stream_mode = "stream" in request.query_params or params.get("stream", False)

    logger.info(f"Calling tool [Session ID: {session_id}, Tool: {tool_name}, Args: {arguments}, Stream mode: {stream_mode}]")

    if stream_mode:
        return await self._mcp_start_stream_call(request, session_id, message_id, tool_name, arguments)
    return await self._mcp_run_tool_call(request, session_id, message_id, tool_name, arguments)

async def _mcp_start_stream_call(self, request, session_id, message_id, tool_name, arguments):
    """Start a tool call in the background; partial results reach the client via a callback."""
    logger.info(f"Using streaming response to handle tool call [Session ID: {session_id}, Tool: {tool_name}]")

    try:
        if not await self._mcp_tool_exists(tool_name):
            raise Exception(f"Tool {tool_name} does not exist")

        # Callback handed to the tool so it can push partial results
        async def callback(content, metadata):
            partial_message = {
                "jsonrpc": "2.0",
                "id": message_id,
                "partial": True,
                "result": {
                    "content": content,
                    "metadata": metadata
                }
            }
            await self._mcp_enqueue(session_id, partial_message)

            # If visualization data is included, broadcast to all clients
            if metadata and "visualization" in metadata:
                await self.broadcast_visualization_data(metadata["visualization"])

        # Build argument dictionary
        kwargs = dict(arguments)
        kwargs['callback'] = callback

        task = asyncio.create_task(
            self._execute_stream_tool_wrapper(tool_name, kwargs, message_id, session_id, request)
        )
        # The event loop keeps only weak references to tasks; hold a strong
        # one until the task is done so it cannot be garbage-collected
        # while running.
        pending = getattr(self, "_mcp_stream_tasks", None)
        if pending is None:
            pending = self._mcp_stream_tasks = set()
        pending.add(task)
        task.add_done_callback(pending.discard)

        # Return received confirmation
        return self._mcp_ack({"status": "processing"})
    except Exception as e:
        logger.error(f"Streaming tool processing error: {str(e)}")
        error_message = {
            "jsonrpc": "2.0",
            "id": message_id,
            "success": False,
            "error": str(e)
        }
        await self._mcp_enqueue(session_id, error_message)
        return self._mcp_ack({"status": "error", "message": str(e)}, status_code=500)

async def _mcp_run_tool_call(self, request, session_id, message_id, tool_name, arguments):
    """Run a tool call to completion and queue its formatted result or error."""
    logger.info(f"Using standard response to handle tool call [Session ID: {session_id}, Tool: {tool_name}]")

    try:
        if not await self._mcp_tool_exists(tool_name):
            not_found = f"Tool '{tool_name}' not found"
            error_response = {
                "jsonrpc": "2.0",
                "id": message_id,
                "error": {
                    "code": -32601,
                    "message": not_found
                }
            }
            await self._mcp_enqueue(session_id, error_response)
            return self._mcp_ack({"status": "error", "message": not_found}, status_code=404)

        result = await self.call_tool(tool_name, arguments, request)

        # A dict that already has a list under "content" is passed through;
        # anything else is wrapped into a single content entry.
        if isinstance(result, dict) and "content" in result and isinstance(result["content"], list):
            formatted_result = result
        else:
            formatted_result = {
                "content": [
                    {
                        "type": "json",
                        "text": result if isinstance(result, str) else json.dumps(result, ensure_ascii=False)
                    }
                ]
            }

        response = {
            "jsonrpc": "2.0",
            "id": message_id,
            "result": formatted_result
        }
        await self._mcp_enqueue(session_id, response)
        return self._mcp_ack({"status": "success"})
    except Exception as e:
        logger.error(f"Tool call error: {str(e)}", exc_info=True)

        # Standard error object; replaced below when the exception text
        # itself is a JSON error object.
        error_obj = {
            "code": -32000,
            "message": str(e)
        }
        if str(e).startswith('{"code":'):
            try:
                error_obj = json.loads(str(e))
            except ValueError:
                # Not valid JSON after all: keep the standard error object
                pass

        error_response = {
            "jsonrpc": "2.0",
            "id": message_id,
            "error": error_obj
        }
        await self._mcp_enqueue(session_id, error_response)
        return self._mcp_ack({"status": "error", "message": str(e)}, status_code=500)
|
|
|
|
|
|
|
|
|
|
async def _execute_stream_tool_wrapper(self, tool_name, kwargs, message_id, session_id, request):
    """Run a tool call and queue its outcome for the owning session

    The tool is executed through ``self.call_tool``. On success a message
    with ``"success": True`` and the tool result is put on the session's
    queue. If anything in that path raises, a message with
    ``"success": False`` and the stringified exception is queued instead.
    When the session is no longer present in ``self.client_sessions``
    nothing is queued and a warning is logged.

    Args:
        tool_name: Tool name
        kwargs: Parameters forwarded to the tool
        message_id: Message ID echoed back in the queued message
        session_id: Key of the target session in ``self.client_sessions``
        request: Request object, passed through to ``self.call_tool``
    """
    try:
        # Delegate the actual work to the standard tool entry point
        tool_output = await self.call_tool(tool_name, kwargs, request)

        # Completion payload delivered to the client
        completion = {
            "jsonrpc": "2.0",
            "id": message_id,
            "success": True,
            "result": tool_output,
        }

        # The client may have disconnected while the tool was running
        if session_id not in self.client_sessions:
            logger.warning(f"Streaming tool execution completed but session closed [Session ID: {session_id}]")
        else:
            outbox = self.client_sessions[session_id]["queue"]
            await outbox.put(completion)
            logger.info(f"Streaming tool execution completed [Session ID: {session_id}, Message ID: {message_id}]")
    except Exception as e:
        logger.error(f"Streaming tool execution failed [Session ID: {session_id}]: {str(e)}")

        # Failure payload delivered to the client
        failure = {
            "jsonrpc": "2.0",
            "id": message_id,
            "success": False,
            "error": str(e),
        }

        # Same liveness check as on the success path
        if session_id not in self.client_sessions:
            logger.warning(f"Streaming tool execution failed but session closed [Session ID: {session_id}]")
        else:
            outbox = self.client_sessions[session_id]["queue"]
            await outbox.put(failure)
|
|
|
|
|
|
|
|
|
|
async def broadcast_tool_result(self, tool_name, result_data):
    """Build a tool-result notification and hand it to ``self.broadcast_message``

    The notification is a JSON-RPC 2.0 message with method
    ``notifications/tool_result`` whose params carry the tool name and
    its result.

    Args:
        tool_name: Tool name
        result_data: Result data
    """
    logger.info(f"Broadcasting tool result: {tool_name}")

    # Payload describing which tool produced which result
    notification_params = {
        "type": "tool_result",
        "tool": tool_name,
        "result": result_data,
    }
    notification = {
        "jsonrpc": "2.0",
        "method": "notifications/tool_result",
        "params": notification_params,
    }
    await self.broadcast_message(notification)
|
|
|
|
|
|
|
|
|
|
async def call_tool(self, tool_name, arguments, request):
    """
    Calls the specified tool and returns the result

    The tool is resolved in this order:
    1. ``tool_name`` is translated through a short-name -> ``mcp_doris_*``
       table (names not in the table are used unchanged).
    2. The resulting name is looked up as an attribute of
       ``doris_mcp_server.tools.mcp_doris_tools``.
    3. If that yields nothing, the tools returned by
       ``self.mcp_server.list_tools()`` are searched by their ``name``.

    The resolved object is invoked directly if callable, otherwise through
    its ``run`` / ``execute`` / ``call`` / ``__call__`` attribute, or through
    the ``'function'`` entry if it is a dict. Both coroutine and plain
    (synchronous) callables are supported.

    Args:
        tool_name: Tool name (short name or MCP internal name)
        arguments: Tool parameters
        request: Original request object

    Returns:
        Tool call result

    Raises:
        ValueError: If the tool cannot be found, or if resolving it,
            preparing its arguments or running it fails. Each failure is
            logged and wrapped exactly once; the underlying exception is
            attached as ``__cause__``.
    """
    # Local import: only needed to detect awaitable tool results below
    import inspect

    logger.info(f"Calling tool: {tool_name}, Parameters: {json.dumps(arguments, ensure_ascii=False)}")

    # Get recent query content, used to handle random_string parameter
    recent_query = self._extract_recent_query(request)

    # Handle tool name mapping - Add support for standard name tools
    tool_mapping = {
        # --- Retained/New Tools ---
        "exec_query": "mcp_doris_exec_query",
        "get_table_schema": "mcp_doris_get_table_schema",
        "get_db_table_list": "mcp_doris_get_db_table_list",
        "get_db_list": "mcp_doris_get_db_list",
        "get_table_comment": "mcp_doris_get_table_comment",
        "get_table_column_comments": "mcp_doris_get_table_column_comments",
        "get_table_indexes": "mcp_doris_get_table_indexes",
        "get_recent_audit_logs": "mcp_doris_get_recent_audit_logs",
        "get_catalog_list": "mcp_doris_get_catalog_list"
    }

    # If it's a standard name, convert to MCP name
    mapped_tool_name = tool_mapping.get(tool_name, tool_name)

    # Stage 1: resolve the tool and prepare its arguments
    try:
        # Import tool module
        import doris_mcp_server.tools.mcp_doris_tools as mcp_tools

        # Get the corresponding tool function
        tool_function = getattr(mcp_tools, mapped_tool_name, None)

        if not tool_function:
            # If it doesn't exist in mcp_tools, try getting the tool using the MCP server instance
            for tool in await self.mcp_server.list_tools():
                if getattr(tool, 'name', '') == mapped_tool_name:
                    tool_function = tool
                    break

        # Process common input parameter conversions (only when there is a tool to call)
        processed_args = None
        if tool_function:
            processed_args = self._process_tool_arguments(mapped_tool_name, arguments, recent_query)
    except AttributeError as e:
        logger.error(f"Tool function does not exist: {mapped_tool_name}, Error: {str(e)}")
        raise ValueError(f"Tool function does not exist: {mapped_tool_name}") from e
    except Exception as e:
        logger.error(f"Error calling tool: {str(e)}", exc_info=True)
        raise ValueError(f"Error calling tool: {str(e)}") from e

    # Raised outside the try blocks so the message is not wrapped a second time
    if not tool_function:
        logger.error(f"Tool not found: {tool_name} / {mapped_tool_name}")
        raise ValueError(f"Tool not found: {tool_name} / {mapped_tool_name}")

    # Stage 2: pick the entry point and call the tool function
    try:
        # Log tool type and attribute information to help debugging
        logger.debug(f"Tool function type: {type(tool_function)}")
        logger.debug(f"Tool function attributes: {dir(tool_function)}")

        if callable(tool_function):
            logger.debug("Tool function is callable, calling directly")
            entry_point = tool_function
        elif hasattr(tool_function, 'run'):
            logger.debug("Tool function has run method, calling run method")
            entry_point = tool_function.run
        elif hasattr(tool_function, 'execute'):
            logger.debug("Tool function has execute method, calling execute method")
            entry_point = tool_function.execute
        elif hasattr(tool_function, 'call'):
            logger.debug("Tool function has call method, calling call method")
            entry_point = tool_function.call
        elif hasattr(tool_function, '__call__'):
            logger.debug("Tool function has __call__ method, calling __call__ method")
            entry_point = tool_function.__call__
        elif isinstance(tool_function, dict) and 'function' in tool_function:
            # If it's a dict type, try getting the function from it
            logger.debug("Tool is a dictionary type, trying to get 'function' key")
            entry_point = tool_function['function']
            if not callable(entry_point):
                raise ValueError(f"Function in dictionary is not callable: {type(entry_point)}")
        else:
            raise ValueError(f"Unsupported tool type: {type(tool_function)}, Attributes: {dir(tool_function)}")

        result = entry_point(**processed_args)
        # Coroutine functions return an awaitable; plain functions return the value itself
        if inspect.isawaitable(result):
            result = await result
    except Exception as e:
        logger.error(f"Failed to call tool function: {str(e)}", exc_info=True)
        raise ValueError(f"Error calling tool: {str(e)}") from e

    # Return tool execution result
    return result
|
|
|
|
|
|
|
|
|
|
def _process_tool_arguments(self, tool_name, arguments, recent_query):
    """
    Process tool parameters, supporting special handling logic

    A shallow copy of ``arguments`` is returned. When the copy contains a
    ``random_string`` entry and ``tool_name`` starts with ``mcp_doris_``,
    that entry is removed and used as a fallback value:

    - ``mcp_doris_exec_query``: if ``sql`` is missing or empty it is filled
      from ``random_string``, or else from ``recent_query`` (the first
      fenced ``sql`` code block if there is one, otherwise the whole text).
    - schema / comment / column-comment / index tools: if ``table_name`` is
      missing or empty it is filled from ``random_string``.
    - any other ``mcp_doris_*`` tool: ``random_string`` is simply dropped.

    Args:
        tool_name: Tool name (MCP internal name, e.g., mcp_doris_...)
        arguments: Original parameters; ``None`` is treated as no parameters
        recent_query: Recent query content; ignored unless it is a string

    Returns:
        Processed parameter dictionary
    """
    import re

    # Copy parameters to avoid modifying the original object; a request
    # without an arguments object must not crash the copy
    processed_args = dict(arguments or {})

    # The fallback logic only applies to Doris tools that were actually
    # given a random_string; everything else passes through untouched
    if "random_string" not in processed_args or not tool_name.startswith("mcp_doris_"):
        return processed_args

    random_string = processed_args.pop("random_string", "")
    logger.debug(f"Processing random_string parameter for tool {tool_name}: '{random_string}'")

    # 1. For exec_query
    if tool_name == "mcp_doris_exec_query":
        if not processed_args.get("sql"):
            # Only textual queries can be searched or used as SQL
            if not isinstance(recent_query, str):
                recent_query = None

            sql_fallback = random_string or recent_query
            if sql_fallback:
                if not random_string and recent_query:
                    # Prefer an explicit fenced SQL block over the raw message text
                    sql_matches = re.findall(r'```sql\s*([\s\S]+?)\s*```', recent_query)
                    if sql_matches:
                        sql_fallback = sql_matches[0].strip()

                # The extracted block may be blank after stripping
                if sql_fallback:
                    logger.info(f"Using random_string/recent_query as SQL for exec_query: {sql_fallback[:100]}...")
                    processed_args["sql"] = sql_fallback
                else:
                    logger.warning("exec_query missing sql parameter, and random_string/recent_query is empty or SQL cannot be extracted")
            else:
                logger.warning("exec_query missing sql parameter, and both random_string and recent_query are empty")

    # 2. For tools requiring table_name
    elif tool_name in (
        "mcp_doris_get_table_schema",
        "mcp_doris_get_table_comment",
        "mcp_doris_get_table_column_comments",
        "mcp_doris_get_table_indexes",
    ):
        if not processed_args.get("table_name"):
            table_fallback = random_string
            if table_fallback:
                logger.info(f"Using random_string/recent_query as table_name for {tool_name}: {table_fallback}")
                processed_args["table_name"] = table_fallback
            else:
                logger.warning(f"{tool_name} missing table_name parameter, and random_string/recent_query is empty or table name cannot be extracted")

    # 3. Other tools
    else:
        logger.debug(f"Tool {tool_name} does not apply random_string fallback logic, or logic is undefined.")

    return processed_args
|
|
|
|
|
|
|
|
|
|
def _extract_recent_query(self, request: "Request") -> Optional[str]:
    """
    Extract the most recent user query from the request

    The request body is taken from the request's ``_body`` attribute
    (parsed as JSON) or, if that is absent, empty or unparsable, from its
    ``_json`` attribute. NOTE(review): these look like the caches Starlette
    fills after ``await request.body()`` / ``await request.json()``, so the
    caller presumably has to read the body first - confirm.

    Lookup order inside the body:
    1. ``params.messages``: the last entry whose ``role`` is ``"user"``
    2. ``params.message``: used if its ``role`` is ``"user"``

    Entries or containers of an unexpected type are skipped rather than
    aborting the search.

    Args:
        request: Request object

    Returns:
        Optional[str]: The most recent user query, or None if not found
    """
    try:
        # Try to extract message history from request body
        body = None
        body_bytes = getattr(request, "_body", None)
        if body_bytes:
            try:
                body = json.loads(body_bytes)
            except (ValueError, TypeError, RecursionError):
                # Not valid JSON (or not decodable): fall back to the parsed cache
                body = None

        if not body:
            body = getattr(request, "_json", {})

        # Anything other than a JSON object cannot carry params
        if not isinstance(body, dict):
            return None
        params = body.get("params")
        if not isinstance(params, dict):
            return None

        # Find the most recent user message from message history
        messages = params.get("messages")
        if isinstance(messages, (list, tuple)):
            # Iterate messages in reverse to find the most recent user message
            for msg in reversed(messages):
                if isinstance(msg, dict) and msg.get("role") == "user":
                    return msg.get("content", "")

        # If not found in message history, try extracting from the original message
        message = params.get("message")
        if isinstance(message, dict) and message.get("role") == "user":
            return message.get("content", "")

        return None
    except Exception as e:
        # Last-resort guard: query extraction must never break a tool call
        logger.error(f"Error extracting recent query: {str(e)}")
        return None
|
|
|
|
|
|
|
|
|
|
def format_tool_result(self, result):
    """
    Normalize a tool call result

    - a dict is returned unchanged
    - a string is parsed as JSON and the parsed value is returned; if it
      is not valid JSON it is wrapped as ``{"content": <string>}``
    - any other value is wrapped as ``{"content": str(value)}``

    Args:
        result: Original tool call result

    Returns:
        Formatted result, or ``{"error": <message>}`` if formatting itself fails
    """
    try:
        # Dictionaries are already in the target shape
        if isinstance(result, dict):
            return result

        # Everything that is not text is stringified
        if not isinstance(result, str):
            return {"content": str(result)}

        # Text: prefer its JSON interpretation, fall back to plain content
        try:
            parsed = json.loads(result)
        except json.JSONDecodeError:
            return {"content": result}
        return parsed
    except Exception as e:
        logger.error(f"Error formatting tool result: {str(e)}")
        return {"error": str(e)}
|
|
|
|
|
|
|
|
|
|
def _get_session_id(self, request: "Request") -> Optional[str]:
    """
    Get session ID from the request

    Tries to get session ID from the following locations (in priority order):
    1. Query parameter session_id
    2. session_id field in request body
    3. X-Session-ID header

    A source that is missing or holds an empty value is skipped and the
    next one is tried. The request body is read from the request's
    ``_json`` attribute or, failing that, parsed from its ``_body``
    attribute; a body that is not valid JSON is ignored.

    Args:
        request: Request object

    Returns:
        Optional[str]: Session ID, or None if not found
    """
    # Get from query parameter
    session_id = request.query_params.get("session_id")
    if session_id:
        return session_id

    # Try getting from request body
    body = getattr(request, "_json", None)
    if not body:
        body_bytes = getattr(request, "_body", None)
        if body_bytes:
            try:
                body = json.loads(body_bytes)
            except (ValueError, TypeError, RecursionError):
                # Body is not JSON: there is no session_id field to read
                body = None

    # An empty/null session_id in the body must not block the header fallback
    if isinstance(body, dict) and body.get("session_id"):
        return body["session_id"]

    # Get from request header
    session_id = request.headers.get("X-Session-ID")
    if session_id:
        return session_id

    return None
|