Files
doris-mcp-server/doris_mcp_server/sse_server.py
2025-06-06 14:35:53 +08:00

1261 lines
56 KiB
Python

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Doris MCP SSE Server Implementation
Implements a standard MCP SSE server based on MCP's SseServerTransport,
supports bidirectional communication with clients, and integrates with the existing Doris-MCP-Server.
"""
import asyncio
import json
import uuid
import logging
import time
from typing import Any, Optional
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from sse_starlette.sse import EventSourceResponse
# Module-level logger for this file, registered under the name "doris-mcp-sse"
logger = logging.getLogger("doris-mcp-sse")
class DorisMCPSseServer:
    """Doris MCP SSE Server Implementation.

    Registers the health/status, SSE and message routes on a FastAPI
    application and tracks one in-memory session (with its own asyncio
    queue) per connected SSE client in ``self.client_sessions``.
    """
def __init__(self, mcp_server, app: FastAPI):
    """
    Initialize the Doris MCP SSE server.

    Installs the CORS middleware, registers the SSE routes and schedules
    the two background jobs (idle-session cleanup and periodic status
    broadcast) to start with the application.

    Args:
        mcp_server: FastMCP server instance
        app: FastAPI application instance the routes are registered on
    """
    self.mcp_server = mcp_server
    # A non-FastAPI object is only reported; it is still used as-is below.
    if not isinstance(app, FastAPI):
        logger.warning("Passed application is not a FastAPI instance, will use the existing FastAPI instance")
    self.app = app
    # Add CORS middleware
    self.app.add_middleware(
        CORSMiddleware,
        allow_origins=["http://localhost:3100"],  # Specify frontend domain
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
        expose_headers=["*"],
    )
    # Client session management: session_id -> session dict
    self.client_sessions = {}
    # Strong references to the long-running background tasks. The event
    # loop itself only keeps weak references, so an unreferenced task may
    # be garbage-collected before it finishes.
    self._background_tasks = set()
    # Set up SSE routes
    self.setup_sse_routes()

    # Register startup event
    @self.app.on_event("startup")
    async def startup_event():
        # Session cleanup first, then the periodic status broadcaster.
        for job in (self.cleanup_idle_sessions, self.send_periodic_updates):
            task = asyncio.create_task(job())
            self._background_tasks.add(task)
            task.add_done_callback(self._background_tasks.discard)
def setup_sse_routes(self):
    """Register the health, status, SSE and message routes on ``self.app``.

    Routes:
        GET /health: static liveness payload with a timestamp.
        GET /status: server name, client count and the names of the
            registered tools, resources and prompts.
        GET /sse: creates a client session and streams its queue as
            Server-Sent Events.
        OPTIONS /mcp/messages: preflight reply carrying CORS headers.
        POST /mcp/messages: delegates to ``self.mcp_message``.
    """
    # NOTE(review): a wildcard origin combined with credentials=true is
    # normally refused by browsers; kept unchanged to preserve responses.
    cors_headers = {
        "Access-Control-Allow-Origin": "*",
        "Access-Control-Allow-Credentials": "true",
        "Access-Control-Allow-Methods": "*",
        "Access-Control-Allow-Headers": "*",
        "Access-Control-Expose-Headers": "*",
    }

    def names_of(items):
        """Return each item's ``name`` attribute, or ``str(item)`` without one."""
        return [item.name if hasattr(item, 'name') else str(item) for item in items]

    @self.app.get("/health")
    async def health_check():
        """Health check endpoint"""
        return {
            "status": "healthy",
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "server": "Doris MCP Server"
        }

    @self.app.get("/status")
    async def status():
        """Get server status"""
        try:
            tool_names = names_of(await self.mcp_server.list_tools())
            logger.info(f"Getting tool list, currently registered tools: {tool_names}")
            resource_names = names_of(await self.mcp_server.list_resources())
            prompt_names = names_of(await self.mcp_server.list_prompts())
            return {
                "status": "running",
                "name": self.mcp_server.name,
                "mode": "mcp_sse",
                "clients": len(self.client_sessions),
                "tools": tool_names,
                "resources": resource_names,
                "prompts": prompt_names
            }
        except Exception as e:
            logger.error(f"Error getting status: {str(e)}")
            return {
                "status": "error",
                "error": str(e)
            }

    @self.app.get("/sse")
    async def mcp_sse_init(request: Request):
        """SSE service entry point, establishes client connection (New Endpoint)"""
        session_id = str(uuid.uuid4())
        logger.info(f"New SSE connection [Session ID: {session_id}] at /sse")
        # The generator below keeps its own reference to the queue, so it
        # keeps working after the session entry is removed from the
        # registry (e.g. by the idle cleanup, which deletes the entry right
        # after queueing the close event).
        queue = asyncio.Queue()
        now = time.time()
        self.client_sessions[session_id] = {
            "client_id": request.headers.get("X-Client-ID", f"client_{str(uuid.uuid4())[:8]}"),
            "created_at": now,
            "last_active": now,
            "queue": queue
        }
        # The first event tells the client where to POST its messages.
        await queue.put({
            "event": "endpoint",
            "data": f"/mcp/messages?session_id={session_id}"
        })

        async def event_generator():
            try:
                while True:
                    try:
                        message = await asyncio.wait_for(queue.get(), timeout=30)
                    except asyncio.TimeoutError:
                        # Nothing queued for 30 s: keep the connection alive.
                        yield {"event": "ping", "data": "keepalive"}
                        continue
                    if isinstance(message, dict) and message.get("event") == "close":
                        logger.info(f"Received close command [Session ID: {session_id}]")
                        break
                    if isinstance(message, dict) and "event" in message:
                        # System event: forwarded under its own event name.
                        yield {"event": message["event"], "data": message.get("data", "")}
                    elif isinstance(message, str):
                        yield {"event": "message", "data": message}
                    else:
                        # JSON-RPC payloads and any other value are serialized.
                        yield {"event": "message", "data": json.dumps(message)}
            except asyncio.CancelledError:
                logger.info(f"SSE connection cancelled [Session ID: {session_id}]")
                # Cancellation must propagate; cleanup still runs in finally.
                raise
            except Exception as e:
                logger.error(f"SSE event generator error [Session ID: {session_id}]: {str(e)}")
            finally:
                if session_id in self.client_sessions:
                    logger.info(f"Cleaning up session [Session ID: {session_id}]")
                    del self.client_sessions[session_id]

        return EventSourceResponse(
            event_generator(),
            headers={
                **cors_headers,
                "Cache-Control": "no-cache",
                "Connection": "keep-alive"
            }
        )

    @self.app.options("/mcp/messages")
    async def mcp_messages_options(request: Request):
        """Handle preflight requests"""
        return JSONResponse({}, headers=dict(cors_headers))

    @self.app.post("/mcp/messages")
    async def mcp_messages_handler(request: Request):
        """Handle client message requests, using class method"""
        return await self.mcp_message(request)
