init doris mcp 0.2.0

This commit is contained in:
Yijia Su
2025-05-06 12:56:55 +08:00
parent 9dc25be87a
commit c190f19cb5
23 changed files with 6405 additions and 0 deletions

View File

@@ -0,0 +1 @@
# Mark directory as a package

View File

@@ -0,0 +1,33 @@
# doris_mcp_server/config.py
import os
import logging
from dotenv import load_dotenv
# Load environment variables from .env file
load_dotenv(override=True)
# Get Log Level from environment variable, default to 'info'
LOG_LEVEL_STR = os.getenv('LOG_LEVEL', 'info').upper()
# Map string level to logging level constant
LOG_LEVEL_MAP = {
'DEBUG': logging.DEBUG,
'INFO': logging.INFO,
'WARNING': logging.WARNING,
'ERROR': logging.ERROR,
'CRITICAL': logging.CRITICAL
}
LOG_LEVEL = LOG_LEVEL_MAP.get(LOG_LEVEL_STR, logging.INFO)
# Function to load config (can be expanded later if needed)
def load_config():
"""Loads configuration settings."""
# Currently, configuration is mainly handled by environment variables
# and constants defined in this module.
# This function can be used to perform additional setup if required.
logging.getLogger(__name__).info("Configuration loaded (mainly from environment variables).")
# You can add other configuration constants here if needed
# Example: DB_HOST = os.getenv("DB_HOST", "localhost")
# But often it's better to access os.getenv directly where needed
# or pass config dictionaries around.

196
doris_mcp_server/main.py Normal file
View File

