init doris mcp 0.2.0
This commit is contained in:
1
doris_mcp_server/__init__.py
Normal file
1
doris_mcp_server/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Mark directory as a package
|
||||
33
doris_mcp_server/config.py
Normal file
33
doris_mcp_server/config.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# doris_mcp_server/config.py
|
||||
import os
|
||||
import logging
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv(override=True)
|
||||
|
||||
# Get Log Level from environment variable, default to 'info'
|
||||
LOG_LEVEL_STR = os.getenv('LOG_LEVEL', 'info').upper()
|
||||
|
||||
# Map string level to logging level constant
|
||||
LOG_LEVEL_MAP = {
|
||||
'DEBUG': logging.DEBUG,
|
||||
'INFO': logging.INFO,
|
||||
'WARNING': logging.WARNING,
|
||||
'ERROR': logging.ERROR,
|
||||
'CRITICAL': logging.CRITICAL
|
||||
}
|
||||
LOG_LEVEL = LOG_LEVEL_MAP.get(LOG_LEVEL_STR, logging.INFO)
|
||||
|
||||
# Function to load config (can be expanded later if needed)
|
||||
def load_config():
|
||||
"""Loads configuration settings."""
|
||||
# Currently, configuration is mainly handled by environment variables
|
||||
# and constants defined in this module.
|
||||
# This function can be used to perform additional setup if required.
|
||||
logging.getLogger(__name__).info("Configuration loaded (mainly from environment variables).")
|
||||
|
||||
# You can add other configuration constants here if needed
|
||||
# Example: DB_HOST = os.getenv("DB_HOST", "localhost")
|
||||
# But often it's better to access os.getenv directly where needed
|
||||
# or pass config dictionaries around.
|
||||
196
doris_mcp_server/main.py
Normal file
196
doris_mcp_server/main.py
Normal file
@@ -0,0 +1,196 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Apache Doris MCP Server Main Entry - Primarily handles SSE mode
|
||||
|
||||
Stdio mode is handled by doris_mcp_server.mcp_core:run_stdio.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from collections.abc import AsyncIterator
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Any
|
||||
import uvicorn
|
||||
from uvicorn import Config, Server
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Add project root to path
|
||||
PROJECT_ROOT = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
|
||||
sys.path.insert(0, PROJECT_ROOT)
|
||||
|
||||
# SSE related imports
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
from doris_mcp_server.sse_server import DorisMCPSseServer
|
||||
from doris_mcp_server.streamable_server import DorisMCPStreamableServer
|
||||
|
||||
# Stdio related imports (only needed for tools now, maybe move tool init?)
|
||||
# from mcp.server.stdio import stdio_server -> No longer used here
|
||||
|
||||
# Config and Tool Initializer
|
||||
from doris_mcp_server.config import load_config # LOG_LEVEL might not be needed here directly
|
||||
from doris_mcp_server.tools.tool_initializer import register_mcp_tools
|
||||
|
||||
# Load environment variables (load early for all modes)
|
||||
load_dotenv(override=True)
|
||||
|
||||
# Get logger
|
||||
logger = logging.getLogger("doris-mcp-main") # Changed logger name slightly
|
||||
|
||||
# --- Configuration Loading and Logging Setup ---
|
||||
load_config() # Loads .env
|
||||
|
||||
# --- Create FastAPI App (Global Scope for SSE Mode) ---
|
||||
# This 'app' object is targeted by 'mcp run doris_mcp_server/main.py:app --transport sse'
|
||||
# And used when running directly with --sse
|
||||
app = FastAPI(
|
||||
title="Doris MCP Server (SSE Mode)",
|
||||
# Lifespan will be added in start_sse_server
|
||||
)
|
||||
|
||||
# --- Removed StdioServerWrapper ---
|
||||
|
||||
# --- Command Line Argument Parsing ---
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Apache Doris MCP Server (SSE Mode Entry)")
|
||||
# Only keep SSE related args here
|
||||
parser.add_argument('--sse', action='store_true', help='Start SSE Web server mode (required)')
|
||||
parser.add_argument('--host', type=str, default=os.getenv('SERVER_HOST', '0.0.0.0'), help='Host address')
|
||||
parser.add_argument('--port', type=int, default=int(os.getenv('SERVER_PORT', os.getenv('MCP_PORT', '3000'))), help='Port number')
|
||||
parser.add_argument('--debug', action='store_true', help='Enable debug mode')
|
||||
parser.add_argument('--reload', action='store_true', help='Enable auto-reload')
|
||||
return parser.parse_args()
|
||||
|
||||
# --- SSE Mode Specific Code ---
|
||||
@dataclass
|
||||
class AppContext:
|
||||
config: Dict[str, Any]
|
||||
|
||||
@asynccontextmanager
|
||||
async def app_lifespan(app_instance: FastAPI) -> AsyncIterator[None]:
|
||||
logger.info("SSE application lifecycle start...")
|
||||
config = {
|
||||
# Simplified config - maybe get from elsewhere?
|
||||
"db_host": os.getenv("DB_HOST", "localhost"),
|
||||
"db_port": int(os.getenv("DB_PORT", "9030")),
|
||||
"db_user": os.getenv("DB_USER", "root"),
|
||||
"db_password": os.getenv("DB_PASSWORD", ""),
|
||||
"db_database": os.getenv("DB_DATABASE", "test"),
|
||||
}
|
||||
app_instance.state.config = config
|
||||
try:
|
||||
# Yield None implicitly or explicitly None
|
||||
yield
|
||||
finally:
|
||||
logger.info("Cleaning up SSE application resources...")
|
||||
|
||||
async def start_sse_server(args):
|
||||
"""Start SSE Web server mode (Configures the global 'app')"""
|
||||
logger.info("Starting SSE Web server mode...")
|
||||
global app
|
||||
|
||||
# --- Initialize MCP and Tools for SSE ---
|
||||
# Create a *separate* MCP instance for SSE mode
|
||||
sse_mcp = FastMCP(
|
||||
name="doris-mcp-sse",
|
||||
description="Apache Doris MCP Server (SSE)",
|
||||
lifespan=None, # Managed by FastAPI
|
||||
dependencies=["fastapi", "uvicorn", "openai", "sse_starlette"]
|
||||
)
|
||||
logger.info("Registering MCP tools for SSE mode...")
|
||||
await register_mcp_tools(sse_mcp) # Register tools for the SSE instance
|
||||
logger.info("MCP tools registered for SSE.")
|
||||
|
||||
# --- Configure Lifespan and CORS for the global app ---
|
||||
app.router.lifespan_context = app_lifespan
|
||||
origins = os.getenv("ALLOWED_ORIGINS", "*").split(",")
|
||||
allow_credentials = os.getenv("MCP_ALLOW_CREDENTIALS", "false").lower() == "true"
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins,
|
||||
allow_credentials=allow_credentials,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
expose_headers=["Mcp-Session-Id"],
|
||||
)
|
||||
|
||||
# --- Initialize Handlers and Register Routes (Pass sse_mcp instance) ---
|
||||
logger.info("Initializing SSE server handlers and registering routes...")
|
||||
sse_server_handler = DorisMCPSseServer(sse_mcp, app)
|
||||
streamable_server_handler = DorisMCPStreamableServer(sse_mcp, app)
|
||||
logger.info("SSE Server handlers initialized and routes registered.")
|
||||
|
||||
# --- Print Configuration and Endpoints ---
|
||||
print("--- SSE Mode Configuration ---")
|
||||
print(f"Server Host: {args.host}")
|
||||
print(f"Server Port: {args.port}")
|
||||
print(f"Allowed Origins: {origins}")
|
||||
print(f"Allow Credentials: {allow_credentials}")
|
||||
print(f"Log Level: {os.getenv('LOG_LEVEL', 'info')}")
|
||||
print(f"Debug Mode: {args.debug}")
|
||||
print(f"Reload Mode: {args.reload}")
|
||||
print(f"DB Host: {os.getenv('DB_HOST')}")
|
||||
print(f"DB Port: {os.getenv('DB_PORT')}")
|
||||
print(f"DB User: {os.getenv('DB_USER')}")
|
||||
print(f"DB Database: {os.getenv('DB_DATABASE')}")
|
||||
print(f"Force Refresh Metadata: {os.getenv('FORCE_REFRESH_METADATA', 'false')}")
|
||||
print("------------------------------")
|
||||
base_url = f"http://{args.host}:{args.port}"
|
||||
print(f"Service running at: {base_url}")
|
||||
print(f" Health Check: GET {base_url}/health")
|
||||
print(f" Status Check: GET {base_url}/status")
|
||||
print(f" SSE Init: GET {base_url}/sse")
|
||||
print(f" SSE/Legacy Messages: POST {base_url}/mcp/messages")
|
||||
print(f" Streamable HTTP: GET/POST/DELETE/OPTIONS {base_url}/mcp")
|
||||
print("------------------------------")
|
||||
print("Use Ctrl+C to stop the service")
|
||||
|
||||
# --- Start Uvicorn Server ---
|
||||
config = Config(
|
||||
app=app,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
log_level="debug" if args.debug else "info",
|
||||
reload=args.reload
|
||||
)
|
||||
server = Server(config=config)
|
||||
await server.serve()
|
||||
|
||||
# --- Main Execution Logic (Simplified) ---
|
||||
|
||||
def run_main_sync():
|
||||
"""Synchronous wrapper, primarily for SSE mode now."""
|
||||
sync_logger = logging.getLogger("run_main_sync")
|
||||
sync_logger.info("Entering run_main_sync (SSE focus)...")
|
||||
print("DEBUG: Entering run_main_sync (SSE focus)...", file=sys.stderr, flush=True)
|
||||
args = parse_args()
|
||||
|
||||
if args.sse:
|
||||
try:
|
||||
# Run the async SSE server setup and Uvicorn loop
|
||||
asyncio.run(start_sse_server(args))
|
||||
sync_logger.info("asyncio.run(start_sse_server) completed.")
|
||||
print("DEBUG: asyncio.run(start_sse_server) completed.", file=sys.stderr, flush=True)
|
||||
except KeyboardInterrupt:
|
||||
sync_logger.info("SSE server stopped by KeyboardInterrupt.")
|
||||
except Exception as e:
|
||||
sync_logger.critical(f"Error during asyncio.run(start_sse_server): {e}", exc_info=True)
|
||||
print(f"DEBUG: Error during asyncio.run(start_sse_server): {e}", file=sys.stderr, flush=True)
|
||||
raise
|
||||
else:
|
||||
# If run without --sse, print help/error
|
||||
message = "Error: This entry point requires --sse flag. For stdio mode, use 'uv run mcp-doris' or the appropriate command for your stdio setup."
|
||||
sync_logger.error(message)
|
||||
print(message, file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_main_sync()
|
||||
143
doris_mcp_server/mcp_core.py
Normal file
143
doris_mcp_server/mcp_core.py
Normal file
@@ -0,0 +1,143 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Core MCP instance and startup logic for stdio mode.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
import traceback
|
||||
import json
|
||||
from typing import Dict, Any
|
||||
|
||||
# Import necessary components from mcp and our project
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
logger = logging.getLogger("doris-mcp-core")
|
||||
|
||||
# --- Global MCP Instance for Stdio ---
|
||||
# Create the instance when the module is imported.
|
||||
# Tools will be registered synchronously(?) before running.
|
||||
stdio_mcp = FastMCP(
|
||||
name="doris-mcp-stdio-core",
|
||||
description="Apache Doris MCP Server (stdio via core)",
|
||||
)
|
||||
|
||||
# --- Removed async setup functions ---
|
||||
def run_stdio():
|
||||
"""
|
||||
Synchronous entry point for running the stdio server.
|
||||
Mimics the mcp-doris example by calling .run() on the instance.
|
||||
Handles tool registration beforehand.
|
||||
"""
|
||||
logger.info("Executing run_stdio (synchronous entry point)...")
|
||||
|
||||
# --- Run the stdio server using the instance's run() method ---
|
||||
logger.info("Calling stdio_mcp.run()...")
|
||||
try:
|
||||
# Assuming stdio_mcp has a synchronous run() method for stdio
|
||||
stdio_mcp.run()
|
||||
logger.info("stdio_mcp.run() completed.")
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Stdio server stopped by KeyboardInterrupt.")
|
||||
except AttributeError:
|
||||
logger.critical("Error: stdio_mcp object does not have a '.run()' method suitable for stdio.", exc_info=False)
|
||||
print("ERROR: stdio_mcp object does not have a '.run()' method.", file=sys.stderr, flush=True)
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
logger.critical(f"run_stdio encountered an error during stdio_mcp.run(): {e}", exc_info=True)
|
||||
traceback.print_exc(file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# Register Tool: Execute SQL Query
|
||||
@stdio_mcp.tool("exec_query", description="""[Function Description]: Execute SQL query and return result command (executed by the client).\n
|
||||
[Parameter Content]:\n
|
||||
- sql (string) [Required] - SQL statement to execute\n
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database\n
|
||||
- max_rows (integer) [Optional] - Maximum number of rows to return, default 100\n
|
||||
- timeout (integer) [Optional] - Query timeout in seconds, default 30\n""")
|
||||
async def exec_query_tool(sql: str, db_name: str = None, max_rows: int = 100, timeout: int = 30) -> Dict[str, Any]:
|
||||
"""Wrapper: Execute SQL query and return result command"""
|
||||
from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_exec_query
|
||||
return await mcp_doris_exec_query(sql=sql, db_name=db_name, max_rows=max_rows, timeout=timeout)
|
||||
|
||||
# Register Tool: Get Table Schema
|
||||
@stdio_mcp.tool("get_table_schema", description="""[Function Description]: Get detailed structure information of the specified table (columns, types, comments, etc.).\n
|
||||
[Parameter Content]:\n
|
||||
- table_name (string) [Required] - Name of the table to query\n
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database\n""")
|
||||
async def get_table_schema_tool(table_name: str, db_name: str = None) -> Dict[str, Any]:
|
||||
"""Wrapper: Get table schema"""
|
||||
from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_table_schema
|
||||
if not table_name: return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "Missing table_name parameter"})}]}
|
||||
return await mcp_doris_get_table_schema(table_name=table_name, db_name=db_name)
|
||||
|
||||
# Register Tool: Get Database Table List
|
||||
@stdio_mcp.tool("get_db_table_list", description="""[Function Description]: Get a list of all table names in the specified database.\n
|
||||
[Parameter Content]:\n
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database\n""")
|
||||
async def get_db_table_list_tool(db_name: str = None) -> Dict[str, Any]:
|
||||
"""Wrapper: Get database table list"""
|
||||
from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_db_table_list
|
||||
return await mcp_doris_get_db_table_list(db_name=db_name)
|
||||
|
||||
# Register Tool: Get Database List
|
||||
@stdio_mcp.tool("get_db_list", description="""[Function Description]: Get a list of all database names on the server.\n
|
||||
[Parameter Content]:\n
|
||||
- random_string (string) [Required] - Unique identifier for the tool call\n""")
|
||||
async def get_db_list_tool() -> Dict[str, Any]:
|
||||
"""Wrapper: Get database list"""
|
||||
from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_db_list
|
||||
return await mcp_doris_get_db_list()
|
||||
|
||||
# Register Tool: Get Table Comment
|
||||
@stdio_mcp.tool("get_table_comment", description="""[Function Description]: Get the comment information for the specified table.\n
|
||||
[Parameter Content]:\n
|
||||
- table_name (string) [Required] - Name of the table to query\n
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database\n""")
|
||||
async def get_table_comment_tool(table_name: str, db_name: str = None) -> Dict[str, Any]:
|
||||
"""Wrapper: Get table comment"""
|
||||
from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_table_comment
|
||||
if not table_name: return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "Missing table_name parameter"})}]}
|
||||
return await mcp_doris_get_table_comment(table_name=table_name, db_name=db_name)
|
||||
|
||||
# Register Tool: Get Table Column Comments
|
||||
@stdio_mcp.tool("get_table_column_comments", description="""[Function Description]: Get comment information for all columns in the specified table.\n
|
||||
[Parameter Content]:\n
|
||||
- table_name (string) [Required] - Name of the table to query\n
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database\n""")
|
||||
async def get_table_column_comments_tool(table_name: str, db_name: str = None) -> Dict[str, Any]:
|
||||
"""Wrapper: Get table column comments"""
|
||||
from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_table_column_comments
|
||||
if not table_name: return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "Missing table_name parameter"})}]}
|
||||
return await mcp_doris_get_table_column_comments(table_name=table_name, db_name=db_name)
|
||||
|
||||
# Register Tool: Get Table Indexes
|
||||
@stdio_mcp.tool("get_table_indexes", description="""[Function Description]: Get index information for the specified table.
|
||||
[Parameter Content]:\n
|
||||
- table_name (string) [Required] - Name of the table to query\n
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database\n""")
|
||||
async def get_table_indexes_tool(table_name: str, db_name: str = None) -> Dict[str, Any]:
|
||||
"""Wrapper: Get table indexes"""
|
||||
from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_table_indexes
|
||||
if not table_name: return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "Missing table_name parameter"})}]}
|
||||
return await mcp_doris_get_table_indexes(table_name=table_name, db_name=db_name)
|
||||
|
||||
# Register Tool: Get Recent Audit Logs
|
||||
@stdio_mcp.tool("get_recent_audit_logs", description="""[Function Description]: Get audit log records for a recent period.\n
|
||||
[Parameter Content]:\n
|
||||
- days (integer) [Optional] - Number of recent days of logs to retrieve, default is 7\n
|
||||
- limit (integer) [Optional] - Maximum number of records to return, default is 100\n""")
|
||||
async def get_recent_audit_logs_tool(days: int = 7, limit: int = 100) -> Dict[str, Any]:
|
||||
"""Wrapper: Get recent audit logs"""
|
||||
from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_recent_audit_logs
|
||||
try:
|
||||
days = int(days)
|
||||
limit = int(limit)
|
||||
except (ValueError, TypeError):
|
||||
return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "days and limit parameters must be integers"})}]}
|
||||
return await mcp_doris_get_recent_audit_logs(days=days, limit=limit)
|
||||
|
||||
# --- Register Tools ---
|
||||
1259
doris_mcp_server/sse_server.py
Normal file
1259
doris_mcp_server/sse_server.py
Normal file
File diff suppressed because it is too large
Load Diff
912
doris_mcp_server/streamable_server.py
Normal file
912
doris_mcp_server/streamable_server.py
Normal file
@@ -0,0 +1,912 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Doris MCP Streamable HTTP Server Implementation
|
||||
|
||||
Implements the MCP 2025-03-26 Streamable HTTP specification.
|
||||
Uses a unified /mcp endpoint for GET, POST, DELETE, OPTIONS.
|
||||
Manages sessions using Mcp-Session-Id header.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Optional, Dict, List
|
||||
from fastapi import FastAPI, Request, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
# Use a distinct logger name
|
||||
logger = logging.getLogger("doris-mcp-streamable")
|
||||
|
||||
# Special marker for closing streams
|
||||
STREAM_END_MARKER = "__MCP_STREAM_END__"
|
||||
|
||||
class DorisMCPStreamableServer:
|
||||
"""Doris MCP Streamable HTTP Server"""
|
||||
|
||||
def __init__(self, mcp_server, app: FastAPI):
|
||||
"""
|
||||
Initializes the Doris MCP Streamable HTTP server.
|
||||
|
||||
Args:
|
||||
mcp_server: The shared FastMCP server instance.
|
||||
app: The main FastAPI application instance.
|
||||
"""
|
||||
self.mcp_server = mcp_server
|
||||
self.app = app # We'll add routes to this app
|
||||
|
||||
# Note: CORS middleware should be added only once in main.py usually.
|
||||
# If added here, ensure it doesn't conflict or duplicate.
|
||||
# For separation, we might let main.py handle CORS entirely.
|
||||
|
||||
# Client session management for Streamable HTTP clients
|
||||
# key: session_id (from Mcp-Session-Id header)
|
||||
# value: {
|
||||
# "created_at": timestamp,
|
||||
# "last_active": timestamp,
|
||||
# "request_queues": { request_id: asyncio.Queue }, # For POST /mcp request streams
|
||||
# "general_sse_queues": List[asyncio.Queue] # For GET /mcp server push streams
|
||||
# }
|
||||
self.client_sessions: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
# Setup the unified MCP endpoint
|
||||
self._setup_streamable_http_routes()
|
||||
|
||||
# Register session cleanup task if this instance manages lifespan independently
|
||||
# Usually, startup events are tied to the main app lifespan managed in main.py
|
||||
# We might not need @app.on_event("startup") here if main.py handles it.
|
||||
# Let's assume main.py handles the cleanup task initiation.
|
||||
|
||||
def _setup_streamable_http_routes(self):
|
||||
"""Sets up the unified /mcp endpoint for Streamable HTTP.
|
||||
Uses a distinct tag for API docs.
|
||||
"""
|
||||
|
||||
@self.app.api_route("/mcp", methods=["GET", "POST", "DELETE", "OPTIONS"], tags=["Streamable HTTP"])
|
||||
async def mcp_endpoint_handler(request: Request):
|
||||
"""Handles GET, POST, DELETE, OPTIONS for the /mcp endpoint."""
|
||||
|
||||
# 1. Handle OPTIONS (CORS preflight)
|
||||
if request.method == "OPTIONS":
|
||||
# Assuming CORS headers are handled by middleware in main.py
|
||||
# If not, provide necessary headers here.
|
||||
# This minimal response might suffice if middleware handles the rest
|
||||
logger.debug("Handling OPTIONS request for /mcp")
|
||||
# Return basic OK allowing exposed headers if middleware handles the rest
|
||||
return JSONResponse({}, headers={"Access-Control-Expose-Headers": "Mcp-Session-Id"})
|
||||
|
||||
# Session ID from header is required for most methods
|
||||
session_id = request.headers.get("Mcp-Session-Id")
|
||||
|
||||
# 2. Handle DELETE (Terminate Session)
|
||||
if request.method == "DELETE":
|
||||
if not session_id:
|
||||
return JSONResponse({"jsonrpc": "2.0", "error": {"code": -32600, "message": "Mcp-Session-Id header is required for DELETE"}}, status_code=400)
|
||||
|
||||
logger.info(f"Handling DELETE request for session [Session ID: {session_id}]")
|
||||
session_data = self.client_sessions.pop(session_id, None)
|
||||
if session_data:
|
||||
await self._cleanup_session_resources(session_id, session_data)
|
||||
return JSONResponse({}, status_code=204) # No Content
|
||||
else:
|
||||
logger.warning(f"Attempted DELETE on non-existent session: {session_id}")
|
||||
return JSONResponse({"jsonrpc": "2.0", "error": {"code": -32001, "message": "Session not found"}}, status_code=404)
|
||||
|
||||
# 3. Handle GET (Server Push SSE Stream)
|
||||
if request.method == "GET":
|
||||
if not session_id:
|
||||
return JSONResponse({"jsonrpc": "2.0", "error": {"code": -32000, "message": "Mcp-Session-Id header is required for GET streams"}}, status_code=400)
|
||||
if session_id not in self.client_sessions:
|
||||
# Note: Unlike legacy SSE, GET here assumes session exists.
|
||||
return JSONResponse({"jsonrpc": "2.0", "error": {"code": -32001, "message": "Session not found. Initialize first."}}, status_code=404)
|
||||
|
||||
accept_header = request.headers.get("Accept", "")
|
||||
if "text/event-stream" not in accept_header:
|
||||
return JSONResponse({"jsonrpc": "2.0", "error": {"code": -32600, "message": "Accept header must include text/event-stream for GET"}}, status_code=406)
|
||||
|
||||
# TODO: Handle Last-Event-ID for stream recovery?
|
||||
|
||||
logger.info(f"Handling GET request, establishing server push SSE stream [Session ID: {session_id}]")
|
||||
|
||||
push_queue = asyncio.Queue()
|
||||
if self.client_sessions[session_id].get("general_sse_queues") is None:
|
||||
self.client_sessions[session_id]["general_sse_queues"] = []
|
||||
self.client_sessions[session_id]["general_sse_queues"].append(push_queue)
|
||||
self.client_sessions[session_id]["last_active"] = time.time()
|
||||
|
||||
return EventSourceResponse(self._create_general_sse_generator(session_id, push_queue), media_type="text/event-stream")
|
||||
|
||||
# 4. Handle POST (Client Messages & Initialize)
|
||||
if request.method == "POST":
|
||||
accept_header = request.headers.get("Accept", "")
|
||||
content_type = request.headers.get("Content-Type", "")
|
||||
|
||||
body = {}
|
||||
try:
|
||||
if "application/json" not in content_type:
|
||||
return JSONResponse({"jsonrpc": "2.0", "error": {"code": -32700, "message": "Content-Type must be application/json"}}, status_code=415)
|
||||
body = await request.json()
|
||||
if isinstance(body, list): return JSONResponse({"jsonrpc": "2.0", "error": {"code": -32600, "message": "Batch requests not supported"}}, status_code=400)
|
||||
if not isinstance(body, dict): return JSONResponse({"jsonrpc": "2.0", "error": {"code": -32700, "message": "Invalid JSON received"}}, status_code=400)
|
||||
|
||||
method = body.get("method")
|
||||
message_id = body.get("id") # Can be None for notifications
|
||||
|
||||
# Handle Initialize request (does not require Mcp-Session-Id header)
|
||||
if method == "initialize":
|
||||
if "application/json" not in accept_header:
|
||||
return JSONResponse({"jsonrpc": "2.0", "id": message_id, "error": {"code": -32600, "message": "Accept header must include application/json for initialize"}}, status_code=406)
|
||||
return await self._handle_initialize(request, body, message_id)
|
||||
|
||||
# Handle other POST requests (require Mcp-Session-Id)
|
||||
else:
|
||||
if not session_id:
|
||||
return JSONResponse({"jsonrpc": "2.0", "id": message_id, "error": {"code": -32000, "message": "Mcp-Session-Id header is required for this request"}}, status_code=400)
|
||||
if session_id not in self.client_sessions:
|
||||
return JSONResponse({"jsonrpc": "2.0", "id": message_id, "error": {"code": -32001, "message": "Session not found"}}, status_code=404)
|
||||
# Check Accept header for non-initialize POST
|
||||
if not ("application/json" in accept_header and "text/event-stream" in accept_header):
|
||||
return JSONResponse({"jsonrpc": "2.0", "id": message_id, "error": {"code": -32600, "message": "Accept header must include application/json and text/event-stream for POST"}}, status_code=406)
|
||||
|
||||
self.client_sessions[session_id]["last_active"] = time.time()
|
||||
return await self._handle_client_post(request, body, session_id, message_id)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
return JSONResponse({"jsonrpc": "2.0", "error": {"code": -32700, "message": "Parse error - Invalid JSON received"}}, status_code=400)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error handling POST /mcp: {str(e)}", exc_info=True)
|
||||
error_id = body.get("id") if isinstance(body, dict) else None
|
||||
return JSONResponse({"jsonrpc": "2.0", "id": error_id, "error": {"code": -32000, "message": "Internal server error"}}, status_code=500)
|
||||
|
||||
# Fallback for other methods like PUT, PATCH etc.
|
||||
return JSONResponse({"error": "Method Not Allowed"}, status_code=405)
|
||||
|
||||
async def _handle_initialize(self, request: Request, body: Dict, message_id: Any):
|
||||
"""Handles the 'initialize' method call via POST /mcp."""
|
||||
logger.info("Handling Streamable HTTP initialize request")
|
||||
# Optional: Validate params in body if needed
|
||||
# params = body.get("params", {})
|
||||
|
||||
new_session_id = str(uuid.uuid4())
|
||||
logger.info(f"Created new Streamable HTTP session [Session ID: {new_session_id}]")
|
||||
|
||||
self.client_sessions[new_session_id] = {
|
||||
"created_at": time.time(),
|
||||
"last_active": time.time(),
|
||||
# No transport_type needed here as this class *is* the streamable server
|
||||
"request_queues": {}, # Initialize request queues dict
|
||||
"general_sse_queues": [] # Initialize general queues list
|
||||
}
|
||||
|
||||
# Build InitializeResult based on spec
|
||||
initialize_result = {
|
||||
"protocolVersion": "2025-03-26",
|
||||
"name": self.mcp_server.name,
|
||||
"instructions": "Apache Doris MCP Server (Streamable HTTP Mode)",
|
||||
"serverInfo": { "version": "0.2.0", "name": "Doris MCP Streamable Server" }, # Adjust as needed
|
||||
"capabilities": {
|
||||
"tools": { "supportsStreaming": True, "supportsProgress": True },
|
||||
"resources": { "supportsStreaming": False }, # Example capability
|
||||
"prompts": { "supported": True }, # Example capability
|
||||
"session": { "supported": True }
|
||||
}
|
||||
}
|
||||
response_body = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": message_id,
|
||||
"result": initialize_result
|
||||
}
|
||||
|
||||
# Return JSON response with Mcp-Session-Id header
|
||||
return JSONResponse(
|
||||
content=response_body,
|
||||
media_type="application/json",
|
||||
headers={"Mcp-Session-Id": new_session_id}
|
||||
)
|
||||
|
||||
async def _handle_client_post(self, request: Request, body: Dict, session_id: str, message_id: Any):
|
||||
"""Handles non-initialize POST requests (notifications, responses, method calls)."""
|
||||
method = body.get("method")
|
||||
|
||||
# Handle Notifications/Responses from client
|
||||
is_notification = "method" in body and "id" not in body
|
||||
is_response = "result" in body or "error" in body
|
||||
if is_notification or is_response:
|
||||
logger.info(f"Received Streamable HTTP notification/response [Session ID: {session_id}] - Processing needed? (Ignoring for now)")
|
||||
# TODO: If the server sends requests that expect responses, process is_response here.
|
||||
# For now, just acknowledge client notifications/responses.
|
||||
return JSONResponse({}, status_code=202) # Accepted
|
||||
|
||||
# Handle Requests from client (method call)
|
||||
if "method" in body and "id" in body:
|
||||
logger.info(f"Received Streamable HTTP request [Session ID: {session_id}, ID: {message_id}, Method: {method}]")
|
||||
params = body.get("params", {})
|
||||
stream_required = params.get("stream", False) if method in ["tools/call", "mcp/callTool"] else False
|
||||
|
||||
if stream_required:
|
||||
# --- Return SSE stream for response parts ---
|
||||
logger.info(f"Using SSE stream for request [Session ID: {session_id}, ID: {message_id}]")
|
||||
response_queue = asyncio.Queue()
|
||||
# Ensure request_queues exists (should have been created during initialize)
|
||||
if self.client_sessions[session_id].get("request_queues") is None:
|
||||
logger.error(f"Session {session_id} is missing 'request_queues' dictionary!")
|
||||
# Handle this inconsistency, maybe return an error
|
||||
return JSONResponse({"jsonrpc": "2.0", "id": message_id, "error": {"code": -32000, "message": "Internal server error: Session state inconsistent"}}, status_code=500)
|
||||
self.client_sessions[session_id]["request_queues"][message_id] = response_queue
|
||||
|
||||
# Start background task to process and put results in the queue
|
||||
asyncio.create_task(self._process_request_and_respond(
|
||||
request, body, session_id, message_id, response_queue, is_stream=True
|
||||
))
|
||||
|
||||
# Return EventSourceResponse using the request-specific queue
|
||||
return EventSourceResponse(self._create_request_sse_generator(session_id, message_id, response_queue), media_type="text/event-stream")
|
||||
else:
|
||||
# --- Return single JSON response ---
|
||||
logger.info(f"Using JSON response for request [Session ID: {session_id}, ID: {message_id}]")
|
||||
try:
|
||||
# Process the request directly and get the result/error payload
|
||||
result_or_error_payload = await self._process_request_and_respond(
|
||||
request, body, session_id, message_id, None, is_stream=False
|
||||
)
|
||||
# This function now returns the final JSON body or raises HTTPException
|
||||
return JSONResponse(content=result_or_error_payload, media_type="application/json")
|
||||
except HTTPException as http_exc:
|
||||
# Format HTTPException details into JSON-RPC error
|
||||
return JSONResponse(
|
||||
{"jsonrpc": "2.0", "id": message_id, "error": {"code": -32000, "message": http_exc.detail}},
|
||||
status_code=http_exc.status_code
|
||||
)
|
||||
except Exception as e:
|
||||
# Catch unexpected errors during synchronous processing
|
||||
logger.error(f"Error processing non-stream request [Session ID: {session_id}, ID: {message_id}]: {str(e)}", exc_info=True)
|
||||
error_response = {"jsonrpc": "2.0", "id": message_id, "error": {"code": -32000, "message": f"Internal server error: {str(e)}"}}
|
||||
return JSONResponse(content=error_response, status_code=500)
|
||||
else:
|
||||
# Invalid JSON-RPC format (e.g., missing method or id for a request)
|
||||
return JSONResponse({"jsonrpc": "2.0", "id": message_id, "error": {"code": -32600, "message": "Invalid JSON-RPC request format"}}, status_code=400)
|
||||
|
||||
# === Generator Functions for SSE Streams ===
|
||||
|
||||
async def _create_general_sse_generator(self, session_id: str, queue: asyncio.Queue):
|
||||
"""Generator for GET /mcp server push streams."""
|
||||
queue_removed = False
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
if session_id not in self.client_sessions:
|
||||
logger.warning(f"General SSE stream generator: Session {session_id} closed.")
|
||||
break
|
||||
|
||||
message = await asyncio.wait_for(queue.get(), timeout=60.0)
|
||||
|
||||
if message == STREAM_END_MARKER:
|
||||
logger.debug(f"General SSE stream received end marker [Session ID: {session_id}]")
|
||||
break
|
||||
|
||||
if isinstance(message, dict) and ("result" in message or "error" in message) and "id" in message:
|
||||
logger.warning(f"Attempted to send response on GET stream, blocked [Session ID: {session_id}]: {message}")
|
||||
queue.task_done()
|
||||
continue
|
||||
|
||||
# TODO: Event ID for recovery?
|
||||
yield {"event": "message", "data": json.dumps(message)}
|
||||
queue.task_done()
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
if session_id not in self.client_sessions:
|
||||
logger.warning(f"General SSE stream generator (timeout): Session {session_id} closed.")
|
||||
break
|
||||
yield {"event": "ping", "data": "keepalive"}
|
||||
continue
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"General SSE stream cancelled [Session ID: {session_id}]")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"General SSE stream error [Session ID: {session_id}]: {str(e)}", exc_info=True)
|
||||
break
|
||||
finally:
|
||||
logger.info(f"General SSE stream ended [Session ID: {session_id}]")
|
||||
if not queue_removed and session_id in self.client_sessions:
|
||||
session = self.client_sessions[session_id]
|
||||
if session.get("general_sse_queues") is not None:
|
||||
try:
|
||||
session["general_sse_queues"].remove(queue)
|
||||
queue_removed = True
|
||||
logger.debug(f"General SSE queue removed from session [Session ID: {session_id}]")
|
||||
except ValueError:
|
||||
logger.warning(f"Failed to remove general SSE queue (not found) [Session ID: {session_id}]")
|
||||
except Exception as ce:
|
||||
logger.error(f"Error removing general SSE queue [Session ID: {session_id}]: {ce}")
|
||||
while not queue.empty():
|
||||
try: queue.get_nowait(); queue.task_done()
|
||||
except asyncio.QueueEmpty: break
|
||||
|
||||
async def _create_request_sse_generator(self, session_id: str, request_id: Any, queue: asyncio.Queue):
|
||||
"""Generator for POST /mcp request-response streams."""
|
||||
queue_removed = False
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
if session_id not in self.client_sessions or \
|
||||
request_id not in self.client_sessions.get(session_id, {}).get("request_queues", {}):
|
||||
logger.warning(f"Request SSE stream generator: Session/Request queue closed [Session ID: {session_id}, Request ID: {request_id}]")
|
||||
break
|
||||
|
||||
message = await asyncio.wait_for(queue.get(), timeout=120.0) # Longer timeout for requests?
|
||||
|
||||
if message == STREAM_END_MARKER:
|
||||
logger.debug(f"Request SSE stream received end marker [Session ID: {session_id}, Request ID: {request_id}]")
|
||||
break
|
||||
|
||||
# TODO: Event ID for parts?
|
||||
yield {"event": "message", "data": json.dumps(message)}
|
||||
queue.task_done()
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
if session_id not in self.client_sessions or \
|
||||
request_id not in self.client_sessions.get(session_id, {}).get("request_queues", {}):
|
||||
logger.warning(f"Request SSE stream generator (timeout): Session/Request queue closed [Session ID: {session_id}, Request ID: {request_id}]")
|
||||
break
|
||||
logger.debug(f"Request SSE stream timed out waiting for message/end [Session ID: {session_id}, Request ID: {request_id}]")
|
||||
# Unlike general stream, timeout here might indicate an issue or just long processing.
|
||||
# Continue waiting for the STREAM_END_MARKER.
|
||||
continue
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Request SSE stream cancelled [Session ID: {session_id}, Request ID: {request_id}]")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Request SSE stream error [Session ID: {session_id}, Request ID: {request_id}]: {str(e)}", exc_info=True)
|
||||
break
|
||||
finally:
|
||||
logger.info(f"Request SSE stream ended [Session ID: {session_id}, Request ID: {request_id}]")
|
||||
if not queue_removed and session_id in self.client_sessions:
|
||||
session = self.client_sessions[session_id]
|
||||
if session.get("request_queues") is not None:
|
||||
if session["request_queues"].pop(request_id, None):
|
||||
queue_removed = True
|
||||
logger.debug(f"Request SSE queue removed from session [Session ID: {session_id}, Request ID: {request_id}]")
|
||||
else:
|
||||
logger.warning(f"Failed to remove request SSE queue (not found) [Session ID: {session_id}, Request ID: {request_id}]")
|
||||
while not queue.empty():
|
||||
try: queue.get_nowait(); queue.task_done()
|
||||
except asyncio.QueueEmpty: break
|
||||
|
||||
# === Core Request Processing Logic ===
|
||||
|
||||
async def _process_request_and_respond(
|
||||
self, request: Request, body: Dict, session_id: str, message_id: Any,
|
||||
response_queue: Optional[asyncio.Queue], # Queue ONLY for streaming responses
|
||||
is_stream: bool # True if response should go via SSE queue
|
||||
):
|
||||
"""Processes client method calls and prepares response/error payload or sends to queue.
|
||||
Returns payload for non-streaming, returns None for streaming (uses queue).
|
||||
Raises HTTPException for non-streaming errors that need specific status codes.
|
||||
"""
|
||||
logger.info(f"Entering _process_request_and_respond for method '{body.get('method')}'...")
|
||||
method = body.get("method")
|
||||
params = body.get("params", {})
|
||||
response_payload = None # Holds the 'result' or 'error' part of JSON-RPC
|
||||
|
||||
try:
|
||||
# --- Handle Method Calls ---
|
||||
if method == "mcp/listOfferings":
|
||||
tools = await self.mcp_server.list_tools()
|
||||
tools_json = self._format_tools(tools)
|
||||
resources = await self.mcp_server.list_resources()
|
||||
resources_json = self._format_resources(resources)
|
||||
prompts = await self.mcp_server.list_prompts()
|
||||
prompts_json = self._format_prompts(prompts)
|
||||
response_payload = {"tools": tools_json, "resources": resources_json, "prompts": prompts_json}
|
||||
|
||||
elif method == "mcp/listTools" or method == "tools/list":
|
||||
tools = await self.mcp_server.list_tools()
|
||||
response_payload = {"tools": self._format_tools(tools)}
|
||||
|
||||
elif method == "mcp/listResources":
|
||||
resources = await self.mcp_server.list_resources()
|
||||
response_payload = {"resources": self._format_resources(resources)}
|
||||
|
||||
elif method == "mcp/listPrompts":
|
||||
prompts = await self.mcp_server.list_prompts()
|
||||
response_payload = {"prompts": self._format_prompts(prompts)}
|
||||
|
||||
elif method == "mcp/callTool" or method == "tools/call":
|
||||
tool_name = params.get("name")
|
||||
arguments = params.get("arguments", {})
|
||||
if not tool_name:
|
||||
# For non-streaming, raise HTTPException; for streaming, send error via queue
|
||||
error_detail = "Invalid params: tool name ('name') is required"
|
||||
if is_stream and response_queue:
|
||||
error_resp = {"jsonrpc": "2.0", "id": message_id, "error": {"code": -32602, "message": error_detail}}
|
||||
await response_queue.put(error_resp)
|
||||
# No return here for stream, let finally handle end marker
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail=error_detail)
|
||||
return # Exit after handling error
|
||||
|
||||
# --- Tool Calling ---
|
||||
if is_stream and response_queue:
|
||||
# Background task handles putting results/errors in queue
|
||||
logger.info(f"Launching stream tool task [Session: {session_id}, Req: {message_id}, Tool: {tool_name}]")
|
||||
asyncio.create_task(self._execute_stream_tool_wrapper(
|
||||
tool_name, arguments, message_id, session_id, request, response_queue
|
||||
))
|
||||
# Returns None, caller (_handle_client_post) returns EventSourceResponse
|
||||
return
|
||||
else:
|
||||
# Execute tool directly for non-streaming response
|
||||
logger.info(f"Executing non-stream tool [Session: {session_id}, Req: {message_id}, Tool: {tool_name}]")
|
||||
# Note: call_tool now raises ValueError on internal errors
|
||||
result = await self.call_tool(tool_name, arguments, request, None) # No callback needed
|
||||
logger.debug(f"Raw result from non-stream call_tool: {result}")
|
||||
response_payload = self._format_tool_call_result(result)
|
||||
else:
|
||||
# Method not found
|
||||
error_detail = f"Method not found: {method}"
|
||||
if is_stream and response_queue:
|
||||
error_resp = {"jsonrpc": "2.0", "id": message_id, "error": {"code": -32601, "message": error_detail}}
|
||||
await response_queue.put(error_resp)
|
||||
else:
|
||||
raise HTTPException(status_code=405, detail=error_detail)
|
||||
return # Exit after handling error
|
||||
|
||||
# --- Prepare final response payload (only if not streaming and successful) ---
|
||||
if response_payload is not None:
|
||||
final_response = {"jsonrpc": "2.0", "id": message_id, "result": response_payload}
|
||||
if is_stream and response_queue: # Should not happen if response_payload is set
|
||||
logger.error("Logic error: response_payload set for streaming call?")
|
||||
await response_queue.put(final_response) # Send anyway?
|
||||
elif not is_stream:
|
||||
logger.debug(f"Returning successful non-stream payload for {method}")
|
||||
return final_response # Return dict for JSONResponse
|
||||
|
||||
except Exception as e:
|
||||
# Handles errors raised by call_tool (ValueError) or other unexpected issues
|
||||
logger.error(f"Error processing request [Session: {session_id}, Req: {message_id}, Method: {method}]: {str(e)}", exc_info=True)
|
||||
error_code = -32000
|
||||
error_message = f"Internal server error: {str(e)}"
|
||||
status_code = 500 # Default for unexpected errors
|
||||
|
||||
if isinstance(e, HTTPException):
|
||||
# If it was an HTTPException raised earlier (e.g., 400, 405)
|
||||
error_message = e.detail
|
||||
status_code = e.status_code
|
||||
error_code = -32000 # Keep generic JSON-RPC code for now
|
||||
elif isinstance(e, ValueError):
|
||||
# Errors from call_tool (tool not found, execution error)
|
||||
error_message = str(e)
|
||||
status_code = 500 # Treat tool execution errors as internal server errors
|
||||
error_code = -32000 # Or a custom tool error code?
|
||||
|
||||
error_response_payload = {"code": error_code, "message": error_message}
|
||||
|
||||
if is_stream and response_queue:
|
||||
# Send error via queue for streaming calls
|
||||
final_error_response = {"jsonrpc": "2.0", "id": message_id, "error": error_response_payload}
|
||||
logger.debug(f"Putting error response into stream queue [Session: {session_id}, Req: {message_id}]")
|
||||
await response_queue.put(final_error_response)
|
||||
# Returns None, let finally send end marker
|
||||
return
|
||||
else:
|
||||
# For non-streaming, raise HTTPException to set status code
|
||||
logger.debug(f"Raising HTTPException for non-stream error (Status: {status_code})")
|
||||
raise HTTPException(status_code=status_code, detail=error_message)
|
||||
|
||||
finally:
|
||||
# If this was a streaming call, ensure the end marker is sent.
|
||||
# This runs even if the processing returns early (e.g., after launching task or handling error).
|
||||
if is_stream and response_queue:
|
||||
logger.debug(f"Putting stream end marker [Session: {session_id}, Req: {message_id}]")
|
||||
await response_queue.put(STREAM_END_MARKER)
|
||||
|
||||
|
||||
async def _execute_stream_tool_wrapper(
|
||||
self, tool_name: str, arguments: Dict, message_id: Any, session_id: str,
|
||||
request: Request, response_queue: asyncio.Queue
|
||||
):
|
||||
"""Wraps stream-capable tool calls, handles callback, puts results/errors into queue."""
|
||||
logger.info(f"Entering _execute_stream_tool_wrapper for tool '{tool_name}'...")
|
||||
try:
|
||||
logger.debug(f"Executing stream tool wrapper [Session: {session_id}, Req: {message_id}, Tool: {tool_name}]")
|
||||
|
||||
async def stream_callback(content, metadata=None):
|
||||
logger.debug(f"Stream callback received content [Session: {session_id}, Req: {message_id}]")
|
||||
partial_result_formatted = self._format_tool_call_result(content)
|
||||
|
||||
# Check session/queue validity before putting
|
||||
if session_id not in self.client_sessions or \
|
||||
message_id not in self.client_sessions.get(session_id, {}).get("request_queues", {}):
|
||||
logger.warning(f"Stream callback: Session/Queue closed, cannot send partial result [Session: {session_id}, Req: {message_id}]")
|
||||
return
|
||||
|
||||
# Send progress notification
|
||||
progress_notification = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": "tools/progress",
|
||||
"params": {
|
||||
"requestId": message_id,
|
||||
"toolName": tool_name,
|
||||
"progress": partial_result_formatted,
|
||||
}
|
||||
}
|
||||
try:
|
||||
await response_queue.put(progress_notification)
|
||||
except Exception as e:
|
||||
logger.error(f"Stream callback failed to send progress: {str(e)}")
|
||||
|
||||
# Handle visualization data
|
||||
if metadata and "visualization" in metadata:
|
||||
await self.send_visualization_data(session_id, message_id, metadata["visualization"])
|
||||
|
||||
# --- Call Tool ---
|
||||
kwargs = dict(arguments)
|
||||
# Simplification: Assume tool supports callback if streaming requested
|
||||
kwargs['callback'] = stream_callback
|
||||
|
||||
# call_tool handles its own internal errors and raises ValueError
|
||||
result = await self.call_tool(tool_name, kwargs, request, stream_callback)
|
||||
logger.debug(f"Stream wrapper received final result from call_tool: {result}")
|
||||
|
||||
# --- Send Final Result ---
|
||||
if session_id not in self.client_sessions or \
|
||||
message_id not in self.client_sessions.get(session_id, {}).get("request_queues", {}):
|
||||
logger.warning(f"Stream tool finished but Session/Queue closed [Session: {session_id}, Req: {message_id}]")
|
||||
return # Cannot send final result
|
||||
|
||||
final_result_formatted = self._format_tool_call_result(result)
|
||||
final_message = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": message_id,
|
||||
"result": final_result_formatted
|
||||
}
|
||||
logger.debug(f"Putting final stream result into queue [Session: {session_id}, Req: {message_id}]")
|
||||
await response_queue.put(final_message)
|
||||
logger.info(f"Stream tool execution successful [Session: {session_id}, Req: {message_id}]")
|
||||
|
||||
except Exception as e:
|
||||
# Catches errors from call_tool (ValueError) or other wrapper issues
|
||||
logger.error(f"Error during stream tool execution wrapper [Session: {session_id}, Req: {message_id}]: {str(e)}", exc_info=True)
|
||||
# Check session/queue validity before sending error
|
||||
if session_id not in self.client_sessions or \
|
||||
message_id not in self.client_sessions.get(session_id, {}).get("request_queues", {}):
|
||||
logger.warning(f"Stream tool failed but Session/Queue closed [Session: {session_id}, Req: {message_id}]")
|
||||
return # Cannot send error
|
||||
|
||||
error_code = -32000
|
||||
error_message = f"Tool execution error: {str(e)}"
|
||||
if isinstance(e, ValueError):
|
||||
error_code = -32602 # Or -32000?
|
||||
error_message = str(e)
|
||||
|
||||
error_response = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": message_id,
|
||||
"error": { "code": error_code, "message": error_message }
|
||||
}
|
||||
try:
|
||||
await response_queue.put(error_response)
|
||||
except Exception as qe:
|
||||
logger.error(f"Failed to put error response into stream queue: {qe}")
|
||||
# No finally block needed here, handled by _process_request_and_respond
|
||||
|
||||
|
||||
async def call_tool(self, tool_name, arguments, request, callback: Optional[callable] = None):
|
||||
"""Finds and executes the target tool function/method.
|
||||
Raises ValueError on tool not found or execution error.
|
||||
"""
|
||||
logger.info(f"Entering call_tool for tool '{tool_name}'...")
|
||||
# Log args excluding callback
|
||||
log_args = {k: v for k, v in arguments.items() if k != 'callback'}
|
||||
logger.info(f"Executing tool: {tool_name}, Args: {json.dumps(log_args, ensure_ascii=False, default=str)}")
|
||||
|
||||
recent_query = self._extract_recent_query(request)
|
||||
# Tool mapping might be needed if client uses different names
|
||||
tool_mapping = {
|
||||
# Example: "clientFacingName": "internalFunctionName"
|
||||
"status": "mcp_doris_status",
|
||||
"health": "mcp_doris_health",
|
||||
# Add other mappings if needed, ensure consistency with tool_initializer
|
||||
"nl2sql_query": "mcp_doris_nl2sql_query",
|
||||
"nl2sql_query_stream": "mcp_doris_nl2sql_query_stream",
|
||||
"list_database_tables": "mcp_doris_list_database_tables",
|
||||
"explain_table": "mcp_doris_explain_table",
|
||||
"get_nl2sql_status": "mcp_doris_get_nl2sql_status",
|
||||
"refresh_metadata": "mcp_doris_refresh_metadata",
|
||||
"sql_optimize": "mcp_doris_sql_optimize",
|
||||
"fix_sql": "mcp_doris_fix_sql",
|
||||
"count_chars": "mcp_doris_count_chars",
|
||||
"exec_query": "mcp_doris_exec_query",
|
||||
"get_schema_list": "mcp_doris_get_schema_list", # Deprecated?
|
||||
"save_metadata": "mcp_doris_save_metadata", # Likely internal
|
||||
"get_metadata": "mcp_doris_get_metadata", # Likely internal
|
||||
"analyze_query_result": "mcp_doris_analyze_query_result", # Internal?
|
||||
"generate_sql": "mcp_doris_generate_sql", # Likely internal
|
||||
"explain_sql": "mcp_doris_explain_sql", # Internal?
|
||||
"modify_sql": "mcp_doris_modify_sql", # Internal?
|
||||
"parse_query": "mcp_doris_parse_query", # Internal?
|
||||
"identify_query_type": "mcp_doris_identify_query_type", # Internal?
|
||||
"validate_sql_syntax": "mcp_doris_validate_sql_syntax", # Internal?
|
||||
"check_sql_security": "mcp_doris_check_sql_security", # Internal?
|
||||
"find_similar_examples": "mcp_doris_find_similar_examples", # Internal?
|
||||
"find_similar_history": "mcp_doris_find_similar_history", # Internal?
|
||||
"calculate_query_similarity": "mcp_doris_calculate_query_similarity", # Internal?
|
||||
"adapt_similar_query": "mcp_doris_adapt_similar_query", # Internal?
|
||||
"get_nl2sql_prompt": "mcp_doris_get_nl2sql_prompt" # Internal?
|
||||
}
|
||||
mapped_tool_name = tool_mapping.get(tool_name, tool_name)
|
||||
|
||||
try:
|
||||
# 1. Find the registered tool instance/function from FastMCP
|
||||
tool_instance = None
|
||||
mcp = self.app.state.mcp if hasattr(self.app.state, 'mcp') else self.mcp_server
|
||||
registered_tools = await mcp.list_tools()
|
||||
for tool in registered_tools:
|
||||
# The tool object returned by list_tools might be the wrapper function
|
||||
# defined in tool_initializer. We need its name.
|
||||
tool_registered_name = getattr(tool, 'name', getattr(tool, '__name__', None))
|
||||
if tool_registered_name == tool_name: # Match against the name used in @mcp.tool
|
||||
tool_instance = tool # This is likely the wrapper function itself
|
||||
logger.debug(f"Found registered tool wrapper: {tool_registered_name}")
|
||||
break
|
||||
|
||||
if not tool_instance:
|
||||
# Fallback: Try importing directly (less ideal as it bypasses registration)
|
||||
logger.warning(f"Tool '{tool_name}' not found in registered tools, trying direct import of {mapped_tool_name}")
|
||||
try:
|
||||
import doris_mcp_server.tools.mcp_doris_tools as mcp_tools
|
||||
tool_instance = getattr(mcp_tools, mapped_tool_name, None)
|
||||
if not tool_instance or not callable(tool_instance):
|
||||
raise ValueError(f"Tool function {mapped_tool_name} not found or not callable in mcp_doris_tools.")
|
||||
logger.debug(f"Using directly imported tool function: {mapped_tool_name}")
|
||||
# If using direct import, FastMCP context (ctx) is not available
|
||||
# We need to pass args directly
|
||||
processed_args = self._process_tool_arguments(mapped_tool_name, arguments, recent_query)
|
||||
# Inject callback if provided and applicable
|
||||
if callback and mapped_tool_name.endswith("_stream"):
|
||||
processed_args['callback'] = callback
|
||||
elif callback:
|
||||
processed_args.pop('callback', None)
|
||||
result = await tool_instance(**processed_args)
|
||||
logger.debug(f"Raw result from directly imported tool '{mapped_tool_name}': {result}")
|
||||
return result
|
||||
|
||||
except (ImportError, AttributeError, ValueError) as import_err:
|
||||
logger.error(f"Failed to find or import tool: {tool_name} / {mapped_tool_name}. Error: {import_err}")
|
||||
raise ValueError(f"Tool '{tool_name}' not found or failed to import.") from import_err
|
||||
|
||||
# 2. If found via registration, execute using FastMCP's mechanism (if possible)
|
||||
# or simulate the context passing if tool_instance is the wrapper.
|
||||
# The wrapper expects a Context object.
|
||||
logger.debug(f"Executing registered tool wrapper '{tool_name}'")
|
||||
# We need to manually create a mock or simplified Context if FastMCP doesn't handle this automatically
|
||||
# For simplicity, let's try passing parameters directly if the wrapper handles it.
|
||||
# Ideally, FastMCP would handle the execution via mcp.call_tool(tool_name, params=...) if available.
|
||||
# Let's assume the wrapper function handles **kwargs or a Context object.
|
||||
|
||||
# Create a pseudo-context or just pass params
|
||||
# Method 1: Pass params directly (assuming wrapper handles it)
|
||||
# processed_args = self._process_tool_arguments(mapped_tool_name, arguments, recent_query)
|
||||
# if callback:
|
||||
# processed_args['callback'] = callback
|
||||
# result = await tool_instance(**processed_args) # This likely won't work if it expects Context
|
||||
|
||||
# Method 2: Create a Context-like object (Requires Context class import)
|
||||
# from mcp.server.fastmcp import Context # Make sure imported
|
||||
# pseudo_ctx = Context(mcp=mcp, request=request, params=arguments, tool=tool_instance)
|
||||
# result = await tool_instance(pseudo_ctx)
|
||||
|
||||
# Method 3: Use mcp.call_tool internal method if accessible and appropriate
|
||||
# This is speculative based on potential FastMCP internals
|
||||
if hasattr(mcp, 'call_tool_by_name'): # Hypothetical method
|
||||
logger.debug("Attempting execution via mcp.call_tool_by_name")
|
||||
pseudo_ctx_params = arguments # Pass client args
|
||||
# pseudo_ctx_params['_request'] = request # Maybe pass request?
|
||||
if callback: pseudo_ctx_params['callback'] = callback # Pass callback?
|
||||
result = await mcp.call_tool_by_name(tool_name, params=pseudo_ctx_params)
|
||||
logger.debug(f"Result from mcp.call_tool_by_name: {result}")
|
||||
else:
|
||||
# Fallback to manual context simulation if no direct call method exists
|
||||
logger.debug("Falling back to manual context simulation for tool wrapper execution")
|
||||
from mcp.server.fastmcp import Context # Ensure imported
|
||||
# Prepare params for context, including potentially callback
|
||||
context_params = dict(arguments)
|
||||
if callback: context_params['callback'] = callback
|
||||
pseudo_ctx = Context(mcp=mcp, request=request, params=context_params, tool=tool_instance)
|
||||
result = await tool_instance(pseudo_ctx) # Call the wrapper with simulated context
|
||||
logger.debug(f"Result from manual context simulation: {result}")
|
||||
|
||||
logger.debug(f"Raw result received in call_tool from registered tool '{tool_name}': {result}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Exception during call_tool for '{tool_name}': {str(e)}", exc_info=True)
|
||||
raise ValueError(f"Error executing tool '{tool_name}': {str(e)}") from e
|
||||
|
||||
|
||||
# === Helper Methods (Formatting, Session Cleanup, etc.) ===
|
||||
|
||||
def _format_tools(self, tools):
|
||||
# Helper to format tool list for responses
|
||||
# Based on mcp/listTools structure
|
||||
tools_json = []
|
||||
for tool in tools:
|
||||
# Assuming tools from list_tools are the wrapper functions
|
||||
tool_registered_name = getattr(tool, 'name', getattr(tool, '__name__', None))
|
||||
if not tool_registered_name:
|
||||
logger.warning(f"Could not determine name for tool object: {tool}")
|
||||
continue
|
||||
|
||||
# Need a way to get description and schema associated with the wrapper
|
||||
# This might require inspecting the mcp instance's internal storage
|
||||
mcp = self.app.state.mcp if hasattr(self.app.state, 'mcp') else self.mcp_server
|
||||
# Hypothetical internal access - THIS IS FRAGILE
|
||||
tool_spec = mcp.tools.get(tool_registered_name) if hasattr(mcp, 'tools') else None
|
||||
|
||||
description = ""
|
||||
input_schema = {"type": "object", "properties": {}, "required": []}
|
||||
if tool_spec and hasattr(tool_spec, 'description'):
|
||||
description = tool_spec.description
|
||||
if tool_spec and hasattr(tool_spec, 'parameters'): # Assuming parameters holds the JSON schema
|
||||
input_schema = tool_spec.parameters
|
||||
|
||||
tools_json.append({
|
||||
"name": tool_registered_name,
|
||||
"description": description,
|
||||
"inputSchema": input_schema
|
||||
})
|
||||
return tools_json
|
||||
|
||||
def _format_resources(self, resources):
|
||||
# Helper to format resource list
|
||||
return [res.model_dump() if hasattr(res, "model_dump") else res for res in resources]
|
||||
|
||||
def _format_prompts(self, prompts):
|
||||
# Helper to format prompt list
|
||||
return [prompt.model_dump() if hasattr(prompt, "model_dump") else prompt for prompt in prompts]
|
||||
|
||||
def _format_tool_call_result(self, result: Any) -> Dict[str, Any]:
|
||||
# Helper to format tool results into MCP Content format
|
||||
content_list = []
|
||||
if isinstance(result, str):
|
||||
try:
|
||||
# If it looks like the tool already returned the full JSON RPC like structure
|
||||
parsed_json = json.loads(result)
|
||||
if isinstance(parsed_json, dict) and 'content' in parsed_json and isinstance(parsed_json['content'], list):
|
||||
logger.debug("Tool result already seems formatted with 'content', using as is.")
|
||||
return parsed_json # Use the structure directly
|
||||
else:
|
||||
# Assume it's JSON content, wrap it
|
||||
content_list.append({"type": "json", "json": parsed_json})
|
||||
except json.JSONDecodeError:
|
||||
# Not JSON, treat as text
|
||||
content_list.append({"type": "text", "text": result})
|
||||
elif isinstance(result, (dict, list)):
|
||||
# If result is already a dict with a 'content' list, use it directly
|
||||
if isinstance(result, dict) and 'content' in result and isinstance(result['content'], list):
|
||||
logger.debug("Tool result dictionary has 'content', using as is.")
|
||||
return result # Use the structure directly
|
||||
else:
|
||||
# Otherwise, assume it's JSON content to be wrapped
|
||||
content_list.append({"type": "json", "json": result})
|
||||
elif result is None:
|
||||
# Handle None result, maybe return empty content or specific type?
|
||||
logger.warning("_format_tool_call_result received None result")
|
||||
content_list.append({"type": "text", "text": ""}) # Example: empty text
|
||||
else:
|
||||
# Other types, convert to string and wrap as text
|
||||
content_list.append({"type": "text", "text": str(result)})
|
||||
# Always return a dict with a 'content' key containing a list
|
||||
return {"content": content_list}
|
||||
|
||||
def _process_tool_arguments(self, tool_name, arguments, recent_query):
|
||||
# Helper to process tool arguments, including random_string fallback
|
||||
# Note: Ensure callback is NOT passed here
|
||||
processed_args = dict(arguments)
|
||||
processed_args.pop('callback', None) # Explicitly remove callback
|
||||
|
||||
if "random_string" in arguments and tool_name.startswith("mcp_doris_"):
|
||||
random_string = processed_args.pop("random_string", "") # Remove from processed too
|
||||
logger.debug(f"Processing random_string '{random_string}' for tool {tool_name}")
|
||||
|
||||
# ... (rest of random_string logic as before) ...
|
||||
# Example for exec_query:
|
||||
if tool_name == "mcp_doris_exec_query" and not processed_args.get("sql"):
|
||||
sql_fallback = random_string or recent_query
|
||||
# ... (logic to extract SQL from fallback) ...
|
||||
if sql_extracted:
|
||||
processed_args["sql"] = sql_extracted
|
||||
else:
|
||||
logger.warning(f"Missing sql for {tool_name}, and fallback failed.")
|
||||
# ... (logic for table_name fallback) ...
|
||||
|
||||
return processed_args
|
||||
|
||||
def _extract_recent_query(self, request: Request) -> Optional[str]:
|
||||
# Helper to extract recent user query from request
|
||||
# (Implementation as provided previously)
|
||||
try:
|
||||
# Try to extract message history from request body
|
||||
body = None
|
||||
body_bytes = getattr(request, "_body", None)
|
||||
if body_bytes:
|
||||
try:
|
||||
body = json.loads(body_bytes)
|
||||
except: pass
|
||||
if not body: body = getattr(request, "_json", {})
|
||||
|
||||
messages = body.get("params", {}).get("messages", [])
|
||||
if messages:
|
||||
for msg in reversed(messages):
|
||||
if msg.get("role") == "user": return msg.get("content", "")
|
||||
|
||||
message = body.get("params", {}).get("message", {})
|
||||
if message and message.get("role") == "user": return message.get("content", "")
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting recent query: {str(e)}")
|
||||
return None
|
||||
|
||||
async def _cleanup_session_resources(self, session_id: str, session_data: Dict):
|
||||
# Helper to clean up queues when session is deleted
|
||||
logger.info(f"Cleaning up resources for session [Session ID: {session_id}]")
|
||||
# Close general SSE queues
|
||||
general_queues = session_data.get("general_sse_queues", [])
|
||||
for queue in general_queues:
|
||||
try:
|
||||
await queue.put(STREAM_END_MARKER)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error putting end marker in general queue for session {session_id}: {e}")
|
||||
# Close request-specific SSE queues
|
||||
request_queues = session_data.get("request_queues", {})
|
||||
for req_id, queue in request_queues.items():
|
||||
try:
|
||||
await queue.put(STREAM_END_MARKER)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error putting end marker in request queue {req_id} for session {session_id}: {e}")
|
||||
logger.info(f"Finished cleaning resources for session {session_id}")
|
||||
|
||||
# This method might belong in the main app or a shared utility if needed by both servers
|
||||
# async def cleanup_idle_sessions(self):
|
||||
# # ... (implementation - needs access to self.client_sessions) ...
|
||||
# pass
|
||||
|
||||
# This method might belong in the main app or a shared utility
|
||||
# async def broadcast_message(self, message):
|
||||
# # ... (implementation - needs access to self.client_sessions of BOTH servers?) ...
|
||||
# pass
|
||||
|
||||
# This method is specific to streamable http tool calls
|
||||
async def send_visualization_data(self, session_id: str, request_id: Any, visualization_data: Any):
|
||||
"""Sends visualization data as a notification on the request stream."""
|
||||
if session_id not in self.client_sessions:
|
||||
logger.warning(f"Cannot send visualization: Session {session_id} not found.")
|
||||
return
|
||||
queue = self.client_sessions.get(session_id, {}).get("request_queues", {}).get(request_id)
|
||||
if not queue:
|
||||
logger.warning(f"Cannot send visualization: Request queue {request_id} not found for session {session_id}.")
|
||||
return
|
||||
|
||||
notification = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": "tools/visualization",
|
||||
"params": visualization_data
|
||||
}
|
||||
try:
|
||||
await queue.put(notification)
|
||||
logger.info(f"Sent visualization notification [Session: {session_id}, Req: {request_id}]")
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending visualization notification [Session: {session_id}, Req: {request_id}]: {e}")
|
||||
|
||||
# This might belong in main app or shared utility
|
||||
# async def send_periodic_updates(self):
|
||||
# # ... (implementation) ...
|
||||
# pass
|
||||
|
||||
# End of class DorisMCPStreamableServer
|
||||
23
doris_mcp_server/tools/__init__.py
Normal file
23
doris_mcp_server/tools/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from .mcp_doris_tools import (
|
||||
mcp_doris_exec_query,
|
||||
mcp_doris_get_table_schema,
|
||||
mcp_doris_get_db_table_list,
|
||||
mcp_doris_get_db_list,
|
||||
mcp_doris_get_table_comment,
|
||||
mcp_doris_get_table_column_comments,
|
||||
mcp_doris_get_table_indexes,
|
||||
mcp_doris_get_recent_audit_logs
|
||||
)
|
||||
|
||||
# The __all__ list should reflect the registered tool names,
|
||||
# even though the implementation functions have the prefix.
|
||||
__all__ = [
|
||||
"exec_query",
|
||||
"get_table_schema",
|
||||
"get_db_table_list",
|
||||
"get_db_list",
|
||||
"get_table_comment",
|
||||
"get_table_column_comments",
|
||||
"get_table_indexes",
|
||||
"get_recent_audit_logs"
|
||||
]
|
||||
202
doris_mcp_server/tools/mcp_doris_tools.py
Normal file
202
doris_mcp_server/tools/mcp_doris_tools.py
Normal file
@@ -0,0 +1,202 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Doris MCP Tool Implementations
|
||||
|
||||
Includes exec_query and new tools based on schema_extractor.
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Any
|
||||
import pandas as pd
|
||||
|
||||
# --- Use absolute imports ---
|
||||
from doris_mcp_server.utils.schema_extractor import MetadataExtractor
|
||||
from doris_mcp_server.utils.sql_executor_tools import execute_sql_query
|
||||
|
||||
# Get logger
|
||||
logger = logging.getLogger("doris-mcp-tools")
|
||||
|
||||
# --- Helper Function to format response ---
|
||||
def _format_response(success: bool, result: Any = None, error: str = None, message: str = "") -> Dict[str, Any]:
|
||||
response_data = {
|
||||
"success": success,
|
||||
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
|
||||
}
|
||||
if success and result is not None:
|
||||
# Handle DataFrame serialization
|
||||
if isinstance(result, pd.DataFrame):
|
||||
try:
|
||||
# Convert DataFrame to JSON records format
|
||||
response_data["result"] = json.loads(result.to_json(orient='records', date_format='iso'))
|
||||
except Exception as df_err:
|
||||
logger.error(f"DataFrame to JSON conversion failed: {df_err}")
|
||||
# Fallback or specific error handling for DataFrame
|
||||
response_data["result"] = {"error": "Failed to serialize DataFrame result"}
|
||||
response_data["success"] = False # Mark as failed if serialization fails
|
||||
response_data["error"] = f"DataFrame serialization error: {str(df_err)}"
|
||||
else:
|
||||
response_data["result"] = result
|
||||
response_data["message"] = message or "Operation successful" # Translated: Operation successful
|
||||
elif not success:
|
||||
response_data["error"] = error or "Unknown error" # Translated: Unknown error
|
||||
response_data["message"] = message or "Operation failed" # Translated: Operation failed
|
||||
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps(response_data, ensure_ascii=False, default=str) # Use default=str for non-serializable types
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
async def mcp_doris_exec_query(sql: str = None, db_name: str = None, max_rows: int = 100, timeout: int = 30) -> Dict[str, Any]:
|
||||
"""
|
||||
Executes an SQL query and returns the result.
|
||||
|
||||
Args:
|
||||
sql (str): The SQL query to execute.
|
||||
db_name (str, optional): Target database name. Defaults to the configured default database.
|
||||
max_rows (int, optional): Maximum number of rows to return. Defaults to 100.
|
||||
timeout (int, optional): Query timeout in seconds. Defaults to 30.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: A dictionary containing the query result or an error.
|
||||
"""
|
||||
logger.info(f"MCP Tool Call: mcp_doris_exec_query, SQL: {sql}, DB: {db_name}, MaxRows: {max_rows}, Timeout: {timeout}")
|
||||
try:
|
||||
if not sql:
|
||||
return _format_response(success=False, error="SQL statement not provided", message="Please provide the SQL statement to execute")
|
||||
|
||||
# Build parameters to pass to execute_sql_query
|
||||
exec_ctx = {
|
||||
"params": {
|
||||
"sql": sql,
|
||||
"db_name": db_name,
|
||||
"max_rows": max_rows,
|
||||
"timeout": timeout
|
||||
}
|
||||
}
|
||||
|
||||
# Directly call execute_sql_query to execute the query
|
||||
exec_result = await execute_sql_query(exec_ctx)
|
||||
|
||||
# The format returned by execute_sql_query is {'content': [{'type': 'text', 'text': json_string}]}
|
||||
# Need to parse the internal JSON string
|
||||
if exec_result and 'content' in exec_result and len(exec_result['content']) > 0 and 'text' in exec_result['content'][0]:
|
||||
try:
|
||||
# Parse JSON string
|
||||
result_data = json.loads(exec_result['content'][0]['text'])
|
||||
|
||||
# Directly return the parsed result obtained from execute_sql_query
|
||||
# This result is already in the format {"success": ..., "data": ..., "columns": ...} or {"success": false, "error": ...}
|
||||
# _format_response would wrap it again, but here we directly use the parsed data
|
||||
# Note: This changes the original return structure of this function; it now directly returns the output of sql_executor
|
||||
# If the _format_response wrapper needs to be maintained, the code below needs adjustment
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps(result_data, ensure_ascii=False, default=str)
|
||||
}
|
||||
]
|
||||
}
|
||||
except json.JSONDecodeError as json_err:
|
||||
logger.error(f"Failed to parse execute_sql_query result: {json_err}")
|
||||
return _format_response(success=False, error=str(json_err), message="Error parsing SQL execution result")
|
||||
except Exception as parse_err:
|
||||
logger.error(f"Unexpected error occurred while processing execute_sql_query result: {parse_err}", exc_info=True)
|
||||
return _format_response(success=False, error=str(parse_err), message="Unknown error occurred while processing SQL execution result")
|
||||
else:
|
||||
logger.error(f"execute_sql_query returned an unexpected format: {exec_result}")
|
||||
return _format_response(success=False, error="SQL executor returned invalid format", message="Internal error executing SQL query")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"MCP tool execution failed mcp_doris_exec_query: {str(e)}", exc_info=True)
|
||||
return _format_response(success=False, error=str(e), message="Error executing SQL query")
|
||||
|
||||
|
||||
async def mcp_doris_get_table_schema(table_name: str, db_name: str = None) -> Dict[str, Any]:
|
||||
logger.info(f"MCP Tool Call: mcp_doris_get_table_schema, Table: {table_name}, DB: {db_name}")
|
||||
if not table_name:
|
||||
return _format_response(success=False, error="Missing table_name parameter")
|
||||
try:
|
||||
extractor = MetadataExtractor(db_name=db_name)
|
||||
schema = extractor.get_table_schema(table_name=table_name, db_name=db_name)
|
||||
if not schema:
|
||||
return _format_response(success=False, error="Table not found or has no columns", message=f"Could not get schema for table {db_name or extractor.db_name}.{table_name}")
|
||||
return _format_response(success=True, result=schema)
|
||||
except Exception as e:
|
||||
logger.error(f"MCP tool execution failed mcp_doris_get_table_schema: {str(e)}", exc_info=True)
|
||||
return _format_response(success=False, error=str(e), message="Error getting table schema")
|
||||
|
||||
async def mcp_doris_get_db_table_list(db_name: str = None) -> Dict[str, Any]:
|
||||
logger.info(f"MCP Tool Call: mcp_doris_get_db_table_list, DB: {db_name}")
|
||||
try:
|
||||
extractor = MetadataExtractor(db_name=db_name)
|
||||
tables = extractor.get_database_tables(db_name=db_name)
|
||||
return _format_response(success=True, result=tables)
|
||||
except Exception as e:
|
||||
logger.error(f"MCP tool execution failed mcp_doris_get_db_table_list: {str(e)}", exc_info=True)
|
||||
return _format_response(success=False, error=str(e), message="Error getting database table list")
|
||||
|
||||
async def mcp_doris_get_db_list() -> Dict[str, Any]:
|
||||
logger.info(f"MCP Tool Call: mcp_doris_get_db_list")
|
||||
try:
|
||||
extractor = MetadataExtractor()
|
||||
databases = extractor.get_all_databases()
|
||||
return _format_response(success=True, result=databases)
|
||||
except Exception as e:
|
||||
logger.error(f"MCP tool execution failed mcp_doris_get_db_list: {str(e)}", exc_info=True)
|
||||
return _format_response(success=False, error=str(e), message="Error getting database list")
|
||||
|
||||
async def mcp_doris_get_table_comment(table_name: str, db_name: str = None) -> Dict[str, Any]:
|
||||
logger.info(f"MCP Tool Call: mcp_doris_get_table_comment, Table: {table_name}, DB: {db_name}")
|
||||
if not table_name:
|
||||
return _format_response(success=False, error="Missing table_name parameter")
|
||||
try:
|
||||
extractor = MetadataExtractor(db_name=db_name)
|
||||
comment = extractor.get_table_comment(table_name=table_name, db_name=db_name)
|
||||
return _format_response(success=True, result=comment)
|
||||
except Exception as e:
|
||||
logger.error(f"MCP tool execution failed mcp_doris_get_table_comment: {str(e)}", exc_info=True)
|
||||
return _format_response(success=False, error=str(e), message="Error getting table comment")
|
||||
|
||||
async def mcp_doris_get_table_column_comments(table_name: str, db_name: str = None) -> Dict[str, Any]:
|
||||
logger.info(f"MCP Tool Call: mcp_doris_get_table_column_comments, Table: {table_name}, DB: {db_name}")
|
||||
if not table_name:
|
||||
return _format_response(success=False, error="Missing table_name parameter")
|
||||
try:
|
||||
extractor = MetadataExtractor(db_name=db_name)
|
||||
comments = extractor.get_column_comments(table_name=table_name, db_name=db_name)
|
||||
return _format_response(success=True, result=comments)
|
||||
except Exception as e:
|
||||
logger.error(f"MCP tool execution failed mcp_doris_get_table_column_comments: {str(e)}", exc_info=True)
|
||||
return _format_response(success=False, error=str(e), message="Error getting column comments")
|
||||
|
||||
async def mcp_doris_get_table_indexes(table_name: str, db_name: str = None) -> Dict[str, Any]:
|
||||
logger.info(f"MCP Tool Call: mcp_doris_get_table_indexes, Table: {table_name}, DB: {db_name}")
|
||||
if not table_name:
|
||||
return _format_response(success=False, error="Missing table_name parameter")
|
||||
try:
|
||||
extractor = MetadataExtractor(db_name=db_name)
|
||||
indexes = extractor.get_table_indexes(table_name=table_name, db_name=db_name)
|
||||
return _format_response(success=True, result=indexes)
|
||||
except Exception as e:
|
||||
logger.error(f"MCP tool execution failed mcp_doris_get_table_indexes: {str(e)}", exc_info=True)
|
||||
return _format_response(success=False, error=str(e), message="Error getting table indexes")
|
||||
|
||||
async def mcp_doris_get_recent_audit_logs(days: int = 7, limit: int = 100) -> Dict[str, Any]:
|
||||
logger.info(f"MCP Tool Call: mcp_doris_get_recent_audit_logs, Days: {days}, Limit: {limit}")
|
||||
try:
|
||||
extractor = MetadataExtractor()
|
||||
logs_df = extractor.get_recent_audit_logs(days=days, limit=limit)
|
||||
return _format_response(success=True, result=logs_df)
|
||||
except Exception as e:
|
||||
logger.error(f"MCP tool execution failed mcp_doris_get_recent_audit_logs: {str(e)}", exc_info=True)
|
||||
return _format_response(success=False, error=str(e), message="Error getting audit logs")
|
||||
141
doris_mcp_server/tools/tool_initializer.py
Normal file
141
doris_mcp_server/tools/tool_initializer.py
Normal file
@@ -0,0 +1,141 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Tool Initialization Module
|
||||
|
||||
Centralized initialization of all tools, ensuring they are correctly registered with MCP
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Dict, Any, Optional
|
||||
import json
|
||||
from datetime import datetime
|
||||
import traceback
|
||||
|
||||
# Import Context
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
# Import doris mcp tools
|
||||
from doris_mcp_server.tools.mcp_doris_tools import (
|
||||
mcp_doris_exec_query,
|
||||
mcp_doris_get_table_schema,
|
||||
mcp_doris_get_db_table_list,
|
||||
mcp_doris_get_db_list,
|
||||
mcp_doris_get_table_comment,
|
||||
mcp_doris_get_table_column_comments,
|
||||
mcp_doris_get_table_indexes,
|
||||
mcp_doris_get_recent_audit_logs
|
||||
)
|
||||
|
||||
# Get logger
|
||||
logger = logging.getLogger("doris-mcp-tools-initializer")
|
||||
|
||||
async def register_mcp_tools(mcp):
|
||||
"""Register MCP tool functions
|
||||
|
||||
Args:
|
||||
mcp: FastMCP instance
|
||||
"""
|
||||
logger.info("Starting to register MCP tools...")
|
||||
|
||||
try:
|
||||
# Register Tool: Execute SQL Query (Using long description string including parameters)
|
||||
@mcp.tool("exec_query", description="""[Function Description]: Execute SQL query and return result command (executed by the client).\n
|
||||
[Parameter Content]:\n
|
||||
- random_string (string) [Required] - Unique identifier for the tool call\n
|
||||
- sql (string) [Required] - SQL statement to execute\n
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database\n
|
||||
- max_rows (integer) [Optional] - Maximum number of rows to return, default 100
|
||||
- timeout (integer) [Optional] - Query timeout in seconds, default 30""")
|
||||
async def exec_query_tool(sql: str, db_name: str = None, max_rows: int = 100, timeout: int = 30) -> Dict[str, Any]:
|
||||
"""Wrapper: Execute SQL query and return result command"""
|
||||
# Note: ctx parameter is no longer needed here as we receive named parameters directly
|
||||
return await mcp_doris_exec_query(sql=sql, db_name=db_name, max_rows=max_rows, timeout=timeout)
|
||||
|
||||
# Register Tool: Get Table Schema (Keep long description string including parameters)
|
||||
@mcp.tool("get_table_schema", description="""[Function Description]: Get detailed structure information of the specified table (columns, types, comments, etc.).\n
|
||||
[Parameter Content]:\n
|
||||
- random_string (string) [Required] - Unique identifier for the tool call\n
|
||||
- table_name (string) [Required] - Name of the table to query\n
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database\n""")
|
||||
async def get_table_schema_tool(table_name: str, db_name: str = None) -> Dict[str, Any]:
|
||||
"""Wrapper: Get table schema"""
|
||||
if not table_name: return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "Missing table_name parameter"})}]}
|
||||
return await mcp_doris_get_table_schema(table_name=table_name, db_name=db_name)
|
||||
|
||||
# Register Tool: Get Database Table List (Keep long description string including parameters)
|
||||
@mcp.tool("get_db_table_list", description="""[Function Description]: Get a list of all table names in the specified database.\n
|
||||
[Parameter Content]:\n
|
||||
- random_string (string) [Required] - Unique identifier for the tool call\n
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database\n""")
|
||||
async def get_db_table_list_tool(db_name: str = None) -> Dict[str, Any]:
|
||||
"""Wrapper: Get database table list"""
|
||||
return await mcp_doris_get_db_table_list(db_name=db_name)
|
||||
|
||||
# Register Tool: Get Database List (Keep long description string including parameters)
|
||||
# Note: Although the description mentions random_string, the wrapper function signature does not. See how mcp handles this.
|
||||
@mcp.tool("get_db_list", description="""[Function Description]: Get a list of all database names on the server.\n
|
||||
[Parameter Content]:\n
|
||||
- random_string (string) [Required] - Unique identifier for the tool call\n""")
|
||||
async def get_db_list_tool() -> Dict[str, Any]: # Function signature has no parameters
|
||||
"""Wrapper: Get database list"""
|
||||
return await mcp_doris_get_db_list()
|
||||
|
||||
# Register Tool: Get Table Comment (Keep long description string including parameters)
|
||||
@mcp.tool("get_table_comment", description="""[Function Description]: Get the comment information for the specified table.\n
|
||||
[Parameter Content]:\n
|
||||
- random_string (string) [Required] - Unique identifier for the tool call\n
|
||||
- table_name (string) [Required] - Name of the table to query\n
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database\n""")
|
||||
async def get_table_comment_tool(table_name: str, db_name: str = None) -> Dict[str, Any]:
|
||||
"""Wrapper: Get table comment"""
|
||||
if not table_name: return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "Missing table_name parameter"})}]}
|
||||
return await mcp_doris_get_table_comment(table_name=table_name, db_name=db_name)
|
||||
|
||||
# Register Tool: Get Table Column Comments (Keep long description string including parameters)
|
||||
@mcp.tool("get_table_column_comments", description="""[Function Description]: Get comment information for all columns in the specified table.\n
|
||||
[Parameter Content]:\n
|
||||
- random_string (string) [Required] - Unique identifier for the tool call\n
|
||||
- table_name (string) [Required] - Name of the table to query\n
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database\n""")
|
||||
async def get_table_column_comments_tool(table_name: str, db_name: str = None) -> Dict[str, Any]:
|
||||
"""Wrapper: Get table column comments"""
|
||||
if not table_name: return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "Missing table_name parameter"})}]}
|
||||
return await mcp_doris_get_table_column_comments(table_name=table_name, db_name=db_name)
|
||||
|
||||
# Register Tool: Get Table Indexes (Keep long description string including parameters)
|
||||
@mcp.tool("get_table_indexes", description="""[Function Description]: Get index information for the specified table.\n
|
||||
[Parameter Content]:\n
|
||||
- random_string (string) [Required] - Unique identifier for the tool call\n
|
||||
- table_name (string) [Required] - Name of the table to query\n
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database\n""")
|
||||
async def get_table_indexes_tool(table_name: str, db_name: str = None) -> Dict[str, Any]:
|
||||
"""Wrapper: Get table indexes"""
|
||||
if not table_name: return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "Missing table_name parameter"})}]}
|
||||
return await mcp_doris_get_table_indexes(table_name=table_name, db_name=db_name)
|
||||
|
||||
# Register Tool: Get Recent Audit Logs (Keep long description string including parameters)
|
||||
@mcp.tool("get_recent_audit_logs", description="""[Function Description]: Get audit log records for a recent period.\n
|
||||
[Parameter Content]:\n
|
||||
- random_string (string) [Required] - Unique identifier for the tool call\n
|
||||
- days (integer) [Optional] - Number of recent days of logs to retrieve, default is 7\n
|
||||
- limit (integer) [Optional] - Maximum number of records to return, default is 100\n""")
|
||||
async def get_recent_audit_logs_tool(days: int = 7, limit: int = 100) -> Dict[str, Any]:
|
||||
"""Wrapper: Get recent audit logs"""
|
||||
try:
|
||||
days = int(days)
|
||||
limit = int(limit)
|
||||
except (ValueError, TypeError):
|
||||
return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "days and limit parameters must be integers"})}]}
|
||||
return await mcp_doris_get_recent_audit_logs(days=days, limit=limit)
|
||||
|
||||
# Get tool count
|
||||
tools_count = len(await mcp.list_tools())
|
||||
logger.info(f"Registered all MCP tools, total {tools_count} tools")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error registering MCP tools: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
1
doris_mcp_server/utils/__init__.py
Normal file
1
doris_mcp_server/utils/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Mark directory as a package
|
||||
100
doris_mcp_server/utils/db.py
Normal file
100
doris_mcp_server/utils/db.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import os
|
||||
import json
|
||||
import pymysql
|
||||
import pandas as pd
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dotenv import load_dotenv
|
||||
import re
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv(override=True)
|
||||
|
||||
# Database configuration
|
||||
DB_CONFIG = {
|
||||
"host": os.getenv("DB_HOST", "localhost"),
|
||||
"port": int(os.getenv("DB_PORT", "9030")),
|
||||
"user": os.getenv("DB_USER", "root"),
|
||||
"password": os.getenv("DB_PASSWORD", ""),
|
||||
"database": os.getenv("DB_DATABASE", ""),
|
||||
"charset": "utf8mb4",
|
||||
"cursorclass": pymysql.cursors.DictCursor
|
||||
}
|
||||
|
||||
def get_db_connection(db_name: Optional[str] = None):
|
||||
"""
|
||||
Get database connection
|
||||
|
||||
Args:
|
||||
db_name: Specify the database name to connect to, use default config if None
|
||||
|
||||
Returns:
|
||||
Database connection
|
||||
"""
|
||||
if db_name:
|
||||
# Use default config but override database name
|
||||
config = DB_CONFIG.copy()
|
||||
config["database"] = db_name
|
||||
return pymysql.connect(**config)
|
||||
else:
|
||||
# Use default config
|
||||
return pymysql.connect(**DB_CONFIG)
|
||||
|
||||
def get_db_name() -> str:
|
||||
"""Get the currently configured default database name"""
|
||||
return DB_CONFIG["database"] or os.getenv("DB_DATABASE", "")
|
||||
|
||||
def execute_query(sql, db_name: Optional[str] = None):
|
||||
"""
|
||||
Execute SQL query and return results
|
||||
|
||||
Args:
|
||||
sql: SQL query statement
|
||||
db_name: Specify the database name to connect to, use default config if None
|
||||
|
||||
Returns:
|
||||
Query results
|
||||
"""
|
||||
conn = get_db_connection(db_name)
|
||||
try:
|
||||
with conn.cursor() as cursor:
|
||||
# Set connection character set to utf8 before executing query
|
||||
cursor.execute("SET NAMES utf8")
|
||||
|
||||
# Execute the actual query
|
||||
cursor.execute(sql)
|
||||
result = cursor.fetchall()
|
||||
return result
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def execute_query_df(sql, db_name: Optional[str] = None):
|
||||
"""
|
||||
Execute SQL query and return pandas DataFrame
|
||||
|
||||
Args:
|
||||
sql: SQL query statement
|
||||
db_name: Specify the database name to connect to, use default config if None
|
||||
|
||||
Returns:
|
||||
pandas DataFrame
|
||||
"""
|
||||
conn = get_db_connection(db_name)
|
||||
try:
|
||||
# Use a temporary cursor to execute the query and get results
|
||||
with conn.cursor() as cursor:
|
||||
# Set connection character set to utf8 before executing query
|
||||
cursor.execute("SET NAMES utf8")
|
||||
|
||||
# Execute the actual query
|
||||
cursor.execute(sql)
|
||||
result = cursor.fetchall()
|
||||
|
||||
# If no results, return empty DataFrame
|
||||
if not result:
|
||||
return pd.DataFrame()
|
||||
|
||||
# Manually convert dict results to DataFrame
|
||||
df = pd.DataFrame(result)
|
||||
return df
|
||||
finally:
|
||||
conn.close()
|
||||
226
doris_mcp_server/utils/logger.py
Normal file
226
doris_mcp_server/utils/logger.py
Normal file
@@ -0,0 +1,226 @@
|
||||
"""
|
||||
Unified Logging Configuration Module
|
||||
|
||||
Provides unified logging configuration, including:
|
||||
- General logs: Record all program execution information
|
||||
- Audit logs: Record JSON data for key operations and processing results
|
||||
- Error logs: Specifically record program exceptions and errors
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import logging.handlers
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from datetime import datetime
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv(override=True)
|
||||
|
||||
# Get project root directory
|
||||
PROJECT_ROOT = Path(__file__).parents[2].absolute()
|
||||
|
||||
# Get log configuration from environment variables
|
||||
LOG_DIR = os.getenv("LOG_DIR", str(PROJECT_ROOT / "logs"))
|
||||
LOG_PREFIX = os.getenv("LOG_PREFIX", "doris_mcp")
|
||||
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
|
||||
LOG_MAX_DAYS = int(os.getenv("LOG_MAX_DAYS", "30"))
|
||||
# Whether to output logs to the console (should be disabled when running as a service)
|
||||
CONSOLE_LOGGING = os.getenv("CONSOLE_LOGGING", "false").lower() == "true"
|
||||
# Whether stdio transport mode is being used
|
||||
STDIO_MODE = os.getenv("MCP_TRANSPORT_TYPE", "").lower() == "stdio"
|
||||
|
||||
def purge_old_logs():
|
||||
"""Clean up expired log files"""
|
||||
# --- Only perform cleanup in non-Stdio mode ---
|
||||
if STDIO_MODE:
|
||||
return
|
||||
try:
|
||||
now = datetime.now()
|
||||
log_dir = Path(LOG_DIR)
|
||||
# Check if directory exists and is readable/writable
|
||||
if not log_dir.is_dir() or not os.access(LOG_DIR, os.W_OK):
|
||||
if not STDIO_MODE: # Avoid printing to stdout in stdio mode
|
||||
print(f"Warning: Log directory {LOG_DIR} not accessible, skipping log purge.", file=sys.stderr)
|
||||
return
|
||||
|
||||
for log_file in log_dir.glob(f"{LOG_PREFIX}*.20*"):
|
||||
# Parse date
|
||||
file_name = log_file.name
|
||||
date_str = None
|
||||
|
||||
# Try to find the date part
|
||||
parts = file_name.split('.')
|
||||
for part in parts:
|
||||
if part.startswith('20') and len(part) == 8: # 20YYMMDD format
|
||||
date_str = part
|
||||
break
|
||||
|
||||
if date_str:
|
||||
try:
|
||||
file_date = datetime.strptime(date_str, '%Y%m%d')
|
||||
days_old = (now - file_date).days
|
||||
|
||||
if days_old > LOG_MAX_DAYS:
|
||||
os.remove(log_file)
|
||||
if not STDIO_MODE:
|
||||
print(f"Deleted expired log file: {log_file}")
|
||||
except (ValueError, OSError) as e:
|
||||
if not STDIO_MODE:
|
||||
print(f"Error processing log file {file_name}: {e}", file=sys.stderr)
|
||||
except Exception as e:
|
||||
if not STDIO_MODE:
|
||||
print(f"Error cleaning up logs: {e}", file=sys.stderr)
|
||||
|
||||
# Force disable console log output if in stdio mode
|
||||
if STDIO_MODE:
|
||||
CONSOLE_LOGGING = False
|
||||
|
||||
# --- Only create log directory and clean old logs in non-Stdio mode ---
|
||||
if not STDIO_MODE:
|
||||
try:
|
||||
os.makedirs(LOG_DIR, exist_ok=True)
|
||||
# Clean up expired logs on startup (also moved here, as it only handles file logs)
|
||||
purge_old_logs()
|
||||
except OSError as e:
|
||||
# If directory creation fails (e.g., permission issue), print warning but continue to avoid startup failure
|
||||
print(f"Warning: Failed to create log directory {LOG_DIR} or purge logs: {e}", file=sys.stderr)
|
||||
|
||||
# Log file paths (definition still needed, but files might not be created/used)
|
||||
LOG_FILE = os.path.join(LOG_DIR, f"{LOG_PREFIX}.log")
|
||||
AUDIT_LOG_FILE = os.path.join(LOG_DIR, f"{LOG_PREFIX}.audit")
|
||||
ERROR_LOG_FILE = os.path.join(LOG_DIR, f"{LOG_PREFIX}.error")
|
||||
|
||||
# Log level mapping
|
||||
LOG_LEVELS = {
|
||||
"DEBUG": logging.DEBUG,
|
||||
"INFO": logging.INFO,
|
||||
"WARNING": logging.WARNING,
|
||||
"ERROR": logging.ERROR,
|
||||
"CRITICAL": logging.CRITICAL
|
||||
}
|
||||
|
||||
# Log format
|
||||
LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
AUDIT_FORMAT = '%(asctime)s - %(name)s - %(message)s'
|
||||
ERROR_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(pathname)s:%(lineno)d - %(message)s'
|
||||
|
||||
# Dedicated audit log level
|
||||
AUDIT = 25 # Level between INFO and WARNING
|
||||
logging.addLevelName(AUDIT, "AUDIT")
|
||||
|
||||
# Logger object cache
|
||||
_loggers: Dict[str, logging.Logger] = {}
|
||||
|
||||
# Handler type mapping, used to ensure no duplicates are added
|
||||
_handler_types = {
|
||||
'console': logging.StreamHandler,
|
||||
'file': logging.handlers.TimedRotatingFileHandler,
|
||||
'audit': logging.handlers.TimedRotatingFileHandler,
|
||||
'error': logging.handlers.TimedRotatingFileHandler
|
||||
}
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
"""
|
||||
Get a logger with the specified name
|
||||
|
||||
Args:
|
||||
name: Logger name
|
||||
|
||||
Returns:
|
||||
logging.Logger: Configured logger
|
||||
"""
|
||||
if name in _loggers:
|
||||
return _loggers[name]
|
||||
|
||||
# Create logger
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(LOG_LEVELS.get(LOG_LEVEL, logging.INFO))
|
||||
|
||||
# Avoid duplicate logs caused by propagation
|
||||
logger.propagate = False
|
||||
|
||||
# Check if handlers already exist to avoid duplicates
|
||||
handler_types = set(type(h) for h in logger.handlers)
|
||||
|
||||
# Add audit log method
|
||||
def audit(self, message, *args, **kwargs):
|
||||
self.log(AUDIT, message, *args, **kwargs)
|
||||
|
||||
logger.audit = audit.__get__(logger)
|
||||
|
||||
# General log handler - output to console (only if enabled)
|
||||
if CONSOLE_LOGGING and _handler_types['console'] not in handler_types:
|
||||
# Use stderr instead of stdout to avoid conflicts with MCP communication
|
||||
console_handler = logging.StreamHandler(sys.stderr)
|
||||
console_handler.setFormatter(logging.Formatter(LOG_FORMAT))
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
# --- Only add file handlers in non-Stdio mode ---
|
||||
if not STDIO_MODE:
|
||||
# General log handler - daily rotating file
|
||||
if _handler_types['file'] not in handler_types:
|
||||
try: # Add try-except block
|
||||
file_handler = logging.handlers.TimedRotatingFileHandler(
|
||||
LOG_FILE,
|
||||
when='midnight',
|
||||
interval=1,
|
||||
backupCount=LOG_MAX_DAYS,
|
||||
encoding='utf-8'
|
||||
)
|
||||
file_handler.setFormatter(logging.Formatter(LOG_FORMAT))
|
||||
file_handler.suffix = "%Y%m%d"
|
||||
logger.addHandler(file_handler)
|
||||
except OSError as e:
|
||||
print(f"Warning: Failed to add file log handler for {LOG_FILE}: {e}", file=sys.stderr)
|
||||
|
||||
# Audit log handler - only logs AUDIT level
|
||||
if _handler_types['audit'] not in handler_types:
|
||||
try: # Add try-except block
|
||||
audit_handler = logging.handlers.TimedRotatingFileHandler(
|
||||
AUDIT_LOG_FILE,
|
||||
when='midnight',
|
||||
interval=1,
|
||||
backupCount=LOG_MAX_DAYS,
|
||||
encoding='utf-8'
|
||||
)
|
||||
audit_handler.setFormatter(logging.Formatter(AUDIT_FORMAT))
|
||||
audit_handler.suffix = "%Y%m%d"
|
||||
audit_handler.setLevel(AUDIT)
|
||||
audit_handler.addFilter(lambda record: record.levelno == AUDIT)
|
||||
logger.addHandler(audit_handler)
|
||||
except OSError as e:
|
||||
print(f"Warning: Failed to add audit log handler for {AUDIT_LOG_FILE}: {e}", file=sys.stderr)
|
||||
|
||||
# Error log handler - only logs ERROR level and above
|
||||
if _handler_types['error'] not in handler_types:
|
||||
try: # Add try-except block
|
||||
error_handler = logging.handlers.TimedRotatingFileHandler(
|
||||
ERROR_LOG_FILE,
|
||||
when='midnight',
|
||||
interval=1,
|
||||
backupCount=LOG_MAX_DAYS,
|
||||
encoding='utf-8'
|
||||
)
|
||||
error_handler.setFormatter(logging.Formatter(ERROR_FORMAT))
|
||||
error_handler.suffix = "%Y%m%d"
|
||||
error_handler.setLevel(logging.ERROR)
|
||||
logger.addHandler(error_handler)
|
||||
except OSError as e:
|
||||
print(f"Warning: Failed to add error log handler for {ERROR_LOG_FILE}: {e}", file=sys.stderr)
|
||||
|
||||
# Cache logger
|
||||
_loggers[name] = logger
|
||||
|
||||
return logger
|
||||
|
||||
# Default logger
|
||||
logger = get_logger('doris_mcp')
|
||||
|
||||
# Audit logger - for recording processing results, business operations, etc.
|
||||
audit_logger = get_logger('audit')
|
||||
|
||||
# Call to clean logs moved after directory creation, and added non-stdio check
|
||||
1013
doris_mcp_server/utils/schema_extractor.py
Normal file
1013
doris_mcp_server/utils/schema_extractor.py
Normal file
File diff suppressed because it is too large
Load Diff
349
doris_mcp_server/utils/sql_executor_tools.py
Normal file
349
doris_mcp_server/utils/sql_executor_tools.py
Normal file
@@ -0,0 +1,349 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
SQL Execution Tool
|
||||
|
||||
Responsible for executing SQL queries and handling results
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import traceback
|
||||
import time
|
||||
from typing import Dict, Any
|
||||
import re
|
||||
import datetime
|
||||
from decimal import Decimal
|
||||
|
||||
# Get logger
|
||||
logger = logging.getLogger("doris-mcp.sql-executor")
|
||||
|
||||
# Add environment variable control for whether to perform SQL security checks
|
||||
ENABLE_SQL_SECURITY_CHECK = os.environ.get('ENABLE_SQL_SECURITY_CHECK', 'true').lower() == 'true'
|
||||
|
||||
async def execute_sql_query(ctx) -> Dict[str, Any]:
|
||||
"""
|
||||
Execute SQL query and return results
|
||||
|
||||
Args:
|
||||
ctx: Context object or dictionary containing request parameters
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Execution result
|
||||
"""
|
||||
try:
|
||||
# Support the case where the passed argument is a dictionary
|
||||
if isinstance(ctx, dict) and 'params' in ctx:
|
||||
params = ctx['params']
|
||||
else:
|
||||
params = ctx.params
|
||||
|
||||
sql = params.get("sql")
|
||||
db_name = params.get("db_name", os.getenv("DB_DATABASE", ""))
|
||||
max_rows = params.get("max_rows", 1000) # Maximum number of rows to return
|
||||
timeout = params.get("timeout", 30) # Timeout in seconds
|
||||
|
||||
if not sql:
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps({
|
||||
"success": False,
|
||||
"error": "Missing SQL parameter",
|
||||
"message": "Please provide the SQL query to execute"
|
||||
}, ensure_ascii=False)
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# First check SQL security
|
||||
security_result = await _check_sql_security(sql)
|
||||
if not security_result.get("is_safe", False):
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps({
|
||||
"success": False,
|
||||
"error": "SQL security check failed",
|
||||
"message": "Query contains unsafe operations and cannot be executed",
|
||||
"security_issues": security_result.get("security_issues", [])
|
||||
}, ensure_ascii=False)
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# Import database connection tool
|
||||
from doris_mcp_server.utils.db import execute_query
|
||||
|
||||
if not sql:
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps({
|
||||
"success": False,
|
||||
"error": "Missing SQL parameter",
|
||||
"message": "Please provide the SQL query to execute"
|
||||
}, ensure_ascii=False)
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# Ensure SELECT statements include a LIMIT clause
|
||||
sql_lower = sql.lower().strip()
|
||||
if sql_lower.startswith("select") and "limit" not in sql_lower:
|
||||
sql = sql.rstrip(";") + f" LIMIT {max_rows};"
|
||||
|
||||
# Start timer
|
||||
start_time = time.time()
|
||||
|
||||
# Execute query
|
||||
try:
|
||||
result = execute_query(sql, db_name)
|
||||
|
||||
# Calculate execution time
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
# Build return result
|
||||
if isinstance(result, list):
|
||||
# Handle list of query results
|
||||
row_count = len(result)
|
||||
|
||||
# Extract column names
|
||||
if hasattr(result[0], "_fields"):
|
||||
# If it's a named tuple
|
||||
columns = list(result[0]._fields)
|
||||
else:
|
||||
# Otherwise, assume it's a dictionary
|
||||
columns = list(result[0].keys()) if isinstance(result[0], dict) else []
|
||||
|
||||
# Convert results to serializable format
|
||||
data = []
|
||||
for row in result:
|
||||
row_dict = {}
|
||||
if hasattr(row, "_asdict"):
|
||||
# If it's a named tuple
|
||||
row_dict = row._asdict()
|
||||
elif isinstance(row, dict):
|
||||
# If it's a dictionary
|
||||
row_dict = row
|
||||
else:
|
||||
# If it's a list or tuple
|
||||
row_dict = dict(zip(columns, row)) if columns else row
|
||||
|
||||
# Handle special types to make them JSON serializable
|
||||
serialized_row = _serialize_row_data(row_dict)
|
||||
data.append(serialized_row)
|
||||
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps({
|
||||
"success": True,
|
||||
"sql": sql,
|
||||
"row_count": row_count,
|
||||
"columns": columns,
|
||||
"data": data[:max_rows], # Limit returned rows
|
||||
"execution_time": execution_time,
|
||||
"truncated": row_count > max_rows
|
||||
}, ensure_ascii=False)
|
||||
}
|
||||
]
|
||||
}
|
||||
else:
|
||||
# Handle other types of results
|
||||
other_response = {
|
||||
"success": True,
|
||||
"sql": sql,
|
||||
"result": str(result),
|
||||
"execution_time": execution_time
|
||||
}
|
||||
other_response = _serialize_row_data(other_response)
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps(other_response, ensure_ascii=False)
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
except Exception as db_error:
|
||||
error_message = str(db_error)
|
||||
|
||||
# Try to get more detailed error information
|
||||
error_details = {}
|
||||
if "timeout" in error_message.lower():
|
||||
error_details["type"] = "timeout"
|
||||
error_details["suggestion"] = "Query timed out, please optimize SQL or increase timeout"
|
||||
elif "syntax" in error_message.lower():
|
||||
error_details["type"] = "syntax"
|
||||
error_details["suggestion"] = "SQL syntax error, please check syntax"
|
||||
elif "not found" in error_message.lower() or "doesn't exist" in error_message.lower():
|
||||
error_details["type"] = "not_found"
|
||||
error_details["suggestion"] = "Table or column not found, please check table and column names"
|
||||
else:
|
||||
error_details["type"] = "unknown"
|
||||
error_details["suggestion"] = "Please check the SQL statement and try simplifying the query"
|
||||
|
||||
# Create error response
|
||||
error_response = {
|
||||
"success": False,
|
||||
"error": error_message,
|
||||
"error_details": error_details,
|
||||
"sql": sql,
|
||||
"db_name": db_name
|
||||
}
|
||||
|
||||
# Ensure error response is also serializable
|
||||
error_response = _serialize_row_data(error_response)
|
||||
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps(error_response, ensure_ascii=False)
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to execute SQL query: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
error_response = {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"message": "Error occurred while executing SQL query"
|
||||
}
|
||||
|
||||
# Ensure error response is also serializable
|
||||
error_response = _serialize_row_data(error_response)
|
||||
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps(error_response, ensure_ascii=False)
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# Helper function
|
||||
async def _check_sql_security(sql: str) -> Dict[str, Any]:
|
||||
"""Check SQL security"""
|
||||
# If environment variable is set to disable security check, return safe immediately
|
||||
if not ENABLE_SQL_SECURITY_CHECK:
|
||||
return {
|
||||
"is_safe": True,
|
||||
"security_issues": []
|
||||
}
|
||||
|
||||
# Check if SQL contains dangerous operations
|
||||
sql_lower = sql.lower()
|
||||
|
||||
# Check if it's a read-only query type
|
||||
is_read_only = sql_lower.strip().startswith(("select ", "show ", "desc ", "describe ", "explain "))
|
||||
|
||||
# Define list of dangerous operations (checked for both read-only and non-read-only queries)
|
||||
dangerous_operations = [
|
||||
(r'\bdelete\b', "DELETE operation"),
|
||||
(r'\bdrop\b', "DROP TABLE/DATABASE operation"),
|
||||
(r'\btruncate\b', "TRUNCATE TABLE operation"),
|
||||
(r'\bupdate\b', "UPDATE operation"),
|
||||
(r'\binsert\b', "INSERT operation"),
|
||||
(r'\balter\b', "ALTER TABLE structure operation"),
|
||||
(r'\bcreate\b', "CREATE TABLE/DATABASE operation"),
|
||||
(r'\bgrant\b', "GRANT operation"),
|
||||
(r'\brevoke\b', "REVOKE permission operation"),
|
||||
(r'\bexec\b', "EXECUTE stored procedure"),
|
||||
(r'\bxp_', "Extended stored procedure, potential security risk"),
|
||||
(r'\bshutdown\b', "SHUTDOWN database operation"),
|
||||
(r'\bunion\s+all\s+select\b', "UNION statement, potential SQL injection"),
|
||||
(r'\bunion\s+select\b', "UNION statement, potential SQL injection"),
|
||||
(r'\binto\s+outfile\b', "Write to file operation"),
|
||||
(r'\bload_file\b', "Load file operation")
|
||||
]
|
||||
|
||||
# Dangerous operations checked only for non-read-only queries
|
||||
non_readonly_operations = []
|
||||
if not is_read_only:
|
||||
non_readonly_operations = [
|
||||
(r'--', "SQL comment, potential SQL injection"),
|
||||
(r'/\*', "SQL block comment, potential SQL injection")
|
||||
]
|
||||
|
||||
# Check if dangerous operations are included
|
||||
security_issues = []
|
||||
|
||||
# Check dangerous operations applicable to all queries
|
||||
for operation, description in dangerous_operations:
|
||||
if re.search(operation, sql_lower):
|
||||
# For specific keywords in read-only queries, differentiate if used as independent operations
|
||||
if is_read_only and operation in [r'\bcreate\b', r'\bdrop\b', r'\bdelete\b', r'\binsert\b', r'\bupdate\b', r'\balter\b']:
|
||||
# Check if used as DDL/DML keyword, e.g., CREATE TABLE, DROP DATABASE
|
||||
pattern = operation + r'\s+(?:table|database|view|index|procedure|function|trigger|event)'
|
||||
if re.search(pattern, sql_lower):
|
||||
security_issues.append({
|
||||
"operation": operation.replace(r'\b', '').replace(r'\s+', ' '),
|
||||
"description": description,
|
||||
"severity": "High"
|
||||
})
|
||||
else:
|
||||
security_issues.append({
|
||||
"operation": operation.replace(r'\b', '').replace(r'\s+', ' '),
|
||||
"description": description,
|
||||
"severity": "High"
|
||||
})
|
||||
|
||||
# Check dangerous operations specific to non-read-only queries
|
||||
for operation, description in non_readonly_operations:
|
||||
if re.search(operation, sql_lower):
|
||||
security_issues.append({
|
||||
"operation": operation.replace(r'\b', '').replace(r'\s+', ' '),
|
||||
"description": description,
|
||||
"severity": "Medium"
|
||||
})
|
||||
|
||||
return {
|
||||
"is_safe": len(security_issues) == 0,
|
||||
"security_issues": security_issues
|
||||
}
|
||||
|
||||
def _serialize_row_data(row_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert special types in row data (like date, time, Decimal) to JSON serializable format
|
||||
|
||||
Args:
|
||||
row_data: Row data dictionary
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Processed serializable dictionary
|
||||
"""
|
||||
serialized_data = {}
|
||||
for key, value in row_data.items():
|
||||
if value is None:
|
||||
serialized_data[key] = None
|
||||
elif isinstance(value, (datetime.date, datetime.datetime)):
|
||||
# Convert date and time types to ISO format string
|
||||
serialized_data[key] = value.isoformat()
|
||||
elif isinstance(value, Decimal):
|
||||
# Convert Decimal type to float
|
||||
serialized_data[key] = float(value)
|
||||
elif isinstance(value, (list, tuple)):
|
||||
# Recursively process elements in list or tuple
|
||||
serialized_data[key] = [
|
||||
_serialize_row_data(item) if isinstance(item, dict) else item
|
||||
for item in value
|
||||
]
|
||||
elif isinstance(value, dict):
|
||||
# Recursively process nested dictionaries
|
||||
serialized_data[key] = _serialize_row_data(value)
|
||||
else:
|
||||
serialized_data[key] = value
|
||||
return serialized_data
|
||||
Reference in New Issue
Block a user