async def cleanup_idle_sessions(self, check_interval=60, idle_timeout=300):
    """Periodically close client sessions that have been idle too long.

    Runs forever. On every pass, sessions whose ``last_active`` timestamp
    is older than ``idle_timeout`` seconds are removed from
    ``self.client_sessions`` and receive a ``{"event": "close"}`` message
    on their queue.

    Args:
        check_interval: Seconds to sleep between passes (default: one minute).
        idle_timeout: Idle time in seconds after which a session is closed
            (default: five minutes).
    """
    while True:
        await asyncio.sleep(check_interval)
        now = time.time()
        idle_sessions = [
            session_id
            for session_id, session in self.client_sessions.items()
            if now - session["last_active"] > idle_timeout
        ]
        for session_id in idle_sessions:
            # Remove the entry first; a session that already disappeared is
            # simply skipped instead of being reported as an error.
            session = self.client_sessions.pop(session_id, None)
            if session is None:
                continue
            try:
                # Send close message so the SSE stream ends.
                await session["queue"].put({"event": "close"})
                logger.info(f"Cleaned up idle session: {session_id}")
            except Exception as e:
                logger.error(f"Error cleaning up session: {str(e)}")
async def send_periodic_updates(self, interval=5, error_backoff=1):
    """Periodically broadcast a status update to all connected clients.

    Runs forever. Passes without any connected client are skipped. Errors
    are logged and followed by a short pause so a persistent failure does
    not turn into a tight loop.

    Args:
        interval: Seconds between status updates (default: 5).
        error_backoff: Extra seconds to wait after a failed pass (default: 1).
    """
    while True:
        try:
            await asyncio.sleep(interval)
            # If no clients are connected, skip this update
            if not self.client_sessions:
                continue
            status_data = {
                "timestamp": time.time(),
                "clients_count": len(self.client_sessions),
                "server_status": "running"
            }
            await self.broadcast_status_update(status_data)
        except Exception as e:
            logger.error(f"Error sending periodic updates: {str(e)}")
            await asyncio.sleep(error_backoff)