@@ -0,0 +1,196 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Apache Doris MCP Server Main Entry - Primarily handles SSE mode
Stdio mode is handled by doris_mcp_server.mcp_core:run_stdio.
"""
import os
import sys
import argparse
import asyncio
import logging
from contextlib import asynccontextmanager
from collections.abc import AsyncIterator
from dataclasses import dataclass
from typing import Dict, Any
import uvicorn
from uvicorn import Config, Server
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from dotenv import load_dotenv
# Add project root to path
PROJECT_ROOT = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
sys.path.insert(0, PROJECT_ROOT)
# SSE related imports
from mcp.server.fastmcp import FastMCP
from doris_mcp_server.sse_server import DorisMCPSseServer
from doris_mcp_server.streamable_server import DorisMCPStreamableServer
# Stdio related imports (only needed for tools now, maybe move tool init?)
# from mcp.server.stdio import stdio_server -> No longer used here
# Config and Tool Initializer
from doris_mcp_server.config import load_config # LOG_LEVEL might not be needed here directly
from doris_mcp_server.tools.tool_initializer import register_mcp_tools
# Load environment variables (load early for all modes)
load_dotenv(override=True)
# Get logger
logger = logging.getLogger("doris-mcp-main") # Changed logger name slightly
# --- Configuration Loading and Logging Setup ---
load_config() # Loads .env
# --- Create FastAPI App (Global Scope for SSE Mode) ---
# This 'app' object is targeted by 'mcp run doris_mcp_server/main.py:app --transport sse'
# And used when running directly with --sse
app = FastAPI(
title="Doris MCP Server (SSE Mode)",
# Lifespan will be added in start_sse_server
)
# --- Removed StdioServerWrapper ---
# --- Command Line Argument Parsing ---
def parse_args():
parser = argparse.ArgumentParser(description="Apache Doris MCP Server (SSE Mode Entry)")
# Only keep SSE related args here
parser.add_argument('--sse', action='store_true', help='Start SSE Web server mode (required)')
parser.add_argument('--host', type=str, default=os.getenv('SERVER_HOST', '0.0.0.0'), help='Host address')
parser.add_argument('--port', type=int, default=int(os.getenv('SERVER_PORT', os.getenv('MCP_PORT', '3000'))), help='Port number')
parser.add_argument('--debug', action='store_true', help='Enable debug mode')
parser.add_argument('--reload', action='store_true', help='Enable auto-reload')
return parser.parse_args()
# --- SSE Mode Specific Code ---
@dataclass
class AppContext:
config: Dict[str, Any]
@asynccontextmanager
async def app_lifespan(app_instance: FastAPI) -> AsyncIterator[None]:
logger.info("SSE application lifecycle start...")
config = {
# Simplified config - maybe get from elsewhere?
"db_host": os.getenv("DB_HOST", "localhost"),
"db_port": int(os.getenv("DB_PORT", "9030")),
"db_user": os.getenv("DB_USER", "root"),
"db_password": os.getenv("DB_PASSWORD", ""),
"db_database": os.getenv("DB_DATABASE", "test"),
}
app_instance.state.config = config
try:
# Yield None implicitly or explicitly None
yield
finally:
logger.info("Cleaning up SSE application resources...")
async def start_sse_server(args):
"""Start SSE Web server mode (Configures the global 'app')"""
logger.info("Starting SSE Web server mode...")
global app
# --- Initialize MCP and Tools for SSE ---
# Create a *separate* MCP instance for SSE mode
sse_mcp = FastMCP(
name="doris-mcp-sse",
description="Apache Doris MCP Server (SSE)",
lifespan=None, # Managed by FastAPI
dependencies=["fastapi", "uvicorn", "openai", "sse_starlette"]
)
logger.info("Registering MCP tools for SSE mode...")
await register_mcp_tools(sse_mcp) # Register tools for the SSE instance
logger.info("MCP tools registered for SSE.")
# --- Configure Lifespan and CORS for the global app ---
app.router.lifespan_context = app_lifespan
origins = os.getenv("ALLOWED_ORIGINS", "*").split(",")
allow_credentials = os.getenv("MCP_ALLOW_CREDENTIALS", "false").lower() == "true"
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=allow_credentials,
allow_methods=["*"],
allow_headers=["*"],
expose_headers=["Mcp-Session-Id"],
)
# --- Initialize Handlers and Register Routes (Pass sse_mcp instance) ---
logger.info("Initializing SSE server handlers and registering routes...")
sse_server_handler = DorisMCPSseServer(sse_mcp, app)
streamable_server_handler = DorisMCPStreamableServer(sse_mcp, app)
logger.info("SSE Server handlers initialized and routes registered.")
# --- Print Configuration and Endpoints ---
print("--- SSE Mode Configuration ---")
print(f"Server Host: {args.host}")
print(f"Server Port: {args.port}")
print(f"Allowed Origins: {origins}")
print(f"Allow Credentials: {allow_credentials}")
print(f"Log Level: {os.getenv('LOG_LEVEL', 'info')}")
print(f"Debug Mode: {args.debug}")
print(f"Reload Mode: {args.reload}")
print(f"DB Host: {os.getenv('DB_HOST')}")
print(f"DB Port: {os.getenv('DB_PORT')}")
print(f"DB User: {os.getenv('DB_USER')}")
print(f"DB Database: {os.getenv('DB_DATABASE')}")
print(f"Force Refresh Metadata: {os.getenv('FORCE_REFRESH_METADATA', 'false')}")
print("------------------------------")
base_url = f"http://{args.host}:{args.port}"
print(f"Service running at: {base_url}")
print(f" Health Check: GET {base_url}/health")
print(f" Status Check: GET {base_url}/status")
print(f" SSE Init: GET {base_url}/sse")
print(f" SSE/Legacy Messages: POST {base_url}/mcp/messages")
print(f" Streamable HTTP: GET/POST/DELETE/OPTIONS {base_url}/mcp")
print("------------------------------")
print("Use Ctrl+C to stop the service")
# --- Start Uvicorn Server ---
config = Config(
app=app,
host=args.host,
port=args.port,
log_level="debug" if args.debug else "info",
reload=args.reload
)
server = Server(config=config)
await server.serve()
# --- Main Execution Logic (Simplified) ---
def run_main_sync():
"""Synchronous wrapper, primarily for SSE mode now."""
sync_logger = logging.getLogger("run_main_sync")
sync_logger.info("Entering run_main_sync (SSE focus)...")
print("DEBUG: Entering run_main_sync (SSE focus)...", file=sys.stderr, flush=True)
args = parse_args()
if args.sse:
try:
# Run the async SSE server setup and Uvicorn loop
asyncio.run(start_sse_server(args))
sync_logger.info("asyncio.run(start_sse_server) completed.")
print("DEBUG: asyncio.run(start_sse_server) completed.", file=sys.stderr, flush=True)
except KeyboardInterrupt:
sync_logger.info("SSE server stopped by KeyboardInterrupt.")
except Exception as e:
sync_logger.critical(f"Error during asyncio.run(start_sse_server): {e}", exc_info=True)
print(f"DEBUG: Error during asyncio.run(start_sse_server): {e}", file=sys.stderr, flush=True)
raise
else:
# If run without --sse, print help/error
message = "Error: This entry point requires --sse flag. For stdio mode, use 'uv run mcp-doris' or the appropriate command for your stdio setup."
sync_logger.error(message)
print(message, file=sys.stderr)
sys.exit(1)
if __name__ == "__main__":
run_main_sync()

View File

@@ -0,0 +1,143 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Core MCP instance and startup logic for stdio mode.
"""
import asyncio
import logging
import sys
import traceback
import json
from typing import Dict, Any
# Import necessary components from mcp and our project
from mcp.server.fastmcp import FastMCP
logger = logging.getLogger("doris-mcp-core")
# --- Global MCP Instance for Stdio ---
# Create the instance when the module is imported.
# Tools will be registered synchronously(?) before running.
stdio_mcp = FastMCP(
name="doris-mcp-stdio-core",
description="Apache Doris MCP Server (stdio via core)",
)
# --- Removed async setup functions ---
def run_stdio():
"""
Synchronous entry point for running the stdio server.
Mimics the mcp-doris example by calling .run() on the instance.
Handles tool registration beforehand.
"""
logger.info("Executing run_stdio (synchronous entry point)...")
# --- Run the stdio server using the instance's run() method ---
logger.info("Calling stdio_mcp.run()...")
try:
# Assuming stdio_mcp has a synchronous run() method for stdio
stdio_mcp.run()
logger.info("stdio_mcp.run() completed.")
except KeyboardInterrupt:
logger.info("Stdio server stopped by KeyboardInterrupt.")
except AttributeError:
logger.critical("Error: stdio_mcp object does not have a '.run()' method suitable for stdio.", exc_info=False)
print("ERROR: stdio_mcp object does not have a '.run()' method.", file=sys.stderr, flush=True)
sys.exit(1)
except Exception as e:
logger.critical(f"run_stdio encountered an error during stdio_mcp.run(): {e}", exc_info=True)
traceback.print_exc(file=sys.stderr)
sys.exit(1)
# Register Tool: Execute SQL Query
@stdio_mcp.tool("exec_query", description="""[Function Description]: Execute SQL query and return result command (executed by the client).\n
[Parameter Content]:\n
- sql (string) [Required] - SQL statement to execute\n
- db_name (string) [Optional] - Target database name, defaults to the current database\n
- max_rows (integer) [Optional] - Maximum number of rows to return, default 100\n
- timeout (integer) [Optional] - Query timeout in seconds, default 30\n""")
async def exec_query_tool(sql: str, db_name: str = None, max_rows: int = 100, timeout: int = 30) -> Dict[str, Any]:
"""Wrapper: Execute SQL query and return result command"""
from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_exec_query
return await mcp_doris_exec_query(sql=sql, db_name=db_name, max_rows=max_rows, timeout=timeout)
# Register Tool: Get Table Schema
@stdio_mcp.tool("get_table_schema", description="""[Function Description]: Get detailed structure information of the specified table (columns, types, comments, etc.).\n
[Parameter Content]:\n
- table_name (string) [Required] - Name of the table to query\n
- db_name (string) [Optional] - Target database name, defaults to the current database\n""")
async def get_table_schema_tool(table_name: str, db_name: str = None) -> Dict[str, Any]:
"""Wrapper: Get table schema"""
from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_table_schema
if not table_name: return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "Missing table_name parameter"})}]}
return await mcp_doris_get_table_schema(table_name=table_name, db_name=db_name)
# Register Tool: Get Database Table List
@stdio_mcp.tool("get_db_table_list", description="""[Function Description]: Get a list of all table names in the specified database.\n
[Parameter Content]:\n
- db_name (string) [Optional] - Target database name, defaults to the current database\n""")
async def get_db_table_list_tool(db_name: str = None) -> Dict[str, Any]:
"""Wrapper: Get database table list"""
from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_db_table_list
return await mcp_doris_get_db_table_list(db_name=db_name)
# Register Tool: Get Database List
@stdio_mcp.tool("get_db_list", description="""[Function Description]: Get a list of all database names on the server.\n
[Parameter Content]:\n
- random_string (string) [Required] - Unique identifier for the tool call\n""")
async def get_db_list_tool() -> Dict[str, Any]:
"""Wrapper: Get database list"""
from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_db_list
return await mcp_doris_get_db_list()
# Register Tool: Get Table Comment
@stdio_mcp.tool("get_table_comment", description="""[Function Description]: Get the comment information for the specified table.\n
[Parameter Content]:\n
- table_name (string) [Required] - Name of the table to query\n
- db_name (string) [Optional] - Target database name, defaults to the current database\n""")
async def get_table_comment_tool(table_name: str, db_name: str = None) -> Dict[str, Any]:
"""Wrapper: Get table comment"""
from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_table_comment
if not table_name: return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "Missing table_name parameter"})}]}
return await mcp_doris_get_table_comment(table_name=table_name, db_name=db_name)
# Register Tool: Get Table Column Comments
@stdio_mcp.tool("get_table_column_comments", description="""[Function Description]: Get comment information for all columns in the specified table.\n
[Parameter Content]:\n
- table_name (string) [Required] - Name of the table to query\n
- db_name (string) [Optional] - Target database name, defaults to the current database\n""")
async def get_table_column_comments_tool(table_name: str, db_name: str = None) -> Dict[str, Any]:
"""Wrapper: Get table column comments"""
from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_table_column_comments
if not table_name: return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "Missing table_name parameter"})}]}
return await mcp_doris_get_table_column_comments(table_name=table_name, db_name=db_name)
# Register Tool: Get Table Indexes
@stdio_mcp.tool("get_table_indexes", description="""[Function Description]: Get index information for the specified table.
[Parameter Content]:\n
- table_name (string) [Required] - Name of the table to query\n
- db_name (string) [Optional] - Target database name, defaults to the current database\n""")
async def get_table_indexes_tool(table_name: str, db_name: str = None) -> Dict[str, Any]:
"""Wrapper: Get table indexes"""
from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_table_indexes
if not table_name: return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "Missing table_name parameter"})}]}
return await mcp_doris_get_table_indexes(table_name=table_name, db_name=db_name)
# Register Tool: Get Recent Audit Logs
@stdio_mcp.tool("get_recent_audit_logs", description="""[Function Description]: Get audit log records for a recent period.\n
[Parameter Content]:\n
- days (integer) [Optional] - Number of recent days of logs to retrieve, default is 7\n
- limit (integer) [Optional] - Maximum number of records to return, default is 100\n""")
async def get_recent_audit_logs_tool(days: int = 7, limit: int = 100) -> Dict[str, Any]:
"""Wrapper: Get recent audit logs"""
from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_recent_audit_logs
try:
days = int(days)
limit = int(limit)
except (ValueError, TypeError):
return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "days and limit parameters must be integers"})}]}
return await mcp_doris_get_recent_audit_logs(days=days, limit=limit)
# --- Register Tools ---

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,912 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Doris MCP Streamable HTTP Server Implementation
Implements the MCP 2025-03-26 Streamable HTTP specification.
Uses a unified /mcp endpoint for GET, POST, DELETE, OPTIONS.
Manages sessions using Mcp-Session-Id header.
"""
import asyncio
import json
import uuid
import logging
import time
from typing import Any, Optional, Dict, List
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from sse_starlette.sse import EventSourceResponse
# Use a distinct logger name
logger = logging.getLogger("doris-mcp-streamable")
# Special marker for closing streams
STREAM_END_MARKER = "__MCP_STREAM_END__"
class DorisMCPStreamableServer:
"""Doris MCP Streamable HTTP Server"""
def __init__(self, mcp_server, app: FastAPI):
"""
Initializes the Doris MCP Streamable HTTP server.
Args:
mcp_server: The shared FastMCP server instance.
app: The main FastAPI application instance.
"""
self.mcp_server = mcp_server
self.app = app # We'll add routes to this app
# Note: CORS middleware should be added only once in main.py usually.
# If added here, ensure it doesn't conflict or duplicate.
# For separation, we might let main.py handle CORS entirely.
# Client session management for Streamable HTTP clients
# key: session_id (from Mcp-Session-Id header)
# value: {
# "created_at": timestamp,
# "last_active": timestamp,
# "request_queues": { request_id: asyncio.Queue }, # For POST /mcp request streams
# "general_sse_queues": List[asyncio.Queue] # For GET /mcp server push streams
# }
self.client_sessions: Dict[str, Dict[str, Any]] = {}
# Setup the unified MCP endpoint
self._setup_streamable_http_routes()
# Register session cleanup task if this instance manages lifespan independently
# Usually, startup events are tied to the main app lifespan managed in main.py
# We might not need @app.on_event("startup") here if main.py handles it.
# Let's assume main.py handles the cleanup task initiation.
def _setup_streamable_http_routes(self):
"""Sets up the unified /mcp endpoint for Streamable HTTP.
Uses a distinct tag for API docs.
"""
@self.app.api_route("/mcp", methods=["GET", "POST", "DELETE", "OPTIONS"], tags=["Streamable HTTP"])
async def mcp_endpoint_handler(request: Request):
"""Handles GET, POST, DELETE, OPTIONS for the /mcp endpoint."""
# 1. Handle OPTIONS (CORS preflight)
if request.method == "OPTIONS":
# Assuming CORS headers are handled by middleware in main.py
# If not, provide necessary headers here.
# This minimal response might suffice if middleware handles the rest
logger.debug("Handling OPTIONS request for /mcp")
# Return basic OK allowing exposed headers if middleware handles the rest
return JSONResponse({}, headers={"Access-Control-Expose-Headers": "Mcp-Session-Id"})
# Session ID from header is required for most methods
session_id = request.headers.get("Mcp-Session-Id")
# 2. Handle DELETE (Terminate Session)
if request.method == "DELETE":
if not session_id:
return JSONResponse({"jsonrpc": "2.0", "error": {"code": -32600, "message": "Mcp-Session-Id header is required for DELETE"}}, status_code=400)
logger.info(f"Handling DELETE request for session [Session ID: {session_id}]")
session_data = self.client_sessions.pop(session_id, None)
if session_data:
await self._cleanup_session_resources(session_id, session_data)
return JSONResponse({}, status_code=204) # No Content
else:
logger.warning(f"Attempted DELETE on non-existent session: {session_id}")
return JSONResponse({"jsonrpc": "2.0", "error": {"code": -32001, "message": "Session not found"}}, status_code=404)
# 3. Handle GET (Server Push SSE Stream)
if request.method == "GET":
if not session_id:
return JSONResponse({"jsonrpc": "2.0", "error": {"code": -32000, "message": "Mcp-Session-Id header is required for GET streams"}}, status_code=400)
if session_id not in self.client_sessions:
# Note: Unlike legacy SSE, GET here assumes session exists.
return JSONResponse({"jsonrpc": "2.0", "error": {"code": -32001, "message": "Session not found. Initialize first."}}, status_code=404)
accept_header = request.headers.get("Accept", "")
if "text/event-stream" not in accept_header:
return JSONResponse({"jsonrpc": "2.0", "error": {"code": -32600, "message": "Accept header must include text/event-stream for GET"}}, status_code=406)
# TODO: Handle Last-Event-ID for stream recovery?
logger.info(f"Handling GET request, establishing server push SSE stream [Session ID: {session_id}]")
push_queue = asyncio.Queue()
if self.client_sessions[session_id].get("general_sse_queues") is None:
self.client_sessions[session_id]["general_sse_queues"] = []
self.client_sessions[session_id]["general_sse_queues"].append(push_queue)
self.client_sessions[session_id]["last_active"] = time.time()
return EventSourceResponse(self._create_general_sse_generator(session_id, push_queue), media_type="text/event-stream")
# 4. Handle POST (Client Messages & Initialize)
if request.method == "POST":
accept_header = request.headers.get("Accept", "")
content_type = request.headers.get("Content-Type", "")
body = {}
try:
if "application/json" not in content_type:
return JSONResponse({"jsonrpc": "2.0", "error": {"code": -32700, "message": "Content-Type must be application/json"}}, status_code=415)
body = await request.json()
if isinstance(body, list): return JSONResponse({"jsonrpc": "2.0", "error": {"code": -32600, "message": "Batch requests not supported"}}, status_code=400)
if not isinstance(body, dict): return JSONResponse({"jsonrpc": "2.0", "error": {"code": -32700, "message": "Invalid JSON received"}}, status_code=400)
method = body.get("method")
message_id = body.get("id") # Can be None for notifications
# Handle Initialize request (does not require Mcp-Session-Id header)
if method == "initialize":
if "application/json" not in accept_header:
return JSONResponse({"jsonrpc": "2.0", "id": message_id, "error": {"code": -32600, "message": "Accept header must include application/json for initialize"}}, status_code=406)
return await self._handle_initialize(request, body, message_id)
# Handle other POST requests (require Mcp-Session-Id)
else:
if not session_id:
return JSONResponse({"jsonrpc": "2.0", "id": message_id, "error": {"code": -32000, "message": "Mcp-Session-Id header is required for this request"}}, status_code=400)
if session_id not in self.client_sessions:
return JSONResponse({"jsonrpc": "2.0", "id": message_id, "error": {"code": -32001, "message": "Session not found"}}, status_code=404)
# Check Accept header for non-initialize POST
if not ("application/json" in accept_header and "text/event-stream" in accept_header):
return JSONResponse({"jsonrpc": "2.0", "id": message_id, "error": {"code": -32600, "message": "Accept header must include application/json and text/event-stream for POST"}}, status_code=406)
self.client_sessions[session_id]["last_active"] = time.time()
return await self._handle_client_post(request, body, session_id, message_id)
except json.JSONDecodeError:
return JSONResponse({"jsonrpc": "2.0", "error": {"code": -32700, "message": "Parse error - Invalid JSON received"}}, status_code=400)
except Exception as e:
logger.error(f"Unexpected error handling POST /mcp: {str(e)}", exc_info=True)
error_id = body.get("id") if isinstance(body, dict) else None
return JSONResponse({"jsonrpc": "2.0", "id": error_id, "error": {"code": -32000, "message": "Internal server error"}}, status_code=500)
# Fallback for other methods like PUT, PATCH etc.
return JSONResponse({"error": "Method Not Allowed"}, status_code=405)
async def _handle_initialize(self, request: Request, body: Dict, message_id: Any):
"""Handles the 'initialize' method call via POST /mcp."""
logger.info("Handling Streamable HTTP initialize request")
# Optional: Validate params in body if needed
# params = body.get("params", {})
new_session_id = str(uuid.uuid4())
logger.info(f"Created new Streamable HTTP session [Session ID: {new_session_id}]")
self.client_sessions[new_session_id] = {
"created_at": time.time(),
"last_active": time.time(),
# No transport_type needed here as this class *is* the streamable server
"request_queues": {}, # Initialize request queues dict
"general_sse_queues": [] # Initialize general queues list
}
# Build InitializeResult based on spec
initialize_result = {
"protocolVersion": "2025-03-26",
"name": self.mcp_server.name,
"instructions": "Apache Doris MCP Server (Streamable HTTP Mode)",
"serverInfo": { "version": "0.2.0", "name": "Doris MCP Streamable Server" }, # Adjust as needed
"capabilities": {
"tools": { "supportsStreaming": True, "supportsProgress": True },
"resources": { "supportsStreaming": False }, # Example capability
"prompts": { "supported": True }, # Example capability
"session": { "supported": True }
}
}
response_body = {
"jsonrpc": "2.0",
"id": message_id,
"result": initialize_result
}
# Return JSON response with Mcp-Session-Id header
return JSONResponse(
content=response_body,
media_type="application/json",
headers={"Mcp-Session-Id": new_session_id}
)
async def _handle_client_post(self, request: Request, body: Dict, session_id: str, message_id: Any):
"""Handles non-initialize POST requests (notifications, responses, method calls)."""
method = body.get("method")
# Handle Notifications/Responses from client
is_notification = "method" in body and "id" not in body
is_response = "result" in body or "error" in body
if is_notification or is_response:
logger.info(f"Received Streamable HTTP notification/response [Session ID: {session_id}] - Processing needed? (Ignoring for now)")
# TODO: If the server sends requests that expect responses, process is_response here.
# For now, just acknowledge client notifications/responses.
return JSONResponse({}, status_code=202) # Accepted
# Handle Requests from client (method call)
if "method" in body and "id" in body:
logger.info(f"Received Streamable HTTP request [Session ID: {session_id}, ID: {message_id}, Method: {method}]")
params = body.get("params", {})
stream_required = params.get("stream", False) if method in ["tools/call", "mcp/callTool"] else False
if stream_required:
# --- Return SSE stream for response parts ---
logger.info(f"Using SSE stream for request [Session ID: {session_id}, ID: {message_id}]")
response_queue = asyncio.Queue()
# Ensure request_queues exists (should have been created during initialize)
if self.client_sessions[session_id].get("request_queues") is None:
logger.error(f"Session {session_id} is missing 'request_queues' dictionary!")
# Handle this inconsistency, maybe return an error
return JSONResponse({"jsonrpc": "2.0", "id": message_id, "error": {"code": -32000, "message": "Internal server error: Session state inconsistent"}}, status_code=500)
self.client_sessions[session_id]["request_queues"][message_id] = response_queue
# Start background task to process and put results in the queue
asyncio.create_task(self._process_request_and_respond(
request, body, session_id, message_id, response_queue, is_stream=True
))
# Return EventSourceResponse using the request-specific queue
return EventSourceResponse(self._create_request_sse_generator(session_id, message_id, response_queue), media_type="text/event-stream")
else:
# --- Return single JSON response ---
logger.info(f"Using JSON response for request [Session ID: {session_id}, ID: {message_id}]")
try:
# Process the request directly and get the result/error payload
result_or_error_payload = await self._process_request_and_respond(
request, body, session_id, message_id, None, is_stream=False
)
# This function now returns the final JSON body or raises HTTPException
return JSONResponse(content=result_or_error_payload, media_type="application/json")
except HTTPException as http_exc:
# Format HTTPException details into JSON-RPC error
return JSONResponse(
{"jsonrpc": "2.0", "id": message_id, "error": {"code": -32000, "message": http_exc.detail}},
status_code=http_exc.status_code
)
except Exception as e:
# Catch unexpected errors during synchronous processing
logger.error(f"Error processing non-stream request [Session ID: {session_id}, ID: {message_id}]: {str(e)}", exc_info=True)
error_response = {"jsonrpc": "2.0", "id": message_id, "error": {"code": -32000, "message": f"Internal server error: {str(e)}"}}
return JSONResponse(content=error_response, status_code=500)
else:
# Invalid JSON-RPC format (e.g., missing method or id for a request)
return JSONResponse({"jsonrpc": "2.0", "id": message_id, "error": {"code": -32600, "message": "Invalid JSON-RPC request format"}}, status_code=400)
# === Generator Functions for SSE Streams ===
async def _create_general_sse_generator(self, session_id: str, queue: asyncio.Queue):
"""Generator for GET /mcp server push streams."""
queue_removed = False
try:
while True:
try:
if session_id not in self.client_sessions:
logger.warning(f"General SSE stream generator: Session {session_id} closed.")
break
message = await asyncio.wait_for(queue.get(), timeout=60.0)
if message == STREAM_END_MARKER:
logger.debug(f"General SSE stream received end marker [Session ID: {session_id}]")
break
if isinstance(message, dict) and ("result" in message or "error" in message) and "id" in message:
logger.warning(f"Attempted to send response on GET stream, blocked [Session ID: {session_id}]: {message}")
queue.task_done()
continue
# TODO: Event ID for recovery?
yield {"event": "message", "data": json.dumps(message)}
queue.task_done()
except asyncio.TimeoutError:
if session_id not in self.client_sessions:
logger.warning(f"General SSE stream generator (timeout): Session {session_id} closed.")
break
yield {"event": "ping", "data": "keepalive"}
continue
except asyncio.CancelledError:
logger.info(f"General SSE stream cancelled [Session ID: {session_id}]")
raise
except Exception as e:
logger.error(f"General SSE stream error [Session ID: {session_id}]: {str(e)}", exc_info=True)
break
finally:
logger.info(f"General SSE stream ended [Session ID: {session_id}]")
if not queue_removed and session_id in self.client_sessions:
session = self.client_sessions[session_id]
if session.get("general_sse_queues") is not None:
try:
session["general_sse_queues"].remove(queue)
queue_removed = True
logger.debug(f"General SSE queue removed from session [Session ID: {session_id}]")
except ValueError:
logger.warning(f"Failed to remove general SSE queue (not found) [Session ID: {session_id}]")
except Exception as ce:
logger.error(f"Error removing general SSE queue [Session ID: {session_id}]: {ce}")
while not queue.empty():
try: queue.get_nowait(); queue.task_done()
except asyncio.QueueEmpty: break
async def _create_request_sse_generator(self, session_id: str, request_id: Any, queue: asyncio.Queue):
"""Generator for POST /mcp request-response streams."""
queue_removed = False
try:
while True:
try:
if session_id not in self.client_sessions or \
request_id not in self.client_sessions.get(session_id, {}).get("request_queues", {}):
logger.warning(f"Request SSE stream generator: Session/Request queue closed [Session ID: {session_id}, Request ID: {request_id}]")
break
message = await asyncio.wait_for(queue.get(), timeout=120.0) # Longer timeout for requests?
if message == STREAM_END_MARKER:
logger.debug(f"Request SSE stream received end marker [Session ID: {session_id}, Request ID: {request_id}]")
break
# TODO: Event ID for parts?
yield {"event": "message", "data": json.dumps(message)}
queue.task_done()
except asyncio.TimeoutError:
if session_id not in self.client_sessions or \
request_id not in self.client_sessions.get(session_id, {}).get("request_queues", {}):
logger.warning(f"Request SSE stream generator (timeout): Session/Request queue closed [Session ID: {session_id}, Request ID: {request_id}]")
break
logger.debug(f"Request SSE stream timed out waiting for message/end [Session ID: {session_id}, Request ID: {request_id}]")
# Unlike general stream, timeout here might indicate an issue or just long processing.
# Continue waiting for the STREAM_END_MARKER.
continue
except asyncio.CancelledError:
logger.info(f"Request SSE stream cancelled [Session ID: {session_id}, Request ID: {request_id}]")
raise
except Exception as e:
logger.error(f"Request SSE stream error [Session ID: {session_id}, Request ID: {request_id}]: {str(e)}", exc_info=True)
break
finally:
logger.info(f"Request SSE stream ended [Session ID: {session_id}, Request ID: {request_id}]")
if not queue_removed and session_id in self.client_sessions:
session = self.client_sessions[session_id]
if session.get("request_queues") is not None:
if session["request_queues"].pop(request_id, None):
queue_removed = True
logger.debug(f"Request SSE queue removed from session [Session ID: {session_id}, Request ID: {request_id}]")
else:
logger.warning(f"Failed to remove request SSE queue (not found) [Session ID: {session_id}, Request ID: {request_id}]")
while not queue.empty():
try: queue.get_nowait(); queue.task_done()
except asyncio.QueueEmpty: break
# === Core Request Processing Logic ===
async def _process_request_and_respond(
self, request: Request, body: Dict, session_id: str, message_id: Any,
response_queue: Optional[asyncio.Queue], # Queue ONLY for streaming responses
is_stream: bool # True if response should go via SSE queue
):
"""Processes client method calls and prepares response/error payload or sends to queue.
Returns payload for non-streaming, returns None for streaming (uses queue).
Raises HTTPException for non-streaming errors that need specific status codes.
"""
logger.info(f"Entering _process_request_and_respond for method '{body.get('method')}'...")
method = body.get("method")
params = body.get("params", {})
response_payload = None # Holds the 'result' or 'error' part of JSON-RPC
try:
# --- Handle Method Calls ---
if method == "mcp/listOfferings":
tools = await self.mcp_server.list_tools()
tools_json = self._format_tools(tools)
resources = await self.mcp_server.list_resources()
resources_json = self._format_resources(resources)
prompts = await self.mcp_server.list_prompts()
prompts_json = self._format_prompts(prompts)
response_payload = {"tools": tools_json, "resources": resources_json, "prompts": prompts_json}
elif method == "mcp/listTools" or method == "tools/list":
tools = await self.mcp_server.list_tools()
response_payload = {"tools": self._format_tools(tools)}
elif method == "mcp/listResources":
resources = await self.mcp_server.list_resources()
response_payload = {"resources": self._format_resources(resources)}
elif method == "mcp/listPrompts":
prompts = await self.mcp_server.list_prompts()
response_payload = {"prompts": self._format_prompts(prompts)}
elif method == "mcp/callTool" or method == "tools/call":
tool_name = params.get("name")
arguments = params.get("arguments", {})
if not tool_name:
# For non-streaming, raise HTTPException; for streaming, send error via queue
error_detail = "Invalid params: tool name ('name') is required"
if is_stream and response_queue:
error_resp = {"jsonrpc": "2.0", "id": message_id, "error": {"code": -32602, "message": error_detail}}
await response_queue.put(error_resp)
# No return here for stream, let finally handle end marker
else:
raise HTTPException(status_code=400, detail=error_detail)
return # Exit after handling error
# --- Tool Calling ---
if is_stream and response_queue:
# Background task handles putting results/errors in queue
logger.info(f"Launching stream tool task [Session: {session_id}, Req: {message_id}, Tool: {tool_name}]")
asyncio.create_task(self._execute_stream_tool_wrapper(
tool_name, arguments, message_id, session_id, request, response_queue
))
# Returns None, caller (_handle_client_post) returns EventSourceResponse
return
else:
# Execute tool directly for non-streaming response
logger.info(f"Executing non-stream tool [Session: {session_id}, Req: {message_id}, Tool: {tool_name}]")
# Note: call_tool now raises ValueError on internal errors
result = await self.call_tool(tool_name, arguments, request, None) # No callback needed
logger.debug(f"Raw result from non-stream call_tool: {result}")
response_payload = self._format_tool_call_result(result)
else:
# Method not found
error_detail = f"Method not found: {method}"
if is_stream and response_queue:
error_resp = {"jsonrpc": "2.0", "id": message_id, "error": {"code": -32601, "message": error_detail}}
await response_queue.put(error_resp)
else:
raise HTTPException(status_code=405, detail=error_detail)
return # Exit after handling error
# --- Prepare final response payload (only if not streaming and successful) ---
if response_payload is not None:
final_response = {"jsonrpc": "2.0", "id": message_id, "result": response_payload}
if is_stream and response_queue: # Should not happen if response_payload is set
logger.error("Logic error: response_payload set for streaming call?")
await response_queue.put(final_response) # Send anyway?
elif not is_stream:
logger.debug(f"Returning successful non-stream payload for {method}")
return final_response # Return dict for JSONResponse
except Exception as e:
# Handles errors raised by call_tool (ValueError) or other unexpected issues
logger.error(f"Error processing request [Session: {session_id}, Req: {message_id}, Method: {method}]: {str(e)}", exc_info=True)
error_code = -32000
error_message = f"Internal server error: {str(e)}"
status_code = 500 # Default for unexpected errors
if isinstance(e, HTTPException):
# If it was an HTTPException raised earlier (e.g., 400, 405)
error_message = e.detail
status_code = e.status_code
error_code = -32000 # Keep generic JSON-RPC code for now
elif isinstance(e, ValueError):
# Errors from call_tool (tool not found, execution error)
error_message = str(e)
status_code = 500 # Treat tool execution errors as internal server errors
error_code = -32000 # Or a custom tool error code?
error_response_payload = {"code": error_code, "message": error_message}
if is_stream and response_queue:
# Send error via queue for streaming calls
final_error_response = {"jsonrpc": "2.0", "id": message_id, "error": error_response_payload}
logger.debug(f"Putting error response into stream queue [Session: {session_id}, Req: {message_id}]")
await response_queue.put(final_error_response)
# Returns None, let finally send end marker
return
else:
# For non-streaming, raise HTTPException to set status code
logger.debug(f"Raising HTTPException for non-stream error (Status: {status_code})")
raise HTTPException(status_code=status_code, detail=error_message)
finally:
# If this was a streaming call, ensure the end marker is sent.
# This runs even if the processing returns early (e.g., after launching task or handling error).
if is_stream and response_queue:
logger.debug(f"Putting stream end marker [Session: {session_id}, Req: {message_id}]")
await response_queue.put(STREAM_END_MARKER)
async def _execute_stream_tool_wrapper(
self, tool_name: str, arguments: Dict, message_id: Any, session_id: str,
request: Request, response_queue: asyncio.Queue
):
"""Wraps stream-capable tool calls, handles callback, puts results/errors into queue."""
logger.info(f"Entering _execute_stream_tool_wrapper for tool '{tool_name}'...")
try:
logger.debug(f"Executing stream tool wrapper [Session: {session_id}, Req: {message_id}, Tool: {tool_name}]")
async def stream_callback(content, metadata=None):
logger.debug(f"Stream callback received content [Session: {session_id}, Req: {message_id}]")
partial_result_formatted = self._format_tool_call_result(content)
# Check session/queue validity before putting
if session_id not in self.client_sessions or \
message_id not in self.client_sessions.get(session_id, {}).get("request_queues", {}):
logger.warning(f"Stream callback: Session/Queue closed, cannot send partial result [Session: {session_id}, Req: {message_id}]")
return
# Send progress notification
progress_notification = {
"jsonrpc": "2.0",
"method": "tools/progress",
"params": {
"requestId": message_id,
"toolName": tool_name,
"progress": partial_result_formatted,
}
}
try:
await response_queue.put(progress_notification)
except Exception as e:
logger.error(f"Stream callback failed to send progress: {str(e)}")
# Handle visualization data
if metadata and "visualization" in metadata:
await self.send_visualization_data(session_id, message_id, metadata["visualization"])
# --- Call Tool ---
kwargs = dict(arguments)
# Simplification: Assume tool supports callback if streaming requested
kwargs['callback'] = stream_callback
# call_tool handles its own internal errors and raises ValueError
result = await self.call_tool(tool_name, kwargs, request, stream_callback)
logger.debug(f"Stream wrapper received final result from call_tool: {result}")
# --- Send Final Result ---
if session_id not in self.client_sessions or \
message_id not in self.client_sessions.get(session_id, {}).get("request_queues", {}):
logger.warning(f"Stream tool finished but Session/Queue closed [Session: {session_id}, Req: {message_id}]")
return # Cannot send final result
final_result_formatted = self._format_tool_call_result(result)
final_message = {
"jsonrpc": "2.0",
"id": message_id,
"result": final_result_formatted
}
logger.debug(f"Putting final stream result into queue [Session: {session_id}, Req: {message_id}]")
await response_queue.put(final_message)
logger.info(f"Stream tool execution successful [Session: {session_id}, Req: {message_id}]")
except Exception as e:
# Catches errors from call_tool (ValueError) or other wrapper issues
logger.error(f"Error during stream tool execution wrapper [Session: {session_id}, Req: {message_id}]: {str(e)}", exc_info=True)
# Check session/queue validity before sending error
if session_id not in self.client_sessions or \
message_id not in self.client_sessions.get(session_id, {}).get("request_queues", {}):
logger.warning(f"Stream tool failed but Session/Queue closed [Session: {session_id}, Req: {message_id}]")
return # Cannot send error
error_code = -32000
error_message = f"Tool execution error: {str(e)}"
if isinstance(e, ValueError):
error_code = -32602 # Or -32000?
error_message = str(e)
error_response = {
"jsonrpc": "2.0",
"id": message_id,
"error": { "code": error_code, "message": error_message }
}
try:
await response_queue.put(error_response)
except Exception as qe:
logger.error(f"Failed to put error response into stream queue: {qe}")
# No finally block needed here, handled by _process_request_and_respond
async def call_tool(self, tool_name, arguments, request, callback: Optional[callable] = None):
"""Finds and executes the target tool function/method.
Raises ValueError on tool not found or execution error.
"""
logger.info(f"Entering call_tool for tool '{tool_name}'...")
# Log args excluding callback
log_args = {k: v for k, v in arguments.items() if k != 'callback'}
logger.info(f"Executing tool: {tool_name}, Args: {json.dumps(log_args, ensure_ascii=False, default=str)}")
recent_query = self._extract_recent_query(request)
# Tool mapping might be needed if client uses different names
tool_mapping = {
# Example: "clientFacingName": "internalFunctionName"
"status": "mcp_doris_status",
"health": "mcp_doris_health",
# Add other mappings if needed, ensure consistency with tool_initializer
"nl2sql_query": "mcp_doris_nl2sql_query",
"nl2sql_query_stream": "mcp_doris_nl2sql_query_stream",
"list_database_tables": "mcp_doris_list_database_tables",
"explain_table": "mcp_doris_explain_table",
"get_nl2sql_status": "mcp_doris_get_nl2sql_status",
"refresh_metadata": "mcp_doris_refresh_metadata",
"sql_optimize": "mcp_doris_sql_optimize",
"fix_sql": "mcp_doris_fix_sql",
"count_chars": "mcp_doris_count_chars",
"exec_query": "mcp_doris_exec_query",
"get_schema_list": "mcp_doris_get_schema_list", # Deprecated?
"save_metadata": "mcp_doris_save_metadata", # Likely internal
"get_metadata": "mcp_doris_get_metadata", # Likely internal
"analyze_query_result": "mcp_doris_analyze_query_result", # Internal?
"generate_sql": "mcp_doris_generate_sql", # Likely internal
"explain_sql": "mcp_doris_explain_sql", # Internal?
"modify_sql": "mcp_doris_modify_sql", # Internal?
"parse_query": "mcp_doris_parse_query", # Internal?
"identify_query_type": "mcp_doris_identify_query_type", # Internal?
"validate_sql_syntax": "mcp_doris_validate_sql_syntax", # Internal?
"check_sql_security": "mcp_doris_check_sql_security", # Internal?
"find_similar_examples": "mcp_doris_find_similar_examples", # Internal?
"find_similar_history": "mcp_doris_find_similar_history", # Internal?
"calculate_query_similarity": "mcp_doris_calculate_query_similarity", # Internal?
"adapt_similar_query": "mcp_doris_adapt_similar_query", # Internal?
"get_nl2sql_prompt": "mcp_doris_get_nl2sql_prompt" # Internal?
}
mapped_tool_name = tool_mapping.get(tool_name, tool_name)
try:
# 1. Find the registered tool instance/function from FastMCP
tool_instance = None
mcp = self.app.state.mcp if hasattr(self.app.state, 'mcp') else self.mcp_server
registered_tools = await mcp.list_tools()
for tool in registered_tools:
# The tool object returned by list_tools might be the wrapper function
# defined in tool_initializer. We need its name.
tool_registered_name = getattr(tool, 'name', getattr(tool, '__name__', None))
if tool_registered_name == tool_name: # Match against the name used in @mcp.tool
tool_instance = tool # This is likely the wrapper function itself
logger.debug(f"Found registered tool wrapper: {tool_registered_name}")
break
if not tool_instance:
# Fallback: Try importing directly (less ideal as it bypasses registration)
logger.warning(f"Tool '{tool_name}' not found in registered tools, trying direct import of {mapped_tool_name}")
try:
import doris_mcp_server.tools.mcp_doris_tools as mcp_tools
tool_instance = getattr(mcp_tools, mapped_tool_name, None)
if not tool_instance or not callable(tool_instance):
raise ValueError(f"Tool function {mapped_tool_name} not found or not callable in mcp_doris_tools.")
logger.debug(f"Using directly imported tool function: {mapped_tool_name}")
# If using direct import, FastMCP context (ctx) is not available
# We need to pass args directly
processed_args = self._process_tool_arguments(mapped_tool_name, arguments, recent_query)
# Inject callback if provided and applicable
if callback and mapped_tool_name.endswith("_stream"):
processed_args['callback'] = callback
elif callback:
processed_args.pop('callback', None)
result = await tool_instance(**processed_args)
logger.debug(f"Raw result from directly imported tool '{mapped_tool_name}': {result}")
return result
except (ImportError, AttributeError, ValueError) as import_err:
logger.error(f"Failed to find or import tool: {tool_name} / {mapped_tool_name}. Error: {import_err}")
raise ValueError(f"Tool '{tool_name}' not found or failed to import.") from import_err
# 2. If found via registration, execute using FastMCP's mechanism (if possible)
# or simulate the context passing if tool_instance is the wrapper.
# The wrapper expects a Context object.
logger.debug(f"Executing registered tool wrapper '{tool_name}'")
# We need to manually create a mock or simplified Context if FastMCP doesn't handle this automatically
# For simplicity, let's try passing parameters directly if the wrapper handles it.
# Ideally, FastMCP would handle the execution via mcp.call_tool(tool_name, params=...) if available.
# Let's assume the wrapper function handles **kwargs or a Context object.
# Create a pseudo-context or just pass params
# Method 1: Pass params directly (assuming wrapper handles it)
# processed_args = self._process_tool_arguments(mapped_tool_name, arguments, recent_query)
# if callback:
# processed_args['callback'] = callback
# result = await tool_instance(**processed_args) # This likely won't work if it expects Context
# Method 2: Create a Context-like object (Requires Context class import)
# from mcp.server.fastmcp import Context # Make sure imported
# pseudo_ctx = Context(mcp=mcp, request=request, params=arguments, tool=tool_instance)
# result = await tool_instance(pseudo_ctx)
# Method 3: Use mcp.call_tool internal method if accessible and appropriate
# This is speculative based on potential FastMCP internals
if hasattr(mcp, 'call_tool_by_name'): # Hypothetical method
logger.debug("Attempting execution via mcp.call_tool_by_name")
pseudo_ctx_params = arguments # Pass client args
# pseudo_ctx_params['_request'] = request # Maybe pass request?
if callback: pseudo_ctx_params['callback'] = callback # Pass callback?
result = await mcp.call_tool_by_name(tool_name, params=pseudo_ctx_params)
logger.debug(f"Result from mcp.call_tool_by_name: {result}")
else:
# Fallback to manual context simulation if no direct call method exists
logger.debug("Falling back to manual context simulation for tool wrapper execution")
from mcp.server.fastmcp import Context # Ensure imported
# Prepare params for context, including potentially callback
context_params = dict(arguments)
if callback: context_params['callback'] = callback
pseudo_ctx = Context(mcp=mcp, request=request, params=context_params, tool=tool_instance)
result = await tool_instance(pseudo_ctx) # Call the wrapper with simulated context
logger.debug(f"Result from manual context simulation: {result}")
logger.debug(f"Raw result received in call_tool from registered tool '{tool_name}': {result}")
return result
except Exception as e:
logger.error(f"Exception during call_tool for '{tool_name}': {str(e)}", exc_info=True)
raise ValueError(f"Error executing tool '{tool_name}': {str(e)}") from e
# === Helper Methods (Formatting, Session Cleanup, etc.) ===
def _format_tools(self, tools):
# Helper to format tool list for responses
# Based on mcp/listTools structure
tools_json = []
for tool in tools:
# Assuming tools from list_tools are the wrapper functions
tool_registered_name = getattr(tool, 'name', getattr(tool, '__name__', None))
if not tool_registered_name:
logger.warning(f"Could not determine name for tool object: {tool}")
continue
# Need a way to get description and schema associated with the wrapper
# This might require inspecting the mcp instance's internal storage
mcp = self.app.state.mcp if hasattr(self.app.state, 'mcp') else self.mcp_server
# Hypothetical internal access - THIS IS FRAGILE
tool_spec = mcp.tools.get(tool_registered_name) if hasattr(mcp, 'tools') else None
description = ""
input_schema = {"type": "object", "properties": {}, "required": []}
if tool_spec and hasattr(tool_spec, 'description'):
description = tool_spec.description
if tool_spec and hasattr(tool_spec, 'parameters'): # Assuming parameters holds the JSON schema
input_schema = tool_spec.parameters
tools_json.append({
"name": tool_registered_name,
"description": description,
"inputSchema": input_schema
})
return tools_json
def _format_resources(self, resources):
# Helper to format resource list
return [res.model_dump() if hasattr(res, "model_dump") else res for res in resources]
def _format_prompts(self, prompts):
# Helper to format prompt list
return [prompt.model_dump() if hasattr(prompt, "model_dump") else prompt for prompt in prompts]
def _format_tool_call_result(self, result: Any) -> Dict[str, Any]:
# Helper to format tool results into MCP Content format
content_list = []
if isinstance(result, str):
try:
# If it looks like the tool already returned the full JSON RPC like structure
parsed_json = json.loads(result)
if isinstance(parsed_json, dict) and 'content' in parsed_json and isinstance(parsed_json['content'], list):
logger.debug("Tool result already seems formatted with 'content', using as is.")
return parsed_json # Use the structure directly
else:
# Assume it's JSON content, wrap it
content_list.append({"type": "json", "json": parsed_json})
except json.JSONDecodeError:
# Not JSON, treat as text
content_list.append({"type": "text", "text": result})
elif isinstance(result, (dict, list)):
# If result is already a dict with a 'content' list, use it directly
if isinstance(result, dict) and 'content' in result and isinstance(result['content'], list):
logger.debug("Tool result dictionary has 'content', using as is.")
return result # Use the structure directly
else:
# Otherwise, assume it's JSON content to be wrapped
content_list.append({"type": "json", "json": result})
elif result is None:
# Handle None result, maybe return empty content or specific type?
logger.warning("_format_tool_call_result received None result")
content_list.append({"type": "text", "text": ""}) # Example: empty text
else:
# Other types, convert to string and wrap as text
content_list.append({"type": "text", "text": str(result)})
# Always return a dict with a 'content' key containing a list
return {"content": content_list}
def _process_tool_arguments(self, tool_name, arguments, recent_query):
# Helper to process tool arguments, including random_string fallback
# Note: Ensure callback is NOT passed here
processed_args = dict(arguments)
processed_args.pop('callback', None) # Explicitly remove callback
if "random_string" in arguments and tool_name.startswith("mcp_doris_"):
random_string = processed_args.pop("random_string", "") # Remove from processed too
logger.debug(f"Processing random_string '{random_string}' for tool {tool_name}")
# ... (rest of random_string logic as before) ...
# Example for exec_query:
if tool_name == "mcp_doris_exec_query" and not processed_args.get("sql"):
sql_fallback = random_string or recent_query
# ... (logic to extract SQL from fallback) ...
if sql_extracted:
processed_args["sql"] = sql_extracted
else:
logger.warning(f"Missing sql for {tool_name}, and fallback failed.")
# ... (logic for table_name fallback) ...
return processed_args
def _extract_recent_query(self, request: Request) -> Optional[str]:
# Helper to extract recent user query from request
# (Implementation as provided previously)
try:
# Try to extract message history from request body
body = None
body_bytes = getattr(request, "_body", None)
if body_bytes:
try:
body = json.loads(body_bytes)
except: pass
if not body: body = getattr(request, "_json", {})
messages = body.get("params", {}).get("messages", [])
if messages:
for msg in reversed(messages):
if msg.get("role") == "user": return msg.get("content", "")
message = body.get("params", {}).get("message", {})
if message and message.get("role") == "user": return message.get("content", "")
return None
except Exception as e:
logger.error(f"Error extracting recent query: {str(e)}")
return None
async def _cleanup_session_resources(self, session_id: str, session_data: Dict):
# Helper to clean up queues when session is deleted
logger.info(f"Cleaning up resources for session [Session ID: {session_id}]")
# Close general SSE queues
general_queues = session_data.get("general_sse_queues", [])
for queue in general_queues:
try:
await queue.put(STREAM_END_MARKER)
except Exception as e:
logger.warning(f"Error putting end marker in general queue for session {session_id}: {e}")
# Close request-specific SSE queues
request_queues = session_data.get("request_queues", {})
for req_id, queue in request_queues.items():
try:
await queue.put(STREAM_END_MARKER)
except Exception as e:
logger.warning(f"Error putting end marker in request queue {req_id} for session {session_id}: {e}")
logger.info(f"Finished cleaning resources for session {session_id}")
# This method might belong in the main app or a shared utility if needed by both servers
# async def cleanup_idle_sessions(self):
# # ... (implementation - needs access to self.client_sessions) ...
# pass
# This method might belong in the main app or a shared utility
# async def broadcast_message(self, message):
# # ... (implementation - needs access to self.client_sessions of BOTH servers?) ...
# pass
# This method is specific to streamable http tool calls
async def send_visualization_data(self, session_id: str, request_id: Any, visualization_data: Any):
"""Sends visualization data as a notification on the request stream."""
if session_id not in self.client_sessions:
logger.warning(f"Cannot send visualization: Session {session_id} not found.")
return
queue = self.client_sessions.get(session_id, {}).get("request_queues", {}).get(request_id)
if not queue:
logger.warning(f"Cannot send visualization: Request queue {request_id} not found for session {session_id}.")
return
notification = {
"jsonrpc": "2.0",
"method": "tools/visualization",
"params": visualization_data
}
try:
await queue.put(notification)
logger.info(f"Sent visualization notification [Session: {session_id}, Req: {request_id}]")
except Exception as e:
logger.error(f"Error sending visualization notification [Session: {session_id}, Req: {request_id}]: {e}")
# This might belong in main app or shared utility
# async def send_periodic_updates(self):
# # ... (implementation) ...
# pass
# End of class DorisMCPStreamableServer

View File

@@ -0,0 +1,23 @@
from .mcp_doris_tools import (
mcp_doris_exec_query,
mcp_doris_get_table_schema,
mcp_doris_get_db_table_list,
mcp_doris_get_db_list,
mcp_doris_get_table_comment,
mcp_doris_get_table_column_comments,
mcp_doris_get_table_indexes,
mcp_doris_get_recent_audit_logs
)
# The __all__ list should reflect the registered tool names,
# even though the implementation functions have the prefix.
__all__ = [
"exec_query",
"get_table_schema",
"get_db_table_list",
"get_db_list",
"get_table_comment",
"get_table_column_comments",
"get_table_indexes",
"get_recent_audit_logs"
]

View File

@@ -0,0 +1,202 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Doris MCP Tool Implementations
Includes exec_query and new tools based on schema_extractor.
"""
import os
import time
import json
import logging
from typing import Dict, Any
import pandas as pd
# --- Use absolute imports ---
from doris_mcp_server.utils.schema_extractor import MetadataExtractor
from doris_mcp_server.utils.sql_executor_tools import execute_sql_query
# Get logger
logger = logging.getLogger("doris-mcp-tools")
# --- Helper Function to format response ---
def _format_response(success: bool, result: Any = None, error: str = None, message: str = "") -> Dict[str, Any]:
response_data = {
"success": success,
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
}
if success and result is not None:
# Handle DataFrame serialization
if isinstance(result, pd.DataFrame):
try:
# Convert DataFrame to JSON records format
response_data["result"] = json.loads(result.to_json(orient='records', date_format='iso'))
except Exception as df_err:
logger.error(f"DataFrame to JSON conversion failed: {df_err}")
# Fallback or specific error handling for DataFrame
response_data["result"] = {"error": "Failed to serialize DataFrame result"}
response_data["success"] = False # Mark as failed if serialization fails
response_data["error"] = f"DataFrame serialization error: {str(df_err)}"
else:
response_data["result"] = result
response_data["message"] = message or "Operation successful" # Translated: Operation successful
elif not success:
response_data["error"] = error or "Unknown error" # Translated: Unknown error
response_data["message"] = message or "Operation failed" # Translated: Operation failed
return {
"content": [
{
"type": "text",
"text": json.dumps(response_data, ensure_ascii=False, default=str) # Use default=str for non-serializable types
}
]
}
async def mcp_doris_exec_query(sql: str = None, db_name: str = None, max_rows: int = 100, timeout: int = 30) -> Dict[str, Any]:
"""
Executes an SQL query and returns the result.
Args:
sql (str): The SQL query to execute.
db_name (str, optional): Target database name. Defaults to the configured default database.
max_rows (int, optional): Maximum number of rows to return. Defaults to 100.
timeout (int, optional): Query timeout in seconds. Defaults to 30.
Returns:
Dict[str, Any]: A dictionary containing the query result or an error.
"""
logger.info(f"MCP Tool Call: mcp_doris_exec_query, SQL: {sql}, DB: {db_name}, MaxRows: {max_rows}, Timeout: {timeout}")
try:
if not sql:
return _format_response(success=False, error="SQL statement not provided", message="Please provide the SQL statement to execute")
# Build parameters to pass to execute_sql_query
exec_ctx = {
"params": {
"sql": sql,
"db_name": db_name,
"max_rows": max_rows,
"timeout": timeout
}
}
# Directly call execute_sql_query to execute the query
exec_result = await execute_sql_query(exec_ctx)
# The format returned by execute_sql_query is {'content': [{'type': 'text', 'text': json_string}]}
# Need to parse the internal JSON string
if exec_result and 'content' in exec_result and len(exec_result['content']) > 0 and 'text' in exec_result['content'][0]:
try:
# Parse JSON string
result_data = json.loads(exec_result['content'][0]['text'])
# Directly return the parsed result obtained from execute_sql_query
# This result is already in the format {"success": ..., "data": ..., "columns": ...} or {"success": false, "error": ...}
# _format_response would wrap it again, but here we directly use the parsed data
# Note: This changes the original return structure of this function; it now directly returns the output of sql_executor
# If the _format_response wrapper needs to be maintained, the code below needs adjustment
return {
"content": [
{
"type": "text",
"text": json.dumps(result_data, ensure_ascii=False, default=str)
}
]
}
except json.JSONDecodeError as json_err:
logger.error(f"Failed to parse execute_sql_query result: {json_err}")
return _format_response(success=False, error=str(json_err), message="Error parsing SQL execution result")
except Exception as parse_err:
logger.error(f"Unexpected error occurred while processing execute_sql_query result: {parse_err}", exc_info=True)
return _format_response(success=False, error=str(parse_err), message="Unknown error occurred while processing SQL execution result")
else:
logger.error(f"execute_sql_query returned an unexpected format: {exec_result}")
return _format_response(success=False, error="SQL executor returned invalid format", message="Internal error executing SQL query")
except Exception as e:
logger.error(f"MCP tool execution failed mcp_doris_exec_query: {str(e)}", exc_info=True)
return _format_response(success=False, error=str(e), message="Error executing SQL query")
async def mcp_doris_get_table_schema(table_name: str, db_name: str = None) -> Dict[str, Any]:
logger.info(f"MCP Tool Call: mcp_doris_get_table_schema, Table: {table_name}, DB: {db_name}")
if not table_name:
return _format_response(success=False, error="Missing table_name parameter")
try:
extractor = MetadataExtractor(db_name=db_name)
schema = extractor.get_table_schema(table_name=table_name, db_name=db_name)
if not schema:
return _format_response(success=False, error="Table not found or has no columns", message=f"Could not get schema for table {db_name or extractor.db_name}.{table_name}")
return _format_response(success=True, result=schema)
except Exception as e:
logger.error(f"MCP tool execution failed mcp_doris_get_table_schema: {str(e)}", exc_info=True)
return _format_response(success=False, error=str(e), message="Error getting table schema")
async def mcp_doris_get_db_table_list(db_name: str = None) -> Dict[str, Any]:
logger.info(f"MCP Tool Call: mcp_doris_get_db_table_list, DB: {db_name}")
try:
extractor = MetadataExtractor(db_name=db_name)
tables = extractor.get_database_tables(db_name=db_name)
return _format_response(success=True, result=tables)
except Exception as e:
logger.error(f"MCP tool execution failed mcp_doris_get_db_table_list: {str(e)}", exc_info=True)
return _format_response(success=False, error=str(e), message="Error getting database table list")
async def mcp_doris_get_db_list() -> Dict[str, Any]:
logger.info(f"MCP Tool Call: mcp_doris_get_db_list")
try:
extractor = MetadataExtractor()
databases = extractor.get_all_databases()
return _format_response(success=True, result=databases)
except Exception as e:
logger.error(f"MCP tool execution failed mcp_doris_get_db_list: {str(e)}", exc_info=True)
return _format_response(success=False, error=str(e), message="Error getting database list")
async def mcp_doris_get_table_comment(table_name: str, db_name: str = None) -> Dict[str, Any]:
logger.info(f"MCP Tool Call: mcp_doris_get_table_comment, Table: {table_name}, DB: {db_name}")
if not table_name:
return _format_response(success=False, error="Missing table_name parameter")
try:
extractor = MetadataExtractor(db_name=db_name)
comment = extractor.get_table_comment(table_name=table_name, db_name=db_name)
return _format_response(success=True, result=comment)
except Exception as e:
logger.error(f"MCP tool execution failed mcp_doris_get_table_comment: {str(e)}", exc_info=True)
return _format_response(success=False, error=str(e), message="Error getting table comment")
async def mcp_doris_get_table_column_comments(table_name: str, db_name: str = None) -> Dict[str, Any]:
logger.info(f"MCP Tool Call: mcp_doris_get_table_column_comments, Table: {table_name}, DB: {db_name}")
if not table_name:
return _format_response(success=False, error="Missing table_name parameter")
try:
extractor = MetadataExtractor(db_name=db_name)
comments = extractor.get_column_comments(table_name=table_name, db_name=db_name)
return _format_response(success=True, result=comments)
except Exception as e:
logger.error(f"MCP tool execution failed mcp_doris_get_table_column_comments: {str(e)}", exc_info=True)
return _format_response(success=False, error=str(e), message="Error getting column comments")
async def mcp_doris_get_table_indexes(table_name: str, db_name: str = None) -> Dict[str, Any]:
logger.info(f"MCP Tool Call: mcp_doris_get_table_indexes, Table: {table_name}, DB: {db_name}")
if not table_name:
return _format_response(success=False, error="Missing table_name parameter")
try:
extractor = MetadataExtractor(db_name=db_name)
indexes = extractor.get_table_indexes(table_name=table_name, db_name=db_name)
return _format_response(success=True, result=indexes)
except Exception as e:
logger.error(f"MCP tool execution failed mcp_doris_get_table_indexes: {str(e)}", exc_info=True)
return _format_response(success=False, error=str(e), message="Error getting table indexes")
async def mcp_doris_get_recent_audit_logs(days: int = 7, limit: int = 100) -> Dict[str, Any]:
logger.info(f"MCP Tool Call: mcp_doris_get_recent_audit_logs, Days: {days}, Limit: {limit}")
try:
extractor = MetadataExtractor()
logs_df = extractor.get_recent_audit_logs(days=days, limit=limit)
return _format_response(success=True, result=logs_df)
except Exception as e:
logger.error(f"MCP tool execution failed mcp_doris_get_recent_audit_logs: {str(e)}", exc_info=True)
return _format_response(success=False, error=str(e), message="Error getting audit logs")