async def broadcast_status_update(self, status_data):
    """Wrap status data in a JSON-RPC notification and broadcast it.

    Args:
        status_data: Status payload placed under ``params.data`` of a
            ``notifications/status`` message.
    """
    logger.debug(f"Broadcasting status update: {status_data}")
    params = {"type": "status_update", "data": status_data}
    notification = dict(jsonrpc="2.0", method="notifications/status", params=params)
    await self.broadcast_message(notification)
async def broadcast_visualization_data(self, visualization_data):
    """Broadcast a visualization notification to every client.

    Payloads that are empty, not a dict, or lack a ``type`` key are logged
    and dropped.

    Args:
        visualization_data: Visualization payload; must contain ``type``.
    """
    usable = (
        bool(visualization_data)
        and isinstance(visualization_data, dict)
        and "type" in visualization_data
    )
    if not usable:
        logger.warning(f"Invalid visualization data: {visualization_data}")
        return
    logger.info(f"Broadcasting visualization data: {visualization_data['type']}")
    params = {"type": "visualization", "data": visualization_data}
    notification = dict(jsonrpc="2.0", method="notifications/visualization", params=params)
    await self.broadcast_message(notification)
async def send_visualization_data(self, session_id, visualization_data):
    """Queue a visualization notification for a single client.

    Nothing is sent (a warning is logged instead) when the payload is
    empty, not a dict, lacks a ``type`` key, or the session is unknown.

    Args:
        session_id: Target session ID.
        visualization_data: Visualization payload; must contain ``type``.
    """
    usable = (
        bool(visualization_data)
        and isinstance(visualization_data, dict)
        and "type" in visualization_data
    )
    if not usable:
        logger.warning(f"Invalid visualization data: {visualization_data}")
        return
    if session_id not in self.client_sessions:
        logger.warning(f"Session does not exist: {session_id}")
        return
    logger.info(f"Sending visualization data to session {session_id}: {visualization_data['type']}")
    params = {"type": "visualization", "data": visualization_data}
    notification = dict(jsonrpc="2.0", method="notifications/visualization", params=params)
    target_queue = self.client_sessions[session_id]["queue"]
    await target_queue.put(notification)
async def send_tool_result(self, session_id, tool_name, result_data, is_final=True):
    """Queue a tool-result notification for a single client.

    An unknown session is logged as a warning and nothing is sent.

    Args:
        session_id: Target session ID.
        tool_name: Name of the tool that produced the result.
        result_data: Result payload.
        is_final: Whether this is the final result for the call.
    """
    if session_id not in self.client_sessions:
        logger.warning(f"Session does not exist: {session_id}")
        return
    logger.info(f"Sending tool result to session {session_id}: {tool_name}")
    params = {
        "type": "tool_result",
        "tool": tool_name,
        "result": result_data,
        "is_final": is_final,
    }
    notification = dict(jsonrpc="2.0", method="notifications/tool_result", params=params)
    target_queue = self.client_sessions[session_id]["queue"]
    await target_queue.put(notification)
async def broadcast_message(self, message):
    """Put a message on the queue of every active session.

    A failure for one session is logged and does not stop delivery to the
    remaining sessions.

    Args:
        message: Message to broadcast.
    """
    if not self.client_sessions:
        return
    # Iterate over a snapshot: sessions may come and go while we await.
    for session_id in tuple(self.client_sessions.keys()):
        try:
            if session_id not in self.client_sessions:
                # Removed since the snapshot was taken.
                continue
            await self.client_sessions[session_id]["queue"].put(message)
        except Exception as e:
            logger.error(f"Error sending message to session {session_id}: {str(e)}")
async def get_status(self):
    """Return a summary of the server: status, name, mode and client count."""
    summary = {"status": "running"}
    summary["name"] = self.mcp_server.name
    summary["mode"] = "mcp_sse"
    summary["clients"] = len(self.client_sessions)
    return summary
async def mcp_message(self, request: Request):
    """Handle one JSON-RPC 2.0 message posted by a client.

    The HTTP reply is only an acknowledgement; the JSON-RPC response itself
    is put on the session's queue and reaches the client over its SSE
    stream.

    Args:
        request: Incoming request. The session is resolved through
            ``self._get_session_id``.

    Returns:
        JSONResponse: 401 when the session is unknown; 400 when the payload
        is not a JSON-RPC 2.0 request object or a tool call has no name;
        404 for an unknown tool; 500 on processing errors; otherwise a 200
        acknowledgement.
    """
    unset = object()  # distinguishes "no id parsed yet" from an explicit null id
    message_id = unset
    try:
        session_id = self._get_session_id(request)
        if not session_id:
            # The body-based fallback of _get_session_id inspects the
            # payload cached on the request, which only exists once the
            # body has been read.
            await request.body()
            session_id = self._get_session_id(request)
        if not session_id or session_id not in self.client_sessions:
            logger.warning(f"Session does not exist: {session_id}")
            return JSONResponse(
                {"jsonrpc": "2.0", "error": {"code": -32000, "message": "Session does not exist or has expired"}},
                status_code=401
            )
        # Update session last active time
        self.client_sessions[session_id]["last_active"] = time.time()
        try:
            body = await request.json()
            logger.info(f"Received message [Session ID: {session_id}]: {json.dumps(body)}")
            if not isinstance(body, dict):
                # Arrays/scalars are valid JSON but not a request object.
                return JSONResponse(
                    {"jsonrpc": "2.0", "id": None, "error": {"code": -32600, "message": "Invalid request, must be JSON-RPC 2.0 format"}},
                    status_code=400
                )
            message_id = body.get("id", str(uuid.uuid4()))
            if body.get("jsonrpc") != "2.0" or "method" not in body:
                return JSONResponse(
                    {"jsonrpc": "2.0", "id": message_id, "error": {"code": -32600, "message": "Invalid request, must be JSON-RPC 2.0 format"}},
                    status_code=400
                )
            method = body.get("method")
            params = body.get("params", {})
            if not isinstance(params, dict):
                # Null or positional params: treat as "no named parameters".
                params = {}

            if method == "initialize":
                logger.info(f"Processing initialize command [Session ID: {session_id}]")
                result = self._initialize_result()
            elif method == "mcp/listOfferings":
                logger.info(f"Processing listOfferings command [Session ID: {session_id}]")
                tools_json = await self._serialize_tools()
                resources_json = await self._serialize_resources()
                prompts = await self.mcp_server.list_prompts()
                prompts_json = [prompt.model_dump() if hasattr(prompt, "model_dump") else prompt for prompt in prompts]
                result = {
                    "tools": tools_json,
                    "resources": resources_json,
                    "prompts": prompts_json
                }
            elif method in ("mcp/listTools", "tools/list"):
                logger.info(f"Processing listTools command [Session ID: {session_id}]")
                result = {"tools": await self._serialize_tools()}
            elif method == "mcp/listResources":
                logger.info(f"Processing listResources command [Session ID: {session_id}]")
                result = {"resources": await self._serialize_resources()}
            elif method in ("mcp/callTool", "tools/call"):
                return await self._dispatch_tool_call(request, session_id, message_id, params)
            else:
                # FastMCP offers no generic message processor here, so
                # unknown methods are only acknowledged.
                logger.info(f"Processing general message [Session ID: {session_id}]")
                await self._enqueue_for_session(session_id, {
                    "jsonrpc": "2.0",
                    "id": message_id,
                    "result": {
                        "status": "ok",
                        "message": "Message received, but unable to process unrecognized message type"
                    }
                })
                return self._cors_json({"status": "received"})

            await self._enqueue_for_session(session_id, {
                "jsonrpc": "2.0",
                "id": message_id,
                "result": result
            })
            return self._cors_json({"status": "success"})
        except Exception as e:
            logger.error(f"Error processing message: {str(e)}")
            await self._enqueue_for_session(session_id, {
                "jsonrpc": "2.0",
                "id": str(uuid.uuid4()) if message_id is unset else message_id,
                "error": {
                    "code": -32000,
                    "message": str(e)
                }
            })
            return self._cors_json({"status": "error", "message": str(e)}, status_code=500)
    except Exception as e:
        logger.error(f"Error processing request: {str(e)}", exc_info=True)
        return self._cors_json(
            {
                "jsonrpc": "2.0",
                "id": "unknown",
                "error": {
                    "code": -32000,
                    "message": f"Error processing request: {str(e)}"
                }
            },
            status_code=500
        )