View File

@@ -0,0 +1,141 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Tool Initialization Module
Centralized initialization of all tools, ensuring they are correctly registered with MCP
"""
import logging
import os
from typing import List, Dict, Any, Optional
import json
from datetime import datetime
import traceback
# Import Context
from mcp.server.fastmcp import Context
# Import doris mcp tools
from doris_mcp_server.tools.mcp_doris_tools import (
mcp_doris_exec_query,
mcp_doris_get_table_schema,
mcp_doris_get_db_table_list,
mcp_doris_get_db_list,
mcp_doris_get_table_comment,
mcp_doris_get_table_column_comments,
mcp_doris_get_table_indexes,
mcp_doris_get_recent_audit_logs
)
# Get logger
logger = logging.getLogger("doris-mcp-tools-initializer")
async def register_mcp_tools(mcp):
"""Register MCP tool functions
Args:
mcp: FastMCP instance
"""
logger.info("Starting to register MCP tools...")
try:
# Register Tool: Execute SQL Query (Using long description string including parameters)
@mcp.tool("exec_query", description="""[Function Description]: Execute SQL query and return result command (executed by the client).\n
[Parameter Content]:\n
- random_string (string) [Required] - Unique identifier for the tool call\n
- sql (string) [Required] - SQL statement to execute\n
- db_name (string) [Optional] - Target database name, defaults to the current database\n
- max_rows (integer) [Optional] - Maximum number of rows to return, default 100
- timeout (integer) [Optional] - Query timeout in seconds, default 30""")
async def exec_query_tool(sql: str, db_name: str = None, max_rows: int = 100, timeout: int = 30) -> Dict[str, Any]:
"""Wrapper: Execute SQL query and return result command"""
# Note: ctx parameter is no longer needed here as we receive named parameters directly
return await mcp_doris_exec_query(sql=sql, db_name=db_name, max_rows=max_rows, timeout=timeout)
# Register Tool: Get Table Schema (Keep long description string including parameters)
@mcp.tool("get_table_schema", description="""[Function Description]: Get detailed structure information of the specified table (columns, types, comments, etc.).\n
[Parameter Content]:\n
- random_string (string) [Required] - Unique identifier for the tool call\n
- table_name (string) [Required] - Name of the table to query\n
- db_name (string) [Optional] - Target database name, defaults to the current database\n""")
async def get_table_schema_tool(table_name: str, db_name: str = None) -> Dict[str, Any]:
"""Wrapper: Get table schema"""
if not table_name: return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "Missing table_name parameter"})}]}
return await mcp_doris_get_table_schema(table_name=table_name, db_name=db_name)
# Register Tool: Get Database Table List (Keep long description string including parameters)
@mcp.tool("get_db_table_list", description="""[Function Description]: Get a list of all table names in the specified database.\n
[Parameter Content]:\n
- random_string (string) [Required] - Unique identifier for the tool call\n
- db_name (string) [Optional] - Target database name, defaults to the current database\n""")
async def get_db_table_list_tool(db_name: str = None) -> Dict[str, Any]:
"""Wrapper: Get database table list"""
return await mcp_doris_get_db_table_list(db_name=db_name)
# Register Tool: Get Database List (Keep long description string including parameters)
# Note: Although the description mentions random_string, the wrapper function signature does not. See how mcp handles this.
@mcp.tool("get_db_list", description="""[Function Description]: Get a list of all database names on the server.\n
[Parameter Content]:\n
- random_string (string) [Required] - Unique identifier for the tool call\n""")
async def get_db_list_tool() -> Dict[str, Any]: # Function signature has no parameters
"""Wrapper: Get database list"""
return await mcp_doris_get_db_list()
# Register Tool: Get Table Comment (Keep long description string including parameters)
@mcp.tool("get_table_comment", description="""[Function Description]: Get the comment information for the specified table.\n
[Parameter Content]:\n
- random_string (string) [Required] - Unique identifier for the tool call\n
- table_name (string) [Required] - Name of the table to query\n
- db_name (string) [Optional] - Target database name, defaults to the current database\n""")
async def get_table_comment_tool(table_name: str, db_name: str = None) -> Dict[str, Any]:
"""Wrapper: Get table comment"""
if not table_name: return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "Missing table_name parameter"})}]}
return await mcp_doris_get_table_comment(table_name=table_name, db_name=db_name)
# Register Tool: Get Table Column Comments (Keep long description string including parameters)
@mcp.tool("get_table_column_comments", description="""[Function Description]: Get comment information for all columns in the specified table.\n
[Parameter Content]:\n
- random_string (string) [Required] - Unique identifier for the tool call\n
- table_name (string) [Required] - Name of the table to query\n
- db_name (string) [Optional] - Target database name, defaults to the current database\n""")
async def get_table_column_comments_tool(table_name: str, db_name: str = None) -> Dict[str, Any]:
"""Wrapper: Get table column comments"""
if not table_name: return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "Missing table_name parameter"})}]}
return await mcp_doris_get_table_column_comments(table_name=table_name, db_name=db_name)
# Register Tool: Get Table Indexes (Keep long description string including parameters)
@mcp.tool("get_table_indexes", description="""[Function Description]: Get index information for the specified table.\n
[Parameter Content]:\n
- random_string (string) [Required] - Unique identifier for the tool call\n
- table_name (string) [Required] - Name of the table to query\n
- db_name (string) [Optional] - Target database name, defaults to the current database\n""")
async def get_table_indexes_tool(table_name: str, db_name: str = None) -> Dict[str, Any]:
"""Wrapper: Get table indexes"""
if not table_name: return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "Missing table_name parameter"})}]}
return await mcp_doris_get_table_indexes(table_name=table_name, db_name=db_name)
# Register Tool: Get Recent Audit Logs (Keep long description string including parameters)
@mcp.tool("get_recent_audit_logs", description="""[Function Description]: Get audit log records for a recent period.\n
[Parameter Content]:\n
- random_string (string) [Required] - Unique identifier for the tool call\n
- days (integer) [Optional] - Number of recent days of logs to retrieve, default is 7\n
- limit (integer) [Optional] - Maximum number of records to return, default is 100\n""")
async def get_recent_audit_logs_tool(days: int = 7, limit: int = 100) -> Dict[str, Any]:
"""Wrapper: Get recent audit logs"""
try:
days = int(days)
limit = int(limit)
except (ValueError, TypeError):
return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "days and limit parameters must be integers"})}]}
return await mcp_doris_get_recent_audit_logs(days=days, limit=limit)
# Get tool count
tools_count = len(await mcp.list_tools())
logger.info(f"Registered all MCP tools, total {tools_count} tools")
return True
except Exception as e:
logger.error(f"Error registering MCP tools: {str(e)}")
logger.error(traceback.format_exc())
return False

View File

@@ -0,0 +1 @@
# Mark directory as a package

View File

@@ -0,0 +1,100 @@
import os
import json
import pymysql
import pandas as pd
from typing import Dict, List, Optional, Any
from dotenv import load_dotenv
import re
# Load environment variables
load_dotenv(override=True)
# Database configuration
DB_CONFIG = {
"host": os.getenv("DB_HOST", "localhost"),
"port": int(os.getenv("DB_PORT", "9030")),
"user": os.getenv("DB_USER", "root"),
"password": os.getenv("DB_PASSWORD", ""),
"database": os.getenv("DB_DATABASE", ""),
"charset": "utf8mb4",
"cursorclass": pymysql.cursors.DictCursor
}
def get_db_connection(db_name: Optional[str] = None):
"""
Get database connection
Args:
db_name: Specify the database name to connect to, use default config if None
Returns:
Database connection
"""
if db_name:
# Use default config but override database name
config = DB_CONFIG.copy()
config["database"] = db_name
return pymysql.connect(**config)
else:
# Use default config
return pymysql.connect(**DB_CONFIG)
def get_db_name() -> str:
"""Get the currently configured default database name"""
return DB_CONFIG["database"] or os.getenv("DB_DATABASE", "")
def execute_query(sql, db_name: Optional[str] = None):
"""
Execute SQL query and return results
Args:
sql: SQL query statement
db_name: Specify the database name to connect to, use default config if None
Returns:
Query results
"""
conn = get_db_connection(db_name)
try:
with conn.cursor() as cursor:
# Set connection character set to utf8 before executing query
cursor.execute("SET NAMES utf8")
# Execute the actual query
cursor.execute(sql)
result = cursor.fetchall()
return result
finally:
conn.close()
def execute_query_df(sql, db_name: Optional[str] = None):
"""
Execute SQL query and return pandas DataFrame
Args:
sql: SQL query statement
db_name: Specify the database name to connect to, use default config if None
Returns:
pandas DataFrame
"""
conn = get_db_connection(db_name)
try:
# Use a temporary cursor to execute the query and get results
with conn.cursor() as cursor:
# Set connection character set to utf8 before executing query
cursor.execute("SET NAMES utf8")
# Execute the actual query
cursor.execute(sql)
result = cursor.fetchall()
# If no results, return empty DataFrame
if not result:
return pd.DataFrame()
# Manually convert dict results to DataFrame
df = pd.DataFrame(result)
return df
finally:
conn.close()

View File

@@ -0,0 +1,226 @@
"""
Unified Logging Configuration Module
Provides unified logging configuration, including:
- General logs: Record all program execution information
- Audit logs: Record JSON data for key operations and processing results
- Error logs: Specifically record program exceptions and errors
"""
import os
import sys
import logging
import logging.handlers
from pathlib import Path
from typing import Dict
from datetime import datetime
from dotenv import load_dotenv
# Load environment variables
load_dotenv(override=True)
# Get project root directory
PROJECT_ROOT = Path(__file__).parents[2].absolute()
# Get log configuration from environment variables
LOG_DIR = os.getenv("LOG_DIR", str(PROJECT_ROOT / "logs"))
LOG_PREFIX = os.getenv("LOG_PREFIX", "doris_mcp")
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
LOG_MAX_DAYS = int(os.getenv("LOG_MAX_DAYS", "30"))
# Whether to output logs to the console (should be disabled when running as a service)
CONSOLE_LOGGING = os.getenv("CONSOLE_LOGGING", "false").lower() == "true"
# Whether stdio transport mode is being used
STDIO_MODE = os.getenv("MCP_TRANSPORT_TYPE", "").lower() == "stdio"
def purge_old_logs():
"""Clean up expired log files"""
# --- Only perform cleanup in non-Stdio mode ---
if STDIO_MODE:
return
try:
now = datetime.now()
log_dir = Path(LOG_DIR)
# Check if directory exists and is readable/writable
if not log_dir.is_dir() or not os.access(LOG_DIR, os.W_OK):
if not STDIO_MODE: # Avoid printing to stdout in stdio mode
print(f"Warning: Log directory {LOG_DIR} not accessible, skipping log purge.", file=sys.stderr)
return
for log_file in log_dir.glob(f"{LOG_PREFIX}*.20*"):
# Parse date
file_name = log_file.name
date_str = None
# Try to find the date part
parts = file_name.split('.')
for part in parts:
if part.startswith('20') and len(part) == 8: # 20YYMMDD format
date_str = part
break
if date_str:
try:
file_date = datetime.strptime(date_str, '%Y%m%d')
days_old = (now - file_date).days
if days_old > LOG_MAX_DAYS:
os.remove(log_file)
if not STDIO_MODE:
print(f"Deleted expired log file: {log_file}")
except (ValueError, OSError) as e:
if not STDIO_MODE:
print(f"Error processing log file {file_name}: {e}", file=sys.stderr)
except Exception as e:
if not STDIO_MODE:
print(f"Error cleaning up logs: {e}", file=sys.stderr)
# Force disable console log output if in stdio mode
if STDIO_MODE:
CONSOLE_LOGGING = False
# --- Only create log directory and clean old logs in non-Stdio mode ---
if not STDIO_MODE:
try:
os.makedirs(LOG_DIR, exist_ok=True)
# Clean up expired logs on startup (also moved here, as it only handles file logs)
purge_old_logs()
except OSError as e:
# If directory creation fails (e.g., permission issue), print warning but continue to avoid startup failure
print(f"Warning: Failed to create log directory {LOG_DIR} or purge logs: {e}", file=sys.stderr)
# Log file paths (definition still needed, but files might not be created/used)
LOG_FILE = os.path.join(LOG_DIR, f"{LOG_PREFIX}.log")
AUDIT_LOG_FILE = os.path.join(LOG_DIR, f"{LOG_PREFIX}.audit")
ERROR_LOG_FILE = os.path.join(LOG_DIR, f"{LOG_PREFIX}.error")
# Log level mapping
LOG_LEVELS = {
"DEBUG": logging.DEBUG,
"INFO": logging.INFO,
"WARNING": logging.WARNING,
"ERROR": logging.ERROR,
"CRITICAL": logging.CRITICAL
}
# Log format
LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
AUDIT_FORMAT = '%(asctime)s - %(name)s - %(message)s'
ERROR_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(pathname)s:%(lineno)d - %(message)s'
# Dedicated audit log level
AUDIT = 25 # Level between INFO and WARNING
logging.addLevelName(AUDIT, "AUDIT")
# Logger object cache
_loggers: Dict[str, logging.Logger] = {}
# Handler type mapping, used to ensure no duplicates are added
_handler_types = {
'console': logging.StreamHandler,
'file': logging.handlers.TimedRotatingFileHandler,
'audit': logging.handlers.TimedRotatingFileHandler,
'error': logging.handlers.TimedRotatingFileHandler
}
def get_logger(name: str) -> logging.Logger:
"""
Get a logger with the specified name
Args:
name: Logger name
Returns:
logging.Logger: Configured logger
"""
if name in _loggers:
return _loggers[name]
# Create logger
logger = logging.getLogger(name)
logger.setLevel(LOG_LEVELS.get(LOG_LEVEL, logging.INFO))
# Avoid duplicate logs caused by propagation
logger.propagate = False
# Check if handlers already exist to avoid duplicates
handler_types = set(type(h) for h in logger.handlers)
# Add audit log method
def audit(self, message, *args, **kwargs):
self.log(AUDIT, message, *args, **kwargs)
logger.audit = audit.__get__(logger)
# General log handler - output to console (only if enabled)
if CONSOLE_LOGGING and _handler_types['console'] not in handler_types:
# Use stderr instead of stdout to avoid conflicts with MCP communication
console_handler = logging.StreamHandler(sys.stderr)
console_handler.setFormatter(logging.Formatter(LOG_FORMAT))
logger.addHandler(console_handler)
# --- Only add file handlers in non-Stdio mode ---
if not STDIO_MODE:
# General log handler - daily rotating file
if _handler_types['file'] not in handler_types:
try: # Add try-except block
file_handler = logging.handlers.TimedRotatingFileHandler(
LOG_FILE,
when='midnight',
interval=1,
backupCount=LOG_MAX_DAYS,
encoding='utf-8'
)
file_handler.setFormatter(logging.Formatter(LOG_FORMAT))
file_handler.suffix = "%Y%m%d"
logger.addHandler(file_handler)
except OSError as e:
print(f"Warning: Failed to add file log handler for {LOG_FILE}: {e}", file=sys.stderr)
# Audit log handler - only logs AUDIT level
if _handler_types['audit'] not in handler_types:
try: # Add try-except block
audit_handler = logging.handlers.TimedRotatingFileHandler(
AUDIT_LOG_FILE,
when='midnight',
interval=1,
backupCount=LOG_MAX_DAYS,
encoding='utf-8'
)
audit_handler.setFormatter(logging.Formatter(AUDIT_FORMAT))
audit_handler.suffix = "%Y%m%d"
audit_handler.setLevel(AUDIT)
audit_handler.addFilter(lambda record: record.levelno == AUDIT)
logger.addHandler(audit_handler)
except OSError as e:
print(f"Warning: Failed to add audit log handler for {AUDIT_LOG_FILE}: {e}", file=sys.stderr)
# Error log handler - only logs ERROR level and above
if _handler_types['error'] not in handler_types:
try: # Add try-except block
error_handler = logging.handlers.TimedRotatingFileHandler(
ERROR_LOG_FILE,
when='midnight',
interval=1,
backupCount=LOG_MAX_DAYS,
encoding='utf-8'
)
error_handler.setFormatter(logging.Formatter(ERROR_FORMAT))
error_handler.suffix = "%Y%m%d"
error_handler.setLevel(logging.ERROR)
logger.addHandler(error_handler)
except OSError as e:
print(f"Warning: Failed to add error log handler for {ERROR_LOG_FILE}: {e}", file=sys.stderr)
# Cache logger
_loggers[name] = logger
return logger
# Default logger
logger = get_logger('doris_mcp')
# Audit logger - for recording processing results, business operations, etc.
audit_logger = get_logger('audit')
# Call to clean logs moved after directory creation, and added non-stdio check

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,349 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
SQL Execution Tool
Responsible for executing SQL queries and handling results
"""
import os
import json
import logging
import traceback
import time
from typing import Dict, Any
import re
import datetime
from decimal import Decimal
# Get logger
logger = logging.getLogger("doris-mcp.sql-executor")
# Add environment variable control for whether to perform SQL security checks
ENABLE_SQL_SECURITY_CHECK = os.environ.get('ENABLE_SQL_SECURITY_CHECK', 'true').lower() == 'true'
async def execute_sql_query(ctx) -> Dict[str, Any]:
"""
Execute SQL query and return results
Args:
ctx: Context object or dictionary containing request parameters
Returns:
Dict[str, Any]: Execution result
"""
try:
# Support the case where the passed argument is a dictionary
if isinstance(ctx, dict) and 'params' in ctx:
params = ctx['params']
else:
params = ctx.params
sql = params.get("sql")
db_name = params.get("db_name", os.getenv("DB_DATABASE", ""))
max_rows = params.get("max_rows", 1000) # Maximum number of rows to return
timeout = params.get("timeout", 30) # Timeout in seconds
if not sql:
return {
"content": [
{
"type": "text",
"text": json.dumps({
"success": False,
"error": "Missing SQL parameter",
"message": "Please provide the SQL query to execute"
}, ensure_ascii=False)
}
]
}
# First check SQL security
security_result = await _check_sql_security(sql)
if not security_result.get("is_safe", False):
return {
"content": [
{
"type": "text",
"text": json.dumps({
"success": False,
"error": "SQL security check failed",
"message": "Query contains unsafe operations and cannot be executed",
"security_issues": security_result.get("security_issues", [])
}, ensure_ascii=False)
}
]
}
# Import database connection tool
from doris_mcp_server.utils.db import execute_query
if not sql:
return {
"content": [
{
"type": "text",
"text": json.dumps({
"success": False,
"error": "Missing SQL parameter",
"message": "Please provide the SQL query to execute"
}, ensure_ascii=False)
}
]
}
# Ensure SELECT statements include a LIMIT clause
sql_lower = sql.lower().strip()
if sql_lower.startswith("select") and "limit" not in sql_lower:
sql = sql.rstrip(";") + f" LIMIT {max_rows};"
# Start timer
start_time = time.time()
# Execute query
try:
result = execute_query(sql, db_name)
# Calculate execution time
execution_time = time.time() - start_time
# Build return result
if isinstance(result, list):
# Handle list of query results
row_count = len(result)
# Extract column names
if hasattr(result[0], "_fields"):
# If it's a named tuple
columns = list(result[0]._fields)
else:
# Otherwise, assume it's a dictionary
columns = list(result[0].keys()) if isinstance(result[0], dict) else []
# Convert results to serializable format
data = []
for row in result:
row_dict = {}
if hasattr(row, "_asdict"):
# If it's a named tuple
row_dict = row._asdict()
elif isinstance(row, dict):
# If it's a dictionary
row_dict = row
else:
# If it's a list or tuple
row_dict = dict(zip(columns, row)) if columns else row
# Handle special types to make them JSON serializable
serialized_row = _serialize_row_data(row_dict)
data.append(serialized_row)
return {
"content": [
{
"type": "text",
"text": json.dumps({
"success": True,
"sql": sql,
"row_count": row_count,
"columns": columns,
"data": data[:max_rows], # Limit returned rows
"execution_time": execution_time,
"truncated": row_count > max_rows
}, ensure_ascii=False)
}
]
}
else:
# Handle other types of results
other_response = {
"success": True,
"sql": sql,
"result": str(result),
"execution_time": execution_time
}
other_response = _serialize_row_data(other_response)
return {
"content": [
{
"type": "text",
"text": json.dumps(other_response, ensure_ascii=False)
}
]
}
except Exception as db_error:
error_message = str(db_error)
# Try to get more detailed error information
error_details = {}
if "timeout" in error_message.lower():
error_details["type"] = "timeout"
error_details["suggestion"] = "Query timed out, please optimize SQL or increase timeout"
elif "syntax" in error_message.lower():
error_details["type"] = "syntax"
error_details["suggestion"] = "SQL syntax error, please check syntax"
elif "not found" in error_message.lower() or "doesn't exist" in error_message.lower():
error_details["type"] = "not_found"
error_details["suggestion"] = "Table or column not found, please check table and column names"
else:
error_details["type"] = "unknown"
error_details["suggestion"] = "Please check the SQL statement and try simplifying the query"
# Create error response
error_response = {
"success": False,
"error": error_message,
"error_details": error_details,
"sql": sql,
"db_name": db_name
}
# Ensure error response is also serializable
error_response = _serialize_row_data(error_response)
return {
"content": [
{
"type": "text",
"text": json.dumps(error_response, ensure_ascii=False)
}
]
}
except Exception as e:
logger.error(f"Failed to execute SQL query: {str(e)}")
logger.error(traceback.format_exc())
error_response = {
"success": False,
"error": str(e),
"message": "Error occurred while executing SQL query"
}
# Ensure error response is also serializable
error_response = _serialize_row_data(error_response)
return {
"content": [
{
"type": "text",
"text": json.dumps(error_response, ensure_ascii=False)
}
]
}
# Helper function
async def _check_sql_security(sql: str) -> Dict[str, Any]:
"""Check SQL security"""
# If environment variable is set to disable security check, return safe immediately
if not ENABLE_SQL_SECURITY_CHECK:
return {
"is_safe": True,
"security_issues": []
}
# Check if SQL contains dangerous operations
sql_lower = sql.lower()
# Check if it's a read-only query type
is_read_only = sql_lower.strip().startswith(("select ", "show ", "desc ", "describe ", "explain "))
# Define list of dangerous operations (checked for both read-only and non-read-only queries)
dangerous_operations = [
(r'\bdelete\b', "DELETE operation"),
(r'\bdrop\b', "DROP TABLE/DATABASE operation"),
(r'\btruncate\b', "TRUNCATE TABLE operation"),
(r'\bupdate\b', "UPDATE operation"),
(r'\binsert\b', "INSERT operation"),
(r'\balter\b', "ALTER TABLE structure operation"),
(r'\bcreate\b', "CREATE TABLE/DATABASE operation"),
(r'\bgrant\b', "GRANT operation"),
(r'\brevoke\b', "REVOKE permission operation"),
(r'\bexec\b', "EXECUTE stored procedure"),
(r'\bxp_', "Extended stored procedure, potential security risk"),
(r'\bshutdown\b', "SHUTDOWN database operation"),
(r'\bunion\s+all\s+select\b', "UNION statement, potential SQL injection"),
(r'\bunion\s+select\b', "UNION statement, potential SQL injection"),
(r'\binto\s+outfile\b', "Write to file operation"),
(r'\bload_file\b', "Load file operation")
]
# Dangerous operations checked only for non-read-only queries
non_readonly_operations = []
if not is_read_only:
non_readonly_operations = [
(r'--', "SQL comment, potential SQL injection"),
(r'/\*', "SQL block comment, potential SQL injection")
]
# Check if dangerous operations are included
security_issues = []
# Check dangerous operations applicable to all queries
for operation, description in dangerous_operations:
if re.search(operation, sql_lower):
# For specific keywords in read-only queries, differentiate if used as independent operations
if is_read_only and operation in [r'\bcreate\b', r'\bdrop\b', r'\bdelete\b', r'\binsert\b', r'\bupdate\b', r'\balter\b']:
# Check if used as DDL/DML keyword, e.g., CREATE TABLE, DROP DATABASE
pattern = operation + r'\s+(?:table|database|view|index|procedure|function|trigger|event)'
if re.search(pattern, sql_lower):
security_issues.append({
"operation": operation.replace(r'\b', '').replace(r'\s+', ' '),
"description": description,
"severity": "High"
})
else:
security_issues.append({
"operation": operation.replace(r'\b', '').replace(r'\s+', ' '),
"description": description,
"severity": "High"
})
# Check dangerous operations specific to non-read-only queries
for operation, description in non_readonly_operations:
if re.search(operation, sql_lower):
security_issues.append({
"operation": operation.replace(r'\b', '').replace(r'\s+', ' '),
"description": description,
"severity": "Medium"
})
return {
"is_safe": len(security_issues) == 0,
"security_issues": security_issues
}
def _serialize_row_data(row_data: Dict[str, Any]) -> Dict[str, Any]:
"""
Convert special types in row data (like date, time, Decimal) to JSON serializable format
Args:
row_data: Row data dictionary
Returns:
Dict[str, Any]: Processed serializable dictionary
"""
serialized_data = {}
for key, value in row_data.items():
if value is None:
serialized_data[key] = None
elif isinstance(value, (datetime.date, datetime.datetime)):
# Convert date and time types to ISO format string
serialized_data[key] = value.isoformat()
elif isinstance(value, Decimal):
# Convert Decimal type to float
serialized_data[key] = float(value)
elif isinstance(value, (list, tuple)):
# Recursively process elements in list or tuple
serialized_data[key] = [
_serialize_row_data(item) if isinstance(item, dict) else item
for item in value
]
elif isinstance(value, dict):
# Recursively process nested dictionaries
serialized_data[key] = _serialize_row_data(value)
else:
serialized_data[key] = value
return serialized_data