def _cors_headers(self):
    """Return a fresh copy of the CORS headers attached to message replies."""
    # NOTE(review): a wildcard origin combined with credentials=true is
    # normally refused by browsers; kept unchanged to preserve responses.
    return {
        "Access-Control-Allow-Origin": "*",
        "Access-Control-Allow-Credentials": "true",
        "Access-Control-Allow-Methods": "*",
        "Access-Control-Allow-Headers": "*",
        "Access-Control-Expose-Headers": "*"
    }

def _cors_json(self, payload, status_code=200):
    """Build a JSONResponse for ``payload`` carrying the CORS headers."""
    return JSONResponse(payload, status_code=status_code, headers=self._cors_headers())

async def _enqueue_for_session(self, session_id, message):
    """Queue ``message`` for a session; return False if the session is gone."""
    session = self.client_sessions.get(session_id)
    if session is None:
        logger.warning(f"Session closed before message could be queued [Session ID: {session_id}]")
        return False
    await session["queue"].put(message)
    return True

def _initialize_result(self):
    """Return the ``result`` object sent in reply to ``initialize``."""
    return {
        "protocolVersion": "2024-11-05",
        "name": self.mcp_server.name,
        "instructions": "This is an MCP server for Apache Doris database",
        "serverInfo": {
            "version": "0.1.0",
            "name": "Doris MCP Server"
        },
        "capabilities": {
            "tools": {
                "supportsStreaming": True,
                "supportsProgress": True
            },
            "resources": {
                "supportsStreaming": False
            },
            "prompts": {
                "supported": True
            }
        }
    }

async def _serialize_tools(self):
    """Describe every registered tool as name/description/inputSchema."""
    return [
        {
            "name": tool.name if hasattr(tool, "name") else str(tool),
            "description": tool.description if hasattr(tool, "description") else "",
            "inputSchema": tool.parameters if hasattr(tool, "parameters") else {
                "type": "object",
                "properties": {},
                "required": []
            }
        }
        for tool in await self.mcp_server.list_tools()
    ]

async def _serialize_resources(self):
    """Return the registered resources, dumped via ``model_dump`` when available."""
    resources = await self.mcp_server.list_resources()
    return [res.model_dump() if hasattr(res, "model_dump") else res for res in resources]

async def _find_registered_tool(self, tool_name):
    """Return the first registered tool named ``tool_name``, or None."""
    for tool in await self.mcp_server.list_tools():
        if getattr(tool, 'name', '') == tool_name:
            return tool
    return None

async def _dispatch_tool_call(self, request, session_id, message_id, params):
    """Validate a tool-call request and route it to the streaming or standard path."""
    tool_name = params.get("name")
    arguments = params.get("arguments") or {}  # explicit null means "no arguments"
    if not tool_name:
        await self._enqueue_for_session(session_id, {
            "jsonrpc": "2.0",
            "id": message_id,
            "error": {
                "code": -32602,
                "message": "Invalid params: tool name is required"
            }
        })
        return self._cors_json({"status": "error", "message": "Tool name is required"}, status_code=400)
    # Streaming is requested via a "stream" query parameter or params.stream.
    stream_mode = "stream" in request.query_params or params.get("stream", False)
    logger.info(f"Calling tool [Session ID: {session_id}, Tool: {tool_name}, Args: {arguments}, Stream mode: {stream_mode}]")
    if stream_mode:
        return await self._start_streaming_tool_call(request, session_id, message_id, tool_name, arguments)
    return await self._run_standard_tool_call(request, session_id, message_id, tool_name, arguments)

async def _start_streaming_tool_call(self, request, session_id, message_id, tool_name, arguments):
    """Start a tool call in the background and acknowledge with ``processing``."""
    logger.info(f"Using streaming response to handle tool call [Session ID: {session_id}, Tool: {tool_name}]")
    try:
        if not await self._find_registered_tool(tool_name):
            raise ValueError(f"Tool {tool_name} does not exist")

        async def callback(content, metadata):
            # Partial results go to the calling session only ...
            await self._enqueue_for_session(session_id, {
                "jsonrpc": "2.0",
                "id": message_id,
                "partial": True,
                "result": {
                    "content": content,
                    "metadata": metadata
                }
            })
            # ... while visualization data is shared with every client.
            if metadata and "visualization" in metadata:
                await self.broadcast_visualization_data(metadata["visualization"])

        kwargs = dict(arguments)
        kwargs['callback'] = callback
        task = asyncio.create_task(
            self._execute_stream_tool_wrapper(tool_name, kwargs, message_id, session_id, request)
        )
        # Hold a strong reference until the task completes; the event loop
        # only keeps weak references to tasks.
        pending = getattr(self, "_stream_tasks", None)
        if pending is None:
            pending = self._stream_tasks = set()
        pending.add(task)
        task.add_done_callback(pending.discard)
        return self._cors_json({"status": "processing"})
    except Exception as e:
        logger.error(f"Streaming tool processing error: {str(e)}")
        await self._enqueue_for_session(session_id, {
            "jsonrpc": "2.0",
            "id": message_id,
            "success": False,
            "error": str(e)
        })
        return self._cors_json({"status": "error", "message": str(e)}, status_code=500)

async def _run_standard_tool_call(self, request, session_id, message_id, tool_name, arguments):
    """Run a tool call to completion and queue its JSON-RPC response."""
    logger.info(f"Using standard response to handle tool call [Session ID: {session_id}, Tool: {tool_name}]")
    try:
        if not await self._find_registered_tool(tool_name):
            await self._enqueue_for_session(session_id, {
                "jsonrpc": "2.0",
                "id": message_id,
                "error": {
                    "code": -32601,
                    "message": f"Tool '{tool_name}' not found"
                }
            })
            return self._cors_json({"status": "error", "message": f"Tool '{tool_name}' not found"}, status_code=404)
        result = await self.call_tool(tool_name, arguments, request)
        if isinstance(result, dict) and isinstance(result.get("content"), list):
            # Already in the content-list format.
            formatted_result = result
        else:
            formatted_result = {
                "content": [
                    {
                        "type": "json",
                        # default=str: fall back to str() for values the
                        # json module cannot encode natively.
                        "text": result if isinstance(result, str) else json.dumps(result, ensure_ascii=False, default=str)
                    }
                ]
            }
        await self._enqueue_for_session(session_id, {
            "jsonrpc": "2.0",
            "id": message_id,
            "result": formatted_result
        })
        return self._cors_json({"status": "success"})
    except Exception as e:
        logger.error(f"Tool call error: {str(e)}", exc_info=True)
        error_obj = {"code": -32000, "message": str(e)}
        if str(e).startswith('{"code":'):
            # The exception text is itself a JSON error object: pass it on.
            try:
                error_obj = json.loads(str(e))
            except ValueError:
                pass
        await self._enqueue_for_session(session_id, {
            "jsonrpc": "2.0",
            "id": message_id,
            "error": error_obj
        })
        return self._cors_json({"status": "error", "message": str(e)}, status_code=500)
async def _execute_stream_tool_wrapper(self, tool_name, kwargs, message_id, session_id, request):
"""Wrapper for streaming tool calls
Args:
tool_name: Tool name
kwargs: Function parameters
message_id: Message ID
session_id: Session ID
request: Request object
"""
try:
# Execute by calling the standard tool method
result = await self.call_tool(tool_name, kwargs, request)
# Send completion message
final_message = {
"jsonrpc": "2.0",
"id": message_id,
"success": True,
"result": result
}
# Check if session still exists
if session_id in self.client_sessions:
await self.client_sessions[session_id]["queue"].put(final_message)
logger.info(f"Streaming tool execution completed [Session ID: {session_id}, Message ID: {message_id}]")
else:
logger.warning(f"Streaming tool execution completed but session closed [Session ID: {session_id}]")
except Exception as e:
logger.error(f"Streaming tool execution failed [Session ID: {session_id}]: {str(e)}")
# Send error message
error_message = {
"jsonrpc": "2.0",
"id": message_id,
"success": False,
"error": str(e)
}
# Check if session still exists
if session_id in self.client_sessions:
await self.client_sessions[session_id]["queue"].put(error_message)
else:
logger.warning(f"Streaming tool execution failed but session closed [Session ID: {session_id}]")
async def broadcast_tool_result(self, tool_name, result_data):
    """Broadcast a tool-result notification to every client.

    Args:
        tool_name: Name of the tool that produced the result.
        result_data: Result payload.
    """
    logger.info(f"Broadcasting tool result: {tool_name}")
    params = {"type": "tool_result", "tool": tool_name, "result": result_data}
    notification = dict(jsonrpc="2.0", method="notifications/tool_result", params=params)
    await self.broadcast_message(notification)
async def call_tool(self, tool_name, arguments, request):
    """
    Call the specified tool and return its result.

    The tool is looked up first as an attribute of
    ``doris_mcp_server.tools.mcp_doris_tools`` and, failing that, among the
    tools registered on the MCP server. Standard short names (e.g.
    ``exec_query``) are mapped to their ``mcp_doris_`` names first.

    Args:
        tool_name: Tool name (short or ``mcp_doris_`` form)
        arguments: Tool parameters
        request: Original request object

    Returns:
        Tool call result

    Raises:
        ValueError: If the tool cannot be found, has an unsupported type,
            or fails while running.
    """
    import inspect  # only needed here; the file does not import it at module level

    # default=str: the streaming path adds a callback function to the
    # arguments, which plain json.dumps rejects with TypeError before the
    # tool is ever looked up.
    logger.info(f"Calling tool: {tool_name}, Parameters: {json.dumps(arguments, ensure_ascii=False, default=str)}")
    # Get recent query content, used to handle random_string parameter
    recent_query = self._extract_recent_query(request)
    # Standard name -> MCP name
    tool_mapping = {
        "exec_query": "mcp_doris_exec_query",
        "get_table_schema": "mcp_doris_get_table_schema",
        "get_db_table_list": "mcp_doris_get_db_table_list",
        "get_db_list": "mcp_doris_get_db_list",
        "get_table_comment": "mcp_doris_get_table_comment",
        "get_table_column_comments": "mcp_doris_get_table_column_comments",
        "get_table_indexes": "mcp_doris_get_table_indexes",
        "get_recent_audit_logs": "mcp_doris_get_recent_audit_logs",
        "get_catalog_list": "mcp_doris_get_catalog_list"
    }
    mapped_tool_name = tool_mapping.get(tool_name, tool_name)
    try:
        import doris_mcp_server.tools.mcp_doris_tools as mcp_tools
        tool_function = getattr(mcp_tools, mapped_tool_name, None)
        if not tool_function:
            # Not a module-level function: fall back to the MCP registry.
            for tool in await self.mcp_server.list_tools():
                if getattr(tool, 'name', '') == mapped_tool_name:
                    tool_function = tool
                    break
        if not tool_function:
            raise ValueError(f"Tool not found: {tool_name} / {mapped_tool_name}")
        # Process common input parameter conversions
        processed_args = self._process_tool_arguments(mapped_tool_name, arguments, recent_query)
        try:
            # Log tool type and attribute information to help debugging
            logger.debug(f"Tool function type: {type(tool_function)}")
            logger.debug(f"Tool function attributes: {dir(tool_function)}")
            if callable(tool_function):
                logger.debug("Tool function is callable, calling directly")
                entry_point = tool_function
            elif hasattr(tool_function, 'run'):
                logger.debug("Tool function has run method, calling run method")
                entry_point = tool_function.run
            elif hasattr(tool_function, 'execute'):
                logger.debug("Tool function has execute method, calling execute method")
                entry_point = tool_function.execute
            elif hasattr(tool_function, 'call'):
                logger.debug("Tool function has call method, calling call method")
                entry_point = tool_function.call
            elif isinstance(tool_function, dict) and 'function' in tool_function:
                logger.debug("Tool is a dictionary type, trying to get 'function' key")
                entry_point = tool_function['function']
                if not callable(entry_point):
                    raise ValueError(f"Function in dictionary is not callable: {type(entry_point)}")
            else:
                raise ValueError(f"Unsupported tool type: {type(tool_function)}, Attributes: {dir(tool_function)}")
            result = entry_point(**processed_args)
            # Accept both coroutine functions and plain synchronous tools.
            if inspect.isawaitable(result):
                result = await result
        except Exception as e:
            logger.error(f"Failed to call tool function: {str(e)}", exc_info=True)
            raise ValueError(f"Error calling tool: {str(e)}") from e
        return result
    except AttributeError as e:
        logger.error(f"Tool function does not exist: {mapped_tool_name}, Error: {str(e)}")
        raise ValueError(f"Tool function does not exist: {mapped_tool_name}") from e
    except Exception as e:
        logger.error(f"Error calling tool: {str(e)}", exc_info=True)
        raise ValueError(f"Error calling tool: {str(e)}") from e
def _process_tool_arguments(self, tool_name, arguments, recent_query):
"""
Process tool parameters, supporting special handling logic
Args:
tool_name: Tool name (MCP internal name, e.g., mcp_doris_...)
arguments: Original parameters
recent_query: Recent query content
Returns:
Processed parameter dictionary
"""
# Copy parameters to avoid modifying the original object
processed_args = dict(arguments)
# Handle potential random_string parameter as fallback
if "random_string" in processed_args and tool_name.startswith("mcp_doris_"):
random_string = processed_args.pop("random_string", "")
logger.debug(f"Processing random_string parameter for tool {tool_name}: '{random_string}'")
# 1. For exec_query
if tool_name == "mcp_doris_exec_query":
if not processed_args.get("sql"):
sql_fallback = random_string or recent_query
if sql_fallback:
if not random_string and recent_query:
import re
sql_matches = re.findall(r'```sql\s*([\s\S]+?)\s*```', recent_query)
if sql_matches:
sql_fallback = sql_matches[0].strip()
if sql_fallback:
logger.info(f"Using random_string/recent_query as SQL for exec_query: {sql_fallback[:100]}...")
processed_args["sql"] = sql_fallback
else:
logger.warning(f"exec_query missing sql parameter, and random_string/recent_query is empty or SQL cannot be extracted")
else:
logger.warning(f"exec_query missing sql parameter, and both random_string and recent_query are empty")
# 2. For tools requiring table_name
elif tool_name in [
"mcp_doris_get_table_schema",
"mcp_doris_get_table_comment",
"mcp_doris_get_table_column_comments",
"mcp_doris_get_table_indexes"
]:
if not processed_args.get("table_name"):
table_fallback = random_string
if table_fallback:
logger.info(f"Using random_string/recent_query as table_name for {tool_name}: {table_fallback}")
processed_args["table_name"] = table_fallback
else:
logger.warning(f"{tool_name} missing table_name parameter, and random_string/recent_query is empty or table name cannot be extracted")
# 3. Other tools
else:
logger.debug(f"Tool {tool_name} does not apply random_string fallback logic, or logic is undefined.")
# Ensure return is outside the main if block
return processed_args
def _extract_recent_query(self, request: Request) -> Optional[str]:
"""
Extract the most recent user query from the request
Args:
request: Request object
Returns:
Optional[str]: The most recent user query, or None if not found
"""
try:
# Try to extract message history from request body
body = None
body_bytes = getattr(request, "_body", None)
if body_bytes:
try:
body = json.loads(body_bytes)
except:
pass
if not body:
body = getattr(request, "_json", {})
# Find the most recent user message from message history
messages = body.get("params", {}).get("messages", [])
if messages:
# Iterate messages in reverse to find the most recent user message
for msg in reversed(messages):
if msg.get("role") == "user":
return msg.get("content", "")
# If not found in message history, try extracting from the original message
message = body.get("params", {}).get("message", {})
if message and message.get("role") == "user":
return message.get("content", "")
return None
except Exception as e:
logger.error(f"Error extracting recent query: {str(e)}")
return None
def format_tool_result(self, result):
    """
    Normalize a tool call result.

    Dictionaries are returned unchanged. Strings are parsed as JSON when
    possible and otherwise wrapped as ``{"content": <text>}``. Any other
    value is wrapped as ``{"content": str(value)}``.

    Args:
        result: Original tool call result

    Returns:
        Formatted result, or ``{"error": <message>}`` if formatting fails
    """
    try:
        if isinstance(result, dict):
            return result
        if not isinstance(result, str):
            # Neither dict nor text: fall back to its string form.
            return {"content": str(result)}
        try:
            return json.loads(result)
        except json.JSONDecodeError:
            # Plain text that is not JSON.
            return {"content": result}
    except Exception as e:
        logger.error(f"Error formatting tool result: {str(e)}")
        return {"error": str(e)}
def _get_session_id(self, request: Request) -> str:
"""
Get session ID from the request
Tries to get session ID from the following locations (in priority order):
1. Query parameter session_id
2. session_id field in request body
3. X-Session-ID header
Args:
request: Request object
Returns:
str: Session ID, or None if not found
"""
# Get from query parameter
session_id = request.query_params.get("session_id")
if session_id:
return session_id
# Try getting from request body
try:
body = getattr(request, "_json", None)
if not body:
body_bytes = getattr(request, "_body", None)
if body_bytes:
try:
body = json.loads(body_bytes)
except:
pass
if body and isinstance(body, dict) and "session_id" in body:
return body["session_id"]
except:
pass
# Get from request header
session_id = request.headers.get("X-Session-ID")
if session_id:
return session_id
return None