0.3.0 Release Version

This commit is contained in:
FreeOnePlus
2025-06-08 18:44:40 +08:00
parent d9fed06c92
commit 4c913743c7
54 changed files with 12649 additions and 4667 deletions

View File

@@ -1 +1,13 @@
# Mark directory as a package
"""
Doris MCP Server - A Model Context Protocol server for Apache Doris database integration.
This package provides:
- MCP protocol implementation for Apache Doris
- Multi-transport support (stdio, SSE, streamable HTTP)
- Comprehensive database tools and resources
- Enterprise-grade security and monitoring
"""
# Package metadata exposed as plain module attributes.
# NOTE(review): the change that ships this file is titled "0.3.0 Release
# Version", yet the value below is "1.0.0" - confirm which one is intended.
__version__ = "1.0.0"
__author__ = "Doris MCP Team"
__description__ = "Apache Doris MCP Server Implementation"

View File

@@ -0,0 +1,8 @@
"""
Entry point for running doris_mcp_server as a module
"""
from .main import main_sync
if __name__ == "__main__":
main_sync()

View File

@@ -1,33 +0,0 @@
# doris_mcp_server/config.py
import os
import logging
from dotenv import load_dotenv
# Load environment variables from .env file
load_dotenv(override=True)
# Requested level name: taken from the LOG_LEVEL environment variable
# ('info' when unset) and normalised to upper case for the lookup below.
LOG_LEVEL_STR = os.environ.get('LOG_LEVEL', 'info').upper()

# Level name -> numeric constant of the ``logging`` module.
LOG_LEVEL_MAP = {
    level_name: getattr(logging, level_name)
    for level_name in ('DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL')
}

# Effective numeric level; names missing from the table fall back to INFO.
if LOG_LEVEL_STR in LOG_LEVEL_MAP:
    LOG_LEVEL = LOG_LEVEL_MAP[LOG_LEVEL_STR]
else:
    LOG_LEVEL = logging.INFO
# Function to load config (can be expanded later if needed)
def load_config():
    """Announce that configuration is available.

    Nothing is read or returned here: the settings of this module are the
    constants computed at import time. The only effect is one INFO record
    on this module's logger, which leaves a hook for extra setup later.
    """
    module_logger = logging.getLogger(__name__)
    module_logger.info("Configuration loaded (mainly from environment variables).")
# You can add other configuration constants here if needed
# Example: DB_HOST = os.getenv("DB_HOST", "localhost")
# But often it's better to access os.getenv directly where needed
# or pass config dictionaries around.

View File

@@ -1,196 +1,515 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Apache Doris MCP Server Main Entry - Primarily handles SSE mode
Apache Doris MCP Server - Enterprise Database Service Implementation
Stdio mode is handled by doris_mcp_server.mcp_core:run_stdio.
Based on Apache Doris official MCP Server architecture design, providing complete MCP protocol support
Supports independent encapsulation implementation of Resources, Tools, and Prompts
Supports both stdio and streamable HTTP startup modes
"""
import os
import sys
import argparse
import asyncio
import json
import logging
from contextlib import asynccontextmanager
from collections.abc import AsyncIterator
from dataclasses import dataclass
from typing import Dict, Any
import uvicorn
from uvicorn import Config, Server
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from dotenv import load_dotenv
from typing import Any
# Add project root to path
PROJECT_ROOT = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
sys.path.insert(0, PROJECT_ROOT)
from mcp.server import Server
from mcp.server.models import InitializationOptions
# SSE related imports
from mcp.server.fastmcp import FastMCP
from doris_mcp_server.sse_server import DorisMCPSseServer
from doris_mcp_server.streamable_server import DorisMCPStreamableServer
# Stdio related imports (only needed for tools now, maybe move tool init?)
# from mcp.server.stdio import stdio_server -> No longer used here
# Config and Tool Initializer
from doris_mcp_server.config import load_config # LOG_LEVEL might not be needed here directly
from doris_mcp_server.tools.tool_initializer import register_mcp_tools
# Load environment variables (load early for all modes)
load_dotenv(override=True)
# Get logger
logger = logging.getLogger("doris-mcp-main") # Changed logger name slightly
# --- Configuration Loading and Logging Setup ---
load_config() # Loads .env
# --- Create FastAPI App (Global Scope for SSE Mode) ---
# This 'app' object is targeted by 'mcp run doris_mcp_server/main.py:app --transport sse'
# And used when running directly with --sse
app = FastAPI(
title="Doris MCP Server (SSE Mode)",
# Lifespan will be added in start_sse_server
from mcp.types import (
Prompt,
Resource,
TextContent,
Tool,
)
# --- Removed StdioServerWrapper ---
from .tools.tools_manager import DorisToolsManager
from .tools.prompts_manager import DorisPromptsManager
from .tools.resources_manager import DorisResourcesManager
from .utils.config import DorisConfig
from .utils.db import DorisConnectionManager
from .utils.security import DorisSecurityManager
# --- Command Line Argument Parsing ---
def parse_args():
    """Parse the command line of the SSE entry point.

    Defaults for ``--host`` and ``--port`` come from the environment:
    SERVER_HOST (else '0.0.0.0') and SERVER_PORT, then MCP_PORT, then '3000'.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    default_host = os.getenv('SERVER_HOST', '0.0.0.0')
    port_fallback = os.getenv('MCP_PORT', '3000')
    default_port = int(os.getenv('SERVER_PORT', port_fallback))

    cli = argparse.ArgumentParser(description="Apache Doris MCP Server (SSE Mode Entry)")
    # Only SSE-related options are defined on this parser.
    cli.add_argument('--sse', action='store_true', help='Start SSE Web server mode (required)')
    cli.add_argument('--host', type=str, default=default_host, help='Host address')
    cli.add_argument('--port', type=int, default=default_port, help='Port number')
    for switch, help_text in (
        ('--debug', 'Enable debug mode'),
        ('--reload', 'Enable auto-reload'),
    ):
        cli.add_argument(switch, action='store_true', help=help_text)
    return cli.parse_args()
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- SSE Mode Specific Code ---
@dataclass
class AppContext:
    """Container holding the application's settings mapping."""

    # Settings keyed by name; values are not constrained to any type.
    config: Dict[str, Any]
@asynccontextmanager
async def app_lifespan(app_instance: FastAPI) -> AsyncIterator[None]:
    """Lifespan hook: publish DB settings on the app, log start and stop.

    On entry the database settings are read from the environment and stored
    as ``app_instance.state.config``. Nothing is yielded to the caller; on
    exit (normal or not) a cleanup message is logged.
    """
    logger.info("SSE application lifecycle start...")

    read_env = os.getenv
    # Simplified settings read straight from the environment.
    db_settings = {
        "db_host": read_env("DB_HOST", "localhost"),
        "db_port": int(read_env("DB_PORT", "9030")),
        "db_user": read_env("DB_USER", "root"),
        "db_password": read_env("DB_PASSWORD", ""),
        "db_database": read_env("DB_DATABASE", "test"),
    }
    app_instance.state.config = db_settings

    try:
        yield  # no value is handed to the application
    finally:
        logger.info("Cleaning up SSE application resources...")
# NOTE(review): two definitions appear interleaved line by line in this span:
# the DorisServer class with its __init__, and the start_sse_server()
# coroutine. Statements using ``self`` belong to DorisServer.__init__; the
# ones using ``args``/``app``/``sse_mcp`` belong to start_sse_server().
# Confirm the grouping against the real file before relying on it.
class DorisServer:
"""Apache Doris MCP Server main class"""
async def start_sse_server(args):
"""Start SSE Web server mode (Configures the global 'app')"""
logger.info("Starting SSE Web server mode...")
global app
# DorisServer.__init__: keep the config and create the MCP Server object
# named "doris-mcp-server".
def __init__(self, config: DorisConfig):
self.config = config
self.server = Server("doris-mcp-server")
# --- Initialize MCP and Tools for SSE ---
# Create a *separate* MCP instance for SSE mode
sse_mcp = FastMCP(
name="doris-mcp-sse",
description="Apache Doris MCP Server (SSE)",
lifespan=None, # Managed by FastAPI
dependencies=["fastapi", "uvicorn", "openai", "sse_starlette"]
)
logger.info("Registering MCP tools for SSE mode...")
await register_mcp_tools(sse_mcp) # Register tools for the SSE instance
logger.info("MCP tools registered for SSE.")
# Initialize security manager
self.security_manager = DorisSecurityManager(config)
# --- Configure Lifespan and CORS for the global app ---
app.router.lifespan_context = app_lifespan
# ALLOWED_ORIGINS is a comma-separated list; the default "*" allows any origin.
origins = os.getenv("ALLOWED_ORIGINS", "*").split(",")
allow_credentials = os.getenv("MCP_ALLOW_CREDENTIALS", "false").lower() == "true"
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=allow_credentials,
allow_methods=["*"],
allow_headers=["*"],
expose_headers=["Mcp-Session-Id"],
)
# Initialize connection manager, pass in security manager
self.connection_manager = DorisConnectionManager(config, self.security_manager)
# --- Initialize Handlers and Register Routes (Pass sse_mcp instance) ---
logger.info("Initializing SSE server handlers and registering routes...")
# Both handler objects are built from the same sse_mcp instance and app;
# the local names are not used again in this span.
sse_server_handler = DorisMCPSseServer(sse_mcp, app)
streamable_server_handler = DorisMCPStreamableServer(sse_mcp, app)
logger.info("SSE Server handlers initialized and routes registered.")
# Initialize independent managers
# Each manager is constructed with the shared connection manager.
self.resources_manager = DorisResourcesManager(self.connection_manager)
self.tools_manager = DorisToolsManager(self.connection_manager)
self.prompts_manager = DorisPromptsManager(self.connection_manager)
# --- Print Configuration and Endpoints ---
# DB_PASSWORD is not among the values printed below.
print("--- SSE Mode Configuration ---")
print(f"Server Host: {args.host}")
print(f"Server Port: {args.port}")
print(f"Allowed Origins: {origins}")
print(f"Allow Credentials: {allow_credentials}")
print(f"Log Level: {os.getenv('LOG_LEVEL', 'info')}")
print(f"Debug Mode: {args.debug}")
print(f"Reload Mode: {args.reload}")
print(f"DB Host: {os.getenv('DB_HOST')}")
print(f"DB Port: {os.getenv('DB_PORT')}")
print(f"DB User: {os.getenv('DB_USER')}")
print(f"DB Database: {os.getenv('DB_DATABASE')}")
print(f"Force Refresh Metadata: {os.getenv('FORCE_REFRESH_METADATA', 'false')}")
print("------------------------------")
base_url = f"http://{args.host}:{args.port}"
print(f"Service running at: {base_url}")
print(f" Health Check: GET {base_url}/health")
print(f" Status Check: GET {base_url}/status")
print(f" SSE Init: GET {base_url}/sse")
print(f" SSE/Legacy Messages: POST {base_url}/mcp/messages")
print(f" Streamable HTTP: GET/POST/DELETE/OPTIONS {base_url}/mcp")
print("------------------------------")
print("Use Ctrl+C to stop the service")
# DorisServer.__init__ (end): class-specific logger, then handler registration.
self.logger = logging.getLogger(f"{__name__}.DorisServer")
self._setup_handlers()
# --- Start Uvicorn Server ---
config = Config(
app=app,
host=args.host,
port=args.port,
log_level="debug" if args.debug else "info",
reload=args.reload
)
server = Server(config=config)
await server.serve()
# NOTE(review): this span holds the head of DorisServer._setup_handlers with
# its first handler, followed by the first statements of run_main_sync();
# the two definitions appear interleaved here - confirm against the real file.
def _setup_handlers(self):
"""Setup MCP protocol handlers

Each handler is a closure over ``self`` registered through a decorator of
``self.server``.
"""
# --- Main Execution Logic (Simplified) ---
@self.server.list_resources()
async def handle_list_resources() -> list[Resource]:
"""Handle resource list request

Returns what ``self.resources_manager.list_resources()`` yields; any
exception is logged and turned into an empty list.
"""
try:
self.logger.info("Handling resource list request")
resources = await self.resources_manager.list_resources()
self.logger.info(f"Returning {len(resources)} resources")
return resources
except Exception as e:
# Failures are reported in the log only; the caller gets an empty list.
self.logger.error(f"Failed to handle resource list request: {e}")
return []
def run_main_sync():
"""Synchronous wrapper, primarily for SSE mode now."""
sync_logger = logging.getLogger("run_main_sync")
sync_logger.info("Entering run_main_sync (SSE focus)...")
# The same message also goes to stderr, flushed immediately.
print("DEBUG: Entering run_main_sync (SSE focus)...", file=sys.stderr, flush=True)
args = parse_args()
@self.server.read_resource()
async def handle_read_resource(uri: str) -> str:
    """Handle resource read request.

    Delegates to ``self.resources_manager.read_resource`` and returns its
    result. Any exception is logged and converted into a JSON document with
    ``error`` and ``uri`` keys instead of being raised.

    Args:
        uri: Identifier of the resource to read. It is passed through
            ``str()`` before JSON encoding, because the framework may hand
            over a URL object that ``json.dumps`` cannot serialise; without
            that the error path itself would raise.

    Returns:
        The resource content, or the JSON error document described above.
    """
    try:
        self.logger.info(f"Handling resource read request: {uri}")
        content = await self.resources_manager.read_resource(uri)
        return content
    except Exception as e:
        self.logger.error(f"Failed to handle resource read request: {e}")
        return json.dumps(
            {"error": f"Failed to read resource: {str(e)}", "uri": str(uri)},
            ensure_ascii=False,
            indent=2,
        )
@self.server.list_tools()
async def handle_list_tools() -> list[Tool]:
    """Answer a tool list request.

    Returns what ``self.tools_manager.list_tools()`` yields. Any exception
    is logged and replaced by an empty list.
    """
    try:
        self.logger.info("Handling tool list request")
        available_tools = await self.tools_manager.list_tools()
        self.logger.info(f"Returning {len(available_tools)} tools")
    except Exception as err:
        self.logger.error(f"Failed to handle tool list request: {err}")
        return []
    return available_tools
@self.server.call_tool()
async def handle_call_tool(
    name: str, arguments: dict[str, Any]
) -> list[TextContent]:
    """Answer a tool call request.

    The call is forwarded to ``self.tools_manager.call_tool`` and its result
    is wrapped in a single text content item. If anything in that path
    raises, the failure is logged and a JSON error document (``error``,
    ``tool_name``, ``arguments``) is returned as the text instead.
    """
    try:
        self.logger.info(f"Handling tool call request: {name}")
        outcome = await self.tools_manager.call_tool(name, arguments)
        # Built inside the try so a failure while wrapping is reported too.
        return [TextContent(type="text", text=outcome)]
    except Exception as err:
        self.logger.error(f"Failed to handle tool call request: {err}")
        failure_report = {
            "error": f"Tool call failed: {str(err)}",
            "tool_name": name,
            "arguments": arguments,
        }
        failure_text = json.dumps(failure_report, ensure_ascii=False, indent=2)
        return [TextContent(type="text", text=failure_text)]
@self.server.list_prompts()
async def handle_list_prompts() -> list[Prompt]:
    """Answer a prompt list request.

    Returns what ``self.prompts_manager.list_prompts()`` yields. Any
    exception is logged and replaced by an empty list.
    """
    try:
        self.logger.info("Handling prompt list request")
        available_prompts = await self.prompts_manager.list_prompts()
        self.logger.info(f"Returning {len(available_prompts)} prompts")
    except Exception as err:
        self.logger.error(f"Failed to handle prompt list request: {err}")
        return []
    return available_prompts
@self.server.get_prompt()
async def handle_get_prompt(name: str, arguments: dict[str, Any]) -> str:
    """Answer a prompt get request.

    Returns what ``self.prompts_manager.get_prompt`` yields for the given
    name and arguments. On any exception the failure is logged and a JSON
    error document (``error``, ``prompt_name``, ``arguments``) is returned.
    """
    try:
        self.logger.info(f"Handling prompt get request: {name}")
        return await self.prompts_manager.get_prompt(name, arguments)
    except Exception as err:
        self.logger.error(f"Failed to handle prompt get request: {err}")
        failure_report = {
            "error": f"Failed to get prompt: {str(err)}",
            "prompt_name": name,
            "arguments": arguments,
        }
        return json.dumps(failure_report, ensure_ascii=False, indent=2)
# NOTE(review): DorisServer.start_stdio appears interleaved here with the
# remaining statements of run_main_sync() (the ones using ``args`` and
# ``sync_logger``). Confirm the grouping against the real file.
async def start_stdio(self):
"""Start stdio transport mode

Initializes the connection manager, opens the stdio transport and runs the
MCP server on its streams. Failures are logged in detail and re-raised.
"""
self.logger.info("Starting Doris MCP Server (stdio mode)")
if args.sse:
try:
# Run the async SSE server setup and Uvicorn loop
asyncio.run(start_sse_server(args))
sync_logger.info("asyncio.run(start_sse_server) completed.")
print("DEBUG: asyncio.run(start_sse_server) completed.", file=sys.stderr, flush=True)
except KeyboardInterrupt:
sync_logger.info("SSE server stopped by KeyboardInterrupt.")
# Ensure connection manager is initialized
await self.connection_manager.initialize()
self.logger.info("Connection manager initialization completed")
# Start stdio server - using simpler approach
from mcp.server.stdio import stdio_server
self.logger.info("Creating stdio_server transport...")
# Try different startup approaches
try:
async with stdio_server() as streams:
read_stream, write_stream = streams
self.logger.info("stdio_server streams created successfully")
# Create initialization options
# MCP 1.8.0 requires parameters for get_capabilities
from mcp.server.lowlevel.server import NotificationOptions
# All three change notifications (prompts, resources, tools) are advertised.
capabilities = self.server.get_capabilities(
notification_options=NotificationOptions(
prompts_changed=True,
resources_changed=True,
tools_changed=True
),
experimental_capabilities={}
)
# NOTE(review): server_version is a hard-coded literal here - confirm it
# matches the version the package is released under.
init_options = InitializationOptions(
server_name="doris-mcp-server",
server_version="1.0.0",
capabilities=capabilities,
)
self.logger.info("Initialization options created successfully")
# Run server
self.logger.info("Starting to run MCP server...")
await self.server.run(read_stream, write_stream, init_options)
except Exception as inner_e:
self.logger.error(f"stdio_server internal error: {inner_e}")
self.logger.error(f"Error type: {type(inner_e)}")
# Try to get more error information
import traceback
self.logger.error("Complete error stack:")
self.logger.error(traceback.format_exc())
# If it's ExceptionGroup, try to parse
# Detected by the presence of an ``exceptions`` attribute, not by type.
if hasattr(inner_e, 'exceptions'):
self.logger.error(f"ExceptionGroup contains {len(inner_e.exceptions)} exceptions:")
for i, exc in enumerate(inner_e.exceptions):
self.logger.error(f" Exception {i+1}: {type(exc).__name__}: {exc}")
raise inner_e
except Exception as e:
sync_logger.critical(f"Error during asyncio.run(start_sse_server): {e}", exc_info=True)
print(f"DEBUG: Error during asyncio.run(start_sse_server): {e}", file=sys.stderr, flush=True)
self.logger.error(f"stdio server startup failed: {e}")
self.logger.error(f"Error type: {type(e)}")
raise
else:
# If run without --sse, print help/error
message = "Error: This entry point requires --sse flag. For stdio mode, use 'uv run mcp-doris' or the appropriate command for your stdio setup."
sync_logger.error(message)
print(message, file=sys.stderr)
# Exit status 1 signals the missing --sse flag to the calling process.
sys.exit(1)
async def start_http(self, host: str = "localhost", port: int = 3000):
    """Start Streamable HTTP transport mode.

    Initializes the connection manager, wraps the MCP server in a
    ``StreamableHTTPSessionManager`` and serves it with uvicorn through a
    small ASGI dispatcher:

    * paths starting with ``/health`` are answered by a Starlette app;
    * ``/mcp`` and everything under ``/mcp/`` go to the session manager,
      without redirects;
    * every other path receives a plain 404.

    Request headers are logged with credential-bearing values redacted.

    Args:
        host: Interface the HTTP server binds to.
        port: TCP port the HTTP server listens on.

    Raises:
        Exception: Any startup or serving failure is logged (including the
            members of an exception group, when present) and re-raised.
    """
    self.logger.info(f"Starting Doris MCP Server (Streamable HTTP mode) - {host}:{port}")

    try:
        # Ensure connection manager is initialized
        await self.connection_manager.initialize()

        # Use Starlette and StreamableHTTPSessionManager according to official example
        import uvicorn
        import contextlib
        from collections.abc import AsyncIterator
        from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
        from starlette.applications import Starlette
        from starlette.routing import Route
        from starlette.responses import JSONResponse, Response

        # Header values that must never reach the log verbatim.
        redacted_headers = {b"authorization", b"proxy-authorization", b"cookie", b"x-api-key"}

        # Create session manager
        session_manager = StreamableHTTPSessionManager(
            app=self.server,
            json_response=True,  # Enable JSON response
            stateless=False,  # Maintain session state
        )

        self.logger.info(f"StreamableHTTP session manager created, will start at http://{host}:{port}")

        # Health check endpoint
        async def health_check(request):
            return JSONResponse({"status": "healthy", "service": "doris-mcp-server"})

        # Lifecycle manager - simplified since we manage session_manager externally
        @contextlib.asynccontextmanager
        async def lifespan(app: Starlette) -> AsyncIterator[None]:
            """Context manager for managing application lifecycle"""
            self.logger.info("Application started!")
            try:
                yield
            finally:
                self.logger.info("Application is shutting down...")

        # Starlette only serves /health and the lifespan protocol here.
        # NOTE(review): debug=True makes Starlette return tracebacks in error
        # responses - confirm this is wanted outside development.
        starlette_app = Starlette(
            debug=True,
            routes=[
                Route("/health", health_check, methods=["GET"]),
            ],
            lifespan=lifespan,
        )

        # Custom ASGI app that handles both /mcp and /mcp/ without redirects
        async def mcp_app(scope, receive, send):
            scope_type = scope["type"]

            # Lifespan events are delegated so the Starlette hook above runs
            if scope_type == "lifespan":
                await starlette_app(scope, receive, send)
                return

            if scope_type != "http":
                # For other scope types, just return
                self.logger.warning(f"Unsupported scope type: {scope['type']}")
                return

            path = scope.get("path", "")
            self.logger.info(f"Received request for path: {path}")

            try:
                # Handle health check
                if path.startswith("/health"):
                    await starlette_app(scope, receive, send)
                    return

                # Handle MCP requests - both /mcp and /mcp/ go to session manager
                if path == "/mcp" or path.startswith("/mcp/"):
                    self.logger.info(f"Handling MCP request for path: {path}")

                    # Log request details for debugging, hiding credentials
                    method = scope.get("method", "UNKNOWN")
                    headers = dict(scope.get("headers", []))
                    loggable_headers = {
                        key: (b"<redacted>" if key.lower() in redacted_headers else value)
                        for key, value in headers.items()
                    }
                    self.logger.info(f"MCP Request - Method: {method}")
                    self.logger.info(f"MCP Request - Headers: {loggable_headers}")

                    # Handle Dify compatibility for GET requests
                    if method == "GET":
                        # latin-1 maps every byte, so decoding cannot fail
                        # on arbitrary header bytes.
                        accept_header = headers.get(b"accept", b"").decode("latin-1")

                        # Clients asking only for an event stream also get
                        # application/json appended to their Accept header.
                        if "text/event-stream" in accept_header and "application/json" not in accept_header:
                            self.logger.info("Adding application/json to Accept header for GET request")
                            patched_headers = [
                                (key, value + b", application/json") if key == b"accept" else (key, value)
                                for key, value in scope.get("headers", [])
                            ]
                            # Work on a copy: the caller's scope is left untouched
                            scope = dict(scope)
                            scope["headers"] = patched_headers
                            self.logger.info(f"Modified Accept header to: {accept_header}, application/json")

                    await session_manager.handle_request(scope, receive, send)
                    return

                # 404 for other paths
                self.logger.info(f"Path not found: {path}")
                response = Response("Not Found", status_code=404)
                await response(scope, receive, send)

            except Exception as e:
                self.logger.error(f"Error handling request for {path}: {e}")
                import traceback
                self.logger.error(traceback.format_exc())
                response = Response("Internal Server Error", status_code=500)
                await response(scope, receive, send)

        # Start uvicorn server with session manager lifecycle
        config = uvicorn.Config(
            app=mcp_app,
            host=host,
            port=port,
            log_level="info",
        )
        server = uvicorn.Server(config)

        # Run session manager and server together
        async with session_manager.run():
            self.logger.info("Session manager started, now starting HTTP server")
            await server.serve()

    except Exception as e:
        self.logger.error(f"Streamable HTTP server startup failed: {e}")
        import traceback
        self.logger.error("Complete error stack:")
        self.logger.error(traceback.format_exc())

        # If it's ExceptionGroup, try to parse
        if hasattr(e, 'exceptions'):
            self.logger.error(f"ExceptionGroup contains {len(e.exceptions)} exceptions:")
            for i, exc in enumerate(e.exceptions):
                self.logger.error(f" Exception {i+1}: {type(exc).__name__}: {exc}")

        raise
async def shutdown(self):
    """Release the server's resources.

    Closes the connection manager. A failure while closing is logged and
    swallowed, so awaiting this method does not raise for close errors.
    """
    self.logger.info("Shutting down Doris MCP Server")
    try:
        await self.connection_manager.close()
    except Exception as close_error:
        self.logger.error(f"Error occurred while shutting down server: {close_error}")
    else:
        self.logger.info("Doris MCP Server has been shut down")
def create_arg_parser():
    """Build the command line parser of the server.

    Options cover the transport (``--transport``, ``--host``, ``--port``),
    the Doris connection (``--db-host``, ``--db-port``, ``--db-user``,
    ``--db-password``, ``--db-database``) and ``--log-level``.

    Returns:
        A configured ``argparse.ArgumentParser``; nothing is parsed here.
    """
    usage_notes = """
Transport Modes:
stdio - Standard input/output (for local process communication)
http - Streamable HTTP mode (MCP 2025-03-26 protocol)
Examples:
python -m doris_mcp_server --transport stdio
python -m doris_mcp_server --transport http --host 0.0.0.0 --port 3000
"""

    # (flag, add_argument keyword arguments), in the order shown by --help.
    option_table = (
        ("--transport", {
            "type": str,
            "choices": ["stdio", "http"],
            "default": "stdio",
            "help": "Transport protocol type: stdio (local), http (Streamable HTTP)",
        }),
        ("--host", {
            "type": str,
            "default": "localhost",
            "help": "Host address for HTTP mode (default: localhost)",
        }),
        ("--port", {
            "type": int,
            "default": 3000,
            "help": "Port number for HTTP mode (default: 3000)",
        }),
        ("--db-host", {
            "type": str,
            "default": "localhost",
            "help": "Doris database host address (default: localhost)",
        }),
        ("--db-port", {
            "type": int,
            "default": 9030,
            "help": "Doris database port number (default: 9030)",
        }),
        ("--db-user", {
            "type": str,
            "default": "root",
            "help": "Doris database username (default: root)",
        }),
        ("--db-password", {
            "type": str,
            "default": "",
            "help": "Doris database password",
        }),
        ("--db-database", {
            "type": str,
            "default": "information_schema",
            "help": "Doris database name (default: information_schema)",
        }),
        ("--log-level", {
            "type": str,
            "choices": ["DEBUG", "INFO", "WARNING", "ERROR"],
            "default": "INFO",
            "help": "Log level (default: INFO)",
        }),
    )

    parser = argparse.ArgumentParser(
        description="Apache Doris MCP Server - Enterprise Database Service",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_notes,
    )
    for flag, settings in option_table:
        parser.add_argument(flag, **settings)
    return parser
async def main():
    """Parse the command line, build the server and run one transport.

    Configuration priority is: command line arguments > environment/.env
    (via ``DorisConfig.from_env``) > defaults. The server is shut down
    exactly once, whatever way the transport ends.

    Returns:
        0 after a normal stop or an interrupt, 1 after an unsupported
        transport or a runtime error.
    """
    parser = create_arg_parser()
    args = parser.parse_args()

    # Set log level
    logging.getLogger().setLevel(getattr(logging, args.log_level))

    # First load from .env file and environment variables
    config = DorisConfig.from_env()

    # Command line arguments override configuration (if provided).
    # NOTE(review): "provided" is detected by comparing with the parser
    # defaults, so explicitly passing a default value (e.g. --db-host
    # localhost) cannot override a different value from the environment.
    if args.db_host != "localhost":
        config.database.host = args.db_host
    if args.db_port != 9030:
        config.database.port = args.db_port
    if args.db_user != "root":
        config.database.user = args.db_user
    if args.db_password:  # Use password if provided
        config.database.password = args.db_password
    if args.db_database != "information_schema":
        config.database.database = args.db_database
    if args.log_level != "INFO":
        config.logging.level = args.log_level

    # Create server instance
    server = DorisServer(config)

    exit_code = 0
    try:
        if args.transport == "stdio":
            await server.start_stdio()
        elif args.transport == "http":
            await server.start_http(args.host, args.port)
        else:
            logger.error(f"Unsupported transport protocol: {args.transport}")
            exit_code = 1
    except KeyboardInterrupt:
        logger.info("Received interrupt signal, shutting down server...")
    except Exception as e:
        logger.error(f"Server runtime error: {e}")
        exit_code = 1
    finally:
        # Single cleanup point: runs after success, interrupt and failure
        # alike, so the connection manager is never closed twice.
        try:
            await server.shutdown()
        except Exception as shutdown_error:
            logger.error(f"Error occurred while shutting down server: {shutdown_error}")

    return exit_code
def main_sync():
    """Synchronous main function for entry point.

    Runs the ``main`` coroutine to completion and ends the process with the
    code it returns.

    Raises:
        SystemExit: Always, carrying the exit code returned by ``main``.
    """
    exit_code = asyncio.run(main())
    # SystemExit is raised directly instead of calling the interactive
    # ``exit()`` helper: that helper is installed by the ``site`` module
    # (absent under ``python -S`` or in frozen builds) and closes stdin.
    raise SystemExit(exit_code)
if __name__ == "__main__":
# NOTE(review): two entry-point calls appear under this guard. main_sync()
# is the function defined just above; run_main_sync() is the SSE-focused
# wrapper. Confirm which single call the real file keeps.
run_main_sync()
main_sync()

View File

@@ -1,159 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Core MCP instance and startup logic for stdio mode.
"""
import asyncio
import logging
import sys
import traceback
import json
from typing import Dict, Any
# Import necessary components from mcp and our project
from mcp.server.fastmcp import FastMCP
# Module logger with a fixed name (not derived from __name__).
logger = logging.getLogger("doris-mcp-core")
# --- Global MCP Instance for Stdio ---
# Create the instance when the module is imported.
# Tools will be registered synchronously(?) before running.
# The tool decorators further down in this module attach to this instance.
stdio_mcp = FastMCP(
name="doris-mcp-stdio-core",
description="Apache Doris MCP Server (stdio via core)",
)
# --- Removed async setup functions ---
def run_stdio():
    """
    Synchronous entry point for running the stdio server.

    Calls ``run()`` on the module-level ``stdio_mcp`` instance. A missing
    ``run`` method is detected up front, so an ``AttributeError`` raised
    while the server is running is reported as a regular failure with its
    traceback rather than being mistaken for a missing method.

    Raises:
        SystemExit: With code 1 when ``stdio_mcp`` has no callable ``run``
            or when ``run()`` fails with an exception.
    """
    logger.info("Executing run_stdio (synchronous entry point)...")

    # Resolve the method before entering the try block below.
    run = getattr(stdio_mcp, "run", None)
    if not callable(run):
        logger.critical("Error: stdio_mcp object does not have a '.run()' method suitable for stdio.", exc_info=False)
        print("ERROR: stdio_mcp object does not have a '.run()' method.", file=sys.stderr, flush=True)
        sys.exit(1)

    # --- Run the stdio server using the instance's run() method ---
    logger.info("Calling stdio_mcp.run()...")
    try:
        run()
        logger.info("stdio_mcp.run() completed.")
    except KeyboardInterrupt:
        logger.info("Stdio server stopped by KeyboardInterrupt.")
    except Exception as e:
        logger.critical(f"run_stdio encountered an error during stdio_mcp.run(): {e}", exc_info=True)
        traceback.print_exc(file=sys.stderr)
        sys.exit(1)
# Register Tool: Execute SQL Query
@stdio_mcp.tool("exec_query", description="""[Function Description]: Execute SQL query and return result command with catalog federation support.\n
[Parameter Content]:\n
- sql (string) [Required] - SQL statement to execute. MUST use three-part naming for all table references: 'catalog_name.db_name.table_name'. For internal tables use 'internal.db_name.table_name', for external tables use 'catalog_name.db_name.table_name'\n
- db_name (string) [Optional] - Target database name, defaults to the current database\n
- catalog_name (string) [Optional] - Reference catalog name for context, defaults to current catalog\n
- max_rows (integer) [Optional] - Maximum number of rows to return, default 100\n
- timeout (integer) [Optional] - Query timeout in seconds, default 30\n""")
async def exec_query_tool(sql: str, db_name: str = None, catalog_name: str = None, max_rows: int = 100, timeout: int = 30) -> Dict[str, Any]:
    """Forward a SQL statement to ``mcp_doris_exec_query`` and return its result."""
    # Imported on first use rather than at module import time.
    from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_exec_query

    query_options = {
        "sql": sql,
        "db_name": db_name,
        "catalog_name": catalog_name,
        "max_rows": max_rows,
        "timeout": timeout,
    }
    return await mcp_doris_exec_query(**query_options)
# Register Tool: Get Table Schema
@stdio_mcp.tool("get_table_schema", description="""[Function Description]: Get detailed structure information of the specified table (columns, types, comments, etc.).\n
[Parameter Content]:\n
- table_name (string) [Required] - Name of the table to query\n
- db_name (string) [Optional] - Target database name, defaults to the current database\n
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""")
async def get_table_schema_tool(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
    """Forward a schema lookup to ``mcp_doris_get_table_schema``.

    A falsy ``table_name`` is answered directly with a text payload whose
    JSON body reports the missing parameter.
    """
    from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_table_schema

    if table_name:
        return await mcp_doris_get_table_schema(table_name=table_name, db_name=db_name, catalog_name=catalog_name)

    failure_text = json.dumps({"success": False, "error": "Missing table_name parameter"})
    return {"content": [{"type": "text", "text": failure_text}]}
# Register Tool: Get Database Table List
@stdio_mcp.tool("get_db_table_list", description="""[Function Description]: Get a list of all table names in the specified database.\n
[Parameter Content]:\n
- db_name (string) [Optional] - Target database name, defaults to the current database\n
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""")
async def get_db_table_list_tool(db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
    """Forward a table-list request to ``mcp_doris_get_db_table_list``."""
    from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_db_table_list

    table_listing = await mcp_doris_get_db_table_list(db_name=db_name, catalog_name=catalog_name)
    return table_listing
# Register Tool: Get Database List
@stdio_mcp.tool("get_db_list", description="""[Function Description]: Get a list of all database names on the server.\n
[Parameter Content]:\n
- random_string (string) [Required] - Unique identifier for the tool call\n
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""")
async def get_db_list_tool(catalog_name: str = None) -> Dict[str, Any]:
    """Forward a database-list request to ``mcp_doris_get_db_list``."""
    # NOTE(review): the description above advertises a required
    # ``random_string`` parameter that this signature does not accept -
    # confirm which of the two is intended.
    from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_db_list

    database_listing = await mcp_doris_get_db_list(catalog_name=catalog_name)
    return database_listing
# Register Tool: Get Table Comment
@stdio_mcp.tool("get_table_comment", description="""[Function Description]: Get the comment information for the specified table.\n
[Parameter Content]:\n
- table_name (string) [Required] - Name of the table to query\n
- db_name (string) [Optional] - Target database name, defaults to the current database\n
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""")
async def get_table_comment_tool(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
    """Forward a table-comment lookup to ``mcp_doris_get_table_comment``.

    A falsy ``table_name`` is answered directly with a text payload whose
    JSON body reports the missing parameter.
    """
    from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_table_comment

    if table_name:
        return await mcp_doris_get_table_comment(table_name=table_name, db_name=db_name, catalog_name=catalog_name)

    failure_text = json.dumps({"success": False, "error": "Missing table_name parameter"})
    return {"content": [{"type": "text", "text": failure_text}]}
# Register Tool: Get Table Column Comments
@stdio_mcp.tool("get_table_column_comments", description="""[Function Description]: Get comment information for all columns in the specified table.\n
[Parameter Content]:\n
- table_name (string) [Required] - Name of the table to query\n
- db_name (string) [Optional] - Target database name, defaults to the current database\n
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""")
async def get_table_column_comments_tool(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
    """Forward a column-comment lookup to ``mcp_doris_get_table_column_comments``.

    A falsy ``table_name`` is answered directly with a text payload whose
    JSON body reports the missing parameter.
    """
    from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_table_column_comments

    if table_name:
        return await mcp_doris_get_table_column_comments(table_name=table_name, db_name=db_name, catalog_name=catalog_name)

    failure_text = json.dumps({"success": False, "error": "Missing table_name parameter"})
    return {"content": [{"type": "text", "text": failure_text}]}
# Register Tool: Get Table Indexes
@stdio_mcp.tool("get_table_indexes", description="""[Function Description]: Get index information for the specified table.
[Parameter Content]:\n
- table_name (string) [Required] - Name of the table to query\n
- db_name (string) [Optional] - Target database name, defaults to the current database\n
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""")
async def get_table_indexes_tool(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
    """Forward an index lookup to ``mcp_doris_get_table_indexes``.

    A falsy ``table_name`` is answered directly with a text payload whose
    JSON body reports the missing parameter.
    """
    from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_table_indexes

    if table_name:
        return await mcp_doris_get_table_indexes(table_name=table_name, db_name=db_name, catalog_name=catalog_name)

    failure_text = json.dumps({"success": False, "error": "Missing table_name parameter"})
    return {"content": [{"type": "text", "text": failure_text}]}
# Register Tool: Get Recent Audit Logs
@stdio_mcp.tool("get_recent_audit_logs", description="""[Function Description]: Get audit log records for a recent period.\n
[Parameter Content]:\n
- days (integer) [Optional] - Number of recent days of logs to retrieve, default is 7\n
- limit (integer) [Optional] - Maximum number of records to return, default is 100\n""")
async def get_recent_audit_logs_tool(days: int = 7, limit: int = 100) -> Dict[str, Any]:
    """Forward an audit-log request to ``mcp_doris_get_recent_audit_logs``.

    Both arguments are coerced with ``int()`` first. If either cannot be
    converted, a text payload whose JSON body reports the problem is
    returned and the underlying tool is not called.
    """
    from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_recent_audit_logs

    try:
        day_count, row_limit = int(days), int(limit)
    except (ValueError, TypeError):
        failure_text = json.dumps({"success": False, "error": "days and limit parameters must be integers"})
        return {"content": [{"type": "text", "text": failure_text}]}

    return await mcp_doris_get_recent_audit_logs(days=day_count, limit=row_limit)
# Register Tool: Get Catalog List
@stdio_mcp.tool("get_catalog_list", description="""[Function Description]: Get a list of all catalog names on the server.\n
[Parameter Content]:\n
- random_string (string) [Required] - Unique identifier for the tool call\n""")
async def get_catalog_list_tool(random_string: str = None) -> Dict[str, Any]:
    """Stdio wrapper that forwards to ``mcp_doris_get_catalog_list``.

    Args:
        random_string: Placeholder argument advertised in the tool
            description. It is accepted so that the function signature
            matches the documented parameter, but it is not used and is
            not forwarded to the underlying tool. It defaults to ``None``
            so existing callers that pass nothing keep working.

    Returns:
        Whatever ``mcp_doris_get_catalog_list`` returns.
    """
    from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_catalog_list

    return await mcp_doris_get_catalog_list()
# --- Register Tools ---

File diff suppressed because it is too large Load Diff

View File

@@ -1,912 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Doris MCP Streamable HTTP Server Implementation
Implements the MCP 2025-03-26 Streamable HTTP specification.
Uses a unified /mcp endpoint for GET, POST, DELETE, OPTIONS.
Manages sessions using Mcp-Session-Id header.
"""
import asyncio
import json
import uuid
import logging
import time
from typing import Any, Optional, Dict, List
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from sse_starlette.sse import EventSourceResponse
# Module-level logger; it uses its own name ("doris-mcp-streamable") rather
# than the module path so its records can be told apart from other loggers.
logger = logging.getLogger("doris-mcp-streamable")
# Sentinel value placed on a stream queue: the SSE generators in this module
# stop reading and end the stream when they dequeue it.
STREAM_END_MARKER = "__MCP_STREAM_END__"
class DorisMCPStreamableServer:
"""Doris MCP Streamable HTTP Server"""
def __init__(self, mcp_server, app: FastAPI):
    """
    Build the Streamable HTTP server on top of an existing FastAPI app.

    Args:
        mcp_server: The shared FastMCP server instance.
        app: The main FastAPI application instance; the /mcp route is
            registered on it.
    """
    # CORS middleware is intentionally not added here; it is expected to be
    # configured once by whoever owns the FastAPI application.
    self.app = app
    self.mcp_server = mcp_server

    # Per-client session state, keyed by the Mcp-Session-Id header value.
    # Each value has the shape:
    #   {
    #       "created_at": timestamp,
    #       "last_active": timestamp,
    #       "request_queues": {request_id: asyncio.Queue},   # POST /mcp request streams
    #       "general_sse_queues": [asyncio.Queue, ...],      # GET /mcp server push streams
    #   }
    self.client_sessions: Dict[str, Dict[str, Any]] = {}

    # Register the unified /mcp endpoint.
    self._setup_streamable_http_routes()

    # No startup hook or session-cleanup task is registered here; that is
    # assumed to be handled by the application's own lifespan management.
def _setup_streamable_http_routes(self):
    """Register the unified ``/mcp`` endpoint on ``self.app``.

    One handler serves every verb of the Streamable HTTP transport:

    * ``OPTIONS`` - minimal preflight reply exposing ``Mcp-Session-Id``.
    * ``DELETE``  - terminates the session named by ``Mcp-Session-Id``.
    * ``GET``     - opens a server-push SSE stream for an existing session.
    * ``POST``    - ``initialize`` (creates a session) or any other
      JSON-RPC message for an existing session.

    The route uses the "Streamable HTTP" tag in the API docs.
    """
    # Local import: the module header only brings in JSONResponse, and the
    # 204 reply below needs a response class that can be sent without a body.
    from fastapi import Response

    @self.app.api_route("/mcp", methods=["GET", "POST", "DELETE", "OPTIONS"], tags=["Streamable HTTP"])
    async def mcp_endpoint_handler(request: Request):
        """Handles GET, POST, DELETE, OPTIONS for the /mcp endpoint."""
        # 1. Handle OPTIONS (CORS preflight)
        if request.method == "OPTIONS":
            # CORS headers are assumed to be added by middleware; this reply
            # only exposes the session header.
            logger.debug("Handling OPTIONS request for /mcp")
            return JSONResponse({}, headers={"Access-Control-Expose-Headers": "Mcp-Session-Id"})

        # Session ID from header is required for most methods
        session_id = request.headers.get("Mcp-Session-Id")

        # 2. Handle DELETE (Terminate Session)
        if request.method == "DELETE":
            if not session_id:
                return JSONResponse(
                    {"jsonrpc": "2.0", "error": {"code": -32600, "message": "Mcp-Session-Id header is required for DELETE"}},
                    status_code=400,
                )
            logger.info(f"Handling DELETE request for session [Session ID: {session_id}]")
            session_data = self.client_sessions.pop(session_id, None)
            if session_data:
                await self._cleanup_session_resources(session_id, session_data)
                # 204 No Content must not carry a payload. JSONResponse({})
                # would serialise a "{}" body, so send a bare Response.
                return Response(status_code=204)
            else:
                logger.warning(f"Attempted DELETE on non-existent session: {session_id}")
                return JSONResponse(
                    {"jsonrpc": "2.0", "error": {"code": -32001, "message": "Session not found"}},
                    status_code=404,
                )

        # 3. Handle GET (Server Push SSE Stream)
        if request.method == "GET":
            if not session_id:
                return JSONResponse(
                    {"jsonrpc": "2.0", "error": {"code": -32000, "message": "Mcp-Session-Id header is required for GET streams"}},
                    status_code=400,
                )
            if session_id not in self.client_sessions:
                # GET never creates a session; the client must initialize first.
                return JSONResponse(
                    {"jsonrpc": "2.0", "error": {"code": -32001, "message": "Session not found. Initialize first."}},
                    status_code=404,
                )
            accept_header = request.headers.get("Accept", "")
            if "text/event-stream" not in accept_header:
                return JSONResponse(
                    {"jsonrpc": "2.0", "error": {"code": -32600, "message": "Accept header must include text/event-stream for GET"}},
                    status_code=406,
                )
            # TODO: Handle Last-Event-ID for stream recovery?
            logger.info(f"Handling GET request, establishing server push SSE stream [Session ID: {session_id}]")
            push_queue = asyncio.Queue()
            if self.client_sessions[session_id].get("general_sse_queues") is None:
                self.client_sessions[session_id]["general_sse_queues"] = []
            self.client_sessions[session_id]["general_sse_queues"].append(push_queue)
            self.client_sessions[session_id]["last_active"] = time.time()
            return EventSourceResponse(
                self._create_general_sse_generator(session_id, push_queue),
                media_type="text/event-stream",
            )

        # 4. Handle POST (Client Messages & Initialize)
        if request.method == "POST":
            accept_header = request.headers.get("Accept", "")
            content_type = request.headers.get("Content-Type", "")
            # Pre-bound so the generic error handler below can always read it.
            body = {}
            try:
                if "application/json" not in content_type:
                    return JSONResponse(
                        {"jsonrpc": "2.0", "error": {"code": -32700, "message": "Content-Type must be application/json"}},
                        status_code=415,
                    )
                body = await request.json()
                if isinstance(body, list):
                    return JSONResponse(
                        {"jsonrpc": "2.0", "error": {"code": -32600, "message": "Batch requests not supported"}},
                        status_code=400,
                    )
                if not isinstance(body, dict):
                    return JSONResponse(
                        {"jsonrpc": "2.0", "error": {"code": -32700, "message": "Invalid JSON received"}},
                        status_code=400,
                    )
                method = body.get("method")
                message_id = body.get("id")  # Can be None for notifications

                # Handle Initialize request (does not require Mcp-Session-Id header)
                if method == "initialize":
                    if "application/json" not in accept_header:
                        return JSONResponse(
                            {"jsonrpc": "2.0", "id": message_id, "error": {"code": -32600, "message": "Accept header must include application/json for initialize"}},
                            status_code=406,
                        )
                    return await self._handle_initialize(request, body, message_id)
                # Handle other POST requests (require Mcp-Session-Id)
                else:
                    if not session_id:
                        return JSONResponse(
                            {"jsonrpc": "2.0", "id": message_id, "error": {"code": -32000, "message": "Mcp-Session-Id header is required for this request"}},
                            status_code=400,
                        )
                    if session_id not in self.client_sessions:
                        return JSONResponse(
                            {"jsonrpc": "2.0", "id": message_id, "error": {"code": -32001, "message": "Session not found"}},
                            status_code=404,
                        )
                    # Non-initialize POSTs must accept both JSON and SSE replies.
                    if not ("application/json" in accept_header and "text/event-stream" in accept_header):
                        return JSONResponse(
                            {"jsonrpc": "2.0", "id": message_id, "error": {"code": -32600, "message": "Accept header must include application/json and text/event-stream for POST"}},
                            status_code=406,
                        )
                    self.client_sessions[session_id]["last_active"] = time.time()
                    return await self._handle_client_post(request, body, session_id, message_id)
            except json.JSONDecodeError:
                return JSONResponse(
                    {"jsonrpc": "2.0", "error": {"code": -32700, "message": "Parse error - Invalid JSON received"}},
                    status_code=400,
                )
            except Exception as e:
                logger.error(f"Unexpected error handling POST /mcp: {str(e)}", exc_info=True)
                error_id = body.get("id") if isinstance(body, dict) else None
                return JSONResponse(
                    {"jsonrpc": "2.0", "id": error_id, "error": {"code": -32000, "message": "Internal server error"}},
                    status_code=500,
                )

        # Fallback for any verb not handled above.
        return JSONResponse({"error": "Method Not Allowed"}, status_code=405)
async def _handle_initialize(self, request: Request, body: Dict, message_id: Any):
    """Create a new session and answer an ``initialize`` request.

    A fresh UUID becomes the session id; the session record is stored in
    ``self.client_sessions`` and the id is returned to the client in the
    ``Mcp-Session-Id`` response header alongside the JSON-RPC result.
    ``request`` and ``body`` are accepted but not inspected here.
    """
    logger.info("Handling Streamable HTTP initialize request")

    new_session_id = str(uuid.uuid4())
    logger.info(f"Created new Streamable HTTP session [Session ID: {new_session_id}]")

    # Both containers start empty; they are filled when streams are opened.
    session_record = {
        "created_at": time.time(),
        "last_active": time.time(),
        "request_queues": {},
        "general_sse_queues": [],
    }
    self.client_sessions[new_session_id] = session_record

    # Capabilities advertised to the client.
    capabilities = {
        "tools": {"supportsStreaming": True, "supportsProgress": True},
        "resources": {"supportsStreaming": False},
        "prompts": {"supported": True},
        "session": {"supported": True},
    }
    server_info = {"version": "0.2.0", "name": "Doris MCP Streamable Server"}

    payload = {
        "jsonrpc": "2.0",
        "id": message_id,
        "result": {
            "protocolVersion": "2025-03-26",
            "name": self.mcp_server.name,
            "instructions": "Apache Doris MCP Server (Streamable HTTP Mode)",
            "serverInfo": server_info,
            "capabilities": capabilities,
        },
    }

    # The session id travels back in a header, not in the JSON body.
    return JSONResponse(
        content=payload,
        media_type="application/json",
        headers={"Mcp-Session-Id": new_session_id},
    )
async def _handle_client_post(self, request: Request, body: Dict, session_id: str, message_id: Any):
    """Handle a non-initialize POST for an existing session.

    The JSON-RPC message in ``body`` is classified as follows:

    * notification (``method`` without ``id``) or response (``result`` /
      ``error`` present): acknowledged with 202 and otherwise ignored.
    * request (``method`` and ``id`` present): answered either with an SSE
      stream (tool calls whose params set ``stream``) or a single JSON body.
    * anything else: rejected with a JSON-RPC -32600 error and HTTP 400.

    Returns:
        A ``JSONResponse`` or an ``EventSourceResponse``.
    """
    method = body.get("method")

    # Handle Notifications/Responses from client
    is_notification = "method" in body and "id" not in body
    is_response = "result" in body or "error" in body
    if is_notification or is_response:
        logger.info(f"Received Streamable HTTP notification/response [Session ID: {session_id}] - Processing needed? (Ignoring for now)")
        # TODO: If the server sends requests that expect responses, process is_response here.
        return JSONResponse({}, status_code=202)  # Accepted

    # Handle Requests from client (method call)
    if "method" in body and "id" in body:
        logger.info(f"Received Streamable HTTP request [Session ID: {session_id}, ID: {message_id}, Method: {method}]")
        params = body.get("params", {})
        # Only tool calls may opt into a streamed reply.
        stream_required = params.get("stream", False) if method in ["tools/call", "mcp/callTool"] else False

        if stream_required:
            # --- Return SSE stream for response parts ---
            logger.info(f"Using SSE stream for request [Session ID: {session_id}, ID: {message_id}]")
            response_queue = asyncio.Queue()
            # request_queues is created during initialize; its absence means
            # the session record is corrupt.
            if self.client_sessions[session_id].get("request_queues") is None:
                logger.error(f"Session {session_id} is missing 'request_queues' dictionary!")
                return JSONResponse(
                    {"jsonrpc": "2.0", "id": message_id, "error": {"code": -32000, "message": "Internal server error: Session state inconsistent"}},
                    status_code=500,
                )
            self.client_sessions[session_id]["request_queues"][message_id] = response_queue

            # Process in the background; results reach the client via the queue.
            worker = asyncio.create_task(self._process_request_and_respond(
                request, body, session_id, message_id, response_queue, is_stream=True
            ))
            # The event loop keeps only weak references to tasks, so hold a
            # strong one until the task finishes to stop it being collected.
            pending_tasks = getattr(self, "_background_tasks", None)
            if pending_tasks is None:
                pending_tasks = self._background_tasks = set()
            pending_tasks.add(worker)
            worker.add_done_callback(pending_tasks.discard)

            # Stream whatever the worker puts on the request-specific queue.
            return EventSourceResponse(
                self._create_request_sse_generator(session_id, message_id, response_queue),
                media_type="text/event-stream",
            )
        else:
            # --- Return single JSON response ---
            logger.info(f"Using JSON response for request [Session ID: {session_id}, ID: {message_id}]")
            try:
                # Returns the final JSON body or raises HTTPException.
                result_or_error_payload = await self._process_request_and_respond(
                    request, body, session_id, message_id, None, is_stream=False
                )
                return JSONResponse(content=result_or_error_payload, media_type="application/json")
            except HTTPException as http_exc:
                # Re-express the HTTP error as a JSON-RPC error, keeping its status.
                return JSONResponse(
                    {"jsonrpc": "2.0", "id": message_id, "error": {"code": -32000, "message": http_exc.detail}},
                    status_code=http_exc.status_code
                )
            except Exception as e:
                logger.error(f"Error processing non-stream request [Session ID: {session_id}, ID: {message_id}]: {str(e)}", exc_info=True)
                error_response = {"jsonrpc": "2.0", "id": message_id, "error": {"code": -32000, "message": f"Internal server error: {str(e)}"}}
                return JSONResponse(content=error_response, status_code=500)
    else:
        # Neither notification, response nor request: malformed JSON-RPC.
        return JSONResponse(
            {"jsonrpc": "2.0", "id": message_id, "error": {"code": -32600, "message": "Invalid JSON-RPC request format"}},
            status_code=400,
        )
# === Generator Functions for SSE Streams ===
async def _create_general_sse_generator(self, session_id: str, queue: asyncio.Queue):
    """Yield SSE events for a GET /mcp server-push stream.

    Items are read from ``queue`` and emitted as ``message`` events. The
    stream ends when the session disappears, the end marker is dequeued or
    an unexpected error occurs. A ``ping`` event is emitted whenever no
    item arrives within 60 seconds. On exit the queue is detached from the
    session and emptied.
    """
    try:
        while True:
            try:
                if session_id not in self.client_sessions:
                    logger.warning(f"General SSE stream generator: Session {session_id} closed.")
                    break

                item = await asyncio.wait_for(queue.get(), timeout=60.0)
                if item == STREAM_END_MARKER:
                    logger.debug(f"General SSE stream received end marker [Session ID: {session_id}]")
                    break

                # JSON-RPC responses are never sent on this stream; drop them.
                is_rpc_response = (
                    isinstance(item, dict)
                    and ("result" in item or "error" in item)
                    and "id" in item
                )
                if is_rpc_response:
                    logger.warning(f"Attempted to send response on GET stream, blocked [Session ID: {session_id}]: {item}")
                    queue.task_done()
                    continue

                # TODO: Event ID for recovery?
                yield {"event": "message", "data": json.dumps(item)}
                queue.task_done()
            except asyncio.TimeoutError:
                # Nothing to send for a while: stop if the session is gone,
                # otherwise emit a keepalive.
                if session_id not in self.client_sessions:
                    logger.warning(f"General SSE stream generator (timeout): Session {session_id} closed.")
                    break
                yield {"event": "ping", "data": "keepalive"}
                continue
            except asyncio.CancelledError:
                logger.info(f"General SSE stream cancelled [Session ID: {session_id}]")
                raise
            except Exception as e:
                logger.error(f"General SSE stream error [Session ID: {session_id}]: {str(e)}", exc_info=True)
                break
    finally:
        logger.info(f"General SSE stream ended [Session ID: {session_id}]")

        # Detach this queue from the session, if the session still exists.
        session = self.client_sessions.get(session_id)
        if session is not None and session.get("general_sse_queues") is not None:
            try:
                session["general_sse_queues"].remove(queue)
                logger.debug(f"General SSE queue removed from session [Session ID: {session_id}]")
            except ValueError:
                logger.warning(f"Failed to remove general SSE queue (not found) [Session ID: {session_id}]")
            except Exception as ce:
                logger.error(f"Error removing general SSE queue [Session ID: {session_id}]: {ce}")

        # Discard anything still queued.
        while not queue.empty():
            try:
                queue.get_nowait()
                queue.task_done()
            except asyncio.QueueEmpty:
                break
async def _create_request_sse_generator(self, session_id: str, request_id: Any, queue: asyncio.Queue):
    """Yield SSE events for one POST /mcp request-response stream.

    Items are read from ``queue`` and emitted as ``message`` events until
    the end marker is dequeued, the session or its request queue entry
    disappears, or an unexpected error occurs. A 120-second wait without
    data is only logged; the generator keeps waiting. On exit the request
    queue entry is removed from the session and the queue is emptied.
    """

    def _still_registered():
        # True while both the session and this request's queue entry exist.
        return session_id in self.client_sessions and \
            request_id in self.client_sessions.get(session_id, {}).get("request_queues", {})

    try:
        while True:
            try:
                if not _still_registered():
                    logger.warning(f"Request SSE stream generator: Session/Request queue closed [Session ID: {session_id}, Request ID: {request_id}]")
                    break

                item = await asyncio.wait_for(queue.get(), timeout=120.0)
                if item == STREAM_END_MARKER:
                    logger.debug(f"Request SSE stream received end marker [Session ID: {session_id}, Request ID: {request_id}]")
                    break

                # TODO: Event ID for parts?
                yield {"event": "message", "data": json.dumps(item)}
                queue.task_done()
            except asyncio.TimeoutError:
                if not _still_registered():
                    logger.warning(f"Request SSE stream generator (timeout): Session/Request queue closed [Session ID: {session_id}, Request ID: {request_id}]")
                    break
                # A quiet period may simply be a long-running tool; keep
                # waiting for data or the end marker.
                logger.debug(f"Request SSE stream timed out waiting for message/end [Session ID: {session_id}, Request ID: {request_id}]")
                continue
            except asyncio.CancelledError:
                logger.info(f"Request SSE stream cancelled [Session ID: {session_id}, Request ID: {request_id}]")
                raise
            except Exception as e:
                logger.error(f"Request SSE stream error [Session ID: {session_id}, Request ID: {request_id}]: {str(e)}", exc_info=True)
                break
    finally:
        logger.info(f"Request SSE stream ended [Session ID: {session_id}, Request ID: {request_id}]")

        # Drop this request's queue entry if the session is still around.
        session = self.client_sessions.get(session_id)
        if session is not None and session.get("request_queues") is not None:
            if session["request_queues"].pop(request_id, None):
                logger.debug(f"Request SSE queue removed from session [Session ID: {session_id}, Request ID: {request_id}]")
            else:
                logger.warning(f"Failed to remove request SSE queue (not found) [Session ID: {session_id}, Request ID: {request_id}]")

        # Discard anything still queued.
        while not queue.empty():
            try:
                queue.get_nowait()
                queue.task_done()
            except asyncio.QueueEmpty:
                break
# === Core Request Processing Logic ===
async def _process_request_and_respond(
    self, request: Request, body: Dict, session_id: str, message_id: Any,
    response_queue: Optional[asyncio.Queue],  # Queue ONLY for streaming responses
    is_stream: bool  # True if response should go via SSE queue
):
    """Dispatch a client method call and deliver its outcome.

    Non-streaming calls (``is_stream`` false) return the complete JSON-RPC
    response dict, or raise ``HTTPException`` carrying the status code and
    message to report. Streaming calls return ``None``: results and errors
    are put on ``response_queue``, followed by ``STREAM_END_MARKER``.

    For a streaming tool call the tool wrapper is awaited here rather than
    detached as a separate task. The end marker is queued in ``finally``,
    so a detached wrapper would have its output queued *after* the marker
    and the SSE stream would close before any result was delivered.

    Args:
        request: The incoming HTTP request, passed through to the tool call.
        body: Parsed JSON-RPC message (``method`` and ``params`` are read).
        session_id: Session the call belongs to.
        message_id: JSON-RPC id echoed in every response.
        response_queue: Queue feeding the SSE stream, or ``None``.
        is_stream: Whether the reply goes through ``response_queue``.

    Raises:
        HTTPException: For any failure of a non-streaming call.
    """
    logger.info(f"Entering _process_request_and_respond for method '{body.get('method')}'...")
    method = body.get("method")
    params = body.get("params", {})
    response_payload = None  # Holds the 'result' part of the JSON-RPC reply
    try:
        # --- Handle Method Calls ---
        if method == "mcp/listOfferings":
            tools = await self.mcp_server.list_tools()
            tools_json = self._format_tools(tools)
            resources = await self.mcp_server.list_resources()
            resources_json = self._format_resources(resources)
            prompts = await self.mcp_server.list_prompts()
            prompts_json = self._format_prompts(prompts)
            response_payload = {"tools": tools_json, "resources": resources_json, "prompts": prompts_json}
        elif method == "mcp/listTools" or method == "tools/list":
            tools = await self.mcp_server.list_tools()
            response_payload = {"tools": self._format_tools(tools)}
        elif method == "mcp/listResources":
            resources = await self.mcp_server.list_resources()
            response_payload = {"resources": self._format_resources(resources)}
        elif method == "mcp/listPrompts":
            prompts = await self.mcp_server.list_prompts()
            response_payload = {"prompts": self._format_prompts(prompts)}
        elif method == "mcp/callTool" or method == "tools/call":
            tool_name = params.get("name")
            arguments = params.get("arguments", {})
            if not tool_name:
                # Streaming: report via the queue. Otherwise: raise so the
                # caller can set the HTTP status.
                error_detail = "Invalid params: tool name ('name') is required"
                if is_stream and response_queue:
                    error_resp = {"jsonrpc": "2.0", "id": message_id, "error": {"code": -32602, "message": error_detail}}
                    await response_queue.put(error_resp)
                else:
                    raise HTTPException(status_code=400, detail=error_detail)
                return  # finally still queues the end marker

            # --- Tool Calling ---
            if is_stream and response_queue:
                logger.info(f"Launching stream tool task [Session: {session_id}, Req: {message_id}, Tool: {tool_name}]")
                # Await the wrapper so every progress/result/error message
                # is on the queue before finally appends the end marker.
                await self._execute_stream_tool_wrapper(
                    tool_name, arguments, message_id, session_id, request, response_queue
                )
                return
            else:
                logger.info(f"Executing non-stream tool [Session: {session_id}, Req: {message_id}, Tool: {tool_name}]")
                # call_tool raises ValueError on internal errors.
                result = await self.call_tool(tool_name, arguments, request, None)  # No callback needed
                logger.debug(f"Raw result from non-stream call_tool: {result}")
                response_payload = self._format_tool_call_result(result)
        else:
            # Method not found
            error_detail = f"Method not found: {method}"
            if is_stream and response_queue:
                error_resp = {"jsonrpc": "2.0", "id": message_id, "error": {"code": -32601, "message": error_detail}}
                await response_queue.put(error_resp)
            else:
                raise HTTPException(status_code=405, detail=error_detail)
            return

        # --- Prepare final response payload (only set on success) ---
        if response_payload is not None:
            final_response = {"jsonrpc": "2.0", "id": message_id, "result": response_payload}
            if is_stream and response_queue:
                # Not expected: streaming tool calls return before this point.
                logger.error("Logic error: response_payload set for streaming call?")
                await response_queue.put(final_response)
            elif not is_stream:
                logger.debug(f"Returning successful non-stream payload for {method}")
                return final_response
    except Exception as e:
        # Covers ValueError from call_tool, the HTTPExceptions raised above
        # and anything unexpected.
        logger.error(f"Error processing request [Session: {session_id}, Req: {message_id}, Method: {method}]: {str(e)}", exc_info=True)
        error_code = -32000
        error_message = f"Internal server error: {str(e)}"
        status_code = 500  # Default for unexpected errors
        if isinstance(e, HTTPException):
            # Preserve the status and message chosen where it was raised.
            error_message = e.detail
            status_code = e.status_code
            error_code = -32000
        elif isinstance(e, ValueError):
            # Tool lookup/execution failures are reported as server errors.
            error_message = str(e)
            status_code = 500
            error_code = -32000
        error_response_payload = {"code": error_code, "message": error_message}
        if is_stream and response_queue:
            final_error_response = {"jsonrpc": "2.0", "id": message_id, "error": error_response_payload}
            logger.debug(f"Putting error response into stream queue [Session: {session_id}, Req: {message_id}]")
            await response_queue.put(final_error_response)
            return
        else:
            logger.debug(f"Raising HTTPException for non-stream error (Status: {status_code})")
            raise HTTPException(status_code=status_code, detail=error_message)
    finally:
        # Every streaming call ends with the marker, including early returns
        # and error paths, so the SSE generator always terminates.
        if is_stream and response_queue:
            logger.debug(f"Putting stream end marker [Session: {session_id}, Req: {message_id}]")
            await response_queue.put(STREAM_END_MARKER)
async def _execute_stream_tool_wrapper(
    self, tool_name: str, arguments: Dict, message_id: Any, session_id: str,
    request: Request, response_queue: asyncio.Queue
):
    """Run a tool for a streaming request and feed ``response_queue``.

    Partial output reported through the callback is queued as
    ``tools/progress`` notifications; the final result, or a JSON-RPC error
    if the call fails, is queued afterwards. Nothing is queued once the
    session or its request queue entry has gone. This method does not queue
    the end-of-stream marker.
    """
    logger.info(f"Entering _execute_stream_tool_wrapper for tool '{tool_name}'...")

    def _queue_registered():
        # True while the session and this request's queue entry both exist.
        return session_id in self.client_sessions and \
            message_id in self.client_sessions.get(session_id, {}).get("request_queues", {})

    try:
        logger.debug(f"Executing stream tool wrapper [Session: {session_id}, Req: {message_id}, Tool: {tool_name}]")

        async def stream_callback(content, metadata=None):
            logger.debug(f"Stream callback received content [Session: {session_id}, Req: {message_id}]")
            formatted_chunk = self._format_tool_call_result(content)
            if not _queue_registered():
                logger.warning(f"Stream callback: Session/Queue closed, cannot send partial result [Session: {session_id}, Req: {message_id}]")
                return

            # Progress notification carrying the partial result.
            notification = {
                "jsonrpc": "2.0",
                "method": "tools/progress",
                "params": {
                    "requestId": message_id,
                    "toolName": tool_name,
                    "progress": formatted_chunk,
                }
            }
            try:
                await response_queue.put(notification)
            except Exception as e:
                logger.error(f"Stream callback failed to send progress: {str(e)}")

            # Forward visualization data when the tool supplied some.
            if metadata and "visualization" in metadata:
                await self.send_visualization_data(session_id, message_id, metadata["visualization"])

        # --- Call Tool ---
        # Work on a copy so the caller's arguments are left untouched; the
        # callback is passed both inside the arguments and explicitly.
        tool_kwargs = dict(arguments)
        tool_kwargs['callback'] = stream_callback
        outcome = await self.call_tool(tool_name, tool_kwargs, request, stream_callback)
        logger.debug(f"Stream wrapper received final result from call_tool: {outcome}")

        # --- Send Final Result ---
        if not _queue_registered():
            logger.warning(f"Stream tool finished but Session/Queue closed [Session: {session_id}, Req: {message_id}]")
            return

        final_message = {
            "jsonrpc": "2.0",
            "id": message_id,
            "result": self._format_tool_call_result(outcome)
        }
        logger.debug(f"Putting final stream result into queue [Session: {session_id}, Req: {message_id}]")
        await response_queue.put(final_message)
        logger.info(f"Stream tool execution successful [Session: {session_id}, Req: {message_id}]")
    except Exception as e:
        # Errors from call_tool (ValueError) or from the wrapper itself.
        logger.error(f"Error during stream tool execution wrapper [Session: {session_id}, Req: {message_id}]: {str(e)}", exc_info=True)
        if not _queue_registered():
            logger.warning(f"Stream tool failed but Session/Queue closed [Session: {session_id}, Req: {message_id}]")
            return

        if isinstance(e, ValueError):
            error_code = -32602  # Or -32000?
            error_message = str(e)
        else:
            error_code = -32000
            error_message = f"Tool execution error: {str(e)}"

        error_response = {
            "jsonrpc": "2.0",
            "id": message_id,
            "error": {"code": error_code, "message": error_message}
        }
        try:
            await response_queue.put(error_response)
        except Exception as qe:
            logger.error(f"Failed to put error response into stream queue: {qe}")
async def call_tool(self, tool_name, arguments, request, callback: Optional[callable] = None):
    """Find and execute the tool named ``tool_name``.

    The tool is first looked up among the tools reported by the MCP server.
    If it is not registered there, the function of the mapped name is
    loaded from ``doris_mcp_server.tools.mcp_doris_tools`` and called
    directly with processed keyword arguments.

    Args:
        tool_name: Client-facing tool name.
        arguments: Tool arguments; this dict is never modified.
        request: Incoming HTTP request, used to extract the recent query
            and to build the context for registered tools.
        callback: Optional async callback for partial results.

    Returns:
        The raw result produced by the tool.

    Raises:
        ValueError: If the tool cannot be found or its execution fails; the
            original exception is chained as the cause.
    """
    logger.info(f"Entering call_tool for tool '{tool_name}'...")
    # Log args excluding callback (a function is not useful in the log).
    log_args = {k: v for k, v in arguments.items() if k != 'callback'}
    logger.info(f"Executing tool: {tool_name}, Args: {json.dumps(log_args, ensure_ascii=False, default=str)}")
    recent_query = self._extract_recent_query(request)

    # Client-facing name -> function name in mcp_doris_tools.
    tool_mapping = {
        "status": "mcp_doris_status",
        "health": "mcp_doris_health",
        "nl2sql_query": "mcp_doris_nl2sql_query",
        "nl2sql_query_stream": "mcp_doris_nl2sql_query_stream",
        "list_database_tables": "mcp_doris_list_database_tables",
        "explain_table": "mcp_doris_explain_table",
        "get_nl2sql_status": "mcp_doris_get_nl2sql_status",
        "refresh_metadata": "mcp_doris_refresh_metadata",
        "sql_optimize": "mcp_doris_sql_optimize",
        "fix_sql": "mcp_doris_fix_sql",
        "count_chars": "mcp_doris_count_chars",
        "exec_query": "mcp_doris_exec_query",
        "get_schema_list": "mcp_doris_get_schema_list",  # Deprecated?
        "save_metadata": "mcp_doris_save_metadata",  # Likely internal
        "get_metadata": "mcp_doris_get_metadata",  # Likely internal
        "analyze_query_result": "mcp_doris_analyze_query_result",  # Internal?
        "generate_sql": "mcp_doris_generate_sql",  # Likely internal
        "explain_sql": "mcp_doris_explain_sql",  # Internal?
        "modify_sql": "mcp_doris_modify_sql",  # Internal?
        "parse_query": "mcp_doris_parse_query",  # Internal?
        "identify_query_type": "mcp_doris_identify_query_type",  # Internal?
        "validate_sql_syntax": "mcp_doris_validate_sql_syntax",  # Internal?
        "check_sql_security": "mcp_doris_check_sql_security",  # Internal?
        "find_similar_examples": "mcp_doris_find_similar_examples",  # Internal?
        "find_similar_history": "mcp_doris_find_similar_history",  # Internal?
        "calculate_query_similarity": "mcp_doris_calculate_query_similarity",  # Internal?
        "adapt_similar_query": "mcp_doris_adapt_similar_query",  # Internal?
        "get_nl2sql_prompt": "mcp_doris_get_nl2sql_prompt"  # Internal?
    }
    mapped_tool_name = tool_mapping.get(tool_name, tool_name)

    try:
        # 1. Look for the tool among those registered with the MCP server.
        tool_instance = None
        mcp = self.app.state.mcp if hasattr(self.app.state, 'mcp') else self.mcp_server
        registered_tools = await mcp.list_tools()
        for tool in registered_tools:
            # Match on the registered name (``name`` or ``__name__``).
            tool_registered_name = getattr(tool, 'name', getattr(tool, '__name__', None))
            if tool_registered_name == tool_name:
                tool_instance = tool
                logger.debug(f"Found registered tool wrapper: {tool_registered_name}")
                break

        if not tool_instance:
            # Fallback: load the function directly, bypassing registration.
            logger.warning(f"Tool '{tool_name}' not found in registered tools, trying direct import of {mapped_tool_name}")
            # Only the lookup is guarded here. Failures while *running* the
            # tool must not be reported as "not found or failed to import".
            try:
                import doris_mcp_server.tools.mcp_doris_tools as mcp_tools
                tool_instance = getattr(mcp_tools, mapped_tool_name, None)
                if not tool_instance or not callable(tool_instance):
                    raise ValueError(f"Tool function {mapped_tool_name} not found or not callable in mcp_doris_tools.")
            except (ImportError, AttributeError, ValueError) as import_err:
                logger.error(f"Failed to find or import tool: {tool_name} / {mapped_tool_name}. Error: {import_err}")
                raise ValueError(f"Tool '{tool_name}' not found or failed to import.") from import_err

            logger.debug(f"Using directly imported tool function: {mapped_tool_name}")
            # No FastMCP context on this path, so pass arguments directly.
            processed_args = self._process_tool_arguments(mapped_tool_name, arguments, recent_query)
            # Only "*_stream" functions receive the callback; for any other
            # function a callback left in the arguments is removed.
            if callback and mapped_tool_name.endswith("_stream"):
                processed_args['callback'] = callback
            elif callback:
                processed_args.pop('callback', None)
            result = await tool_instance(**processed_args)
            logger.debug(f"Raw result from directly imported tool '{mapped_tool_name}': {result}")
            return result

        # 2. Registered tool: execute through the MCP server if it offers a
        #    by-name call, otherwise simulate the context the wrapper expects.
        logger.debug(f"Executing registered tool wrapper '{tool_name}'")
        if hasattr(mcp, 'call_tool_by_name'):  # Hypothetical method
            logger.debug("Attempting execution via mcp.call_tool_by_name")
            # Copy first: adding the callback must not alter the caller's dict.
            pseudo_ctx_params = dict(arguments)
            if callback:
                pseudo_ctx_params['callback'] = callback
            result = await mcp.call_tool_by_name(tool_name, params=pseudo_ctx_params)
            logger.debug(f"Result from mcp.call_tool_by_name: {result}")
        else:
            logger.debug("Falling back to manual context simulation for tool wrapper execution")
            from mcp.server.fastmcp import Context
            context_params = dict(arguments)
            if callback:
                context_params['callback'] = callback
            # NOTE(review): assumes Context accepts these keyword arguments
            # and that the registered tool object is callable - confirm
            # against the installed FastMCP version.
            pseudo_ctx = Context(mcp=mcp, request=request, params=context_params, tool=tool_instance)
            result = await tool_instance(pseudo_ctx)
            logger.debug(f"Result from manual context simulation: {result}")
        logger.debug(f"Raw result received in call_tool from registered tool '{tool_name}': {result}")
        return result
    except Exception as e:
        # Every failure surfaces as ValueError, which is what callers handle.
        logger.error(f"Exception during call_tool for '{tool_name}': {str(e)}", exc_info=True)
        raise ValueError(f"Error executing tool '{tool_name}': {str(e)}") from e
# === Helper Methods (Formatting, Session Cleanup, etc.) ===
def _format_tools(self, tools):
# Helper to format tool list for responses
# Based on mcp/listTools structure
tools_json = []
for tool in tools:
# Assuming tools from list_tools are the wrapper functions
tool_registered_name = getattr(tool, 'name', getattr(tool, '__name__', None))
if not tool_registered_name:
logger.warning(f"Could not determine name for tool object: {tool}")
continue
# Need a way to get description and schema associated with the wrapper
# This might require inspecting the mcp instance's internal storage
mcp = self.app.state.mcp if hasattr(self.app.state, 'mcp') else self.mcp_server
# Hypothetical internal access - THIS IS FRAGILE
tool_spec = mcp.tools.get(tool_registered_name) if hasattr(mcp, 'tools') else None
description = ""
input_schema = {"type": "object", "properties": {}, "required": []}
if tool_spec and hasattr(tool_spec, 'description'):
description = tool_spec.description
if tool_spec and hasattr(tool_spec, 'parameters'): # Assuming parameters holds the JSON schema
input_schema = tool_spec.parameters
tools_json.append({
"name": tool_registered_name,
"description": description,
"inputSchema": input_schema
})
return tools_json
def _format_resources(self, resources):
# Helper to format resource list
return [res.model_dump() if hasattr(res, "model_dump") else res for res in resources]
def _format_prompts(self, prompts):
# Helper to format prompt list
return [prompt.model_dump() if hasattr(prompt, "model_dump") else prompt for prompt in prompts]
def _format_tool_call_result(self, result: Any) -> Dict[str, Any]:
# Helper to format tool results into MCP Content format
content_list = []
if isinstance(result, str):
try:
# If it looks like the tool already returned the full JSON RPC like structure
parsed_json = json.loads(result)
if isinstance(parsed_json, dict) and 'content' in parsed_json and isinstance(parsed_json['content'], list):
logger.debug("Tool result already seems formatted with 'content', using as is.")
return parsed_json # Use the structure directly
else:
# Assume it's JSON content, wrap it
content_list.append({"type": "json", "json": parsed_json})
except json.JSONDecodeError:
# Not JSON, treat as text
content_list.append({"type": "text", "text": result})
elif isinstance(result, (dict, list)):
# If result is already a dict with a 'content' list, use it directly
if isinstance(result, dict) and 'content' in result and isinstance(result['content'], list):
logger.debug("Tool result dictionary has 'content', using as is.")
return result # Use the structure directly
else:
# Otherwise, assume it's JSON content to be wrapped
content_list.append({"type": "json", "json": result})
elif result is None:
# Handle None result, maybe return empty content or specific type?
logger.warning("_format_tool_call_result received None result")
content_list.append({"type": "text", "text": ""}) # Example: empty text
else:
# Other types, convert to string and wrap as text
content_list.append({"type": "text", "text": str(result)})
# Always return a dict with a 'content' key containing a list
return {"content": content_list}
def _process_tool_arguments(self, tool_name, arguments, recent_query):
# Helper to process tool arguments, including random_string fallback
# Note: Ensure callback is NOT passed here
processed_args = dict(arguments)
processed_args.pop('callback', None) # Explicitly remove callback
if "random_string" in arguments and tool_name.startswith("mcp_doris_"):
random_string = processed_args.pop("random_string", "") # Remove from processed too
logger.debug(f"Processing random_string '{random_string}' for tool {tool_name}")
# ... (rest of random_string logic as before) ...
# Example for exec_query:
if tool_name == "mcp_doris_exec_query" and not processed_args.get("sql"):
sql_fallback = random_string or recent_query
# ... (logic to extract SQL from fallback) ...
if sql_extracted:
processed_args["sql"] = sql_extracted
else:
logger.warning(f"Missing sql for {tool_name}, and fallback failed.")
# ... (logic for table_name fallback) ...
return processed_args
def _extract_recent_query(self, request: Request) -> Optional[str]:
# Helper to extract recent user query from request
# (Implementation as provided previously)
try:
# Try to extract message history from request body
body = None
body_bytes = getattr(request, "_body", None)
if body_bytes:
try:
body = json.loads(body_bytes)
except: pass
if not body: body = getattr(request, "_json", {})
messages = body.get("params", {}).get("messages", [])
if messages:
for msg in reversed(messages):
if msg.get("role") == "user": return msg.get("content", "")
message = body.get("params", {}).get("message", {})
if message and message.get("role") == "user": return message.get("content", "")
return None
except Exception as e:
logger.error(f"Error extracting recent query: {str(e)}")
return None
async def _cleanup_session_resources(self, session_id: str, session_data: Dict):
    """Push the stream-end marker into every queue of a deleted session.

    Both the general SSE queues (``general_sse_queues``) and the
    per-request queues (``request_queues``) receive ``STREAM_END_MARKER``.
    A failure on one queue is logged and does not stop the others.
    """
    logger.info(f"Cleaning up resources for session [Session ID: {session_id}]")

    # General SSE queues.
    for sse_queue in session_data.get("general_sse_queues", []):
        try:
            await sse_queue.put(STREAM_END_MARKER)
        except Exception as exc:
            logger.warning(f"Error putting end marker in general queue for session {session_id}: {exc}")

    # Request-specific SSE queues, keyed by request id.
    per_request = session_data.get("request_queues", {})
    for request_key, pending_queue in per_request.items():
        try:
            await pending_queue.put(STREAM_END_MARKER)
        except Exception as exc:
            logger.warning(f"Error putting end marker in request queue {request_key} for session {session_id}: {exc}")

    logger.info(f"Finished cleaning resources for session {session_id}")
# This method might belong in the main app or a shared utility if needed by both servers
# async def cleanup_idle_sessions(self):
# # ... (implementation - needs access to self.client_sessions) ...
# pass
# This method might belong in the main app or a shared utility
# async def broadcast_message(self, message):
# # ... (implementation - needs access to self.client_sessions of BOTH servers?) ...
# pass
# This method is specific to streamable http tool calls
async def send_visualization_data(self, session_id: str, request_id: Any, visualization_data: Any):
    """Sends visualization data as a notification on the request stream.

    The data is wrapped in a JSON-RPC 2.0 notification with method
    ``tools/visualization`` and put on the queue registered for
    ``request_id`` in the given session. Unknown sessions or requests are
    logged and ignored; queue errors are logged, not raised.
    """
    if session_id not in self.client_sessions:
        logger.warning(f"Cannot send visualization: Session {session_id} not found.")
        return

    session_entry = self.client_sessions.get(session_id, {})
    target_queue = session_entry.get("request_queues", {}).get(request_id)
    if not target_queue:
        logger.warning(f"Cannot send visualization: Request queue {request_id} not found for session {session_id}.")
        return

    payload = {
        "jsonrpc": "2.0",
        "method": "tools/visualization",
        "params": visualization_data,
    }
    try:
        await target_queue.put(payload)
        logger.info(f"Sent visualization notification [Session: {session_id}, Req: {request_id}]")
    except Exception as exc:
        logger.error(f"Error sending visualization notification [Session: {session_id}, Req: {request_id}]: {exc}")
# This might belong in main app or shared utility
# async def send_periodic_updates(self):
# # ... (implementation) ...
# pass
# End of class DorisMCPStreamableServer

View File

@@ -1,25 +1,9 @@
from .mcp_doris_tools import (
mcp_doris_exec_query,
mcp_doris_get_table_schema,
mcp_doris_get_db_table_list,
mcp_doris_get_db_list,
mcp_doris_get_table_comment,
mcp_doris_get_table_column_comments,
mcp_doris_get_table_indexes,
mcp_doris_get_recent_audit_logs,
mcp_doris_get_catalog_list
)
"""
MCP Tools Package - Contains all MCP tool implementations.
# The __all__ list should reflect the registered tool names,
# even though the implementation functions have the prefix.
__all__ = [
"exec_query",
"get_table_schema",
"get_db_table_list",
"get_db_list",
"get_table_comment",
"get_table_column_comments",
"get_table_indexes",
"get_recent_audit_logs",
"get_catalog_list"
]
This package includes:
- Doris database tools
- Resource managers
- Prompt managers
- Tool registration and initialization
"""

View File

@@ -1,230 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Doris MCP Tool Implementations
Includes exec_query and new tools based on schema_extractor.
"""
import os
import time
import json
import logging
from typing import Dict, Any
import pandas as pd
# --- Use absolute imports ---
from doris_mcp_server.utils.schema_extractor import MetadataExtractor
from doris_mcp_server.utils.sql_executor_tools import execute_sql_query
# Get logger
logger = logging.getLogger("doris-mcp-tools")
# --- Helper Function to format response ---
def _format_response(success: bool, result: Any = None, error: str = None, message: str = "") -> Dict[str, Any]:
response_data = {
"success": success,
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
}
if success and result is not None:
# Handle DataFrame serialization
if isinstance(result, pd.DataFrame):
try:
# Convert DataFrame to JSON records format
response_data["result"] = json.loads(result.to_json(orient='records', date_format='iso'))
except Exception as df_err:
logger.error(f"DataFrame to JSON conversion failed: {df_err}")
# Fallback or specific error handling for DataFrame
response_data["result"] = {"error": "Failed to serialize DataFrame result"}
response_data["success"] = False # Mark as failed if serialization fails
response_data["error"] = f"DataFrame serialization error: {str(df_err)}"
else:
response_data["result"] = result
response_data["message"] = message or "Operation successful" # Translated: Operation successful
elif not success:
response_data["error"] = error or "Unknown error" # Translated: Unknown error
response_data["message"] = message or "Operation failed" # Translated: Operation failed
return {
"content": [
{
"type": "text",
"text": json.dumps(response_data, ensure_ascii=False, default=str) # Use default=str for non-serializable types
}
]
}
async def mcp_doris_exec_query(sql: str = None, db_name: str = None, catalog_name: str = None, max_rows: int = 100, timeout: int = 30) -> Dict[str, Any]:
    """
    Execute an SQL query (with catalog federation support) and return its result.

    Args:
        sql (str): The SQL query to execute. Table references MUST use three-part naming:
            - Internal tables: internal.db_name.table_name (e.g., "SELECT * FROM internal.ssb.customer")
            - External tables: catalog_name.db_name.table_name (e.g., "SELECT * FROM mysql.ssb.customer")
            - Cross-catalog queries: "SELECT * FROM mysql.ssb.customer m JOIN internal.ssb.orders o ON m.id = o.customer_id"
            Examples:
            - Query internal catalog: "SELECT COUNT(*) FROM internal.ssb.customer"
            - Query MySQL catalog: "SELECT COUNT(*) FROM mysql.ssb.customer"
            - Cross-catalog join: "SELECT * FROM internal.ssb.customer c JOIN mysql.test.user_info u ON c.id = u.customer_id"
        db_name (str, optional): Target database name. Only used for connection context;
            table names in SQL must be fully qualified.
        catalog_name (str, optional): Reference catalog name for context. It does not affect
            SQL execution; table names in SQL must be fully qualified. Available catalogs
            can be found using the get_catalog_list tool.
        max_rows (int, optional): Maximum number of rows to return. Defaults to 100.
        timeout (int, optional): Query timeout in seconds. Defaults to 30.

    Returns:
        Dict[str, Any]: A text-content envelope holding the executor's JSON output,
        or a failure envelope built by ``_format_response``.
    """
    logger.info(f"MCP Tool Call: mcp_doris_exec_query, SQL: {sql}, DB: {db_name}, Catalog: {catalog_name}, MaxRows: {max_rows}, Timeout: {timeout}")
    try:
        if not sql:
            return _format_response(success=False, error="SQL statement not provided", message="Please provide the SQL statement to execute")

        # execute_sql_query expects its inputs under a "params" key.
        query_params = {
            "sql": sql,
            "db_name": db_name,
            "catalog_name": catalog_name,
            "max_rows": max_rows,
            "timeout": timeout,
        }
        exec_result = await execute_sql_query({"params": query_params})

        # Expected shape: {'content': [{'type': 'text', 'text': json_string}]}
        has_text_payload = (
            exec_result
            and 'content' in exec_result
            and len(exec_result['content']) > 0
            and 'text' in exec_result['content'][0]
        )
        if not has_text_payload:
            logger.error(f"execute_sql_query returned an unexpected format: {exec_result}")
            return _format_response(success=False, error="SQL executor returned invalid format", message="Internal error executing SQL query")

        try:
            # The executor's own payload is passed through (re-serialized),
            # not wrapped a second time by _format_response.
            result_data = json.loads(exec_result['content'][0]['text'])
            serialized = json.dumps(result_data, ensure_ascii=False, default=str)
            return {"content": [{"type": "text", "text": serialized}]}
        except json.JSONDecodeError as json_err:
            logger.error(f"Failed to parse execute_sql_query result: {json_err}")
            return _format_response(success=False, error=str(json_err), message="Error parsing SQL execution result")
        except Exception as parse_err:
            logger.error(f"Unexpected error occurred while processing execute_sql_query result: {parse_err}", exc_info=True)
            return _format_response(success=False, error=str(parse_err), message="Unknown error occurred while processing SQL execution result")
    except Exception as exc:
        logger.error(f"MCP tool execution failed mcp_doris_exec_query: {str(exc)}", exc_info=True)
        return _format_response(success=False, error=str(exc), message="Error executing SQL query")
async def mcp_doris_get_table_schema(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
    """Look up a table's schema via ``MetadataExtractor``.

    Returns the envelope built by ``_format_response``: a failure when
    ``table_name`` is empty, the schema comes back empty, or the lookup raises.
    """
    logger.info(f"MCP Tool Call: mcp_doris_get_table_schema, Table: {table_name}, DB: {db_name}, Catalog: {catalog_name}")
    if not table_name:
        return _format_response(success=False, error="Missing table_name parameter")
    try:
        metadata = MetadataExtractor(db_name=db_name, catalog_name=catalog_name)
        table_schema = metadata.get_table_schema(
            table_name=table_name,
            db_name=db_name,
            catalog_name=catalog_name,
        )
        if table_schema:
            return _format_response(success=True, result=table_schema)
        return _format_response(
            success=False,
            error="Table not found or has no columns",
            message=f"Could not get schema for table {catalog_name or 'default'}.{db_name or metadata.db_name}.{table_name}",
        )
    except Exception as exc:
        logger.error(f"MCP tool execution failed mcp_doris_get_table_schema: {str(exc)}", exc_info=True)
        return _format_response(success=False, error=str(exc), message="Error getting table schema")
async def mcp_doris_get_db_table_list(db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
    """List the tables of a database via ``MetadataExtractor``.

    Returns the envelope built by ``_format_response``; any exception is
    logged and reported as a failure.
    """
    logger.info(f"MCP Tool Call: mcp_doris_get_db_table_list, DB: {db_name}, Catalog: {catalog_name}")
    try:
        metadata = MetadataExtractor(db_name=db_name, catalog_name=catalog_name)
        table_names = metadata.get_database_tables(db_name=db_name, catalog_name=catalog_name)
        return _format_response(success=True, result=table_names)
    except Exception as exc:
        logger.error(f"MCP tool execution failed mcp_doris_get_db_table_list: {str(exc)}", exc_info=True)
        return _format_response(success=False, error=str(exc), message="Error getting database table list")
async def mcp_doris_get_db_list(catalog_name: str = None) -> Dict[str, Any]:
    """List all databases of a catalog via ``MetadataExtractor``.

    Returns the envelope built by ``_format_response``; any exception is
    logged and reported as a failure.
    """
    logger.info(f"MCP Tool Call: mcp_doris_get_db_list, Catalog: {catalog_name}")
    try:
        metadata = MetadataExtractor(catalog_name=catalog_name)
        db_names = metadata.get_all_databases(catalog_name=catalog_name)
        return _format_response(success=True, result=db_names)
    except Exception as exc:
        logger.error(f"MCP tool execution failed mcp_doris_get_db_list: {str(exc)}", exc_info=True)
        return _format_response(success=False, error=str(exc), message="Error getting database list")
async def mcp_doris_get_table_comment(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
    """Fetch a table's comment via ``MetadataExtractor``.

    Returns the envelope built by ``_format_response``: a failure when
    ``table_name`` is empty or the lookup raises.
    """
    logger.info(f"MCP Tool Call: mcp_doris_get_table_comment, Table: {table_name}, DB: {db_name}, Catalog: {catalog_name}")
    if not table_name:
        return _format_response(success=False, error="Missing table_name parameter")
    try:
        metadata = MetadataExtractor(db_name=db_name, catalog_name=catalog_name)
        table_comment = metadata.get_table_comment(
            table_name=table_name,
            db_name=db_name,
            catalog_name=catalog_name,
        )
        return _format_response(success=True, result=table_comment)
    except Exception as exc:
        logger.error(f"MCP tool execution failed mcp_doris_get_table_comment: {str(exc)}", exc_info=True)
        return _format_response(success=False, error=str(exc), message="Error getting table comment")
async def mcp_doris_get_table_column_comments(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
    """Fetch the column comments of a table via ``MetadataExtractor``.

    Returns the envelope built by ``_format_response``: a failure when
    ``table_name`` is empty or the lookup raises.
    """
    logger.info(f"MCP Tool Call: mcp_doris_get_table_column_comments, Table: {table_name}, DB: {db_name}, Catalog: {catalog_name}")
    if not table_name:
        return _format_response(success=False, error="Missing table_name parameter")
    try:
        metadata = MetadataExtractor(db_name=db_name, catalog_name=catalog_name)
        column_comments = metadata.get_column_comments(
            table_name=table_name,
            db_name=db_name,
            catalog_name=catalog_name,
        )
        return _format_response(success=True, result=column_comments)
    except Exception as exc:
        logger.error(f"MCP tool execution failed mcp_doris_get_table_column_comments: {str(exc)}", exc_info=True)
        return _format_response(success=False, error=str(exc), message="Error getting column comments")
async def mcp_doris_get_table_indexes(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
    """Fetch the index definitions of a table via ``MetadataExtractor``.

    Returns the envelope built by ``_format_response``: a failure when
    ``table_name`` is empty or the lookup raises.
    """
    logger.info(f"MCP Tool Call: mcp_doris_get_table_indexes, Table: {table_name}, DB: {db_name}, Catalog: {catalog_name}")
    if not table_name:
        return _format_response(success=False, error="Missing table_name parameter")
    try:
        metadata = MetadataExtractor(db_name=db_name, catalog_name=catalog_name)
        table_indexes = metadata.get_table_indexes(
            table_name=table_name,
            db_name=db_name,
            catalog_name=catalog_name,
        )
        return _format_response(success=True, result=table_indexes)
    except Exception as exc:
        logger.error(f"MCP tool execution failed mcp_doris_get_table_indexes: {str(exc)}", exc_info=True)
        return _format_response(success=False, error=str(exc), message="Error getting table indexes")
async def mcp_doris_get_recent_audit_logs(days: int = 7, limit: int = 100) -> Dict[str, Any]:
    """Fetch recent audit log entries via ``MetadataExtractor``.

    ``days`` and ``limit`` are forwarded unchanged to the extractor. The
    extractor result is handed to ``_format_response`` as-is.
    """
    logger.info(f"MCP Tool Call: mcp_doris_get_recent_audit_logs, Days: {days}, Limit: {limit}")
    try:
        metadata = MetadataExtractor()
        audit_logs = metadata.get_recent_audit_logs(days=days, limit=limit)
        return _format_response(success=True, result=audit_logs)
    except Exception as exc:
        logger.error(f"MCP tool execution failed mcp_doris_get_recent_audit_logs: {str(exc)}", exc_info=True)
        return _format_response(success=False, error=str(exc), message="Error getting audit logs")
async def mcp_doris_get_catalog_list() -> Dict[str, Any]:
    """
    Retrieve the list of Doris catalogs.

    Returns:
        Dict[str, Any]: Envelope from ``_format_response`` containing the
        catalog list, or error information when the lookup raises.
    """
    logger.info(f"MCP Tool Call: mcp_doris_get_catalog_list")
    try:
        metadata = MetadataExtractor()
        catalog_list = metadata.get_catalog_list()
        return _format_response(
            success=True,
            result=catalog_list,
            message="Successfully retrieved catalog list",
        )
    except Exception as exc:
        logger.error(f"MCP tool execution failed mcp_doris_get_catalog_list: {str(exc)}", exc_info=True)
        return _format_response(success=False, error=str(exc), message="Error getting catalog list")

View File

@@ -0,0 +1,455 @@
"""
Apache Doris MCP Prompts Manager
Provides standardized management of query templates and intelligent prompts
"""
from datetime import datetime
from typing import Any
from mcp.types import (
GetPromptResult,
Prompt,
PromptArgument,
PromptMessage,
TextContent,
)
from ..utils.db import DorisConnectionManager
class PromptTemplate:
    """A named prompt template whose text contains ``{key}`` placeholders.

    Instances are plain data holders: ``name``, ``description``,
    ``template`` (the raw text), ``arguments`` (declared prompt arguments),
    ``category`` and ``created_at`` (``datetime.now()`` at construction).
    """

    def __init__(
        self,
        name: str,
        description: str,
        template: str,
        arguments: "list[PromptArgument] | None" = None,
        category: str = "general",
    ):
        """Store the template definition.

        Args:
            name: Unique template name.
            description: Human-readable description.
            template: Template text with ``{key}`` placeholders.
            arguments: Declared arguments; ``None`` means no arguments.
            category: Grouping label, ``"general"`` by default.
        """
        self.name = name
        self.description = description
        self.template = template
        self.arguments = arguments or []
        self.category = category
        self.created_at = datetime.now()

    def render(self, arguments: dict[str, Any]) -> str:
        """Render template content.

        Every ``{key}`` occurrence for a key in ``arguments`` is replaced by
        ``str(value)``. Placeholders without a matching key are left intact.
        Substitution happens in a single pass, so placeholder-like text
        inside a substituted value is never expanded by another argument.

        Args:
            arguments: Mapping of placeholder name to value.

        Returns:
            The rendered text.
        """
        import re  # local import keeps this block self-contained

        if not arguments:
            return self.template

        replacements: dict[str, str] = {}
        for key, value in arguments.items():
            # First key wins if two keys produce the same placeholder text.
            replacements.setdefault(f"{{{key}}}", str(value))

        # Longest tokens first so overlapping tokens resolve deterministically.
        tokens = sorted(replacements, key=len, reverse=True)
        pattern = re.compile("|".join(re.escape(token) for token in tokens))
        return pattern.sub(lambda match: replacements[match.group(0)], self.template)
class DorisPromptsManager:
"""Apache Doris Prompts Manager"""
def __init__(self, connection_manager: DorisConnectionManager):
    """Store the connection manager and build the built-in template registry."""
    self.connection_manager = connection_manager
    # Mapping of template name -> PromptTemplate, built once at construction.
    self.templates = self._init_prompt_templates()
def _init_prompt_templates(self) -> dict[str, PromptTemplate]:
    """Build the registry of built-in prompt templates.

    Returns:
        A dict mapping each template name to its ``PromptTemplate``. Six
        templates are defined: ``sales_analysis``,
        ``user_behavior_analysis``, ``performance_optimization``,
        ``data_quality_check``, ``report_generation`` and
        ``real_time_monitoring``.

    NOTE(review): several template texts use placeholders that are not
    declared argument names (``{product_filter}``, ``{region_filter}``,
    ``{behavior_filter}``, ``{table_scope}``, ``{business_scope}``). They are
    only filled if the argument processing step supplies those keys.
    """
    templates = {}
    # Sales data analysis template
    # Placeholders: {date_range}, {product_filter}, {region_filter}
    templates["sales_analysis"] = PromptTemplate(
        name="sales_analysis",
        description="Sales data analysis query template for generating sales statistics and trend analysis queries",
        template="""Please help me analyze sales data with the following requirements:
Analysis time range: {date_range}
{product_filter}
{region_filter}
Please generate SQL queries to analyze the following dimensions:
1. Total sales amount and order quantity
2. Sales trends by time dimension
3. Top-selling product rankings
4. Sales personnel performance statistics
Data table structure reference:
- Order table: Contains order ID, customer ID, salesperson ID, order amount, order time and other fields
- Product table: Contains product ID, product name, product category, price and other fields
- Customer table: Contains customer ID, customer name, region and other fields
Please ensure query results are easy to understand and analyze.""",
        arguments=[
            PromptArgument(
                name="date_range",
                description="Date range for analysis, such as 'Q1 2024' or 'last 30 days'",
                required=True,
            ),
            PromptArgument(
                name="product_category",
                description="Product category filter condition, such as 'electronics'",
                required=False,
            ),
            PromptArgument(
                name="region",
                description="Sales region filter condition, such as 'East China'",
                required=False,
            ),
        ],
        category="business_analysis",
    )
    # User behavior analysis template
    # Placeholders: {user_segment}, {behavior_filter}, {time_period}
    templates["user_behavior_analysis"] = PromptTemplate(
        name="user_behavior_analysis",
        description="User behavior analysis query template for analyzing user activity patterns and preferences",
        template="""Please help me analyze user behavior data, analysis objectives:
User segment: {user_segment}
{behavior_filter}
Analysis period: {time_period}
Please generate SQL queries to analyze the following aspects:
1. User activity statistics (DAU, MAU)
2. User behavior path analysis
3. Feature usage preference statistics
4. User retention rate analysis
Data table structure reference:
- User table: Contains user ID, registration time, user type, region and other fields
- Behavior log table: Contains user ID, behavior type, behavior time, page path and other fields
- Session table: Contains session ID, user ID, session start time, session duration and other fields
Please provide easy-to-understand statistical results and visualization suggestions.""",
        arguments=[
            PromptArgument(
                name="user_segment",
                description="User segment conditions, such as 'new users', 'active users'",
                required=True,
            ),
            PromptArgument(
                name="behavior_type",
                description="Behavior type filter, such as 'login', 'purchase', 'browse'",
                required=False,
            ),
            PromptArgument(
                name="time_period",
                description="Analysis time period, such as 'last 7 days', 'this month'",
                required=False,
            ),
        ],
        category="user_analysis",
    )
    # Performance optimization analysis template
    # Placeholders: {focus_area}, {table_scope}, {metrics}
    templates["performance_optimization"] = PromptTemplate(
        name="performance_optimization",
        description="Database performance optimization analysis template for identifying performance bottlenecks and optimization opportunities",
        template="""Please help me with database performance analysis and optimization recommendations:
Focus area: {focus_area}
{table_scope}
Performance metrics: {metrics}
Please generate SQL queries to analyze the following content:
1. Table and query performance statistics
2. Index usage efficiency analysis
3. Slow query identification and analysis
4. Storage space usage
Analysis objectives:
- Identify performance bottlenecks
- Provide optimization recommendations
- Evaluate optimization effects
Please provide specific optimization recommendations and implementation steps.""",
        arguments=[
            PromptArgument(
                name="focus_area",
                description="Performance area of focus, such as 'query performance', 'storage optimization'",
                required=True,
            ),
            PromptArgument(
                name="table_name",
                description="Specific table name (optional), if analyzing specific table performance",
                required=False,
            ),
            PromptArgument(
                name="metrics",
                description="Performance metrics of interest, such as 'response time', 'throughput'",
                required=False,
            ),
        ],
        category="performance",
    )
    # Data quality check template
    # Placeholders: {target_table}, {quality_dimensions}, {check_level}
    templates["data_quality_check"] = PromptTemplate(
        name="data_quality_check",
        description="Data quality check template for detecting data integrity and consistency issues",
        template="""Please help me perform data quality checks:
Check target: {target_table}
{quality_dimensions}
Check level: {check_level}
Please generate SQL queries to check the following data quality issues:
1. Data integrity (null values, duplicate values)
2. Data consistency (format, range)
3. Data accuracy (business rule validation)
4. Data timeliness (update frequency)
Check items:
- Required field null value checks
- Primary key and unique constraint validation
- Data format and type checks
- Business logic consistency validation
- Data distribution anomaly detection
Please provide detailed problem reports and fix recommendations.""",
        arguments=[
            PromptArgument(
                name="target_table", description="Target table name to check", required=True
            ),
            PromptArgument(
                name="quality_dimensions",
                description="Quality check dimensions, such as 'integrity', 'consistency', 'accuracy'",
                required=False,
            ),
            PromptArgument(
                name="check_level",
                description="Check level, such as 'basic check', 'deep check'",
                required=False,
            ),
        ],
        category="data_quality",
    )
    # Report generation template
    # Placeholders: {report_type}, {report_period}, {business_scope}
    templates["report_generation"] = PromptTemplate(
        name="report_generation",
        description="Business report generation template for creating standardized business reports",
        template="""Please help me generate business reports:
Report type: {report_type}
Report period: {report_period}
{business_scope}
Please generate SQL queries to build the following report content:
1. Key business indicator summary
2. Trend analysis and year-over-year/month-over-month comparison
3. Anomaly data identification and explanation
4. Business insights and recommendations
Report requirements:
- Data accuracy and timeliness
- Clear hierarchical structure
- Easy-to-understand data presentation
- Decision-supporting analytical perspective
Please provide complete report structure and data acquisition logic.""",
        arguments=[
            PromptArgument(
                name="report_type",
                description="Report type, such as 'sales report', 'operations report', 'financial report'",
                required=True,
            ),
            PromptArgument(
                name="report_period",
                description="Report period, such as 'daily report', 'weekly report', 'monthly report'",
                required=True,
            ),
            PromptArgument(
                name="business_unit",
                description="Business unit scope, such as 'East China region', 'Product line A'",
                required=False,
            ),
        ],
        category="reporting",
    )
    # Real-time monitoring template
    # Placeholders: {monitoring_target}, {alert_threshold}, {monitoring_frequency}
    templates["real_time_monitoring"] = PromptTemplate(
        name="real_time_monitoring",
        description="Real-time monitoring query template for building real-time data monitoring and alerting",
        template="""Please help me design real-time monitoring queries:
Monitoring target: {monitoring_target}
Alert threshold: {alert_threshold}
Monitoring frequency: {monitoring_frequency}
Please generate SQL queries to implement the following monitoring functions:
1. Real-time statistics of key indicators
2. Anomaly detection and alerting
3. Trend change monitoring
4. System health status checks
Monitoring dimensions:
- Business indicator monitoring (transaction volume, user activity, etc.)
- Technical indicator monitoring (performance, error rate, etc.)
- Data quality monitoring (integrity, consistency, etc.)
Please provide complete monitoring solution and implementation recommendations.""",
        arguments=[
            PromptArgument(
                name="monitoring_target",
                description="Monitoring target, such as 'transaction system', 'user activity'",
                required=True,
            ),
            PromptArgument(
                name="alert_threshold",
                description="Alert threshold setting, such as 'error rate > 5%'",
                required=False,
            ),
            PromptArgument(
                name="monitoring_frequency",
                description="Monitoring frequency, such as 'real-time', 'every minute', 'every 5 minutes'",
                required=False,
            ),
        ],
        category="monitoring",
    )
    return templates
async def list_prompts(self) -> list[Prompt]:
    """Return one ``Prompt`` descriptor per registered template.

    Each descriptor copies the template's name, description and declared
    arguments, in registry iteration order.
    """
    return [
        Prompt(
            name=registered.name,
            description=registered.description,
            arguments=registered.arguments,
        )
        for registered in self.templates.values()
    ]
async def get_prompt(self, name: str, arguments: dict[str, Any]) -> GetPromptResult:
    """Render the named template into a single user prompt message.

    The caller's arguments are validated and completed, the template text
    is rendered, and the database context text is appended before the
    closing instruction.

    Raises:
        ValueError: If no template with ``name`` is registered (a missing
            required argument also surfaces as ``ValueError`` from the
            argument processing step).
    """
    if name not in self.templates:
        raise ValueError(f"Prompt template named '{name}' not found")
    selected = self.templates[name]

    # Validate required arguments and fill defaults for optional ones.
    filled_args = await self._process_arguments(selected, arguments)
    body_text = selected.render(filled_args)

    # Append live database context so the prompt carries schema hints.
    db_context = await self._get_database_context()
    full_content = f"""{body_text}
Database context information:
{db_context}
Please generate accurate and efficient SQL queries based on the above requirements and database structure."""

    user_message = PromptMessage(
        role="user",
        content=TextContent(type="text", text=full_content),
    )
    return GetPromptResult(description=selected.description, messages=[user_message])
async def _process_arguments(
self, template: PromptTemplate, arguments: dict[str, Any]
) -> dict[str, Any]:
"""Process template arguments"""
processed = {}
for arg in template.arguments:
if arg.name in arguments:
processed[arg.name] = arguments[arg.name]
elif arg.required:
raise ValueError(f"Missing required parameter: {arg.name}")
else:
# Provide default handling for optional parameters
processed[arg.name] = self._get_default_argument_text(arg.name)
return processed
def _get_default_argument_text(self, arg_name: str) -> str:
"""Get default text for optional parameters"""
defaults = {
"product_category": "",
"region": "",
"behavior_type": "",
"time_period": "No time range restriction",
"table_name": "",
"metrics": "All performance metrics",
"quality_dimensions": "All quality dimensions",
"check_level": "Standard check",
"business_unit": "Full business scope",
"alert_threshold": "Use default threshold",
"monitoring_frequency": "Real-time monitoring",
}
return defaults.get(arg_name, "")
async def _get_database_context(self) -> str:
    """Summarize the current database for inclusion in a prompt.

    Runs two ``information_schema.tables`` queries on the ``"system"``
    connection: one for the table count and total row count, one for the ten
    base tables with the most rows.

    Returns:
        A multi-line text block with the statistics and one line per table.
        On any failure, a one-line message describing the error is returned
        instead, so prompt generation still succeeds.
    """
    try:
        connection = await self.connection_manager.get_connection("system")

        # Get basic database information
        db_info_sql = """
        SELECT
        COUNT(*) as table_count,
        SUM(table_rows) as total_rows
        FROM information_schema.tables
        WHERE table_schema = DATABASE()
        AND table_type = 'BASE TABLE'
        """
        db_result = await connection.execute(db_info_sql)
        db_info = db_result.data[0] if db_result.data else {}

        # Get main table list
        tables_sql = """
        SELECT
        table_name,
        table_comment,
        table_rows
        FROM information_schema.tables
        WHERE table_schema = DATABASE()
        AND table_type = 'BASE TABLE'
        ORDER BY table_rows DESC
        LIMIT 10
        """
        tables_result = await connection.execute(tables_sql)

        # SUM() is NULL when no table matches and table_rows itself may be
        # NULL; the "," format spec raises on None, so fall back to 0.
        table_count = db_info.get("table_count") or 0
        total_rows = db_info.get("total_rows") or 0

        lines = [
            "Current database statistics:",
            f"- Total number of tables: {table_count}",
            f"- Total data rows: {total_rows:,}",
            "Main data tables:",
        ]
        for table in tables_result.data or []:
            entry = f"- {table['table_name']}"
            if table.get("table_comment"):
                entry += f": {table['table_comment']}"
            entry += f" ({table.get('table_rows') or 0:,} rows)"
            lines.append(entry)
        return "\n".join(lines)
    except Exception as e:
        # Deliberate best effort: the context is supplementary, so a database
        # problem degrades the prompt instead of failing it.
        return f"Unable to get database context information: {str(e)}"
def get_templates_by_category(self, category: str) -> list[PromptTemplate]:
    """Return every registered template whose ``category`` equals ``category``.

    Templates are returned in the iteration order of ``self.templates``; the
    list is empty when nothing matches.
    """
    matching = []
    for candidate in self.templates.values():
        if candidate.category == category:
            matching.append(candidate)
    return matching
def get_all_categories(self) -> list[str]:
    """Return the distinct template categories in ascending sorted order."""
    distinct = set()
    for template in self.templates.values():
        distinct.add(template.category)
    ordered = list(distinct)
    ordered.sort()
    return ordered

View File

@@ -0,0 +1,361 @@
"""
Apache Doris MCP Resources Manager
Provides standardized abstraction and access interface for database metadata
"""
import json
import logging
import time
from datetime import datetime
from typing import Any

from mcp.types import Resource

from ..utils.db import DorisConnectionManager
class TableMetadata:
    """Data table metadata.

    Plain value holder describing one table: its name, optional comment,
    row count, column descriptions and creation time.
    """

    def __init__(
        self,
        name: str,
        comment: str | None = None,
        row_count: int = 0,
        columns: list[dict] | None = None,
        create_time: datetime | None = None,
    ):
        """Store the table attributes.

        Args:
            name: Table name.
            comment: Table comment, if any.
            row_count: Row count for the table.
            columns: Column description dicts; omitted, ``None`` or empty
                all result in a new empty list.
            create_time: Creation time of the table, if known.
        """
        self.name = name
        self.comment = comment
        self.row_count = row_count
        # "or []" gives each instance its own list rather than a shared default.
        self.columns = columns or []
        self.create_time = create_time
class ViewMetadata:
    """Data view metadata.

    Plain value holder describing one view: its name, optional comment and
    optional definition text.
    """

    def __init__(
        self,
        name: str,
        comment: str | None = None,
        definition: str | None = None,
    ):
        """Store the view attributes.

        Args:
            name: View name.
            comment: View comment, if any.
            definition: View definition text, if available.
        """
        self.name = name
        self.comment = comment
        self.definition = definition
class MetadataCache:
    """Metadata cache manager.

    In-memory key/value store with one time-to-live shared by all entries.
    Entries are stored in ``self.cache`` as ``(value, stored_at)`` tuples and
    are evicted lazily, the first time they are read after expiring.
    """

    def __init__(self, ttl_seconds: int = 300):
        """Create an empty cache.

        Args:
            ttl_seconds: Lifetime of an entry in seconds.
        """
        self.cache = {}
        self.ttl = ttl_seconds

    async def get(self, key: str) -> Any | None:
        """Return the cached value for ``key``.

        Returns:
            The stored value, or ``None`` when the key is absent or its
            entry has expired (an expired entry is removed).
        """
        if key in self.cache:
            data, stored_at = self.cache[key]
            # Monotonic time is unaffected by system clock adjustments, which
            # would otherwise make entries expire early or live too long.
            if time.monotonic() - stored_at < self.ttl:
                return data
            del self.cache[key]
        return None

    async def set(self, key: str, value: Any):
        """Store ``value`` under ``key`` and restart its time-to-live."""
        self.cache[key] = (value, time.monotonic())
class DorisResourcesManager:
    """Apache Doris Resources Manager.

    Exposes tables, views and database-level statistics as MCP resources
    addressed by ``doris://<type>/<name>`` URIs. All metadata is read from
    ``information_schema`` for the current database over the ``"system"``
    connection. Table and view listings go through ``MetadataCache``.
    """

    def __init__(self, connection_manager: DorisConnectionManager):
        """Store the connection manager and create an empty metadata cache."""
        self.connection_manager = connection_manager
        self.metadata_cache = MetadataCache()

    async def list_resources(self) -> list[Resource]:
        """List all available database resources.

        Returns:
            One resource per table, one per view, then a database statistics
            resource. Best effort: if a metadata query fails, the error is
            logged and whatever was collected before the failure is returned.
        """
        resources = []
        try:
            # Get metadata for all tables
            for table in await self._get_table_metadata():
                # The row count may be None; the "," format spec rejects None.
                row_count = table.row_count or 0
                resources.append(
                    Resource(
                        uri=f"doris://table/{table.name}",
                        name=f"Data Table: {table.name}",
                        description=f"{table.comment or 'Data table'} (rows: {row_count:,})",
                        mimeType="application/json",
                    )
                )

            # Get metadata for all views
            for view in await self._get_view_metadata():
                resources.append(
                    Resource(
                        uri=f"doris://view/{view.name}",
                        name=f"Data View: {view.name}",
                        description=f"{view.comment or 'Data view'}",
                        mimeType="application/json",
                    )
                )

            # Add database statistics resource
            resources.append(
                Resource(
                    uri="doris://stats/database",
                    name="Database Statistics",
                    description="Overall database statistics and performance metrics",
                    mimeType="application/json",
                )
            )
        except Exception as e:
            # Log rather than print(): with the stdio transport, stdout is
            # the protocol channel and stray output would corrupt it.
            logging.getLogger(__name__).error("Failed to get resource list: %s", e)
        return resources

    async def read_resource(self, uri: str) -> str:
        """Read detailed information of a specific resource.

        Args:
            uri: Resource URI such as ``doris://table/<name>``,
                ``doris://view/<name>`` or ``doris://stats/database``.

        Returns:
            A JSON document describing the resource. Failures are not
            raised; they are reported as a JSON object with ``error`` and
            ``uri`` keys.
        """
        # NOTE(review): callers may hand over a URL object instead of a str
        # (confirm against the server wiring); str() keeps parsing and the
        # error payload below JSON-serializable either way.
        uri = str(uri)
        try:
            resource_type, resource_name = self._parse_resource_uri(uri)
            if resource_type == "table":
                return await self._get_table_schema(resource_name)
            elif resource_type == "view":
                return await self._get_view_definition(resource_name)
            elif resource_type == "stats" and resource_name == "database":
                return await self._get_database_stats()
            else:
                raise ValueError(f"Unsupported resource type: {resource_type}")
        except Exception as e:
            return json.dumps(
                {"error": f"Failed to read resource: {str(e)}", "uri": uri},
                ensure_ascii=False,
                indent=2,
            )

    async def _get_table_metadata(self) -> list[TableMetadata]:
        """Get metadata for all base tables, using the cache when possible.

        Issues one column query per table in addition to the table listing.
        """
        cache_key = "table_metadata"
        cached = await self.metadata_cache.get(cache_key)
        # Compare with None so a cached empty list still counts as a hit;
        # a truthiness test would re-query an empty database on every call.
        if cached is not None:
            return cached

        connection = await self.connection_manager.get_connection("system")
        # Query basic table information
        tables_query = """
        SELECT
        table_name,
        table_comment,
        table_rows as row_count,
        create_time
        FROM information_schema.tables
        WHERE table_schema = DATABASE()
        AND table_type = 'BASE TABLE'
        ORDER BY table_name
        """
        result = await connection.execute(tables_query)

        tables = []
        for row in result.data:
            # Get column information for the table
            columns = await self._get_table_columns(connection, row["table_name"])
            tables.append(
                TableMetadata(
                    name=row["table_name"],
                    comment=row.get("table_comment"),
                    # A NULL row count is normalized to 0.
                    row_count=row.get("row_count") or 0,
                    columns=columns,
                    create_time=row.get("create_time"),
                )
            )
        await self.metadata_cache.set(cache_key, tables)
        return tables

    async def _get_table_columns(self, connection, table_name: str) -> list[dict]:
        """Get column information for a table, ordered by ordinal position."""
        columns_query = """
        SELECT
        column_name,
        data_type,
        is_nullable,
        column_default,
        column_comment,
        column_key
        FROM information_schema.columns
        WHERE table_schema = DATABASE()
        AND table_name = %s
        ORDER BY ordinal_position
        """
        result = await connection.execute(columns_query, (table_name,))
        return [dict(row) for row in result.data]

    async def _get_view_metadata(self) -> list[ViewMetadata]:
        """Get metadata for all views, using the cache when possible."""
        cache_key = "view_metadata"
        cached = await self.metadata_cache.get(cache_key)
        # Compare with None so a cached empty list still counts as a hit.
        if cached is not None:
            return cached

        connection = await self.connection_manager.get_connection("system")
        views_query = """
        SELECT
        table_name,
        table_comment,
        view_definition
        FROM information_schema.views
        WHERE table_schema = DATABASE()
        ORDER BY table_name
        """
        result = await connection.execute(views_query)

        views = [
            ViewMetadata(
                name=row["table_name"],
                comment=row.get("table_comment"),
                definition=row.get("view_definition"),
            )
            for row in result.data
        ]
        await self.metadata_cache.set(cache_key, views)
        return views

    async def _get_table_schema(self, table_name: str) -> str:
        """Get detailed structure information of a table as JSON.

        Raises:
            ValueError: If the table is not found in the current database.
        """
        connection = await self.connection_manager.get_connection("system")
        # Get basic table information
        table_info_query = """
        SELECT
        table_name,
        table_comment,
        table_rows,
        create_time,
        engine
        FROM information_schema.tables
        WHERE table_schema = DATABASE()
        AND table_name = %s
        """
        table_result = await connection.execute(table_info_query, (table_name,))
        if not table_result.data:
            raise ValueError(f"Table {table_name} does not exist")
        table_info = table_result.data[0]

        # Get column information
        columns = await self._get_table_columns(connection, table_name)
        # Get index information
        indexes = await self._get_table_indexes(connection, table_name)

        schema_info = {
            "table_name": table_info["table_name"],
            "comment": table_info.get("table_comment"),
            # A NULL row count is reported as 0 rather than null.
            "row_count": table_info.get("table_rows") or 0,
            # str() keeps the creation time JSON-serializable.
            "create_time": str(table_info.get("create_time")),
            "engine": table_info.get("engine"),
            "columns": columns,
            "indexes": indexes,
        }
        return json.dumps(schema_info, ensure_ascii=False, indent=2)

    async def _get_table_indexes(self, connection, table_name: str) -> list[dict]:
        """Get index information for a table, ordered by index and position."""
        indexes_query = """
        SELECT
        index_name,
        column_name,
        index_type,
        non_unique
        FROM information_schema.statistics
        WHERE table_schema = DATABASE()
        AND table_name = %s
        ORDER BY index_name, seq_in_index
        """
        result = await connection.execute(indexes_query, (table_name,))
        return [dict(row) for row in result.data]

    async def _get_view_definition(self, view_name: str) -> str:
        """Get definition information of a view as JSON.

        Raises:
            ValueError: If the view is not found in the current database.
        """
        connection = await self.connection_manager.get_connection("system")
        view_query = """
        SELECT
        table_name,
        table_comment,
        view_definition
        FROM information_schema.views
        WHERE table_schema = DATABASE()
        AND table_name = %s
        """
        result = await connection.execute(view_query, (view_name,))
        if not result.data:
            raise ValueError(f"View {view_name} does not exist")
        view_info = result.data[0]

        schema_info = {
            "view_name": view_info["table_name"],
            "comment": view_info.get("table_comment"),
            "definition": view_info.get("view_definition"),
        }
        return json.dumps(schema_info, ensure_ascii=False, indent=2)

    async def _get_database_stats(self) -> str:
        """Get database statistics (table, view and row counts) as JSON."""
        connection = await self.connection_manager.get_connection("system")
        # Get table statistics
        table_stats_query = """
        SELECT
        COUNT(*) as table_count,
        SUM(table_rows) as total_rows
        FROM information_schema.tables
        WHERE table_schema = DATABASE()
        AND table_type = 'BASE TABLE'
        """
        table_result = await connection.execute(table_stats_query)
        table_stats = table_result.data[0] if table_result.data else {}

        # Get view statistics
        view_stats_query = """
        SELECT COUNT(*) as view_count
        FROM information_schema.views
        WHERE table_schema = DATABASE()
        """
        view_result = await connection.execute(view_stats_query)
        view_stats = view_result.data[0] if view_result.data else {}

        stats_info = {
            # Literal placeholder; the real database name is not queried here.
            "database_name": "current_database",
            "table_count": table_stats.get("table_count", 0),
            "view_count": view_stats.get("view_count", 0),
            # SUM() is NULL when no table matches, and a driver may return it
            # as Decimal, which json.dumps cannot serialize; int() covers both.
            "total_rows": int(table_stats.get("total_rows") or 0),
            "last_updated": datetime.now().isoformat(),
        }
        return json.dumps(stats_info, ensure_ascii=False, indent=2)

    def _parse_resource_uri(self, uri: str) -> tuple[str, str]:
        """Split a ``doris://<type>/<name>`` URI into ``(type, name)``.

        Path segments after the name are ignored.

        Raises:
            ValueError: If the scheme is not ``doris://`` or if the type or
                name component is missing or empty.
        """
        prefix = "doris://"
        if not uri.startswith(prefix):
            raise ValueError("Invalid resource URI format")
        parts = uri[len(prefix):].split("/")
        # Reject empty components too ("doris://table/"), which would
        # otherwise surface later as a confusing lookup of an empty name.
        if len(parts) < 2 or not parts[0] or not parts[1]:
            raise ValueError("Incomplete resource URI format")
        return parts[0], parts[1]

View File

@@ -1,157 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Tool Initialization Module
Centralized initialization of all tools, ensuring they are correctly registered with MCP
"""
import logging
import os
from typing import List, Dict, Any, Optional
import json
from datetime import datetime
import traceback
# Import Context
from mcp.server.fastmcp import Context
# Import doris mcp tools
from doris_mcp_server.tools.mcp_doris_tools import (
mcp_doris_exec_query,
mcp_doris_get_table_schema,
mcp_doris_get_db_table_list,
mcp_doris_get_db_list,
mcp_doris_get_table_comment,
mcp_doris_get_table_column_comments,
mcp_doris_get_table_indexes,
mcp_doris_get_recent_audit_logs,
mcp_doris_get_catalog_list
)
# Get logger
logger = logging.getLogger("doris-mcp-tools-initializer")
async def register_mcp_tools(mcp):
    """Register the Doris MCP tool wrappers on ``mcp``.

    Each wrapper is declared with ``@mcp.tool`` and forwards its named
    parameters to the matching ``mcp_doris_*`` function.

    Args:
        mcp: FastMCP instance

    Returns:
        True when every tool was registered, False when an exception was
        raised (the error and its traceback are logged).
    """
    logger.info("Starting to register MCP tools...")
    try:
        # NOTE: every description below advertises a required `random_string`
        # parameter, but none of the wrapper signatures accept one.

        # Register Tool: Execute SQL Query (long description string lists the parameters)
        @mcp.tool("exec_query", description="""[Function Description]: Execute SQL query and return result command with catalog federation support.\n
[Parameter Content]:\n
- random_string (string) [Required] - Unique identifier for the tool call\n
- sql (string) [Required] - SQL statement to execute. MUST use three-part naming for all table references: 'catalog_name.db_name.table_name'. For internal tables use 'internal.db_name.table_name', for external tables use 'catalog_name.db_name.table_name'\n
- db_name (string) [Optional] - Target database name, defaults to the current database\n
- catalog_name (string) [Optional] - Reference catalog name for context, defaults to current catalog\n
- max_rows (integer) [Optional] - Maximum number of rows to return, default 100
- timeout (integer) [Optional] - Query timeout in seconds, default 30""")
        async def exec_query_tool(sql: str, db_name: str = None, catalog_name: str = None, max_rows: int = 100, timeout: int = 30) -> Dict[str, Any]:
            """Wrapper: Execute SQL query and return result command"""
            # No ctx parameter: the arguments arrive as named parameters
            return await mcp_doris_exec_query(sql=sql, db_name=db_name, catalog_name=catalog_name, max_rows=max_rows, timeout=timeout)

        # Register Tool: Get Table Schema
        @mcp.tool("get_table_schema", description="""[Function Description]: Get detailed structure information of the specified table (columns, types, comments, etc.).\n
[Parameter Content]:\n
- random_string (string) [Required] - Unique identifier for the tool call\n
- table_name (string) [Required] - Name of the table to query\n
- db_name (string) [Optional] - Target database name, defaults to the current database\n
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""")
        async def get_table_schema_tool(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
            """Wrapper: Get table schema"""
            # An empty table_name is answered with an error payload instead of a query
            if not table_name: return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "Missing table_name parameter"})}]}
            return await mcp_doris_get_table_schema(table_name=table_name, db_name=db_name, catalog_name=catalog_name)

        # Register Tool: Get Database Table List
        @mcp.tool("get_db_table_list", description="""[Function Description]: Get a list of all table names in the specified database.\n
[Parameter Content]:\n
- random_string (string) [Required] - Unique identifier for the tool call\n
- db_name (string) [Optional] - Target database name, defaults to the current database\n
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""")
        async def get_db_table_list_tool(db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
            """Wrapper: Get database table list"""
            return await mcp_doris_get_db_table_list(db_name=db_name, catalog_name=catalog_name)

        # Register Tool: Get Database List
        @mcp.tool("get_db_list", description="""[Function Description]: Get a list of all database names on the server.\n
[Parameter Content]:\n
- random_string (string) [Required] - Unique identifier for the tool call\n
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""")
        async def get_db_list_tool(catalog_name: str = None) -> Dict[str, Any]:  # only parameter is the optional catalog_name
            """Wrapper: Get database list"""
            return await mcp_doris_get_db_list(catalog_name=catalog_name)

        # Register Tool: Get Table Comment
        @mcp.tool("get_table_comment", description="""[Function Description]: Get the comment information for the specified table.\n
[Parameter Content]:\n
- random_string (string) [Required] - Unique identifier for the tool call\n
- table_name (string) [Required] - Name of the table to query\n
- db_name (string) [Optional] - Target database name, defaults to the current database\n
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""")
        async def get_table_comment_tool(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
            """Wrapper: Get table comment"""
            # An empty table_name is answered with an error payload instead of a query
            if not table_name: return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "Missing table_name parameter"})}]}
            return await mcp_doris_get_table_comment(table_name=table_name, db_name=db_name, catalog_name=catalog_name)

        # Register Tool: Get Table Column Comments
        @mcp.tool("get_table_column_comments", description="""[Function Description]: Get comment information for all columns in the specified table.\n
[Parameter Content]:\n
- random_string (string) [Required] - Unique identifier for the tool call\n
- table_name (string) [Required] - Name of the table to query\n
- db_name (string) [Optional] - Target database name, defaults to the current database\n
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""")
        async def get_table_column_comments_tool(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
            """Wrapper: Get table column comments"""
            # An empty table_name is answered with an error payload instead of a query
            if not table_name: return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "Missing table_name parameter"})}]}
            return await mcp_doris_get_table_column_comments(table_name=table_name, db_name=db_name, catalog_name=catalog_name)

        # Register Tool: Get Table Indexes
        @mcp.tool("get_table_indexes", description="""[Function Description]: Get index information for the specified table.\n
[Parameter Content]:\n
- random_string (string) [Required] - Unique identifier for the tool call\n
- table_name (string) [Required] - Name of the table to query\n
- db_name (string) [Optional] - Target database name, defaults to the current database\n
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""")
        async def get_table_indexes_tool(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
            """Wrapper: Get table indexes"""
            # An empty table_name is answered with an error payload instead of a query
            if not table_name: return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "Missing table_name parameter"})}]}
            return await mcp_doris_get_table_indexes(table_name=table_name, db_name=db_name, catalog_name=catalog_name)

        # Register Tool: Get Recent Audit Logs
        @mcp.tool("get_recent_audit_logs", description="""[Function Description]: Get audit log records for a recent period.\n
[Parameter Content]:\n
- random_string (string) [Required] - Unique identifier for the tool call\n
- days (integer) [Optional] - Number of recent days of logs to retrieve, default is 7\n
- limit (integer) [Optional] - Maximum number of records to return, default is 100\n""")
        async def get_recent_audit_logs_tool(days: int = 7, limit: int = 100) -> Dict[str, Any]:
            """Wrapper: Get recent audit logs"""
            # Coerce both values to int; anything non-convertible is answered
            # with an error payload instead of being forwarded.
            try:
                days = int(days)
                limit = int(limit)
            except (ValueError, TypeError):
                return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "days and limit parameters must be integers"})}]}
            return await mcp_doris_get_recent_audit_logs(days=days, limit=limit)

        # Register Tool: Get Catalog List
        @mcp.tool("get_catalog_list", description="""[Function Description]: Get a list of all catalog names on the server.\n
[Parameter Content]:\n
- random_string (string) [Required] - Unique identifier for the tool call\n""")
        async def get_catalog_list_tool() -> Dict[str, Any]:
            """Wrapper: Get catalog list"""
            return await mcp_doris_get_catalog_list()

        # Ask the server how many tools it now exposes, for the log line
        tools_count = len(await mcp.list_tools())
        logger.info(f"Registered all MCP tools, total {tools_count} tools")
        return True
    except Exception as e:
        # Registration failure is reported through the return value, not raised
        logger.error(f"Error registering MCP tools: {str(e)}")
        logger.error(traceback.format_exc())
        return False

View File

@@ -0,0 +1,766 @@
"""
Apache Doris MCP Tools Manager
Responsible for tool registration, management, scheduling and routing, does not contain specific business logic implementation
"""
import json
import time
from datetime import datetime
from typing import Any, Dict, List
from mcp.types import Tool
from ..utils.db import DorisConnectionManager
from ..utils.query_executor import DorisQueryExecutor
from ..utils.analysis_tools import TableAnalyzer, PerformanceMonitor
from ..utils.schema_extractor import MetadataExtractor
from ..utils.logger import get_logger
logger = get_logger(__name__)
class DorisToolsManager:
"""Apache Doris Tools Manager"""
def __init__(self, connection_manager: DorisConnectionManager):
    """Keep the connection manager and build the business-logic processors.

    The manager itself only registers and routes tools; the query executor,
    table analyzer, performance monitor and metadata extractor created here
    are each constructed with the same connection manager.
    """
    self.connection_manager = connection_manager

    # Initialize business logic processors
    shared_manager = connection_manager
    self.query_executor = DorisQueryExecutor(shared_manager)
    self.table_analyzer = TableAnalyzer(shared_manager)
    self.performance_monitor = PerformanceMonitor(shared_manager)
    self.metadata_extractor = MetadataExtractor(connection_manager=shared_manager)

    logger.info("DorisToolsManager initialized with business logic processors")
async def register_tools_with_mcp(self, mcp):
    """Register all tools to MCP server

    Declares eleven wrappers with ``@mcp.tool``. Each wrapper only packs its
    named parameters into a dict and forwards them to ``self.call_tool``
    under a routing name; no business logic lives here.

    Args:
        mcp: Server object exposing the ``tool`` decorator used below.
    """
    logger.info("Starting to register MCP tools")

    # Column statistical analysis tool
    # NOTE(review): only this tool and performance_stats pass an explicit
    # inputSchema to mcp.tool - confirm the decorator accepts that keyword.
    @mcp.tool(
        "column_analysis",
        description="""[Function Description]: Analyze statistical information and data distribution of the specified column.
[Parameter Content]:
- table_name (string) [Required] - Name of the table to analyze
- column_name (string) [Required] - Name of the column to analyze
- analysis_type (string) [Optional] - Type of analysis to perform, default is "basic"
* "basic": Basic statistics (count, null values, distinct values)
* "distribution": Data distribution analysis (frequency, percentiles)
* "detailed": Comprehensive analysis including all above plus patterns and outliers
""",
        inputSchema={
            "type": "object",
            "properties": {
                "table_name": {"type": "string", "description": "Table name"},
                "column_name": {
                    "type": "string",
                    "description": "Column name to analyze",
                },
                "analysis_type": {
                    "type": "string",
                    "enum": ["basic", "distribution", "detailed"],
                    "description": "Analysis type",
                    "default": "basic",
                },
            },
            "required": ["table_name", "column_name"],
        }
    )
    async def column_analysis_tool(
        table_name: str,
        column_name: str,
        analysis_type: str = "basic"
    ) -> str:
        """Column statistical analysis tool"""
        return await self.call_tool("column_analysis", {
            "table_name": table_name,
            "column_name": column_name,
            "analysis_type": analysis_type
        })

    # Database performance monitoring tool
    # The registered name carries an "[Experimental]" suffix, while the call
    # below is routed under the plain name "performance_stats".
    @mcp.tool(
        "performance_stats[Experimental]",
        description="""[Important]: This tool is experimental and may not be fully functional!
[Function Description]: Get database performance statistics information.
[Parameter Content]:
- metric_type (string) [Optional] - Type of performance metrics to retrieve, default is "queries"
* "queries": Query performance metrics (execution time, frequency, etc.)
* "connections": Connection statistics (active connections, connection pool status)
* "tables": Table-level statistics (size, row count, access patterns)
* "system": System-level metrics (CPU, memory, disk usage)
- time_range (string) [Optional] - Time range for statistics, default is "1h"
* "1h": Last 1 hour
* "6h": Last 6 hours
* "24h": Last 24 hours
* "7d": Last 7 days
""",
        inputSchema={
            "type": "object",
            "properties": {
                "metric_type": {
                    "type": "string",
                    "enum": ["queries", "connections", "tables", "system"],
                    "description": "Performance metric type",
                    "default": "queries",
                },
                "time_range": {
                    "type": "string",
                    "enum": ["1h", "6h", "24h", "7d"],
                    "description": "Time range",
                    "default": "1h",
                },
            },
        }
    )
    async def performance_stats_tool(
        metric_type: str = "queries",
        time_range: str = "1h"
    ) -> str:
        """Database performance monitoring tool"""
        return await self.call_tool("performance_stats", {
            "metric_type": metric_type,
            "time_range": time_range
        })

    # SQL query execution tool (supports catalog federation queries)
    @mcp.tool(
        "exec_query",
        description="""[Function Description]: Execute SQL query and return result command with catalog federation support.
[Parameter Content]:
- sql (string) [Required] - SQL statement to execute. MUST use three-part naming for all table references: 'catalog_name.db_name.table_name'. For internal tables use 'internal.db_name.table_name', for external tables use 'catalog_name.db_name.table_name'
- db_name (string) [Optional] - Target database name, defaults to the current database
- catalog_name (string) [Optional] - Reference catalog name for context, defaults to current catalog
- max_rows (integer) [Optional] - Maximum number of rows to return, default 100
- timeout (integer) [Optional] - Query timeout in seconds, default 30
""",
    )
    async def exec_query_tool(
        sql: str,
        db_name: str = None,
        catalog_name: str = None,
        max_rows: int = 100,
        timeout: int = 30,
    ) -> str:
        """Execute SQL query (supports federation queries)"""
        return await self.call_tool("exec_query", {
            "sql": sql,
            "db_name": db_name,
            "catalog_name": catalog_name,
            "max_rows": max_rows,
            "timeout": timeout
        })

    # Get table schema tool
    @mcp.tool(
        "get_table_schema",
        description="""[Function Description]: Get detailed structure information of the specified table (columns, types, comments, etc.).
[Parameter Content]:
- table_name (string) [Required] - Name of the table to query
- db_name (string) [Optional] - Target database name, defaults to the current database
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
""",
    )
    async def get_table_schema_tool(
        table_name: str, db_name: str = None, catalog_name: str = None
    ) -> str:
        """Get table schema information"""
        return await self.call_tool("get_table_schema", {
            "table_name": table_name,
            "db_name": db_name,
            "catalog_name": catalog_name
        })

    # Get database table list tool
    @mcp.tool(
        "get_db_table_list",
        description="""[Function Description]: Get a list of all table names in the specified database.
[Parameter Content]:
- db_name (string) [Optional] - Target database name, defaults to the current database
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
""",
    )
    async def get_db_table_list_tool(
        db_name: str = None, catalog_name: str = None
    ) -> str:
        """Get database table list"""
        return await self.call_tool("get_db_table_list", {
            "db_name": db_name,
            "catalog_name": catalog_name
        })

    # Get database list tool
    @mcp.tool(
        "get_db_list",
        description="""[Function Description]: Get a list of all database names on the server.
[Parameter Content]:
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
""",
    )
    async def get_db_list_tool(catalog_name: str = None) -> str:
        """Get database list"""
        return await self.call_tool("get_db_list", {
            "catalog_name": catalog_name
        })

    # Get table comment tool
    @mcp.tool(
        "get_table_comment",
        description="""[Function Description]: Get the comment information for the specified table.
[Parameter Content]:
- table_name (string) [Required] - Name of the table to query
- db_name (string) [Optional] - Target database name, defaults to the current database
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
""",
    )
    async def get_table_comment_tool(
        table_name: str, db_name: str = None, catalog_name: str = None
    ) -> str:
        """Get table comment"""
        return await self.call_tool("get_table_comment", {
            "table_name": table_name,
            "db_name": db_name,
            "catalog_name": catalog_name
        })

    # Get table column comments tool
    @mcp.tool(
        "get_table_column_comments",
        description="""[Function Description]: Get comment information for all columns in the specified table.
[Parameter Content]:
- table_name (string) [Required] - Name of the table to query
- db_name (string) [Optional] - Target database name, defaults to the current database
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
""",
    )
    async def get_table_column_comments_tool(
        table_name: str, db_name: str = None, catalog_name: str = None
    ) -> str:
        """Get table column comments"""
        return await self.call_tool("get_table_column_comments", {
            "table_name": table_name,
            "db_name": db_name,
            "catalog_name": catalog_name
        })

    # Get table indexes tool
    @mcp.tool(
        "get_table_indexes",
        description="""[Function Description]: Get index information for the specified table.
[Parameter Content]:
- table_name (string) [Required] - Name of the table to query
- db_name (string) [Optional] - Target database name, defaults to the current database
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
""",
    )
    async def get_table_indexes_tool(
        table_name: str, db_name: str = None, catalog_name: str = None
    ) -> str:
        """Get table indexes"""
        return await self.call_tool("get_table_indexes", {
            "table_name": table_name,
            "db_name": db_name,
            "catalog_name": catalog_name
        })

    # Get audit logs tool
    @mcp.tool(
        "get_recent_audit_logs",
        description="""[Function Description]: Get audit log records for a recent period.
[Parameter Content]:
- days (integer) [Optional] - Number of recent days of logs to retrieve, default is 7
- limit (integer) [Optional] - Maximum number of records to return, default is 100
""",
    )
    async def get_recent_audit_logs_tool(
        days: int = 7, limit: int = 100
    ) -> str:
        """Get audit logs"""
        return await self.call_tool("get_recent_audit_logs", {
            "days": days,
            "limit": limit
        })

    # Get catalog list tool
    # The only wrapper with a required `random_string` parameter; it is
    # forwarded to call_tool unchanged.
    @mcp.tool(
        "get_catalog_list",
        description="""[Function Description]: Get a list of all catalog names on the server.
[Parameter Content]:
- random_string (string) [Required] - Unique identifier for the tool call
""",
    )
    async def get_catalog_list_tool(random_string: str) -> str:
        """Get catalog list"""
        return await self.call_tool("get_catalog_list", {
            "random_string": random_string
        })

    # The count in this message is hard-coded; update it when adding or
    # removing a registration above.
    logger.info("Successfully registered 11 tools to MCP server (2 core tools + 9 migrated tools)")
async def list_tools(self) -> List[Tool]:
    """Return the static catalogue of query tools (used in stdio mode).

    Each entry is a ``Tool`` carrying a name, a description and a JSON-Schema
    style ``inputSchema``. The list is built from literals only; nothing is
    queried when this method runs.
    """

    def text_field(description: str) -> Dict[str, Any]:
        # A fresh dict per call, so no two schemas share a mutable object.
        return {"type": "string", "description": description}

    def int_field(description: str, default: int) -> Dict[str, Any]:
        return {"type": "integer", "description": description, "default": default}

    def choice_field(options: List[str], description: str, default: str) -> Dict[str, Any]:
        return {
            "type": "string",
            "enum": options,
            "description": description,
            "default": default,
        }

    def table_lookup_schema() -> Dict[str, Any]:
        # Schema shared by every tool addressed by table / database / catalog.
        return {
            "type": "object",
            "properties": {
                "table_name": text_field("Table name"),
                "db_name": text_field("Database name"),
                "catalog_name": text_field("Catalog name"),
            },
            "required": ["table_name"],
        }

    column_analysis = Tool(
        name="column_analysis[Experimental]",
        description="""[Important]: This tool is experimental and may not be fully functional!
[Function Description]: Analyze statistical information and data distribution of the specified column.
[Parameter Content]:
- table_name (string) [Required] - Name of the table to analyze
- column_name (string) [Required] - Name of the column to analyze
- analysis_type (string) [Optional] - Type of analysis to perform, default is "basic"
* "basic": Basic statistics (count, null values, distinct values)
* "distribution": Data distribution analysis (frequency, percentiles)
* "detailed": Comprehensive analysis including all above plus patterns and outliers
""",
        inputSchema={
            "type": "object",
            "properties": {
                "table_name": text_field("Table name"),
                "column_name": text_field("Column name to analyze"),
                "analysis_type": choice_field(
                    ["basic", "distribution", "detailed"], "Analysis type", "basic"
                ),
            },
            "required": ["table_name", "column_name"],
        },
    )

    performance_stats = Tool(
        name="performance_stats",
        description="""[Function Description]: Get database performance statistics information.
[Parameter Content]:
- metric_type (string) [Optional] - Type of performance metrics to retrieve, default is "queries"
* "queries": Query performance metrics (execution time, frequency, etc.)
* "connections": Connection statistics (active connections, connection pool status)
* "tables": Table-level statistics (size, row count, access patterns)
* "system": System-level metrics (CPU, memory, disk usage)
- time_range (string) [Optional] - Time range for statistics, default is "1h"
* "1h": Last 1 hour
* "6h": Last 6 hours
* "24h": Last 24 hours
* "7d": Last 7 days
""",
        inputSchema={
            "type": "object",
            "properties": {
                "metric_type": choice_field(
                    ["queries", "connections", "tables", "system"],
                    "Performance metric type",
                    "queries",
                ),
                "time_range": choice_field(["1h", "6h", "24h", "7d"], "Time range", "1h"),
            },
        },
    )

    exec_query = Tool(
        name="exec_query",
        description="""[Function Description]: Execute SQL query and return result command with catalog federation support.
[Parameter Content]:
- sql (string) [Required] - SQL statement to execute. MUST use three-part naming for all table references: 'catalog_name.db_name.table_name'. For internal tables use 'internal.db_name.table_name', for external tables use 'catalog_name.db_name.table_name'
- db_name (string) [Optional] - Target database name, defaults to the current database
- catalog_name (string) [Optional] - Reference catalog name for context, defaults to current catalog
- max_rows (integer) [Optional] - Maximum number of rows to return, default 100
- timeout (integer) [Optional] - Query timeout in seconds, default 30
""",
        inputSchema={
            "type": "object",
            "properties": {
                "sql": text_field("SQL statement to execute, must use three-part naming"),
                "db_name": text_field("Target database name"),
                "catalog_name": text_field("Catalog name"),
                "max_rows": int_field("Maximum number of rows to return", 100),
                "timeout": int_field("Timeout in seconds", 30),
            },
            "required": ["sql"],
        },
    )

    get_table_schema = Tool(
        name="get_table_schema",
        description="""[Function Description]: Get detailed structure information of the specified table (columns, types, comments, etc.).
[Parameter Content]:
- table_name (string) [Required] - Name of the table to query
- db_name (string) [Optional] - Target database name, defaults to the current database
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
""",
        inputSchema=table_lookup_schema(),
    )

    get_db_table_list = Tool(
        name="get_db_table_list",
        description="""[Function Description]: Get a list of all table names in the specified database.
[Parameter Content]:
- db_name (string) [Optional] - Target database name, defaults to the current database
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
""",
        inputSchema={
            "type": "object",
            "properties": {
                "db_name": text_field("Database name"),
                "catalog_name": text_field("Catalog name"),
            },
        },
    )

    get_db_list = Tool(
        name="get_db_list",
        description="""[Function Description]: Get a list of all database names on the server.
[Parameter Content]:
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
""",
        inputSchema={
            "type": "object",
            "properties": {
                "catalog_name": text_field("Catalog name"),
            },
        },
    )

    get_table_comment = Tool(
        name="get_table_comment",
        description="""[Function Description]: Get the comment information for the specified table.
[Parameter Content]:
- table_name (string) [Required] - Name of the table to query
- db_name (string) [Optional] - Target database name, defaults to the current database
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
""",
        inputSchema=table_lookup_schema(),
    )

    get_table_column_comments = Tool(
        name="get_table_column_comments",
        description="""[Function Description]: Get comment information for all columns in the specified table.
[Parameter Content]:
- table_name (string) [Required] - Name of the table to query
- db_name (string) [Optional] - Target database name, defaults to the current database
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
""",
        inputSchema=table_lookup_schema(),
    )

    get_table_indexes = Tool(
        name="get_table_indexes",
        description="""[Function Description]: Get index information for the specified table.
[Parameter Content]:
- table_name (string) [Required] - Name of the table to query
- db_name (string) [Optional] - Target database name, defaults to the current database
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
""",
        inputSchema=table_lookup_schema(),
    )

    get_recent_audit_logs = Tool(
        name="get_recent_audit_logs",
        description="""[Function Description]: Get audit log records for a recent period.
[Parameter Content]:
- days (integer) [Optional] - Number of recent days of logs to retrieve, default is 7
- limit (integer) [Optional] - Maximum number of records to return, default is 100
""",
        inputSchema={
            "type": "object",
            "properties": {
                "days": int_field("Number of recent days", 7),
                "limit": int_field("Maximum number of records", 100),
            },
        },
    )

    get_catalog_list = Tool(
        name="get_catalog_list",
        description="""[Function Description]: Get a list of all catalog names on the server.
[Parameter Content]:
- random_string (string) [Required] - Unique identifier for the tool call
""",
        inputSchema={
            "type": "object",
            "properties": {
                "random_string": text_field("Unique identifier"),
            },
            "required": ["random_string"],
        },
    )

    # Order matches the original listing.
    return [
        column_analysis,
        performance_stats,
        exec_query,
        get_table_schema,
        get_db_table_list,
        get_db_list,
        get_table_comment,
        get_table_column_comments,
        get_table_indexes,
        get_recent_audit_logs,
        get_catalog_list,
    ]
async def call_tool(self, name: str, arguments: Dict[str, Any]) -> str:
    """
    Call the specified query tool (tool routing and scheduling center).

    Args:
        name: Tool name. Both "column_analysis" and the name advertised by
            ``list_tools`` ("column_analysis[Experimental]") reach the column
            analysis handler.
        arguments: Raw tool arguments, passed unchanged to the handler.

    Returns:
        A JSON string. On success it is the handler result; dict results
        additionally get an ``_execution_info`` entry (tool name, elapsed
        seconds, timestamp). On any failure, including an unknown tool name,
        it is a JSON object with ``error``, ``tool_name``, ``arguments`` and
        ``timestamp``; the exception is logged, not raised.
    """
    # Names advertised to clients that differ from the internal route name.
    aliases = {"column_analysis[Experimental]": "column_analysis"}
    # Whitelist of routable tools; the handler for tool X is self._X_tool.
    routable = (
        "column_analysis",
        "performance_stats",
        # ===== 9 tool routes migrated from source project =====
        "exec_query",
        "get_table_schema",
        "get_db_table_list",
        "get_db_list",
        "get_table_comment",
        "get_table_column_comments",
        "get_table_indexes",
        "get_recent_audit_logs",
        "get_catalog_list",
    )
    try:
        start_time = time.time()
        route = aliases.get(name, name)
        if route not in routable:
            raise ValueError(f"Unknown tool: {name}")
        # Resolve lazily so only the requested handler needs to exist.
        handler = getattr(self, f"_{route}_tool")
        result = await handler(arguments)
        execution_time = time.time() - start_time
        # Add execution information
        if isinstance(result, dict):
            result["_execution_info"] = {
                "tool_name": name,
                "execution_time": round(execution_time, 3),
                "timestamp": datetime.now().isoformat(),
            }
        return json.dumps(result, ensure_ascii=False, indent=2)
    except Exception as e:
        # Top-level boundary: report the failure to the client as JSON.
        logger.error(f"Tool call failed {name}: {str(e)}")
        error_result = {
            "error": str(e),
            "tool_name": name,
            "arguments": arguments,
            "timestamp": datetime.now().isoformat(),
        }
        return json.dumps(error_result, ensure_ascii=False, indent=2)
# The following are tool routing methods, responsible for calling corresponding business logic processors
async def _column_analysis_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
"""Column statistical analysis tool routing"""
table_name = arguments.get("table_name")
column_name = arguments.get("column_name")
analysis_type = arguments.get("analysis_type", "basic")
# Delegate to table analyzer for processing
return await self.table_analyzer.analyze_column(
table_name, column_name, analysis_type
)
async def _performance_stats_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
"""Database performance statistics tool routing"""
metric_type = arguments.get("metric_type", "queries")
time_range = arguments.get("time_range", "1h")
# Delegate to performance monitor for processing
return await self.performance_monitor.get_performance_stats(
metric_type, time_range
)
async def _exec_query_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
"""SQL query execution tool routing (supports federation queries)"""
sql = arguments.get("sql")
db_name = arguments.get("db_name")
catalog_name = arguments.get("catalog_name")
max_rows = arguments.get("max_rows", 100)
timeout = arguments.get("timeout", 30)
# Delegate to metadata extractor for processing
return await self.metadata_extractor.exec_query_for_mcp(
sql, db_name, catalog_name, max_rows, timeout
)
async def _get_table_schema_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
"""Get table schema tool routing"""
table_name = arguments.get("table_name")
db_name = arguments.get("db_name")
catalog_name = arguments.get("catalog_name")
# Delegate to metadata extractor for processing
return await self.metadata_extractor.get_table_schema_for_mcp(
table_name, db_name, catalog_name
)
async def _get_db_table_list_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
"""Get database table list tool routing"""
db_name = arguments.get("db_name")
catalog_name = arguments.get("catalog_name")
# Delegate to metadata extractor for processing
return await self.metadata_extractor.get_db_table_list_for_mcp(db_name, catalog_name)
async def _get_db_list_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
"""Get database list tool routing"""
catalog_name = arguments.get("catalog_name")
# Delegate to metadata extractor for processing
return await self.metadata_extractor.get_db_list_for_mcp(catalog_name)
async def _get_table_comment_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
"""Get table comment tool routing"""
table_name = arguments.get("table_name")
db_name = arguments.get("db_name")
catalog_name = arguments.get("catalog_name")
# Delegate to metadata extractor for processing
return await self.metadata_extractor.get_table_comment_for_mcp(
table_name, db_name, catalog_name
)
async def _get_table_column_comments_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
"""Get table column comments tool routing"""
table_name = arguments.get("table_name")
db_name = arguments.get("db_name")
catalog_name = arguments.get("catalog_name")
# Delegate to metadata extractor for processing
return await self.metadata_extractor.get_table_column_comments_for_mcp(
table_name, db_name, catalog_name
)
async def _get_table_indexes_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
"""Get table indexes tool routing"""
table_name = arguments.get("table_name")
db_name = arguments.get("db_name")
catalog_name = arguments.get("catalog_name")
# Delegate to metadata extractor for processing
return await self.metadata_extractor.get_table_indexes_for_mcp(
table_name, db_name, catalog_name
)
async def _get_recent_audit_logs_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
"""Get audit logs tool routing"""
days = arguments.get("days", 7)
limit = arguments.get("limit", 100)
# Delegate to metadata extractor for processing
return await self.metadata_extractor.get_recent_audit_logs_for_mcp(days, limit)
async def _get_catalog_list_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
"""Get catalog list tool routing"""
# random_string parameter is required in the source project, but not actually used in business logic
# Here we ignore it and directly call business logic
# Delegate to metadata extractor for processing
return await self.metadata_extractor.get_catalog_list_for_mcp()

View File

@@ -1 +1,10 @@
# Mark directory as a package
"""
Utilities Package - Contains utility classes and helper functions.
This package includes:
- Database connection and operations
- Configuration management
- Security utilities
- Query execution helpers
- Logging configuration
"""

View File

@@ -0,0 +1,318 @@
"""
Data Analysis Tools Module
Provides data analysis functions including table analysis, column statistics, performance monitoring, etc.
"""
import time
from datetime import datetime
from typing import Any, Dict, List
from .db import DorisConnectionManager
from .logger import get_logger
logger = get_logger(__name__)
class TableAnalyzer:
    """Table analyzer.

    Runs metadata and statistics queries through connections obtained from
    the injected connection manager. Table and column names come from tool
    callers, so they are never spliced into SQL as raw text: values compared
    against ``information_schema`` are sent as ``%s`` parameters, and names
    used as identifiers are validated and backtick-quoted first.
    """

    def __init__(self, connection_manager: "DorisConnectionManager"):
        self.connection_manager = connection_manager

    @staticmethod
    def _quote_identifier(identifier: str) -> str:
        """Return *identifier* as a backtick-quoted SQL identifier.

        Dotted names are quoted part by part (``db.tbl`` -> `` `db`.`tbl` ``);
        a part that is already wrapped in backticks is accepted as is.

        Raises:
            ValueError: If the name is empty, has an empty part, or a part
                contains a backtick or a control character.
        """
        if not isinstance(identifier, str) or not identifier.strip():
            raise ValueError(f"Invalid identifier: {identifier!r}")
        quoted_parts = []
        for raw_part in identifier.split("."):
            part = raw_part.strip()
            if len(part) >= 2 and part.startswith("`") and part.endswith("`"):
                part = part[1:-1]
            # A backtick inside the name would terminate the quoting early.
            if not part or "`" in part or any(ord(ch) < 32 for ch in part):
                raise ValueError(f"Invalid identifier: {identifier!r}")
            quoted_parts.append(f"`{part}`")
        return ".".join(quoted_parts)

    async def get_table_summary(
        self,
        table_name: str,
        include_sample: bool = True,
        sample_size: int = 10
    ) -> Dict[str, Any]:
        """Get table summary information.

        Args:
            table_name: Table to look up in the current database.
            include_sample: Whether to attach sample rows.
            sample_size: Number of sample rows; no sample is read when <= 0.

        Returns:
            Dict with table metadata, the column list and, when requested,
            ``sample_data``.

        Raises:
            ValueError: If the table is not found in ``information_schema``
                or *table_name* is not a valid identifier.
        """
        connection = await self.connection_manager.get_connection("query")
        # Get table basic information
        table_info_sql = """
            SELECT
                table_name,
                table_comment,
                table_rows,
                create_time,
                engine
            FROM information_schema.tables
            WHERE table_schema = DATABASE()
              AND table_name = %s
        """
        table_info_result = await connection.execute(table_info_sql, (table_name,))
        if not table_info_result.data:
            raise ValueError(f"Table {table_name} does not exist")
        table_info = table_info_result.data[0]
        # Get column information
        columns_sql = """
            SELECT
                column_name,
                data_type,
                is_nullable,
                column_comment
            FROM information_schema.columns
            WHERE table_schema = DATABASE()
              AND table_name = %s
            ORDER BY ordinal_position
        """
        columns_result = await connection.execute(columns_sql, (table_name,))
        summary = {
            "table_name": table_info["table_name"],
            "comment": table_info.get("table_comment"),
            "row_count": table_info.get("table_rows", 0),
            "create_time": str(table_info.get("create_time")),
            "engine": table_info.get("engine"),
            "column_count": len(columns_result.data),
            "columns": columns_result.data,
        }
        # Get sample data
        if include_sample and sample_size > 0:
            table_ref = self._quote_identifier(table_name)
            # int() keeps anything but a plain number out of the LIMIT clause.
            sample_sql = f"SELECT * FROM {table_ref} LIMIT {int(sample_size)}"
            sample_result = await connection.execute(sample_sql)
            summary["sample_data"] = sample_result.data
        return summary

    async def analyze_column(
        self,
        table_name: str,
        column_name: str,
        analysis_type: str = "basic"
    ) -> Dict[str, Any]:
        """Analyze column statistics.

        Args:
            table_name: Table holding the column.
            column_name: Column to analyze.
            analysis_type: "basic" (counts only), "distribution" (adds the
                20 most frequent values) or "detailed" (adds min/max/avg).

        Returns:
            Dict with ``success`` set to True plus the statistics, or
            ``success`` False and an ``error`` message. Failures, including
            invalid identifiers, are reported in the dict rather than raised.
        """
        try:
            table_ref = self._quote_identifier(table_name)
            column_ref = self._quote_identifier(column_name)
            connection = await self.connection_manager.get_connection("query")
            # Basic statistics
            basic_stats_sql = f"""
                SELECT
                    COUNT(*) as total_count,
                    COUNT({column_ref}) as non_null_count,
                    COUNT(DISTINCT {column_ref}) as distinct_count
                FROM {table_ref}
            """
            basic_result = await connection.execute(basic_stats_sql)
            if not basic_result.data:
                return {
                    "success": False,
                    "error": f"Unable to get statistics for table {table_name} column {column_name}"
                }
            # The column name is added here instead of being embedded in the
            # query as a string literal.
            analysis = {"column_name": column_name, **basic_result.data[0]}
            analysis["success"] = True
            analysis["analysis_type"] = analysis_type
            if analysis_type in ["distribution", "detailed"]:
                # Data distribution analysis
                distribution_sql = f"""
                    SELECT
                        {column_ref} as value,
                        COUNT(*) as frequency
                    FROM {table_ref}
                    WHERE {column_ref} IS NOT NULL
                    GROUP BY {column_ref}
                    ORDER BY frequency DESC
                    LIMIT 20
                """
                distribution_result = await connection.execute(distribution_sql)
                analysis["value_distribution"] = distribution_result.data
            if analysis_type == "detailed":
                # Detailed statistics (for numeric types)
                try:
                    numeric_stats_sql = f"""
                        SELECT
                            MIN({column_ref}) as min_value,
                            MAX({column_ref}) as max_value,
                            AVG({column_ref}) as avg_value
                        FROM {table_ref}
                        WHERE {column_ref} IS NOT NULL
                    """
                    numeric_result = await connection.execute(numeric_stats_sql)
                    if numeric_result.data:
                        analysis.update(numeric_result.data[0])
                except Exception as exc:
                    # Best effort: non-numeric columns don't support numeric
                    # statistics, so the rest of the analysis is still returned.
                    logger.debug(f"Numeric statistics skipped for {column_name}: {exc}")
            return analysis
        except Exception as e:
            logger.error(f"Column analysis failed: {e}")
            return {
                "success": False,
                "error": str(e),
                "column_name": column_name,
                "table_name": table_name
            }

    async def analyze_table_relationships(
        self,
        table_name: str,
        depth: int = 2
    ) -> Dict[str, Any]:
        """Analyze table relationships.

        Returns the requested table's metadata together with every other
        base table of the current database. *depth* is only echoed back in
        the result; no query uses it.

        Raises:
            ValueError: If the table is not found in ``information_schema``.
        """
        connection = await self.connection_manager.get_connection("system")
        # Get table basic information
        table_info_sql = """
            SELECT
                table_name,
                table_comment,
                table_rows
            FROM information_schema.tables
            WHERE table_schema = DATABASE()
              AND table_name = %s
        """
        table_result = await connection.execute(table_info_sql, (table_name,))
        if not table_result.data:
            raise ValueError(f"Table {table_name} does not exist")
        # Get all tables list (for analyzing potential relationships)
        all_tables_sql = """
            SELECT
                table_name,
                table_comment
            FROM information_schema.tables
            WHERE table_schema = DATABASE()
              AND table_type = 'BASE TABLE'
              AND table_name != %s
        """
        all_tables_result = await connection.execute(all_tables_sql, (table_name,))
        return {
            "center_table": table_result.data[0],
            "related_tables": all_tables_result.data,
            "depth": depth,
            "note": "Table relationship analysis based on column name similarity and business logic inference",
        }
class PerformanceMonitor:
    """Performance monitor.

    Reports query, connection-pool, table and system metrics. The "queries"
    and "system" figures are fixed placeholder values (flagged as simulated
    in their ``note`` field); only "tables" queries the database.
    """

    def __init__(self, connection_manager: "DorisConnectionManager"):
        self.connection_manager = connection_manager

    async def get_performance_stats(
        self,
        metric_type: str = "queries",
        time_range: str = "1h"
    ) -> Dict[str, Any]:
        """Get performance statistics.

        Args:
            metric_type: "queries", "connections", "tables" or "system".
            time_range: Echoed back in the result; none of the current
                metrics filter by it.

        Returns:
            Dict with ``metric_type``, ``time_range``, ``timestamp`` and the
            metric-specific fields.

        Raises:
            ValueError: If *metric_type* is not one of the supported values.
        """
        if metric_type == "queries":
            # Query performance metrics
            stats = {
                "metric_type": "queries",
                "time_range": time_range,
                "timestamp": datetime.now().isoformat(),
                "total_queries": 0,
                "avg_execution_time": 0.0,
                "slow_queries": 0,
                "error_queries": 0,
                "note": "Query performance statistics (simulated data)"
            }
        elif metric_type == "connections":
            # Connection statistics
            connection_metrics = await self.connection_manager.get_metrics()
            last_check = connection_metrics.last_health_check
            stats = {
                "metric_type": "connections",
                "time_range": time_range,
                "timestamp": datetime.now().isoformat(),
                "total_connections": connection_metrics.total_connections,
                "active_connections": connection_metrics.active_connections,
                "idle_connections": connection_metrics.idle_connections,
                "failed_connections": connection_metrics.failed_connections,
                "connection_errors": connection_metrics.connection_errors,
                "avg_connection_time": connection_metrics.avg_connection_time,
                "last_health_check": last_check.isoformat() if last_check else None
            }
        elif metric_type == "tables":
            # Table-level statistics. This is the only branch that needs a
            # database connection, so it is acquired here: the other metric
            # types stay available while the database is unreachable.
            connection = await self.connection_manager.get_connection("system")
            tables_sql = """
                SELECT
                    table_name,
                    table_rows,
                    data_length,
                    index_length,
                    create_time,
                    update_time
                FROM information_schema.tables
                WHERE table_schema = DATABASE()
                  AND table_type = 'BASE TABLE'
                ORDER BY table_rows DESC
                LIMIT 20
            """
            tables_result = await connection.execute(tables_sql)
            stats = {
                "metric_type": "tables",
                "time_range": time_range,
                "timestamp": datetime.now().isoformat(),
                "table_count": len(tables_result.data),
                "tables": tables_result.data
            }
        elif metric_type == "system":
            # System-level metrics (simulated)
            stats = {
                "metric_type": "system",
                "time_range": time_range,
                "timestamp": datetime.now().isoformat(),
                "cpu_usage": 45.2,
                "memory_usage": 68.5,
                "disk_usage": 72.1,
                "network_io": {
                    "bytes_sent": 1024000,
                    "bytes_received": 2048000
                },
                "note": "System metrics (simulated data)"
            }
        else:
            raise ValueError(f"Unsupported metric type: {metric_type}")
        return stats

    async def get_query_history(
        self,
        limit: int = 50,
        order_by: str = "time"
    ) -> Dict[str, Any]:
        """Get query history.

        Always returns an empty history: no query is issued, and *limit* and
        *order_by* are only echoed back.
        """
        # Since Doris doesn't have a built-in query history table,
        # we return simulated data
        return {
            "total_queries": 0,
            "queries": [],
            "limit": limit,
            "order_by": order_by,
            "note": "Query history feature requires audit log configuration"
        }

View File

@@ -0,0 +1,608 @@
#!/usr/bin/env python3
"""
Doris Configuration Management Module
Implements configuration loading, validation and management functionality
"""
import json
import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
try:
from dotenv import load_dotenv
except ImportError:
load_dotenv = None
@dataclass
class DatabaseConfig:
    """Database connection configuration.

    Groups the connection parameters and the connection-pool settings. The
    defaults target ``localhost:9030`` as user ``root`` with an empty
    password and the ``test`` database.
    """

    # Connection parameters
    host: str = "localhost"
    port: int = 9030
    user: str = "root"
    password: str = ""
    database: str = "test"
    charset: str = "utf8mb4"
    # Connection pool configuration
    # NOTE(review): timeout / interval / age look like seconds — confirm
    # against the connection manager that consumes them.
    min_connections: int = 5
    max_connections: int = 20
    connection_timeout: int = 30
    health_check_interval: int = 60
    max_connection_age: int = 3600
@dataclass
class SecurityConfig:
    """Security configuration.

    Covers authentication, SQL restrictions, sensitive-table tagging and
    data masking. Mutable defaults use ``default_factory`` so every instance
    gets its own list/dict.
    """

    # Authentication configuration
    auth_type: str = "token"  # token, basic, oauth
    # NOTE(review): placeholder value — deployments should override it.
    token_secret: str = "default_secret"
    token_expiry: int = 3600
    # SQL security configuration
    # Keywords of statements that modify data, schema or privileges.
    blocked_keywords: list[str] = field(
        default_factory=lambda: [
            "DROP",
            "DELETE",
            "TRUNCATE",
            "ALTER",
            "CREATE",
            "INSERT",
            "UPDATE",
            "GRANT",
            "REVOKE",
        ]
    )
    max_query_complexity: int = 100
    max_result_rows: int = 10000
    # Sensitive table configuration
    # presumably maps table name -> sensitivity label; verify against the
    # security module that reads it.
    sensitive_tables: dict[str, str] = field(default_factory=dict)
    # Data masking configuration
    enable_masking: bool = True
    masking_rules: list[dict[str, Any]] = field(default_factory=list)
@dataclass
class PerformanceConfig:
    """Performance configuration.

    Settings for the query cache, query concurrency limits and connection
    pool tuning.
    """

    # Query cache configuration
    enable_query_cache: bool = True
    cache_ttl: int = 300
    max_cache_size: int = 1000
    # Concurrency control configuration
    max_concurrent_queries: int = 50
    query_timeout: int = 300
    # Connection pool optimization configuration
    connection_pool_size: int = 20
    idle_timeout: int = 1800
@dataclass
class LoggingConfig:
    """Logging configuration.

    Holds the log level, the record format string, optional log-file
    settings and the audit-log switches. ``file_path`` and
    ``audit_file_path`` default to ``None`` (no path configured).
    """

    level: str = "INFO"
    # Format string in the standard ``logging`` %-style.
    format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    file_path: str | None = None
    max_file_size: int = 10 * 1024 * 1024  # 10MB
    backup_count: int = 5
    # Audit log configuration
    enable_audit: bool = True
    audit_file_path: str | None = None
@dataclass
class MonitoringConfig:
    """Monitoring configuration.

    Ports and URL paths for the metrics and health-check endpoints, plus the
    alerting switch and its optional webhook URL (``None`` when unset).
    """

    # Metrics collection configuration
    enable_metrics: bool = True
    metrics_port: int = 8081
    metrics_path: str = "/metrics"
    # Health check configuration
    health_check_port: int = 8082
    health_check_path: str = "/health"
    # Alert configuration
    enable_alerts: bool = False
    alert_webhook_url: str | None = None
@dataclass
class DorisConfig:
"""Doris MCP Server complete configuration"""
# Basic configuration
server_name: str = "doris-mcp-server"
server_version: str = "1.0.0"
server_port: int = 8080
# Sub-configuration modules
database: DatabaseConfig = field(default_factory=DatabaseConfig)
security: SecurityConfig = field(default_factory=SecurityConfig)
performance: PerformanceConfig = field(default_factory=PerformanceConfig)
logging: LoggingConfig = field(default_factory=LoggingConfig)
monitoring: MonitoringConfig = field(default_factory=MonitoringConfig)
# Custom configuration
custom_config: dict[str, Any] = field(default_factory=dict)
@classmethod
def from_file(cls, config_path: str) -> "DorisConfig":
    """Load configuration from a JSON file.

    Args:
        config_path: Path of the configuration file; only ``.json``
            (case-insensitive suffix) is supported.

    Returns:
        A configuration built by ``_from_dict`` from the parsed content.

    Raises:
        FileNotFoundError: If *config_path* does not exist.
        ValueError: If the format is unsupported, or the file cannot be read
            or parsed, or its top-level JSON value is not an object. The
            underlying exception is attached as ``__cause__``.
    """
    config_file = Path(config_path)
    if not config_file.exists():
        raise FileNotFoundError(f"Configuration file does not exist: {config_path}")
    try:
        # Reject unsupported formats (like YAML) before opening the file.
        if config_file.suffix.lower() != ".json":
            raise ValueError(f"Unsupported configuration file format: {config_file.suffix}")
        with open(config_file, encoding="utf-8") as f:
            config_data = json.load(f)
        if not isinstance(config_data, dict):
            raise ValueError("top-level JSON value must be an object")
        return cls._from_dict(config_data)
    except Exception as e:
        # Every failure surfaces as ValueError; chain the original so the
        # real cause stays visible in tracebacks.
        raise ValueError(f"Failed to load configuration file: {e}") from e
@classmethod
def from_env(cls, env_file: str | None = None) -> "DorisConfig":
    """Build a configuration from environment variables.

    Starts from the dataclass defaults and overrides every setting whose
    environment variable is present. Integer settings are parsed with
    ``int``; boolean settings are true only for the case-insensitive
    value "true".

    Args:
        env_file: Path of a .env file to load first. When omitted, the
            first existing file among .env, .env.local, .env.production
            and .env.development is loaded. Loading is skipped with a
            warning when python-dotenv is not installed.
    """
    log = logging.getLogger(__name__)

    def env_int(name: str, fallback: int) -> int:
        # An unset variable falls back to the current (default) value.
        return int(os.getenv(name, str(fallback)))

    def env_flag(name: str, fallback: bool) -> bool:
        return os.getenv(name, str(fallback).lower()).lower() == "true"

    # Step 1: optionally populate the environment from a .env file
    if load_dotenv is None:
        log.warning("python-dotenv not installed, cannot load .env files")
    elif env_file:
        if Path(env_file).exists():
            load_dotenv(env_file)
            log.info(f"Loaded environment configuration file: {env_file}")
        else:
            log.warning(f"Environment configuration file does not exist: {env_file}")
    else:
        candidates = (".env", ".env.local", ".env.production", ".env.development")
        chosen = next((path for path in candidates if Path(path).exists()), None)
        if chosen is None:
            log.info("No .env configuration file found, using system environment variables")
        else:
            load_dotenv(chosen)
            log.info(f"Loaded environment configuration file: {chosen}")

    # Step 2: overlay environment values on the defaults
    config = cls()

    # Database connection and pool settings
    db = config.database
    db.host = os.getenv("DORIS_HOST", db.host)
    db.port = env_int("DORIS_PORT", db.port)
    db.user = os.getenv("DORIS_USER", db.user)
    db.password = os.getenv("DORIS_PASSWORD", db.password)
    db.database = os.getenv("DORIS_DATABASE", db.database)
    db.min_connections = env_int("DORIS_MIN_CONNECTIONS", db.min_connections)
    db.max_connections = env_int("DORIS_MAX_CONNECTIONS", db.max_connections)
    db.connection_timeout = env_int("DORIS_CONNECTION_TIMEOUT", db.connection_timeout)
    db.health_check_interval = env_int("DORIS_HEALTH_CHECK_INTERVAL", db.health_check_interval)
    db.max_connection_age = env_int("DORIS_MAX_CONNECTION_AGE", db.max_connection_age)

    # Security settings
    security = config.security
    security.auth_type = os.getenv("AUTH_TYPE", security.auth_type)
    security.token_secret = os.getenv("TOKEN_SECRET", security.token_secret)
    security.token_expiry = env_int("TOKEN_EXPIRY", security.token_expiry)
    security.max_result_rows = env_int("MAX_RESULT_ROWS", security.max_result_rows)
    security.max_query_complexity = env_int("MAX_QUERY_COMPLEXITY", security.max_query_complexity)
    security.enable_masking = env_flag("ENABLE_MASKING", security.enable_masking)

    # Performance settings (the cache flag defaults to enabled when unset)
    performance = config.performance
    performance.enable_query_cache = env_flag("ENABLE_QUERY_CACHE", True)
    performance.cache_ttl = env_int("CACHE_TTL", performance.cache_ttl)
    performance.max_cache_size = env_int("MAX_CACHE_SIZE", performance.max_cache_size)
    performance.max_concurrent_queries = env_int(
        "MAX_CONCURRENT_QUERIES", performance.max_concurrent_queries
    )
    performance.query_timeout = env_int("QUERY_TIMEOUT", performance.query_timeout)

    # Logging settings
    log_settings = config.logging
    log_settings.level = os.getenv("LOG_LEVEL", log_settings.level)
    log_settings.file_path = os.getenv("LOG_FILE_PATH", log_settings.file_path)
    log_settings.enable_audit = env_flag("ENABLE_AUDIT", log_settings.enable_audit)
    log_settings.audit_file_path = os.getenv("AUDIT_FILE_PATH", log_settings.audit_file_path)

    # Monitoring settings (the metrics flag defaults to enabled when unset)
    monitoring = config.monitoring
    monitoring.enable_metrics = env_flag("ENABLE_METRICS", True)
    monitoring.metrics_port = env_int("METRICS_PORT", monitoring.metrics_port)
    monitoring.health_check_port = env_int("HEALTH_CHECK_PORT", monitoring.health_check_port)
    monitoring.enable_alerts = env_flag("ENABLE_ALERTS", monitoring.enable_alerts)
    monitoring.alert_webhook_url = os.getenv("ALERT_WEBHOOK_URL", monitoring.alert_webhook_url)

    # Server settings
    config.server_name = os.getenv("SERVER_NAME", config.server_name)
    config.server_version = os.getenv("SERVER_VERSION", config.server_version)
    config.server_port = env_int("SERVER_PORT", config.server_port)
    return config
@classmethod
def _from_dict(cls, config_data: dict[str, Any]) -> "DorisConfig":
    """Create a configuration object from a dictionary.

    Starts from the defaults, then copies the top-level server keys and,
    for each sub-configuration section present in *config_data*, every
    entry whose key names an existing attribute of that section. Unknown
    keys are ignored. The ``custom`` entry (default ``{}``) becomes
    ``custom_config``.
    """
    config = cls()

    # Top-level server settings
    for attr in ("server_name", "server_version", "server_port"):
        if attr in config_data:
            setattr(config, attr, config_data[attr])

    # Sub-configuration sections, each handled the same way
    for section in ("database", "security", "performance", "logging", "monitoring"):
        if section not in config_data:
            continue
        target = getattr(config, section)
        for option, value in config_data[section].items():
            if hasattr(target, option):
                setattr(target, option, value)

    # Custom configuration
    config.custom_config = config_data.get("custom", {})
    return config
def to_dict(self) -> dict[str, Any]:
    """Serialize the configuration into a nested dictionary.

    Secrets are never exported: the database password and the token secret
    are replaced by ``"***"``, and the masking rules are reported only as a
    count.

    Returns:
        A dictionary with the server fields, one sub-dictionary per
        configuration section, and the custom configuration under ``custom``.
    """

    def export(section, names, replaced=None):
        # Copy the named attributes in order; names listed in ``replaced``
        # take the substitute value and are never read from the section.
        replaced = replaced or {}
        return {
            name: replaced[name] if name in replaced else getattr(section, name)
            for name in names
        }

    database = export(
        self.database,
        (
            "host",
            "port",
            "user",
            "password",
            "database",
            "charset",
            "min_connections",
            "max_connections",
            "connection_timeout",
            "health_check_interval",
            "max_connection_age",
        ),
        {"password": "***"},  # Hide password
    )
    security = export(
        self.security,
        (
            "auth_type",
            "token_secret",
            "token_expiry",
            "blocked_keywords",
            "max_query_complexity",
            "max_result_rows",
            "sensitive_tables",
            "enable_masking",
            "masking_rules",
        ),
        {
            "token_secret": "***",  # Hide secret key
            "masking_rules": len(self.security.masking_rules),
        },
    )
    performance = export(
        self.performance,
        (
            "enable_query_cache",
            "cache_ttl",
            "max_cache_size",
            "max_concurrent_queries",
            "query_timeout",
            "connection_pool_size",
            "idle_timeout",
        ),
    )
    logging_section = export(
        self.logging,
        (
            "level",
            "format",
            "file_path",
            "max_file_size",
            "backup_count",
            "enable_audit",
            "audit_file_path",
        ),
    )
    monitoring = export(
        self.monitoring,
        (
            "enable_metrics",
            "metrics_port",
            "metrics_path",
            "health_check_port",
            "health_check_path",
            "enable_alerts",
            "alert_webhook_url",
        ),
    )

    return {
        "server_name": self.server_name,
        "server_version": self.server_version,
        "server_port": self.server_port,
        "database": database,
        "security": security,
        "performance": performance,
        "logging": logging_section,
        "monitoring": monitoring,
        "custom": self.custom_config,
    }
def save_to_file(self, config_path: str):
    """Save the configuration to a JSON file.

    The serialized form is the one produced by ``to_dict()``.

    Args:
        config_path: Destination path. Only the ``.json`` suffix
            (case-insensitive) is supported.

    Raises:
        ValueError: If the suffix is not supported, or if writing the file
            fails.
    """
    config_file = Path(config_path)

    # Reject unsupported formats before touching the filesystem. Opening the
    # target in "w" mode first would truncate an existing file (for example
    # a hand-written config.yaml) just to report that it cannot be written.
    if config_file.suffix.lower() != ".json":
        raise ValueError(
            "Failed to save configuration file: "
            f"Unsupported configuration file format: {config_file.suffix}"
        )

    config_file.parent.mkdir(parents=True, exist_ok=True)
    try:
        with open(config_file, "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
    except Exception as e:
        # Chain the cause so the original traceback is preserved.
        raise ValueError(f"Failed to save configuration file: {e}") from e
def validate(self) -> list[str]:
    """Check the configuration and collect every problem found.

    All rules are evaluated (one failure does not hide another) and the
    messages come back in a fixed order: database, security, performance,
    logging, monitoring.

    Returns:
        A list of human-readable error messages; empty when the
        configuration is valid.
    """
    db = self.database
    security = self.security
    performance = self.performance
    log = self.logging
    monitoring = self.monitoring

    # Each entry pairs "is this rule violated?" with the message to report.
    rules = [
        # Database
        (not db.host, "Database host address cannot be empty"),
        (not (1 <= db.port <= 65535), "Database port must be in the range 1-65535"),
        (not db.user, "Database username cannot be empty"),
        (db.min_connections <= 0, "Minimum connections must be greater than 0"),
        (
            db.max_connections <= db.min_connections,
            "Maximum connections must be greater than minimum connections",
        ),
        # Security
        (
            security.auth_type not in ["token", "basic", "oauth"],
            "Authentication type must be one of token, basic, or oauth",
        ),
        (security.token_expiry <= 0, "Token expiry time must be greater than 0"),
        (security.max_query_complexity <= 0, "Maximum query complexity must be greater than 0"),
        (security.max_result_rows <= 0, "Maximum result rows must be greater than 0"),
        # Performance
        (performance.cache_ttl <= 0, "Cache TTL must be greater than 0"),
        (
            performance.max_concurrent_queries <= 0,
            "Maximum concurrent queries must be greater than 0",
        ),
        (performance.query_timeout <= 0, "Query timeout must be greater than 0"),
        # Logging
        (
            log.level not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
            "Log level must be one of DEBUG, INFO, WARNING, ERROR, or CRITICAL",
        ),
        (log.max_file_size <= 0, "Maximum log file size must be greater than 0"),
        (log.backup_count < 0, "Log backup count cannot be negative"),
        # Monitoring
        (
            not (1 <= monitoring.metrics_port <= 65535),
            "Monitoring port must be in the range 1-65535",
        ),
        (
            not (1 <= monitoring.health_check_port <= 65535),
            "Health check port must be in the range 1-65535",
        ),
    ]
    return [message for violated, message in rules if violated]
def get_connection_string(self) -> str:
    """Return a ``mysql://`` style connection string for display.

    The password is always rendered as ``***``; the real value is never read.
    """
    db = self.database
    return "mysql://{}:***@{}:{}/{}".format(db.user, db.host, db.port, db.database)
def get_config_summary(self) -> dict[str, Any]:
    """Return a compact overview of the configuration.

    Returns:
        A dictionary with display strings for the server, database endpoint
        and connection pool range, plus small ``security``, ``performance``
        and ``monitoring`` sub-dictionaries.
    """
    db = self.database
    security = self.security
    performance = self.performance
    monitoring = self.monitoring

    security_summary = {
        "auth_type": security.auth_type,
        "masking_enabled": security.enable_masking,
        "blocked_keywords_count": len(security.blocked_keywords),
    }
    performance_summary = {
        "cache_enabled": performance.enable_query_cache,
        "max_concurrent": performance.max_concurrent_queries,
        "query_timeout": performance.query_timeout,
    }
    monitoring_summary = {
        "metrics_enabled": monitoring.enable_metrics,
        "alerts_enabled": monitoring.enable_alerts,
    }

    return {
        "server": f"{self.server_name} v{self.server_version}",
        "database": f"{db.host}:{db.port}/{db.database}",
        "connection_pool": f"{db.min_connections}-{db.max_connections}",
        "security": security_summary,
        "performance": performance_summary,
        "monitoring": monitoring_summary,
    }
class ConfigManager:
    """Configuration manager class.

    Wraps a ``DorisConfig`` and provides helpers to apply its logging
    settings, validate it, and log a short summary.
    """

    def __init__(self, config: "DorisConfig"):
        """Store the configuration and create this module's logger."""
        self.config = config
        self.logger = logging.getLogger(__name__)

    @staticmethod
    def _discard_handlers(target: logging.Logger) -> None:
        """Detach every handler from ``target`` and close it."""
        for handler in target.handlers[:]:
            target.removeHandler(handler)
            # Closing releases the file descriptor held by file-based
            # handlers; merely removing them would leak it on every reload.
            handler.close()

    @staticmethod
    def _rotating_file_handler(path, max_bytes, backup_count, formatter):
        """Create a UTF-8 rotating file handler, creating parent directories."""
        from logging.handlers import RotatingFileHandler

        # RotatingFileHandler does not create missing directories itself.
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        handler = RotatingFileHandler(
            path,
            maxBytes=max_bytes,
            backupCount=backup_count,
            encoding="utf-8",
        )
        handler.setFormatter(formatter)
        return handler

    def setup_logging(self):
        """Apply the logging section of the configuration.

        Replaces all handlers on the root logger with a console handler and,
        when ``file_path`` is set, a rotating file handler. When auditing is
        enabled and ``audit_file_path`` is set, the ``audit`` logger gets a
        single rotating file handler at INFO level. The method is safe to
        call repeatedly: replaced handlers are closed and the audit handler
        is swapped rather than stacked.

        Failures to create a file-based handler are logged as warnings and do
        not interrupt the setup.
        """
        log_cfg = self.config.logging

        # Configure root logger
        root_logger = logging.getLogger()
        root_logger.setLevel(getattr(logging, log_cfg.level.upper()))

        # Clear (and close) existing handlers
        self._discard_handlers(root_logger)

        # Create formatter
        formatter = logging.Formatter(log_cfg.format)

        # Console handler
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(formatter)
        root_logger.addHandler(console_handler)

        # File handler (if configured)
        if log_cfg.file_path:
            try:
                root_logger.addHandler(
                    self._rotating_file_handler(
                        log_cfg.file_path,
                        log_cfg.max_file_size,
                        log_cfg.backup_count,
                        formatter,
                    )
                )
            except Exception as e:
                self.logger.warning(f"Failed to setup file logging: {e}")

        # Audit log handler (if configured)
        if log_cfg.enable_audit and log_cfg.audit_file_path:
            try:
                audit_handler = self._rotating_file_handler(
                    log_cfg.audit_file_path,
                    log_cfg.max_file_size,
                    log_cfg.backup_count,
                    formatter,
                )
                audit_logger = logging.getLogger("audit")
                # Drop handlers left by a previous call only once the new one
                # exists, so every audit record is written exactly once.
                self._discard_handlers(audit_logger)
                audit_logger.addHandler(audit_handler)
                audit_logger.setLevel(logging.INFO)
            except Exception as e:
                self.logger.warning(f"Failed to setup audit logging: {e}")

    def validate_config(self) -> bool:
        """Validate the configuration and log the outcome.

        Returns:
            True when ``config.validate()`` reports no errors, otherwise
            False after logging each error.
        """
        errors = self.config.validate()

        if errors:
            self.logger.error("Configuration validation failed:")
            for error in errors:
                self.logger.error(f"  - {error}")
            return False

        self.logger.info("Configuration validation passed")
        return True

    def log_config_summary(self):
        """Log the summary returned by ``config.get_config_summary()``."""
        summary = self.config.get_config_summary()

        self.logger.info("Configuration Summary:")
        self.logger.info(f"  Server: {summary['server']}")
        self.logger.info(f"  Database: {summary['database']}")
        self.logger.info(f"  Connection Pool: {summary['connection_pool']}")
        self.logger.info(f"  Security: {summary['security']}")
        self.logger.info(f"  Performance: {summary['performance']}")
        self.logger.info(f"  Monitoring: {summary['monitoring']}")
def create_default_config_file(config_path: str):
    """Write a configuration file containing the built-in defaults.

    Args:
        config_path: Destination path, passed to ``DorisConfig.save_to_file``.
    """
    DorisConfig().save_to_file(config_path)
    print(f"Default configuration file created: {config_path}")
# Example usage
if __name__ == "__main__":
    # Start from the built-in defaults. Alternatives:
    #   config = DorisConfig.from_env()                 # environment variables
    #   config = DorisConfig.from_file("config.json")   # configuration file
    config = DorisConfig()
    config_manager = ConfigManager(config)

    # Only apply and persist a configuration that passes validation.
    if not config_manager.validate_config():
        print("Configuration validation failed")
    else:
        config_manager.setup_logging()
        config_manager.log_config_summary()

        # Save configuration
        config.save_to_file("example_config.json")
        print("Configuration saved to example_config.json")

View File

@@ -1,100 +1,479 @@
import os
import json
import pymysql
import pandas as pd
from typing import Dict, List, Optional, Any
from dotenv import load_dotenv
import re
#!/usr/bin/env python3
"""
Apache Doris Database Connection Management Module
# Load environment variables
load_dotenv(override=True)
Provides high-performance database connection pool management, automatic reconnection mechanism and connection health check functionality
Supports asynchronous operations and concurrent connection management, ensuring stability and performance for enterprise applications
"""
# Database configuration
DB_CONFIG = {
"host": os.getenv("DB_HOST", "localhost"),
"port": int(os.getenv("DB_PORT", "9030")),
"user": os.getenv("DB_USER", "root"),
"password": os.getenv("DB_PASSWORD", ""),
"database": os.getenv("DB_DATABASE", ""),
"charset": "utf8mb4",
"cursorclass": pymysql.cursors.DictCursor
}
import asyncio
import logging
import time
from contextlib import asynccontextmanager
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List
def get_db_connection(db_name: Optional[str] = None):
import aiomysql
from aiomysql import Connection, Pool
@dataclass
class ConnectionMetrics:
"""Connection pool performance metrics maintained by DorisConnectionManager"""
# Incremented each time a new session connection is created
total_connections: int = 0
# Set to the number of tracked session connections when metrics are refreshed
active_connections: int = 0
# Set from the pool's free-connection count when metrics are refreshed
idle_connections: int = 0
# Incremented once per unhealthy session removed by the periodic health check
failed_connections: int = 0
# Incremented when creating a connection for a session raises
connection_errors: int = 0
# NOTE(review): nothing in the visible code updates this value; it appears to stay at 0.0 - confirm
avg_connection_time: float = 0.0
# UTC time (naive datetime) of the last completed health check; None until one has run
last_health_check: datetime | None = None
@dataclass
class QueryResult:
"""Query result wrapper returned by DorisConnection.execute"""
# Fetched rows (a dict cursor is used, so presumably one dict per row), after optional masking;
# empty when the statement produced no result set
data: list[dict[str, Any]]
# Contains "columns", "query" and "params"; "security_check" is added when SQL validation ran
metadata: dict[str, Any]
# Elapsed time in seconds, measured with time.time() around execution and fetch
execution_time: float
# Number of fetched rows for result-set statements, otherwise the cursor's rowcount
row_count: int
class DorisConnection:
    """Doris database connection wrapper class.

    Wraps one raw aiomysql connection bound to a session. Tracks creation and
    last-use timestamps, a query counter and a health flag, and optionally
    routes statements through a security manager for SQL validation and
    result masking.
    """

    def __init__(self, connection: Connection, session_id: str, security_manager=None):
        """Wrap ``connection`` for ``session_id``.

        Args:
            connection: Raw database connection.
            session_id: Identifier of the session that owns the connection.
            security_manager: Optional object whose ``validate_sql_security``
                and ``apply_data_masking`` coroutines are awaited by
                ``execute`` when an auth context is supplied.
        """
        self.connection = connection
        self.session_id = session_id
        self.created_at = datetime.utcnow()
        self.last_used = datetime.utcnow()
        self.query_count = 0
        self.is_healthy = True
        self.security_manager = security_manager

    async def execute(self, sql: str, params: tuple | None = None, auth_context=None) -> QueryResult:
        """Execute SQL query.

        Args:
            sql: Statement to run.
            params: Optional parameters passed to the cursor with the statement.
            auth_context: Optional authentication context. When given together
                with a security manager, the SQL is validated first and the
                returned rows are masked.

        Returns:
            A ``QueryResult`` with the rows, metadata (columns, query, params
            and, if validation ran, ``security_check``), the execution time in
            seconds and the row count.

        Raises:
            ValueError: If SQL security validation rejects the statement.
            Exception: Any error raised by the security manager or the
                database driver is logged and re-raised.
        """
        start_time = time.time()

        try:
            security_result = await self._validate_security(sql, auth_context)
            data, columns, row_count = await self._run_statement(sql, params)

            execution_time = time.time() - start_time
            self.last_used = datetime.utcnow()
            self.query_count += 1

            # If security manager exists and has auth context, apply data masking
            final_data = list(data) if data else []
            if self.security_manager and auth_context and final_data:
                final_data = await self.security_manager.apply_data_masking(final_data, auth_context)

            metadata = {"columns": columns, "query": sql, "params": params}
            if security_result:
                metadata["security_check"] = security_result

            return QueryResult(
                data=final_data,
                metadata=metadata,
                execution_time=execution_time,
                row_count=row_count,
            )

        except Exception as e:
            logging.error(f"Query execution failed: {e}")
            raise

    async def _validate_security(self, sql: str, auth_context):
        """Run SQL security validation; return a summary dict or None if skipped."""
        if not (self.security_manager and auth_context):
            return None

        validation_result = await self.security_manager.validate_sql_security(sql, auth_context)
        if not validation_result.is_valid:
            raise ValueError(f"SQL security validation failed: {validation_result.error_message}")

        return {
            "is_valid": validation_result.is_valid,
            "risk_level": validation_result.risk_level,
            "blocked_operations": validation_result.blocked_operations,
        }

    async def _run_statement(self, sql: str, params: tuple | None):
        """Run the statement on the connection; return (rows, columns, row_count)."""
        try:
            async with self.connection.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(sql, params)

                # The cursor description is set exactly when the statement
                # produced a result set. Relying on it instead of matching SQL
                # prefixes also covers WITH ... SELECT, parenthesised queries
                # and statements that start with a comment.
                if cursor.description is not None:
                    data = await cursor.fetchall()
                    return data, [desc[0] for desc in cursor.description], len(data)

                return [], [], cursor.rowcount
        except Exception:
            # Only failures while talking to the database count against the
            # connection's health; a query rejected by the security manager
            # says nothing about the connection itself.
            self.is_healthy = False
            raise

    async def ping(self) -> bool:
        """Check connection health status.

        Returns:
            True if the ping succeeded, False otherwise. The ``is_healthy``
            flag is updated to match.
        """
        try:
            await self.connection.ping()
            self.is_healthy = True
            return True
        except Exception:
            self.is_healthy = False
            return False

    async def close(self):
        """Close the underlying connection; errors are logged, not raised."""
        try:
            if self.connection and not self.connection.closed:
                await self.connection.ensure_closed()
        except Exception as e:
            logging.error(f"Error occurred while closing connection: {e}")
class DorisConnectionManager:
"""Doris database connection manager
Provides connection pool management, connection health monitoring, fault recovery and other functions
Supports session-level connection reuse and intelligent load balancing
Integrates security manager to provide unified security validation and data masking
"""
Get database connection
Args:
db_name: Specify the database name to connect to, use default config if None
Returns:
Database connection
"""
if db_name:
# Use default config but override database name
config = DB_CONFIG.copy()
config["database"] = db_name
return pymysql.connect(**config)
else:
# Use default config
return pymysql.connect(**DB_CONFIG)
def get_db_name() -> str:
"""Get the currently configured default database name"""
return DB_CONFIG["database"] or os.getenv("DB_DATABASE", "")
def __init__(self, config, security_manager=None):
self.config = config
self.pool: Pool | None = None
self.session_connections: dict[str, DorisConnection] = {}
self.metrics = ConnectionMetrics()
self.logger = logging.getLogger(__name__)
self.security_manager = security_manager
def execute_query(sql, db_name: Optional[str] = None):
"""
Execute SQL query and return results
Args:
sql: SQL query statement
db_name: Specify the database name to connect to, use default config if None
Returns:
Query results
"""
conn = get_db_connection(db_name)
try:
with conn.cursor() as cursor:
# Set connection character set to utf8 before executing query
cursor.execute("SET NAMES utf8")
# Health check configuration
self.health_check_interval = config.database.health_check_interval or 60
self.max_connection_age = config.database.max_connection_age or 3600
self.connection_timeout = config.database.connection_timeout or 30
# Start background tasks
self._health_check_task = None
self._cleanup_task = None
async def initialize(self):
    """Initialize connection manager.

    Creates the connection pool from the database configuration and starts
    the background health-check and cleanup tasks.

    Raises:
        Exception: Any error raised while creating the pool or starting the
            tasks is logged and re-raised.
    """
    db = self.config.database

    # Resolve the effective pool bounds once so the log line reports the
    # values actually used rather than unset (None/0) configuration values.
    min_size = db.min_connections or 5
    max_size = db.max_connections or 20

    try:
        # Create connection pool
        self.pool = await aiomysql.create_pool(
            host=db.host,
            port=db.port,
            user=db.user,
            password=db.password,
            db=db.database,
            # Honour the configured charset; fall back to the previous
            # hard-coded value when none is configured.
            charset=getattr(db, "charset", None) or "utf8",
            minsize=min_size,
            maxsize=max_size,
            autocommit=True,
            connect_timeout=self.connection_timeout,
        )

        self.logger.info(
            f"Connection pool initialized successfully, min connections: {min_size}, "
            f"max connections: {max_size}"
        )

        # Start background monitoring tasks
        self._health_check_task = asyncio.create_task(self._health_check_loop())
        self._cleanup_task = asyncio.create_task(self._cleanup_loop())

    except Exception as e:
        self.logger.error(f"Connection pool initialization failed: {e}")
        raise
async def get_connection(self, session_id: str) -> DorisConnection:
    """Return the connection bound to ``session_id``, creating it if needed.

    A connection already registered for the session is reused when it
    answers a ping. If the ping fails, the stale connection is cleaned up
    and replaced by a new one.
    """
    if session_id not in self.session_connections:
        return await self._create_new_connection(session_id)

    cached = self.session_connections[session_id]
    if await cached.ping():
        return cached

    # The cached connection no longer responds: discard it and start over.
    await self._cleanup_session_connection(session_id)
    return await self._create_new_connection(session_id)
async def _create_new_connection(self, session_id: str) -> DorisConnection:
"""Create new database connection"""
try:
if not self.pool:
raise RuntimeError("Connection pool not initialized")
# Get connection from pool
raw_connection = await self.pool.acquire()
# Execute the actual query
cursor.execute(sql)
result = cursor.fetchall()
return result
finally:
conn.close()
def execute_query_df(sql, db_name: Optional[str] = None):
"""
Execute SQL query and return pandas DataFrame
Args:
sql: SQL query statement
db_name: Specify the database name to connect to, use default config if None
Returns:
pandas DataFrame
"""
conn = get_db_connection(db_name)
try:
# Use a temporary cursor to execute the query and get results
with conn.cursor() as cursor:
# Set connection character set to utf8 before executing query
cursor.execute("SET NAMES utf8")
# Create wrapped connection
doris_conn = DorisConnection(raw_connection, session_id, self.security_manager)
# Execute the actual query
cursor.execute(sql)
result = cursor.fetchall()
# Store in session connections
self.session_connections[session_id] = doris_conn
self.metrics.total_connections += 1
self.logger.debug(f"Created new connection for session: {session_id}")
return doris_conn
except Exception as e:
self.metrics.connection_errors += 1
self.logger.error(f"Failed to create connection for session {session_id}: {e}")
raise
async def release_connection(self, session_id: str):
    """Release the connection held by ``session_id``; unknown ids are ignored."""
    if session_id not in self.session_connections:
        return
    await self._cleanup_session_connection(session_id)
async def _cleanup_session_connection(self, session_id: str):
    """Clean up session connection.

    Detaches the connection from the session and hands it back to the pool.
    A connection that is still open and healthy is returned untouched so the
    pool can reuse it. A connection flagged unhealthy is closed first, so the
    pool discards it instead of handing a dead connection to the next caller.

    Errors are logged and never propagate; the session entry is always
    removed. Unknown session ids are ignored.
    """
    if session_id not in self.session_connections:
        return

    conn = self.session_connections[session_id]
    try:
        raw = conn.connection
        if raw is not None:
            if not raw.closed and not conn.is_healthy:
                # Close BEFORE releasing. Closing a connection after it has
                # been returned would leave a dead connection in the pool's
                # free list.
                await conn.close()

            if self.pool:
                # Always release, even when the connection is already closed:
                # this is what updates the pool's in-use bookkeeping. The
                # pool only re-queues connections that are still open.
                self.pool.release(raw)
            elif not raw.closed:
                # No pool to return to; just close the connection.
                await conn.close()

    except Exception as e:
        self.logger.error(f"Error cleaning up connection for session {session_id}: {e}")
    finally:
        # Remove from session connections. pop() tolerates the entry having
        # been removed by another task while this one was awaiting.
        self.session_connections.pop(session_id, None)
        self.logger.debug(f"Cleaned up connection for session: {session_id}")
async def _health_check_loop(self):
    """Run the health check forever, once per ``health_check_interval`` seconds.

    The loop ends when the task is cancelled. Any other error is logged and
    the loop continues with the next interval.
    """
    running = True
    while running:
        try:
            await asyncio.sleep(self.health_check_interval)
            await self._perform_health_check()
        except asyncio.CancelledError:
            # Cancellation is the normal way to stop this loop.
            running = False
        except Exception as e:
            self.logger.error(f"Health check error: {e}")
async def _perform_health_check(self):
    """Perform health check.

    Pings every session connection, cleans up the ones that fail (counting
    each in ``metrics.failed_connections``), refreshes the connection metrics
    and records the time of the check. Errors are logged and not raised.
    """
    try:
        # Iterate over a snapshot: ping() yields to the event loop, and other
        # tasks may add or remove sessions meanwhile. Iterating the live dict
        # would then fail with "dictionary changed size during iteration" and
        # abort the whole check.
        unhealthy_sessions = [
            session_id
            for session_id, conn in list(self.session_connections.items())
            if not await conn.ping()
        ]

        # Clean up unhealthy connections
        for session_id in unhealthy_sessions:
            await self._cleanup_session_connection(session_id)
            self.metrics.failed_connections += 1

        # Update metrics
        await self._update_connection_metrics()
        self.metrics.last_health_check = datetime.utcnow()

        if unhealthy_sessions:
            self.logger.warning(f"Cleaned up {len(unhealthy_sessions)} unhealthy connections")

    except Exception as e:
        self.logger.error(f"Health check failed: {e}")
async def _cleanup_loop(self):
    """Run the idle-connection cleanup forever, once every five minutes.

    The loop ends when the task is cancelled. Any other error is logged and
    the loop continues with the next interval.
    """
    running = True
    while running:
        try:
            await asyncio.sleep(300)  # Run every 5 minutes
            await self._cleanup_idle_connections()
        except asyncio.CancelledError:
            # Cancellation is the normal way to stop this loop.
            running = False
        except Exception as e:
            self.logger.error(f"Cleanup loop error: {e}")
async def _cleanup_idle_connections(self):
"""Clean up idle connections"""
current_time = datetime.utcnow()
idle_sessions = []
# If no results, return empty DataFrame
if not result:
return pd.DataFrame()
# Manually convert dict results to DataFrame
df = pd.DataFrame(result)
return df
finally:
conn.close()
for session_id, conn in self.session_connections.items():
# Check if connection has exceeded maximum age
age = (current_time - conn.created_at).total_seconds()
if age > self.max_connection_age:
idle_sessions.append(session_id)
# Clean up idle connections
for session_id in idle_sessions:
await self._cleanup_session_connection(session_id)
if idle_sessions:
self.logger.info(f"Cleaned up {len(idle_sessions)} idle connections")
async def _update_connection_metrics(self):
    """Refresh the active and idle connection gauges.

    Active connections are the tracked session connections. The idle count
    is read from the pool and left unchanged when there is no pool.
    """
    metrics = self.metrics
    metrics.active_connections = len(self.session_connections)

    if not self.pool:
        return
    metrics.idle_connections = self.pool.freesize
async def get_metrics(self) -> ConnectionMetrics:
    """Refresh the connection gauges and return the metrics object."""
    await self._update_connection_metrics()
    current = self.metrics
    return current
async def execute_query(
    self, session_id: str, sql: str, params: tuple | None = None, auth_context=None
) -> QueryResult:
    """Run ``sql`` on the connection bound to ``session_id``.

    The session connection is obtained (or created) first. ``params`` and
    ``auth_context`` are forwarded unchanged to the connection's ``execute``.
    """
    session_conn = await self.get_connection(session_id)
    result = await session_conn.execute(sql, params, auth_context)
    return result
@asynccontextmanager
async def get_connection_context(self, session_id: str):
    """Async context manager yielding the connection bound to ``session_id``.

    Nothing is closed or released on exit: the connection stays registered
    for the session so it can be reused.
    """
    session_conn = await self.get_connection(session_id)
    yield session_conn
async def close(self):
    """Shut the connection manager down.

    Cancels the background tasks, cleans up every session connection, then
    closes the pool and waits for it to finish closing. Errors are logged
    and not raised.
    """
    try:
        # Stop both background loops and wait until they have finished.
        for task in (self._health_check_task, self._cleanup_task):
            if not task:
                continue
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

        # Work on a snapshot of the ids because cleanup mutates the mapping.
        for session_id in list(self.session_connections):
            await self._cleanup_session_connection(session_id)

        # Close connection pool
        if self.pool:
            self.pool.close()
            await self.pool.wait_closed()

        self.logger.info("Connection manager closed successfully")

    except Exception as e:
        self.logger.error(f"Error closing connection manager: {e}")
async def test_connection(self) -> bool:
    """Check that the database answers a trivial query.

    Returns:
        True when ``SELECT 1`` returns a row through a pooled connection.
        False when there is no pool, no row comes back, or any error occurs
        (the error is logged).
    """
    try:
        pool = self.pool
        if not pool:
            return False

        async with pool.acquire() as conn, conn.cursor() as cursor:
            await cursor.execute("SELECT 1")
            row = await cursor.fetchone()
        return row is not None

    except Exception as e:
        self.logger.error(f"Connection test failed: {e}")
        return False
class ConnectionPoolMonitor:
    """Connection pool monitor.

    Builds status snapshots, per-session details and a health report from a
    ``DorisConnectionManager``.
    """

    def __init__(self, connection_manager: DorisConnectionManager):
        """Remember the manager to observe and create this module's logger."""
        self.connection_manager = connection_manager
        self.logger = logging.getLogger(__name__)

    async def get_pool_status(self) -> dict[str, Any]:
        """Return pool size, free connections, session count and counters.

        Pool figures are reported as 0 when the manager has no pool. The
        last health check is an ISO-8601 string, or None if none has run.
        """
        manager = self.connection_manager
        metrics = await manager.get_metrics()

        pool = manager.pool
        last_check = metrics.last_health_check

        return {
            "pool_size": pool.size if pool else 0,
            "free_connections": pool.freesize if pool else 0,
            "active_sessions": len(manager.session_connections),
            "total_connections": metrics.total_connections,
            "failed_connections": metrics.failed_connections,
            "connection_errors": metrics.connection_errors,
            "avg_connection_time": metrics.avg_connection_time,
            "last_health_check": last_check.isoformat() if last_check else None,
        }

    async def get_session_details(self) -> list[dict[str, Any]]:
        """Return one dictionary per session connection.

        Each entry holds the session id, ISO-8601 creation and last-use
        times, the query count, the health flag and the connection age in
        seconds.
        """
        return [
            {
                "session_id": session_id,
                "created_at": conn.created_at.isoformat(),
                "last_used": conn.last_used.isoformat(),
                "query_count": conn.query_count,
                "is_healthy": conn.is_healthy,
                "connection_age": (datetime.utcnow() - conn.created_at).total_seconds(),
            }
            for session_id, conn in self.connection_manager.session_connections.items()
        ]

    async def generate_health_report(self) -> dict[str, Any]:
        """Generate connection health report.

        Combines the pool status and session details with a session summary
        and a list of recommendations derived from fixed thresholds.
        """
        pool_status = await self.get_pool_status()
        session_details = await self.get_session_details()

        # Share of sessions flagged healthy; treated as 1.0 with no sessions.
        total_sessions = len(session_details)
        healthy_sessions = len([s for s in session_details if s["is_healthy"]])
        if total_sessions > 0:
            health_ratio = healthy_sessions / total_sessions
        else:
            health_ratio = 1.0

        report = {
            "timestamp": datetime.utcnow().isoformat(),
            "pool_status": pool_status,
            "session_summary": {
                "total_sessions": total_sessions,
                "healthy_sessions": healthy_sessions,
                "health_ratio": health_ratio,
            },
            "session_details": session_details,
            "recommendations": [],
        }

        # Each rule pairs a trigger with the advice to emit, in report order.
        advice_rules = (
            (
                health_ratio < 0.8,
                "Consider checking database connectivity and network stability",
            ),
            (
                pool_status["connection_errors"] > 10,
                "High connection error rate detected, review connection configuration",
            ),
            (
                pool_status["active_sessions"] > pool_status["pool_size"] * 0.9,
                "Connection pool utilization is high, consider increasing pool size",
            ),
        )
        for triggered, advice in advice_rules:
            if triggered:
                report["recommendations"].append(advice)

        return report

View File

@@ -1,226 +1,85 @@
"""
Unified Logging Configuration Module
Provides unified logging configuration, including:
- General logs: Record all program execution information
- Audit logs: Record JSON data for key operations and processing results
- Error logs: Specifically record program exceptions and errors
Logging configuration for Doris MCP Server.
"""
import os
import sys
import logging
import logging.handlers
import logging.config
import sys
from pathlib import Path
from typing import Dict
from datetime import datetime
from dotenv import load_dotenv
from typing import Any
# Load environment variables
load_dotenv(override=True)
# Get project root directory
PROJECT_ROOT = Path(__file__).parents[2].absolute()
def setup_logging(
level: str = "INFO",
log_file: str | None = None,
log_format: str | None = None,
) -> None:
"""
Setup logging configuration.
# Get log configuration from environment variables
LOG_DIR = os.getenv("LOG_DIR", str(PROJECT_ROOT / "logs"))
LOG_PREFIX = os.getenv("LOG_PREFIX", "doris_mcp")
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
LOG_MAX_DAYS = int(os.getenv("LOG_MAX_DAYS", "30"))
# Whether to output logs to the console (should be disabled when running as a service)
CONSOLE_LOGGING = os.getenv("CONSOLE_LOGGING", "false").lower() == "true"
# Whether stdio transport mode is being used
STDIO_MODE = os.getenv("MCP_TRANSPORT_TYPE", "").lower() == "stdio"
Args:
level: Logging level (DEBUG, INFO, WARNING, ERROR)
log_file: Optional log file path
log_format: Optional custom log format
"""
if log_format is None:
log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
def purge_old_logs():
"""Clean up expired log files"""
# --- Only perform cleanup in non-Stdio mode ---
if STDIO_MODE:
return
try:
now = datetime.now()
log_dir = Path(LOG_DIR)
# Check if directory exists and is readable/writable
if not log_dir.is_dir() or not os.access(LOG_DIR, os.W_OK):
if not STDIO_MODE: # Avoid printing to stdout in stdio mode
print(f"Warning: Log directory {LOG_DIR} not accessible, skipping log purge.", file=sys.stderr)
return
# Base configuration
config: dict[str, Any] = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"default": {"format": log_format, "datefmt": "%Y-%m-%d %H:%M:%S"}
},
"handlers": {
"console": {
"class": "logging.StreamHandler",
"level": level,
"formatter": "default",
"stream": sys.stdout,
}
},
"root": {"level": level, "handlers": ["console"]},
"loggers": {
"doris_mcp_server": {
"level": level,
"handlers": ["console"],
"propagate": False,
}
},
}
for log_file in log_dir.glob(f"{LOG_PREFIX}*.20*"):
# Parse date
file_name = log_file.name
date_str = None
# Try to find the date part
parts = file_name.split('.')
for part in parts:
if part.startswith('20') and len(part) == 8: # 20YYMMDD format
date_str = part
break
if date_str:
try:
file_date = datetime.strptime(date_str, '%Y%m%d')
days_old = (now - file_date).days
if days_old > LOG_MAX_DAYS:
os.remove(log_file)
if not STDIO_MODE:
print(f"Deleted expired log file: {log_file}")
except (ValueError, OSError) as e:
if not STDIO_MODE:
print(f"Error processing log file {file_name}: {e}", file=sys.stderr)
except Exception as e:
if not STDIO_MODE:
print(f"Error cleaning up logs: {e}", file=sys.stderr)
# Add file handler if log_file is specified
if log_file:
# Ensure log directory exists
log_path = Path(log_file)
log_path.parent.mkdir(parents=True, exist_ok=True)
# Force disable console log output if in stdio mode
if STDIO_MODE:
CONSOLE_LOGGING = False
config["handlers"]["file"] = {
"class": "logging.handlers.RotatingFileHandler",
"level": level,
"formatter": "default",
"filename": log_file,
"maxBytes": 10485760, # 10MB
"backupCount": 5,
}
# --- Only create log directory and clean old logs in non-Stdio mode ---
if not STDIO_MODE:
try:
os.makedirs(LOG_DIR, exist_ok=True)
# Clean up expired logs on startup (also moved here, as it only handles file logs)
purge_old_logs()
except OSError as e:
# If directory creation fails (e.g., permission issue), print warning but continue to avoid startup failure
print(f"Warning: Failed to create log directory {LOG_DIR} or purge logs: {e}", file=sys.stderr)
# Add file handler to root and package loggers
config["root"]["handlers"].append("file")
config["loggers"]["doris_mcp_server"]["handlers"].append("file")
# Log file paths (definition still needed, but files might not be created/used)
LOG_FILE = os.path.join(LOG_DIR, f"{LOG_PREFIX}.log")
AUDIT_LOG_FILE = os.path.join(LOG_DIR, f"{LOG_PREFIX}.audit")
ERROR_LOG_FILE = os.path.join(LOG_DIR, f"{LOG_PREFIX}.error")
# Log level mapping
LOG_LEVELS = {
"DEBUG": logging.DEBUG,
"INFO": logging.INFO,
"WARNING": logging.WARNING,
"ERROR": logging.ERROR,
"CRITICAL": logging.CRITICAL
}
# Log format
LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
AUDIT_FORMAT = '%(asctime)s - %(name)s - %(message)s'
ERROR_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(pathname)s:%(lineno)d - %(message)s'
# Dedicated audit log level
AUDIT = 25 # Level between INFO and WARNING
logging.addLevelName(AUDIT, "AUDIT")
# Logger object cache
_loggers: Dict[str, logging.Logger] = {}
# Handler type mapping, used to ensure no duplicates are added
_handler_types = {
'console': logging.StreamHandler,
'file': logging.handlers.TimedRotatingFileHandler,
'audit': logging.handlers.TimedRotatingFileHandler,
'error': logging.handlers.TimedRotatingFileHandler
}
logging.config.dictConfig(config)
def get_logger(name: str) -> logging.Logger:
    """
    Return the logger registered under ``name``, configuring it on first use.

    A newly built logger gets the configured level, stops propagating to its
    parents, gains an ``audit(message, ...)`` method that logs at the custom
    AUDIT level, and receives a stderr console handler (when console logging
    is enabled). Outside stdio mode it also receives daily-rotating general,
    audit and error file handlers. The result is cached in ``_loggers``.

    Args:
        name: Logger name

    Returns:
        logging.Logger: Configured logger
    """
    try:
        return _loggers[name]
    except KeyError:
        pass

    log = logging.getLogger(name)
    log.setLevel(LOG_LEVELS.get(LOG_LEVEL, logging.INFO))
    # Handlers are attached directly, so propagation would duplicate records
    log.propagate = False

    # Snapshot of handler classes already attached, taken before adding any
    present = {type(existing) for existing in log.handlers}

    def audit(self, message, *args, **kwargs):
        self.log(AUDIT, message, *args, **kwargs)

    # Bind the function to this logger instance so it behaves like a method
    log.audit = audit.__get__(log)

    def attach_rotating(kind, path, fmt, label, level=None, record_filter=None):
        """Attach one midnight-rotating file handler unless its class is already present."""
        if _handler_types[kind] in present:
            return
        try:
            handler = logging.handlers.TimedRotatingFileHandler(
                path,
                when='midnight',
                interval=1,
                backupCount=LOG_MAX_DAYS,
                encoding='utf-8'
            )
            handler.setFormatter(logging.Formatter(fmt))
            handler.suffix = "%Y%m%d"
            if level is not None:
                handler.setLevel(level)
            if record_filter is not None:
                handler.addFilter(record_filter)
            log.addHandler(handler)
        except OSError as e:
            print(f"Warning: Failed to add {label} log handler for {path}: {e}", file=sys.stderr)

    if CONSOLE_LOGGING and _handler_types['console'] not in present:
        # stderr keeps log output away from stdout, which MCP communication uses
        stream_handler = logging.StreamHandler(sys.stderr)
        stream_handler.setFormatter(logging.Formatter(LOG_FORMAT))
        log.addHandler(stream_handler)

    # File handlers are only added in non-stdio mode
    if not STDIO_MODE:
        attach_rotating('file', LOG_FILE, LOG_FORMAT, "file")
        # Audit file: accepts AUDIT-level records only
        attach_rotating(
            'audit', AUDIT_LOG_FILE, AUDIT_FORMAT, "audit",
            level=AUDIT,
            record_filter=lambda record: record.levelno == AUDIT,
        )
        # Error file: ERROR and above
        attach_rotating('error', ERROR_LOG_FILE, ERROR_FORMAT, "error", level=logging.ERROR)

    _loggers[name] = log
    return log
# Default logger
logger = get_logger('doris_mcp')
# Audit logger - for recording processing results, business operations, etc.
audit_logger = get_logger('audit')
# Call to clean logs moved after directory creation, and added non-stdio check
return logging.getLogger(name)

View File

@@ -0,0 +1,800 @@
#!/usr/bin/env python3
"""
Doris Query Execution Module
Implements query optimization, cache management and performance monitoring functionality
"""
import asyncio
import hashlib
import json
import logging
import time
import os
import uuid
import traceback
from dataclasses import dataclass
from datetime import datetime, timedelta, date
from typing import Any, Dict
from decimal import Decimal
from .db import DorisConnectionManager, QueryResult
@dataclass
class QueryRequest:
"""Query request wrapper"""
sql: str
session_id: str
user_id: str
parameters: dict[str, Any] | None = None
timeout: int | None = None
cache_enabled: bool = True
@dataclass
class CachedQuery:
    """A stored query result together with its age and usage bookkeeping."""

    # The cached result object
    result: QueryResult
    # Naive UTC timestamp taken when the entry was stored
    created_at: datetime
    # Lifetime in seconds; zero or negative means the entry never expires
    ttl: int
    # Number of times access() has been called
    access_count: int = 0
    # Naive UTC timestamp of the latest access(), None until first use
    last_accessed: datetime | None = None

    def is_expired(self) -> bool:
        """Return True once the entry is older than its ttl (never for ttl <= 0)."""
        return (
            self.ttl > 0
            and (datetime.utcnow() - self.created_at).total_seconds() > self.ttl
        )

    def access(self):
        """Count one use of the entry and stamp the time of that use."""
        self.last_accessed = datetime.utcnow()
        self.access_count += 1
@dataclass
class QueryMetrics:
    """Running counters and timings describing executed queries."""

    # --- outcome counters ---
    total_queries: int = 0
    successful_queries: int = 0
    failed_queries: int = 0

    # --- cache lookup counters ---
    cache_hits: int = 0
    cache_misses: int = 0

    # --- timing, in seconds ---
    avg_execution_time: float = 0.0
    total_execution_time: float = 0.0

    # Queries whose duration exceeded the slow-query threshold
    slow_queries: int = 0
    # Queries currently in flight
    concurrent_queries: int = 0
class QueryCache:
    """In-memory cache of query results keyed by SQL text and bound parameters.

    Entries are ``CachedQuery`` objects stored in ``self.cache``. When the
    cache is full, the entry with the oldest ``created_at`` is evicted to make
    room. Every lookup is counted as a hit or a miss so that ``get_stats``
    reports a meaningful hit rate.
    """

    def __init__(self, max_size: int = 1000, default_ttl: int = 300):
        """
        Args:
            max_size: Number of entries at which eviction starts.
            default_ttl: Lifetime in seconds used when ``set`` is given no ttl.
        """
        self.max_size = max_size
        self.default_ttl = default_ttl
        self.cache: dict[str, CachedQuery] = {}
        self.logger = logging.getLogger(__name__)
        # Lookup counters backing the hit rate reported by get_stats()
        self._hits = 0
        self._misses = 0

    def _generate_cache_key(
        self, sql: str, parameters: dict[str, Any] | None = None
    ) -> str:
        """Derive a stable key from the statement text and its parameters."""
        # Only surrounding whitespace is normalized. Letter case is preserved:
        # lowercasing would merge queries that differ only in the case of a
        # string literal or identifier and hand back the wrong cached result.
        cache_data = {"sql": sql.strip(), "parameters": parameters or {}}
        # default=str is deliberate: parameter values such as datetime or
        # Decimal are not JSON-native and would otherwise raise TypeError.
        cache_string = json.dumps(cache_data, sort_keys=True, default=str)
        # md5 is used purely as a fingerprint here, not for security
        return hashlib.md5(cache_string.encode()).hexdigest()

    async def get(
        self, sql: str, parameters: dict[str, Any] | None = None
    ) -> CachedQuery | None:
        """Return the live entry for this query, or None on a miss or expiry."""
        cache_key = self._generate_cache_key(sql, parameters)
        cached_query = self.cache.get(cache_key)

        if cached_query is None:
            self._misses += 1
            return None

        if cached_query.is_expired():
            # Drop the stale entry so it stops occupying a slot
            del self.cache[cache_key]
            self._misses += 1
            self.logger.debug(f"Cache expired, cleaned up: {cache_key}")
            return None

        cached_query.access()
        self._hits += 1
        self.logger.debug(f"Cache hit: {cache_key}")
        return cached_query

    async def set(
        self,
        sql: str,
        result: QueryResult,
        parameters: dict[str, Any] | None = None,
        ttl: int | None = None,
    ) -> str:
        """Store ``result`` for this query and return the cache key used.

        A ``ttl`` of None (or 0) falls back to ``default_ttl``.
        """
        cache_key = self._generate_cache_key(sql, parameters)

        # Overwriting an existing key does not grow the cache, so only evict
        # when a genuinely new entry would exceed the size limit.
        if cache_key not in self.cache and len(self.cache) >= self.max_size:
            await self._evict_oldest()

        self.cache[cache_key] = CachedQuery(
            result=result, created_at=datetime.utcnow(), ttl=ttl or self.default_ttl
        )
        self.logger.debug(f"Cache set: {cache_key}")
        return cache_key

    async def _evict_oldest(self):
        """Remove the entry with the earliest creation time, if any."""
        if not self.cache:
            return
        oldest_key = min(self.cache, key=lambda k: self.cache[k].created_at)
        del self.cache[oldest_key]
        self.logger.debug(f"Cleaned up oldest cache: {oldest_key}")

    async def clear_expired(self):
        """Remove every entry whose ttl has elapsed."""
        expired_keys = [
            key for key, cached_query in self.cache.items() if cached_query.is_expired()
        ]
        for key in expired_keys:
            del self.cache[key]
        if expired_keys:
            self.logger.info(f"Cleaned up {len(expired_keys)} expired cache items")

    async def clear_all(self):
        """Remove every entry from the cache."""
        cache_count = len(self.cache)
        self.cache.clear()
        self.logger.info(f"Cleaned up all cache, total {cache_count} items")

    def get_stats(self) -> dict[str, Any]:
        """Return cache size, access count and the lookup hit rate.

        ``total_access`` is the summed ``access_count`` of the entries still
        in the cache; ``hit_rate`` is hits / (hits + misses) over all lookups
        made through ``get`` (0.0 before the first lookup).
        """
        total_access = sum(cached.access_count for cached in self.cache.values())
        lookups = self._hits + self._misses
        return {
            "cache_size": len(self.cache),
            "max_size": self.max_size,
            "total_access": total_access,
            "hit_rate": self._hits / lookups if lookups else 0.0,
        }
class QueryOptimizer:
    """Rule-based SQL rewriter applied to a statement before execution.

    Each rule is a dict with a ``name``, an optional regex ``pattern`` that the
    statement must match (case-insensitively), optional ``conditions`` checked
    against the caller-supplied context, and an ``action`` naming the rewrite.
    """

    def __init__(self, config):
        self.config = config
        self.logger = logging.getLogger(__name__)
        self.optimization_rules = self._load_optimization_rules()

    def _load_optimization_rules(self) -> list[dict[str, Any]]:
        """Return the built-in rewrite rules, in application order."""
        return [
            {
                "name": "add_limit_clause",
                "description": "Add default limit for SELECT queries without LIMIT",
                # Matches any statement that begins with SELECT, leading
                # whitespace allowed. Whether a LIMIT is already present is
                # decided in _add_limit_clause; a negative lookahead placed
                # after a greedy ".*" can never reject a statement.
                "pattern": r"^\s*select\b",
                "action": "add_limit",
                "params": {"default_limit": 1000},
            },
            {
                "name": "optimize_count_query",
                "description": "Optimize COUNT queries",
                "pattern": r"select\s+count\(\*\)\s+from\s+(\w+)",
                "action": "optimize_count",
                "params": {},
            },
        ]

    async def optimize_query(self, sql: str, context: dict[str, Any]) -> str:
        """Apply every matching rule in order and return the rewritten SQL."""
        optimized_sql = sql
        for rule in self.optimization_rules:
            if self._should_apply_rule(rule, optimized_sql, context):
                optimized_sql = await self._apply_optimization_rule(
                    optimized_sql, rule, context
                )
                self.logger.debug(f"Applied optimization rule: {rule['name']}")
        return optimized_sql

    def _should_apply_rule(
        self, rule: dict[str, Any], sql: str, context: dict[str, Any]
    ) -> bool:
        """Return True when the rule's pattern matches and all conditions hold."""
        import re

        pattern = rule.get("pattern")
        if pattern and not re.search(pattern, sql, re.IGNORECASE):
            return False

        # Expose the statement to conditions (e.g. "query_size") without
        # mutating the caller's dict; an explicit context["sql"] still wins.
        effective_context = {"sql": sql, **(context or {})}
        return all(
            self._check_condition(condition, effective_context)
            for condition in rule.get("conditions", ())
        )

    def _check_condition(
        self, condition: dict[str, Any], context: dict[str, Any]
    ) -> bool:
        """Evaluate one rule condition; unknown condition types pass."""
        condition_type = condition.get("type")

        if condition_type == "user_role":
            required_roles = condition.get("roles", [])
            user_roles = context.get("user_roles", [])
            return any(role in user_roles for role in required_roles)

        if condition_type == "query_size":
            max_size = condition.get("max_size", 1000)
            return len(context.get("sql", "")) <= max_size

        return True

    async def _apply_optimization_rule(
        self, sql: str, rule: dict[str, Any], context: dict[str, Any]
    ) -> str:
        """Dispatch to the rewrite named by the rule's action."""
        action = rule.get("action")
        params = rule.get("params", {})

        if action == "add_limit":
            return await self._add_limit_clause(sql, params)
        if action == "optimize_count":
            return await self._optimize_count_query(sql, params)
        if action == "add_hints":
            return await self._add_query_hints(sql, params)
        return sql

    async def _add_limit_clause(self, sql: str, params: dict[str, Any]) -> str:
        """Append ``LIMIT <default_limit>`` unless the statement already has a LIMIT."""
        import re

        default_limit = params.get("default_limit", 1000)

        if re.search(r"\blimit\s+\d+", sql, re.IGNORECASE):
            return sql

        # Strip surrounding whitespace and any trailing semicolon so the
        # clause is appended to the statement itself.
        statement = sql.strip().rstrip(";").rstrip()
        return f"{statement} LIMIT {default_limit}"

    async def _optimize_count_query(self, sql: str, params: dict[str, Any]) -> str:
        """Rewrite ``COUNT(*)`` as ``COUNT(1)``."""
        import re

        # The rule pattern matches case-insensitively, so the rewrite must as
        # well; a plain str.replace("COUNT(*)", ...) skips "count(*)".
        return re.sub(r"count\(\*\)", "COUNT(1)", sql, flags=re.IGNORECASE)

    async def _add_query_hints(self, sql: str, params: dict[str, Any]) -> str:
        """Prefix the statement with a ``/*+ ... */`` hint comment, if hints are given."""
        hints = params.get("hints", [])
        if not hints:
            return sql
        hint_string = "/*+ " + " ".join(hints) + " */"
        return f"{hint_string} {sql}"
class DorisQueryExecutor:
    """Doris query executor with caching and optimization.

    Wraps a connection manager with three concerns: an in-memory result cache
    (consulted only for read-only statements), a rule-based SQL optimizer, and
    running execution metrics. A background task purges expired cache entries
    when the executor is created inside a running event loop.
    """

    # Leading keywords of statements treated as read-only. Only these are
    # served from or written to the result cache: replaying a cached result
    # for a write statement would silently skip the write.
    _CACHEABLE_PREFIXES = ("SELECT", "SHOW", "DESC", "EXPLAIN")

    def __init__(self, connection_manager: DorisConnectionManager, config=None):
        """
        Args:
            connection_manager: Source of per-session database connections.
            config: Object with an optional ``performance`` attribute carrying
                ``max_cache_size``, ``cache_ttl`` and ``max_concurrent_queries``;
                defaults are used for anything missing.
        """
        self.connection_manager = connection_manager
        self.config = config or self._create_default_config()
        self.logger = logging.getLogger(__name__)

        # getattr on a missing/None performance section yields the defaults
        performance = getattr(self.config, 'performance', None) or None
        cache_size = getattr(performance, 'max_cache_size', 1000)
        cache_ttl = getattr(performance, 'cache_ttl', 300)

        self.query_cache = QueryCache(max_size=cache_size, default_ttl=cache_ttl)
        self.query_optimizer = QueryOptimizer(self.config)
        self.metrics = QueryMetrics()

        # Performance monitoring
        self.slow_query_threshold = 5.0  # seconds
        self.max_concurrent_queries = getattr(
            getattr(self.config, 'performance', None), 'max_concurrent_queries', 50
        )

        # Background tasks
        self._background_tasks = []
        self._start_background_tasks()

    def _create_default_config(self):
        """Create default configuration"""
        class DefaultPerformanceConfig:
            def __init__(self):
                self.max_cache_size = 1000
                self.cache_ttl = 300
                self.max_concurrent_queries = 50

        class DefaultConfig:
            def __init__(self):
                self.performance = DefaultPerformanceConfig()

        return DefaultConfig()

    def _start_background_tasks(self):
        """Start the cache cleanup task when an event loop is running."""
        # Check for a loop BEFORE creating the coroutine object; otherwise a
        # failed create_task leaves a coroutine that is never awaited.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No event loop running (e.g., in tests), skip background tasks
            self.logger.debug("No event loop running, skipping background tasks")
            return
        self._background_tasks.append(
            asyncio.create_task(self._cache_cleanup_loop())
        )

    async def _cache_cleanup_loop(self):
        """Purge expired cache entries every five minutes until cancelled."""
        while True:
            try:
                await asyncio.sleep(300)  # Run every 5 minutes
                await self.query_cache.clear_expired()
            except asyncio.CancelledError:
                break
            except Exception as e:
                self.logger.error(f"Cache cleanup error: {e}")

    @classmethod
    def _is_cacheable(cls, sql: str) -> bool:
        """Return True when the statement starts with a read-only keyword."""
        return (sql or "").lstrip().upper().startswith(cls._CACHEABLE_PREFIXES)

    async def execute_query(
        self, query_request: QueryRequest, auth_context=None
    ) -> QueryResult:
        """Execute a query, serving read-only statements from cache when possible.

        Args:
            query_request: Statement, session and options.
            auth_context: Optional caller context passed through to the connection.

        Returns:
            The query result (cached or freshly executed).

        Raises:
            Exception: Whatever the underlying execution raises; the failure
                is counted and logged before being re-raised.
        """
        start_time = time.perf_counter()
        self.metrics.total_queries += 1
        self.metrics.concurrent_queries += 1
        use_cache = query_request.cache_enabled and self._is_cacheable(query_request.sql)

        try:
            if use_cache:
                cached_result = await self.query_cache.get(
                    query_request.sql, query_request.parameters
                )
                if cached_result:
                    self.metrics.cache_hits += 1
                    # A cache hit is a successfully answered query
                    self.metrics.successful_queries += 1
                    self.logger.debug(f"Cache hit for query: {query_request.sql[:50]}...")
                    return cached_result.result
                self.metrics.cache_misses += 1

            result = await self._execute_query_internal(query_request, auth_context)

            # Empty results are not cached
            if use_cache and result.row_count > 0:
                await self.query_cache.set(
                    query_request.sql, result, query_request.parameters
                )

            self.metrics.successful_queries += 1
            return result
        except Exception as e:
            self.metrics.failed_queries += 1
            self.logger.error(f"Query execution failed: {e}")
            raise
        finally:
            execution_time = time.perf_counter() - start_time
            self.metrics.concurrent_queries -= 1
            self._update_execution_metrics(execution_time)

    async def _execute_query_internal(
        self, query_request: QueryRequest, auth_context
    ) -> QueryResult:
        """Optimize the statement and run it on the session's connection.

        Raises:
            TimeoutError: When ``query_request.timeout`` elapses first.
        """
        optimized_sql = await self.query_optimizer.optimize_query(
            query_request.sql, {"user_roles": getattr(auth_context, 'roles', [])}
        )

        connection = await self.connection_manager.get_connection(
            query_request.session_id
        )

        if not query_request.timeout:
            return await connection.execute(
                optimized_sql, query_request.parameters, auth_context
            )

        try:
            return await asyncio.wait_for(
                connection.execute(optimized_sql, query_request.parameters, auth_context),
                timeout=query_request.timeout
            )
        except asyncio.TimeoutError as exc:
            # TimeoutError is still an Exception, so existing handlers keep working
            raise TimeoutError(
                f"Query timeout after {query_request.timeout} seconds"
            ) from exc

    def _update_execution_metrics(self, execution_time: float):
        """Fold one finished query's duration into the running metrics."""
        self.metrics.total_execution_time += execution_time

        # total_execution_time covers successes, cache hits and failures, so
        # the average is taken over all completed queries.
        completed = self.metrics.successful_queries + self.metrics.failed_queries
        if completed > 0:
            self.metrics.avg_execution_time = (
                self.metrics.total_execution_time / completed
            )

        if execution_time > self.slow_query_threshold:
            self.metrics.slow_queries += 1
            self.logger.warning(
                f"Slow query detected: {execution_time:.2f}s (threshold: {self.slow_query_threshold}s)"
            )

    async def execute_batch_queries(
        self, query_requests: list[QueryRequest], auth_context=None
    ) -> list[QueryResult]:
        """Execute several queries concurrently.

        Returns:
            One item per request, in order. A failed query contributes its
            exception object in place of a result instead of aborting the batch.

        Raises:
            Exception: When the batch is larger than ``max_concurrent_queries``.
        """
        if len(query_requests) > self.max_concurrent_queries:
            raise Exception(
                f"Batch size {len(query_requests)} exceeds maximum concurrent queries {self.max_concurrent_queries}"
            )

        tasks = [
            self.execute_query(request, auth_context) for request in query_requests
        ]
        return await asyncio.gather(*tasks, return_exceptions=True)

    async def explain_query(self, sql: str, session_id: str) -> dict[str, Any]:
        """Get query execution plan"""
        explain_sql = f"EXPLAIN {sql}"
        connection = await self.connection_manager.get_connection(session_id)
        result = await connection.execute(explain_sql)

        return {
            "query": sql,
            "execution_plan": result.data,
            "estimated_cost": "N/A",  # Doris doesn't provide cost estimates
        }

    async def get_query_stats(self) -> dict[str, Any]:
        """Return execution metrics and cache metrics as nested dicts."""
        cache_stats = self.query_cache.get_stats()
        metrics = self.metrics
        lookups = metrics.cache_hits + metrics.cache_misses

        return {
            "query_metrics": {
                "total_queries": metrics.total_queries,
                "successful_queries": metrics.successful_queries,
                "failed_queries": metrics.failed_queries,
                "success_rate": (
                    metrics.successful_queries / metrics.total_queries
                    if metrics.total_queries > 0
                    else 0.0
                ),
                "avg_execution_time": metrics.avg_execution_time,
                "slow_queries": metrics.slow_queries,
                "concurrent_queries": metrics.concurrent_queries,
            },
            "cache_metrics": {
                # Spread the cache's own stats first so the executor-level
                # hit_rate below is not overwritten by a same-named key.
                **cache_stats,
                "cache_hits": metrics.cache_hits,
                "cache_misses": metrics.cache_misses,
                "hit_rate": metrics.cache_hits / lookups if lookups > 0 else 0.0,
            },
        }

    async def clear_cache(self):
        """Clear query cache"""
        await self.query_cache.clear_all()

    async def execute_sql_for_mcp(
        self,
        sql: str,
        limit: int = 1000,
        timeout: int = 30,
        session_id: str = "mcp_session",
        user_id: str = "mcp_user"
    ) -> Dict[str, Any]:
        """Execute SQL query for MCP interface - unified method.

        Args:
            sql: Statement to run; a SELECT without LIMIT gets ``LIMIT <limit>``.
            limit: Row cap appended to unbounded SELECT statements.
            timeout: Query timeout in seconds.
            session_id: Session used for the connection.
            user_id: User recorded on the request and auth context.

        Returns:
            A dict with ``success``, ``data``, ``error`` and ``metadata``.
            Errors are reported in the dict; this method does not raise.
        """
        import re

        try:
            if not sql or not sql.strip():
                return {
                    "success": False,
                    "error": "SQL query is required",
                    "data": None
                }

            # Normalize first so leading whitespace or a trailing ';' cannot
            # defeat the SELECT / LIMIT detection below.
            sql = sql.strip()
            is_select = re.match(r"select\b", sql, re.IGNORECASE) is not None
            # Word-boundary match: a column such as "rate_limit" is not a LIMIT clause
            has_limit = re.search(r"\blimit\b", sql, re.IGNORECASE) is not None
            if is_select and not has_limit:
                sql = f"{sql.rstrip(';').rstrip()} LIMIT {limit}"

            # Create auth context for MCP calls
            class MockAuthContext:
                def __init__(self):
                    self.user_id = user_id
                    self.roles = ["data_analyst"]
                    self.permissions = ["read_data", "execute_query"]
                    self.session_id = session_id
                    self.security_level = "internal"

            auth_context = MockAuthContext()

            query_request = QueryRequest(
                sql=sql,
                session_id=session_id,
                user_id=user_id,
                timeout=timeout,
                cache_enabled=True
            )

            result = await self.execute_query(query_request, auth_context)

            processed_data = [
                self._serialize_row_data(row) for row in (result.data or [])
            ]
            # Tolerate a result whose metadata is missing or None
            result_metadata = getattr(result, "metadata", None) or {}

            return {
                "success": True,
                "data": processed_data,
                "metadata": {
                    "row_count": result.row_count,
                    "execution_time": result.execution_time,
                    "columns": result_metadata.get("columns", []),
                    "query": sql
                },
                "error": None
            }

        except Exception as e:
            error_msg = str(e)
            self.logger.error(f"SQL execution error: {error_msg}")

            # Analyze error for better user feedback
            error_analysis = self._analyze_error(error_msg)

            return {
                "success": False,
                "error": error_analysis.get("user_message", error_msg),
                "error_type": error_analysis.get("error_type", "execution_error"),
                "data": None,
                "metadata": {
                    "query": sql,
                    "error_details": error_msg
                }
            }

    def _serialize_row_data(self, row_data: Dict[str, Any]) -> Dict[str, Any]:
        """Convert one result row into JSON-friendly values.

        None and str/int/float/bool pass through, Decimal becomes float,
        date/datetime become ISO strings, bytes are decoded as UTF-8 (falling
        back to ``str``), and anything else is stringified.
        """
        serialized = {}
        for key, value in row_data.items():
            if value is None or isinstance(value, (str, int, float, bool)):
                serialized[key] = value
            elif isinstance(value, Decimal):
                serialized[key] = float(value)
            elif isinstance(value, (datetime, date)):
                serialized[key] = value.isoformat()
            elif isinstance(value, bytes):
                try:
                    serialized[key] = value.decode('utf-8')
                except UnicodeDecodeError:
                    serialized[key] = str(value)
            else:
                serialized[key] = str(value)
        return serialized

    def _analyze_error(self, error_message: str) -> Dict[str, str]:
        """Classify an error message and provide user-friendly feedback."""
        error_msg_lower = error_message.lower()

        if "table" in error_msg_lower and "doesn't exist" in error_msg_lower:
            return {
                "error_type": "table_not_found",
                "user_message": "The specified table does not exist. Please check the table name and database."
            }
        if "column" in error_msg_lower and ("unknown" in error_msg_lower or "doesn't exist" in error_msg_lower):
            return {
                "error_type": "column_not_found",
                "user_message": "One or more columns in the query do not exist. Please check column names."
            }
        if "syntax error" in error_msg_lower or "sql syntax" in error_msg_lower:
            return {
                "error_type": "syntax_error",
                "user_message": "SQL syntax error. Please check your query syntax."
            }
        if "access denied" in error_msg_lower or "permission" in error_msg_lower:
            return {
                "error_type": "permission_denied",
                "user_message": "Access denied. You don't have permission to execute this query."
            }
        if "timeout" in error_msg_lower:
            return {
                "error_type": "timeout",
                "user_message": "Query execution timed out. Try simplifying your query or adding more specific filters."
            }
        return {
            "error_type": "general_error",
            "user_message": f"Query execution failed: {error_message}"
        }

    async def close(self):
        """Cancel background tasks, wait for them to finish, and clear the cache."""
        tasks = list(self._background_tasks)
        for task in tasks:
            task.cancel()
        if tasks:
            # return_exceptions=True absorbs the CancelledError of each task
            await asyncio.gather(*tasks, return_exceptions=True)
        self._background_tasks.clear()

        await self.query_cache.clear_all()

        self.logger.info("Query executor closed")
class QueryPerformanceMonitor:
    """Keeps a rolling window of per-query timings and summarises them on demand."""

    def __init__(self, query_executor: DorisQueryExecutor):
        self.query_executor = query_executor
        self.logger = logging.getLogger(__name__)
        # One dict per recorded query, oldest first
        self.performance_records = []

    async def record_query_performance(
        self, query_request: QueryRequest, result: QueryResult, execution_time: float
    ):
        """Append one timing record, keeping only the newest 1000 records."""
        self.performance_records.append(
            {
                "timestamp": datetime.utcnow(),
                "sql": query_request.sql,
                "user_id": query_request.user_id,
                "session_id": query_request.session_id,
                "execution_time": execution_time,
                "row_count": result.row_count,
                "cache_hit": False,  # This would need to be passed from executor
            }
        )

        # Trim to the most recent 1000 entries once the window overflows
        if len(self.performance_records) > 1000:
            self.performance_records = self.performance_records[-1000:]

    async def get_performance_report(
        self, time_range_minutes: int = 60
    ) -> dict[str, Any]:
        """Summarise the records from the last ``time_range_minutes`` minutes.

        Returns a dict holding only a ``message`` key when the window is empty.
        """
        window_start = datetime.utcnow() - timedelta(minutes=time_range_minutes)
        in_window = [
            entry
            for entry in self.performance_records
            if entry["timestamp"] >= window_start
        ]

        if not in_window:
            return {"message": "No performance data available for the specified time range"}

        durations = [entry["execution_time"] for entry in in_window]
        rows = [entry["row_count"] for entry in in_window]
        query_count = len(in_window)

        return {
            "time_range_minutes": time_range_minutes,
            "total_queries": query_count,
            "avg_execution_time": sum(durations) / query_count,
            "max_execution_time": max(durations),
            "min_execution_time": min(durations),
            "avg_row_count": sum(rows) / query_count,
            "query_distribution": self._analyze_query_distribution(in_window),
        }

    def _analyze_query_distribution(
        self, records: list[dict[str, Any]]
    ) -> dict[str, Any]:
        """Count records per statement type and per user."""
        known_types = ("SELECT", "INSERT", "UPDATE", "DELETE")
        type_counts = {}
        user_counts = {}

        for entry in records:
            statement = entry["sql"].strip().upper()
            # First matching keyword wins; everything else is "OTHER"
            kind = next(
                (keyword for keyword in known_types if statement.startswith(keyword)),
                "OTHER",
            )
            type_counts[kind] = type_counts.get(kind, 0) + 1

            who = entry["user_id"]
            user_counts[who] = user_counts.get(who, 0) + 1

        return {"query_types": type_counts, "user_distribution": user_counts}
# Unified convenience function for MCP integration
async def execute_sql_query(sql: str, connection_manager: DorisConnectionManager, **kwargs) -> Dict[str, Any]:
    """Run one SQL statement through a throwaway DorisQueryExecutor.

    Recognised keyword arguments are ``limit`` (default 1000), ``timeout``
    (default 30), ``session_id`` (default "mcp_session") and ``user_id``
    (default "mcp_user"). The executor is always closed afterwards. Any
    exception is converted into a failure dict instead of being raised.
    """
    try:
        executor = DorisQueryExecutor(connection_manager)
        try:
            # Fall back to the documented defaults for options not supplied
            return await executor.execute_sql_for_mcp(
                sql=sql,
                limit=kwargs.get("limit", 1000),
                timeout=kwargs.get("timeout", 30),
                session_id=kwargs.get("session_id", "mcp_session"),
                user_id=kwargs.get("user_id", "mcp_user"),
            )
        finally:
            await executor.close()
    except Exception as e:
        return {
            "success": False,
            "error": f"Query execution failed: {str(e)}",
            "data": None
        }

View File

@@ -8,6 +8,8 @@ import os
import json
import pandas as pd
import re
import uuid
import time
from typing import Dict, List, Any, Optional, Tuple
from dotenv import load_dotenv
from datetime import datetime, timedelta
@@ -26,23 +28,25 @@ ENABLE_MULTI_DATABASE=os.getenv("ENABLE_MULTI_DATABASE",True)
MULTI_DATABASE_NAMES=os.getenv("MULTI_DATABASE_NAMES","")
# Import local modules
from doris_mcp_server.utils.db import execute_query_df, execute_query
from .db import DorisConnectionManager
class MetadataExtractor:
"""Apache Doris Metadata Extractor"""
def __init__(self, db_name: str = None, catalog_name: str = None):
def __init__(self, db_name: str = None, catalog_name: str = None, connection_manager=None):
"""
Initialize the metadata extractor
Args:
db_name: Default database name, uses the currently connected database if not specified
catalog_name: Default catalog name for federation queries, uses the current catalog if not specified
connection_manager: DorisConnectionManager instance for database operations
"""
# Get configuration from environment variables
self.db_name = db_name or os.getenv("DB_DATABASE", "")
self.catalog_name = catalog_name # Store catalog name for federation support
self.metadata_db = METADATA_DB_NAME # Use constant
self.connection_manager = connection_manager
# Caching system
self.metadata_cache = {}
@@ -65,6 +69,9 @@ class MetadataExtractor:
# List of excluded system databases
self.excluded_databases = self._load_excluded_databases()
# Session ID for database queries
self._session_id = f"metadata_extractor_{uuid.uuid4().hex[:8]}"
def _load_excluded_databases(self) -> List[str]:
"""
Load the list of excluded databases configuration
@@ -482,7 +489,7 @@ class MetadataExtractor:
TABLE_SCHEMA = '{db_name}'
AND TABLE_NAME = '{table_name}'
"""
table_type_result = execute_query(table_type_query)
table_type_result = self._execute_query(table_type_query)
if table_type_result:
schema["table_type"] = table_type_result[0].get("TABLE_TYPE", "")
schema["engine"] = table_type_result[0].get("ENGINE", "")
@@ -633,31 +640,52 @@ class MetadataExtractor:
else:
query = f"SHOW INDEX FROM `{db_name}`.`{table_name}`"
df = execute_query_df(query)
# Process results
indexes = []
current_index = None
for _, row in df.iterrows():
index_name = row['Key_name']
column_name = row['Column_name']
try:
df = self._execute_query(query, return_dataframe=True)
if current_index is None or current_index['name'] != index_name:
# Process results
indexes = []
current_index = None
if not df.empty:
for _, row in df.iterrows():
try:
index_name = row['Key_name']
column_name = row['Column_name']
if current_index is None or current_index['name'] != index_name:
if current_index is not None:
indexes.append(current_index)
current_index = {
'name': index_name,
'columns': [column_name],
'unique': row['Non_unique'] == 0,
'type': row['Index_type']
}
else:
current_index['columns'].append(column_name)
except Exception as row_error:
logger.warning(f"Failed to process index row data: {row_error}")
continue
if current_index is not None:
indexes.append(current_index)
current_index = {
'name': index_name,
'columns': [column_name],
'unique': row['Non_unique'] == 0,
'type': row['Index_type']
}
else:
current_index['columns'].append(column_name)
if current_index is not None:
indexes.append(current_index)
except Exception as df_error:
logger.warning(f"DataFrame processing failed, trying regular query: {df_error}")
# Fall back to regular query
result = self._execute_query(query, return_dataframe=False)
indexes = []
if result:
# Simple processing, no complex index grouping
for row in result:
if isinstance(row, dict):
indexes.append({
'name': row.get('Key_name', ''),
'columns': [row.get('Column_name', '')],
'unique': row.get('Non_unique', 1) == 0,
'type': row.get('Index_type', '')
})
# Update cache
self.metadata_cache[cache_key] = indexes
@@ -748,7 +776,7 @@ class MetadataExtractor:
ORDER BY time DESC
LIMIT {limit}
"""
df = execute_query_df(query)
df = self._execute_query(query, return_dataframe=True)
return df
except Exception as e:
logger.error(f"Error getting audit logs: {str(e)}")
@@ -768,7 +796,7 @@ class MetadataExtractor:
try:
# Use SHOW CATALOGS command to get catalog list
query = "SHOW CATALOGS"
result = execute_query(query)
result = self._execute_query(query)
if not result:
catalogs = []
@@ -1057,7 +1085,7 @@ class MetadataExtractor:
AND TABLE_NAME = '{table_name}'
"""
partitions = execute_query(query)
partitions = self._execute_query(query)
if not partitions:
return {}
@@ -1099,10 +1127,511 @@ class MetadataExtractor:
# Replace 'information_schema' with 'catalog_name.information_schema'
modified_query = query.replace('information_schema', f'{catalog_name}.information_schema')
logger.info(f"Modified query for catalog {catalog_name}: {modified_query}")
return execute_query(modified_query, db_name)
return self._execute_query(modified_query, db_name)
else:
# Execute the original query
return execute_query(query, db_name)
return self._execute_query(query, db_name)
except Exception as e:
logger.error(f"Error executing query with catalog: {str(e)}")
raise
raise
async def _execute_query_async(self, query: str, db_name: str = None, return_dataframe: bool = False):
"""
Execute database query asynchronously
Args:
query: SQL query to execute
db_name: Database name to use (optional)
return_dataframe: Whether to return a pandas DataFrame instead of list
Returns:
Query result data (list of dictionaries or pandas DataFrame)
"""
try:
if self.connection_manager:
# Use the injected connection manager directly (async)
result = await self.connection_manager.execute_query(self._session_id, query, None)
# Extract data from QueryResult
if hasattr(result, 'data'):
data = result.data
else:
data = result
# Convert to DataFrame if requested
if return_dataframe and data:
import pandas as pd
return pd.DataFrame(data)
elif return_dataframe:
import pandas as pd
return pd.DataFrame()
else:
return data
else:
# Fallback: Return empty result
logger.warning("No connection manager provided, returning empty result")
if return_dataframe:
import pandas as pd
return pd.DataFrame()
else:
return []
except Exception as e:
logger.error(f"Error executing query: {str(e)}")
# Return empty result instead of raising exception to prevent cascade failures
if return_dataframe:
import pandas as pd
return pd.DataFrame()
else:
return []
def _execute_query(self, query: str, db_name: str = None, return_dataframe: bool = False):
"""
Execute database query with proper session management (sync wrapper)
Args:
query: SQL query to execute
db_name: Database name to use (optional)
return_dataframe: Whether to return a pandas DataFrame instead of list
Returns:
Query result data (list of dictionaries or pandas DataFrame)
"""
try:
if self.connection_manager:
import asyncio
# Try to run the async query
try:
# Check if there's a running event loop
loop = asyncio.get_running_loop()
# If we're in an async context, we need to run in a separate thread
import concurrent.futures
def run_in_new_loop():
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
try:
return new_loop.run_until_complete(
self._execute_query_async(query, db_name, return_dataframe)
)
finally:
new_loop.close()
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_in_new_loop)
return future.result(timeout=30)
except RuntimeError:
# No running loop, we can safely create one
return asyncio.run(
self._execute_query_async(query, db_name, return_dataframe)
)
else:
# Fallback: Return empty result
logger.warning("No connection manager provided, returning empty result")
if return_dataframe:
import pandas as pd
return pd.DataFrame()
else:
return []
except Exception as e:
logger.error(f"Error executing query: {str(e)}")
# Return empty result instead of raising exception to prevent cascade failures
if return_dataframe:
import pandas as pd
return pd.DataFrame()
else:
return []
async def get_table_schema_async(self, table_name: str, db_name: str = None, catalog_name: str = None) -> List[Dict[str, Any]]:
    """Asynchronously get table schema information

    Issues a ``DESCRIBE`` statement for the table and maps each returned row to
    a column description.

    Args:
        table_name: Table to describe
        db_name: Database name; falls back to ``self.db_name`` when empty
        catalog_name: Catalog name; falls back to ``self.catalog_name`` when empty

    Returns:
        One dict per column with the keys ``column_name``, ``data_type``,
        ``is_nullable``, ``default_value``, ``comment``, ``key`` and ``extra``.
        An empty list is returned when the query yields nothing or fails.
    """
    try:
        effective_catalog = catalog_name or self.catalog_name
        effective_db = db_name or self.db_name

        # The catalog prefix is only added for non-internal catalogs.
        if effective_catalog and effective_catalog != "internal":
            query = f"DESCRIBE `{effective_catalog}`.`{effective_db}`.`{table_name}`"
        else:
            query = f"DESCRIBE `{effective_db}`.`{table_name}`"

        result = await self._execute_query_async(query, db_name)
        if not result:
            return []

        return [
            {
                'column_name': row.get('Field', ''),
                'data_type': row.get('Type', ''),
                # The flag's casing is not guaranteed ('YES' vs 'Yes'), so the
                # comparison is case-insensitive.
                'is_nullable': str(row.get('Null', 'NO')).strip().upper() == 'YES',
                'default_value': row.get('Default', None),
                'comment': row.get('Comment', ''),
                'key': row.get('Key', ''),
                'extra': row.get('Extra', '')
            }
            for row in result
            if isinstance(row, dict)
        ]
    except Exception as e:
        logger.error(f"Failed to get table schema: {e}")
        return []
async def get_all_databases_async(self, catalog_name: str = None) -> List[str]:
    """Asynchronously list the databases of a catalog.

    Args:
        catalog_name: Catalog to inspect; falls back to ``self.catalog_name``.

    Returns:
        The first column value of every dict row returned by ``SHOW DATABASES``,
        skipping empty values. An empty list on failure or empty result.
    """
    try:
        target_catalog = catalog_name or self.catalog_name

        # Plain statement for the internal catalog, qualified one otherwise.
        statement = "SHOW DATABASES"
        if target_catalog and target_catalog != "internal":
            statement = f"SHOW DATABASES FROM `{target_catalog}`"

        rows = await self._execute_query_async(statement)
        if not rows:
            return []

        names = []
        for record in rows:
            if not isinstance(record, dict):
                continue
            # The database name is the first column of the row.
            first_value = next(iter(record.values()), None)
            if first_value:
                names.append(first_value)
        return names
    except Exception as e:
        logger.error(f"Failed to get database list: {e}")
        return []
async def get_database_tables_async(self, db_name: str = None, catalog_name: str = None) -> List[str]:
    """Asynchronously list the tables of a database.

    Args:
        db_name: Database to inspect; falls back to ``self.db_name``.
        catalog_name: Catalog to inspect; falls back to ``self.catalog_name``.

    Returns:
        The first column value of every dict row returned by ``SHOW TABLES``,
        skipping empty values. An empty list on failure or empty result.
    """
    try:
        target_catalog = catalog_name or self.catalog_name
        target_db = db_name or self.db_name

        # Only non-internal catalogs need the catalog qualifier.
        statement = f"SHOW TABLES FROM `{target_db}`"
        if target_catalog and target_catalog != "internal":
            statement = f"SHOW TABLES FROM `{target_catalog}`.`{target_db}`"

        rows = await self._execute_query_async(statement, target_db)
        if not rows:
            return []

        found = []
        for record in rows:
            if not isinstance(record, dict):
                continue
            # The table name is the first column of the row.
            first_value = next(iter(record.values()), None)
            if first_value:
                found.append(first_value)
        return found
    except Exception as e:
        logger.error(f"Failed to get table list: {e}")
        return []
async def get_catalog_list_async(self) -> List[str]:
    """Asynchronously list catalog names via ``SHOW CATALOGS``.

    Returns:
        Catalog names as strings. The ``CatalogName`` column is used when
        present; otherwise the second column, or the only column when the row
        has a single value. An empty list on failure or empty result.
    """
    try:
        rows = await self._execute_query_async("SHOW CATALOGS")
        if not rows:
            return []

        names = []
        for record in rows:
            if not isinstance(record, dict):
                continue

            if 'CatalogName' in record:
                candidate = record['CatalogName']
            else:
                # Positional fallback: prefer the second column, then the first.
                columns = list(record.values())
                if len(columns) > 1:
                    candidate = columns[1]
                elif columns:
                    candidate = columns[0]
                else:
                    candidate = None

            if candidate:
                names.append(str(candidate))
        return names
    except Exception as e:
        logger.error(f"Failed to get catalog list: {e}")
        return []
# ==================== Business layer methods (original metadata_tools.py functionality) ====================
def _format_response(self, success: bool, result: Any = None, error: str = None, message: str = "") -> Dict[str, Any]:
"""Format response result"""
response_data = {
"success": success,
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
}
if success and result is not None:
response_data["result"] = result
response_data["message"] = message or "Operation successful"
elif not success:
response_data["error"] = error or "Unknown error"
response_data["message"] = message or "Operation failed"
return response_data
async def exec_query_for_mcp(
    self,
    sql: str,
    db_name: str = None,
    catalog_name: str = None,
    max_rows: int = 100,
    timeout: int = 30
) -> Dict[str, Any]:
    """
    Run a SQL statement through the shared query executor (MCP tool entry point)

    Args:
        sql: SQL statement to execute
        db_name: Database name (only logged here)
        catalog_name: Catalog name (only logged here)
        max_rows: Row limit handed to the executor as ``limit``
        timeout: Timeout handed to the executor

    Returns:
        Whatever ``execute_sql_query`` returns, or a formatted failure response
        when no SQL was given or an exception occurred.
    """
    logger.info(f"Executing SQL query: {sql}, DB: {db_name}, Catalog: {catalog_name}, MaxRows: {max_rows}, Timeout: {timeout}")

    try:
        if not sql:
            return self._format_response(success=False, error="No SQL statement provided", message="Please provide SQL statement to execute")

        # Imported lazily, at call time
        from .query_executor import execute_sql_query

        # NOTE(review): db_name and catalog_name are not forwarded to the executor.
        return await execute_sql_query(
            sql=sql,
            connection_manager=self.connection_manager,
            limit=max_rows,
            timeout=timeout
        )
    except Exception as e:
        logger.error(f"Failed to execute SQL query: {str(e)}", exc_info=True)
        return self._format_response(success=False, error=str(e), message="Error occurred while executing SQL query")
async def get_table_schema_for_mcp(
    self,
    table_name: str,
    db_name: str = None,
    catalog_name: str = None
) -> Dict[str, Any]:
    """Return the column-level schema of a table as a formatted MCP response

    Args:
        table_name: Table to describe (required)
        db_name: Database name (optional)
        catalog_name: Catalog name (optional)

    Returns:
        Success response carrying the schema list, or a failure response when
        the table name is missing, the schema is empty, or an error occurred.
    """
    logger.info(f"Getting table schema: Table: {table_name}, DB: {db_name}, Catalog: {catalog_name}")

    if not table_name:
        return self._format_response(success=False, error="Missing table_name parameter")

    try:
        columns = await self.get_table_schema_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name)

        if columns:
            return self._format_response(success=True, result=columns)

        # An empty schema is reported as a failure.
        return self._format_response(
            success=False,
            error="Table does not exist or has no columns",
            message=f"Unable to get schema for table {catalog_name or 'default'}.{db_name or self.db_name}.{table_name}"
        )
    except Exception as e:
        logger.error(f"Failed to get table schema: {str(e)}", exc_info=True)
        return self._format_response(success=False, error=str(e), message="Error occurred while getting table schema")
async def get_db_table_list_for_mcp(
    self,
    db_name: str = None,
    catalog_name: str = None
) -> Dict[str, Any]:
    """Return the table names of a database as a formatted MCP response

    Args:
        db_name: Database name (optional)
        catalog_name: Catalog name (optional)

    Returns:
        Success response with the table list, or a failure response on error.
    """
    logger.info(f"Getting database table list: DB: {db_name}, Catalog: {catalog_name}")

    try:
        table_names = await self.get_database_tables_async(db_name=db_name, catalog_name=catalog_name)
    except Exception as e:
        logger.error(f"Failed to get database table list: {str(e)}", exc_info=True)
        return self._format_response(success=False, error=str(e), message="Error occurred while getting database table list")
    else:
        return self._format_response(success=True, result=table_names)
async def get_db_list_for_mcp(self, catalog_name: str = None) -> Dict[str, Any]:
    """Return the database names of a catalog as a formatted MCP response

    Args:
        catalog_name: Catalog name (optional)

    Returns:
        Success response with the database list, or a failure response on error.
    """
    logger.info(f"Getting database list: Catalog: {catalog_name}")

    try:
        database_names = await self.get_all_databases_async(catalog_name=catalog_name)
        response = self._format_response(success=True, result=database_names)
    except Exception as e:
        logger.error(f"Failed to get database list: {str(e)}", exc_info=True)
        response = self._format_response(success=False, error=str(e), message="Error occurred while getting database list")
    return response
async def get_table_comment_for_mcp(
    self,
    table_name: str,
    db_name: str = None,
    catalog_name: str = None
) -> Dict[str, Any]:
    """Return the comment of a table as a formatted MCP response

    Args:
        table_name: Table to look up (required)
        db_name: Database name (optional)
        catalog_name: Catalog name (optional)

    Returns:
        Success response with the comment, or a failure response when the table
        name is missing or an error occurred.
    """
    logger.info(f"Getting table comment: Table: {table_name}, DB: {db_name}, Catalog: {catalog_name}")

    if not table_name:
        return self._format_response(success=False, error="Missing table_name parameter")

    lookup = {"table_name": table_name, "db_name": db_name, "catalog_name": catalog_name}
    try:
        # NOTE(review): get_table_comment is called without await here.
        table_comment = self.get_table_comment(**lookup)
        return self._format_response(success=True, result=table_comment)
    except Exception as e:
        logger.error(f"Failed to get table comment: {str(e)}", exc_info=True)
        return self._format_response(success=False, error=str(e), message="Error occurred while getting table comment")
async def get_table_column_comments_for_mcp(
    self,
    table_name: str,
    db_name: str = None,
    catalog_name: str = None
) -> Dict[str, Any]:
    """Return the column comments of a table as a formatted MCP response

    Args:
        table_name: Table to look up (required)
        db_name: Database name (optional)
        catalog_name: Catalog name (optional)

    Returns:
        Success response with the column comments, or a failure response when
        the table name is missing or an error occurred.
    """
    logger.info(f"Getting table column comments: Table: {table_name}, DB: {db_name}, Catalog: {catalog_name}")

    if not table_name:
        return self._format_response(success=False, error="Missing table_name parameter")

    lookup = {"table_name": table_name, "db_name": db_name, "catalog_name": catalog_name}
    try:
        # NOTE(review): get_column_comments is called without await here.
        column_comments = self.get_column_comments(**lookup)
        return self._format_response(success=True, result=column_comments)
    except Exception as e:
        logger.error(f"Failed to get table column comments: {str(e)}", exc_info=True)
        return self._format_response(success=False, error=str(e), message="Error occurred while getting table column comments")
async def get_table_indexes_for_mcp(
    self,
    table_name: str,
    db_name: str = None,
    catalog_name: str = None
) -> Dict[str, Any]:
    """Return the index information of a table as a formatted MCP response

    Args:
        table_name: Table to look up (required)
        db_name: Database name (optional)
        catalog_name: Catalog name (optional)

    Returns:
        Success response with the index information, or a failure response when
        the table name is missing or an error occurred.
    """
    logger.info(f"Getting table indexes: Table: {table_name}, DB: {db_name}, Catalog: {catalog_name}")

    if not table_name:
        return self._format_response(success=False, error="Missing table_name parameter")

    lookup = {"table_name": table_name, "db_name": db_name, "catalog_name": catalog_name}
    try:
        # NOTE(review): get_table_indexes is called without await here.
        index_info = self.get_table_indexes(**lookup)
        return self._format_response(success=True, result=index_info)
    except Exception as e:
        logger.error(f"Failed to get table indexes: {str(e)}", exc_info=True)
        return self._format_response(success=False, error=str(e), message="Error occurred while getting table indexes")
def _serialize_datetime_objects(self, data):
"""Serialize datetime objects to JSON compatible format"""
if isinstance(data, list):
return [self._serialize_datetime_objects(item) for item in data]
elif isinstance(data, dict):
return {key: self._serialize_datetime_objects(value) for key, value in data.items()}
elif hasattr(data, 'isoformat'): # datetime, date, time objects
return data.isoformat()
elif hasattr(data, 'strftime'): # pandas Timestamp objects
return data.strftime('%Y-%m-%d %H:%M:%S')
else:
return data
async def get_recent_audit_logs_for_mcp(self, days: int = 7, limit: int = 100) -> Dict[str, Any]:
    """Return recent audit log records as a formatted MCP response

    Args:
        days: Passed to ``get_recent_audit_logs`` as ``days``
        limit: Passed to ``get_recent_audit_logs`` as ``limit``

    Returns:
        Success response whose result is the date/time-serialized log data, or
        a failure response on error.
    """
    logger.info(f"Getting audit logs: Days: {days}, Limit: {limit}")

    try:
        audit_logs = self.get_recent_audit_logs(days=days, limit=limit)

        if not hasattr(audit_logs, 'to_dict'):
            # No to_dict method: serialize the value as returned.
            payload = self._serialize_datetime_objects(audit_logs)
        else:
            try:
                records = audit_logs.to_dict('records')
            except Exception as e:
                logger.warning(f"DataFrame.to_dict failed, trying manual conversion: {e}")
                # Manual row-by-row conversion as a fallback
                records = []
                if not audit_logs.empty:
                    for _, row in audit_logs.iterrows():
                        records.append(dict(row))
            payload = self._serialize_datetime_objects(records)

        return self._format_response(success=True, result=payload)
    except Exception as e:
        logger.error(f"Failed to get audit logs: {str(e)}", exc_info=True)
        return self._format_response(success=False, error=str(e), message="Error occurred while getting audit logs")
async def get_catalog_list_for_mcp(self) -> Dict[str, Any]:
    """Return the catalog names as a formatted MCP response

    Returns:
        Success response with the catalog list, or a failure response on error.
    """
    logger.info("Getting catalog list")

    try:
        catalog_names = await self.get_catalog_list_async()
    except Exception as e:
        logger.error(f"Failed to get catalog list: {str(e)}", exc_info=True)
        return self._format_response(success=False, error=str(e), message="Error occurred while getting catalog list")
    else:
        return self._format_response(success=True, result=catalog_names, message="Successfully retrieved catalog list")
# ==================== Compatibility aliases ====================
# For backward compatibility, create MetadataManager alias
class MetadataManager:
    """
    Backward-compatibility facade over MetadataExtractor

    Every method forwards to the matching ``*_for_mcp`` method of the wrapped
    extractor and returns its response unchanged.
    """

    def __init__(self, connection_manager=None):
        # The wrapped extractor does all the work.
        self.extractor = MetadataExtractor(connection_manager=connection_manager)

    async def exec_query(self, sql: str, db_name: str = None, catalog_name: str = None, max_rows: int = 100, timeout: int = 30) -> Dict[str, Any]:
        """Run a SQL statement via the extractor's ``exec_query_for_mcp``"""
        return await self.extractor.exec_query_for_mcp(
            sql=sql,
            db_name=db_name,
            catalog_name=catalog_name,
            max_rows=max_rows,
            timeout=timeout,
        )

    async def get_table_schema(self, table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
        """Fetch the column-level schema of a table"""
        return await self.extractor.get_table_schema_for_mcp(
            table_name=table_name, db_name=db_name, catalog_name=catalog_name
        )

    async def get_db_table_list(self, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
        """Fetch the table names of a database"""
        return await self.extractor.get_db_table_list_for_mcp(
            db_name=db_name, catalog_name=catalog_name
        )

    async def get_db_list(self, catalog_name: str = None) -> Dict[str, Any]:
        """Fetch the database names of a catalog"""
        return await self.extractor.get_db_list_for_mcp(catalog_name=catalog_name)

    async def get_table_comment(self, table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
        """Fetch the comment of a table"""
        return await self.extractor.get_table_comment_for_mcp(
            table_name=table_name, db_name=db_name, catalog_name=catalog_name
        )

    async def get_table_column_comments(self, table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
        """Fetch the comments of all columns of a table"""
        return await self.extractor.get_table_column_comments_for_mcp(
            table_name=table_name, db_name=db_name, catalog_name=catalog_name
        )

    async def get_table_indexes(self, table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
        """Fetch the index information of a table"""
        return await self.extractor.get_table_indexes_for_mcp(
            table_name=table_name, db_name=db_name, catalog_name=catalog_name
        )

    async def get_recent_audit_logs(self, days: int = 7, limit: int = 100) -> Dict[str, Any]:
        """Fetch recent audit log records"""
        return await self.extractor.get_recent_audit_logs_for_mcp(days=days, limit=limit)

    async def get_catalog_list(self) -> Dict[str, Any]:
        """Fetch the Doris catalog list"""
        return await self.extractor.get_catalog_list_for_mcp()

View File

@@ -0,0 +1,861 @@
#!/usr/bin/env python3
"""
Doris Security Management Module
Implements enterprise-level authentication, authorization, SQL security validation and data masking functionality
"""
import hashlib
import logging
import re
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any
import sqlparse
from sqlparse.sql import Statement
from sqlparse.tokens import Keyword, Name
class SecurityLevel(Enum):
    """Classification levels shared by users, tables and masking rules"""

    # Members are declared in the same order as the original definition.
    PUBLIC = "public"
    INTERNAL = "internal"
    CONFIDENTIAL = "confidential"
    SECRET = "secret"
@dataclass
class AuthContext:
    """Identity and entitlements attached to an authenticated request"""

    # Identity of the caller
    user_id: str
    # Role names and permission names granted to the caller
    roles: list[str]
    permissions: list[str]
    # Session the request belongs to
    session_id: str
    # Timestamps are optional and default to None
    login_time: datetime | None = None
    last_activity: datetime | None = None
    # Clearance of the caller; INTERNAL unless stated otherwise
    security_level: SecurityLevel = SecurityLevel.INTERNAL
@dataclass
class ValidationResult:
"""Validation result"""
is_valid: bool
error_message: str | None = None
risk_level: str = "low"
blocked_operations: list[str] = None
def __post_init__(self):
if self.blocked_operations is None:
self.blocked_operations = []
@dataclass
class MaskingRule:
    """One data masking rule"""

    # Regular expression text describing the columns the rule targets
    column_pattern: str
    # Name of the masking algorithm to apply (e.g. "phone_mask")
    algorithm: str
    # Keyword parameters for the algorithm
    parameters: dict[str, Any]
    # Security level associated with the rule
    security_level: SecurityLevel
class DorisSecurityManager:
    """Doris security manager

    Provides complete security control functionality, including authentication,
    authorization, SQL security validation and data masking. The actual work is
    delegated to the four provider objects created in ``__init__``.
    """

    def __init__(self, config):
        """Create the security components and load the rule sets from ``config``.

        Args:
            config: Either a mapping-like object exposing ``get`` or a
                configuration object; objects without ``get`` fall back to the
                built-in defaults (masking rules may also come from
                ``config.security.masking_rules``).
        """
        self.config = config
        self.logger = logging.getLogger(__name__)

        # Initialize security components
        self.auth_provider = AuthenticationProvider(config)
        self.authz_provider = AuthorizationProvider(config)
        self.sql_validator = SQLSecurityValidator(config)
        self.masking_processor = DataMaskingProcessor(config)

        # Security rule configuration
        self.blocked_keywords = self._load_blocked_keywords()
        self.sensitive_tables = self._load_sensitive_tables()
        self.masking_rules = self._load_masking_rules()

    def _load_blocked_keywords(self) -> set[str]:
        """Return the default blocked SQL keywords merged with configured ones."""
        default_blocked = {
            "DROP",
            "DELETE",
            "TRUNCATE",
            "ALTER",
            "CREATE",
            "INSERT",
            "UPDATE",
            "GRANT",
            "REVOKE",
            "EXEC",
            "EXECUTE",
            "SHUTDOWN",
            "KILL",
        }

        # Load custom rules from configuration file
        if hasattr(self.config, 'get'):
            custom_blocked = set(self.config.get("blocked_keywords", []))
        else:
            custom_blocked = set()

        return default_blocked.union(custom_blocked)

    def _load_sensitive_tables(self) -> dict[str, SecurityLevel]:
        """Return the table -> security level mapping (defaults plus configuration).

        String levels from the configuration are converted to ``SecurityLevel``;
        unknown strings fall back to ``SecurityLevel.INTERNAL``.
        """
        tables = {
            "user_info": SecurityLevel.CONFIDENTIAL,
            "payment_records": SecurityLevel.SECRET,
            "employee_data": SecurityLevel.CONFIDENTIAL,
            "public_reports": SecurityLevel.PUBLIC,
        }

        if hasattr(self.config, 'get'):
            config_tables = self.config.get("sensitive_tables", {})
            for table_name, level in config_tables.items():
                if isinstance(level, str):
                    try:
                        tables[table_name] = SecurityLevel(level.lower())
                    except ValueError:
                        tables[table_name] = SecurityLevel.INTERNAL
                else:
                    tables[table_name] = level

        return tables

    def _load_masking_rules(self) -> list[MaskingRule]:
        """Return the default masking rules followed by the configured ones.

        Dict rules are turned into ``MaskingRule`` objects. A string
        ``security_level`` is converted to ``SecurityLevel`` (unknown values
        become ``INTERNAL``), the same conversion ``DataMaskingProcessor``
        applies. The configuration dicts themselves are never modified.
        """
        rules = [
            MaskingRule(
                column_pattern=r".*phone.*|.*mobile.*",
                algorithm="phone_mask",
                parameters={"mask_char": "*", "keep_prefix": 3, "keep_suffix": 4},
                security_level=SecurityLevel.INTERNAL,
            ),
            MaskingRule(
                column_pattern=r".*email.*",
                algorithm="email_mask",
                parameters={"mask_char": "*"},
                security_level=SecurityLevel.INTERNAL,
            ),
            MaskingRule(
                column_pattern=r".*id_card.*|.*identity.*",
                algorithm="id_mask",
                parameters={"mask_char": "*", "keep_prefix": 6, "keep_suffix": 4},
                security_level=SecurityLevel.CONFIDENTIAL,
            ),
        ]

        # Load custom rules from configuration
        custom_rules = []
        if hasattr(self.config, 'get'):
            custom_rules = self.config.get("masking_rules", [])
        elif hasattr(self.config, 'security') and hasattr(self.config.security, 'masking_rules'):
            custom_rules = self.config.security.masking_rules

        # A configured value of None is treated like "no custom rules".
        for rule_config in custom_rules or []:
            if isinstance(rule_config, dict):
                # Work on a copy so the caller's configuration stays untouched.
                rule_kwargs = dict(rule_config)
                level = rule_kwargs.get('security_level')
                if isinstance(level, str):
                    try:
                        rule_kwargs['security_level'] = SecurityLevel(level.lower())
                    except ValueError:
                        rule_kwargs['security_level'] = SecurityLevel.INTERNAL
                rules.append(MaskingRule(**rule_kwargs))
            elif isinstance(rule_config, MaskingRule):
                rules.append(rule_config)

        return rules

    async def authenticate_request(self, auth_info: dict[str, Any]) -> AuthContext:
        """Validate request authentication information via the authentication provider."""
        return await self.auth_provider.authenticate(auth_info)

    async def authorize_resource_access(
        self, auth_context: AuthContext, resource_uri: str
    ) -> bool:
        """Check whether the caller may read the resource identified by ``resource_uri``."""
        return await self.authz_provider.check_permission(
            auth_context, resource_uri, "read"
        )

    async def validate_sql_security(
        self, sql: str, auth_context: AuthContext
    ) -> ValidationResult:
        """Validate SQL query security via the SQL validator."""
        return await self.sql_validator.validate(sql, auth_context)

    async def apply_data_masking(
        self, data: list[dict[str, Any]], auth_context: AuthContext
    ) -> list[dict[str, Any]]:
        """Apply data masking processing via the masking processor."""
        return await self.masking_processor.process(data, auth_context)
class AuthenticationProvider:
    """Authentication provider

    Supports ``token`` and ``basic`` (username/password) authentication. The
    token and user tables are in-code placeholders marked as simplified
    implementations for testing.
    """

    def __init__(self, config):
        self.config = config
        self.logger = logging.getLogger(__name__)
        self.session_cache = {}

    async def authenticate(self, auth_info: dict[str, Any]) -> AuthContext:
        """Perform identity authentication.

        Args:
            auth_info: Dict with an optional ``type`` ("token" by default, or
                "basic") plus the fields required by that type.

        Returns:
            The authentication context of the caller.

        Raises:
            ValueError: Unsupported type, missing fields or invalid credentials.
        """
        auth_type = auth_info.get("type", "token")

        if auth_type == "token":
            return await self._authenticate_token(auth_info)
        elif auth_type == "basic":
            return await self._authenticate_basic(auth_info)
        else:
            raise ValueError(f"Unsupported authentication type: {auth_type}")

    def _build_auth_context(
        self, user_info: dict[str, Any], auth_info: dict[str, Any]
    ) -> AuthContext:
        """Build the AuthContext shared by the token and basic authentication paths."""
        return AuthContext(
            user_id=user_info["user_id"],
            roles=user_info["roles"],
            permissions=user_info["permissions"],
            session_id=auth_info.get("session_id", "default"),
            # NOTE(review): utcnow() yields a naive datetime and is deprecated
            # since Python 3.12; kept so login_time stays naive for callers.
            login_time=datetime.utcnow(),
            security_level=SecurityLevel(user_info.get("security_level", "internal")),
        )

    async def _authenticate_token(self, auth_info: dict[str, Any]) -> AuthContext:
        """Token authentication"""
        token = auth_info.get("token")
        if not token:
            raise ValueError("Missing authentication token")

        # Validate token (simplified implementation, should validate JWT or query authentication service in practice)
        user_info = await self._validate_token(token)
        return self._build_auth_context(user_info, auth_info)

    async def _authenticate_basic(self, auth_info: dict[str, Any]) -> AuthContext:
        """Basic authentication (username password)"""
        username = auth_info.get("username")
        password = auth_info.get("password")

        if not username or not password:
            raise ValueError("Missing username or password")

        # Validate username password (simplified implementation)
        user_info = await self._validate_credentials(username, password)
        return self._build_auth_context(user_info, auth_info)

    async def _validate_token(self, token: str) -> dict[str, Any]:
        """Validate token validity

        Raises:
            ValueError: The token is not known.
        """
        # Simplified implementation for testing, should parse JWT or query authentication service in practice
        valid_tokens = {
            "valid_token_123": {
                "user_id": "test_user",
                "roles": ["data_analyst"],
                "permissions": ["read_data"],
                "security_level": SecurityLevel.INTERNAL,
            },
            "admin_token_456": {
                "user_id": "admin_user",
                "roles": ["data_admin"],
                "permissions": ["admin"],
                "security_level": SecurityLevel.SECRET,
            }
        }

        if token in valid_tokens:
            return valid_tokens[token]
        else:
            raise ValueError("Invalid token")

    async def _validate_credentials(
        self, username: str, password: str
    ) -> dict[str, Any]:
        """Validate user credentials

        Returns:
            The user record without its password.

        Raises:
            ValueError: Unknown user or wrong password.
        """
        # Local import keeps this block self-contained.
        import hmac

        # Simplified implementation for testing, should query user database in practice
        valid_users = {
            "admin": {
                "password": "admin123",
                "user_id": "admin_user",
                "roles": ["data_admin"],
                "permissions": ["admin", "read_data", "write_data"],
                "security_level": SecurityLevel.SECRET,
            },
            "analyst": {
                "password": "analyst123",
                "user_id": "analyst_user",
                "roles": ["data_analyst"],
                "permissions": ["read_data"],
                "security_level": SecurityLevel.INTERNAL,
            }
        }

        record = valid_users.get(username)
        # A comparison is performed even for unknown users, and it is
        # constant-time, so response timing does not leak how much of the
        # password matched.
        expected = record["password"] if record is not None else ""
        supplied = password if isinstance(password, str) else ""
        password_matches = hmac.compare_digest(
            supplied.encode("utf-8"), expected.encode("utf-8")
        )

        if record is None or not isinstance(password, str) or not password_matches:
            raise ValueError("Incorrect username or password")

        # Remove password from returned info
        return {key: value for key, value in record.items() if key != "password"}
class AuthorizationProvider:
    """Decides whether an authenticated caller may act on a resource"""

    def __init__(self, config):
        self.config = config
        self.logger = logging.getLogger(__name__)
        self.permission_cache = {}
        # Table -> security level mapping used by the security level check
        self.sensitive_tables = self._load_sensitive_tables()

    def _load_sensitive_tables(self) -> dict[str, SecurityLevel]:
        """Build the table -> security level mapping from defaults and configuration"""
        table_levels = {
            "user_info": SecurityLevel.CONFIDENTIAL,
            "payment_records": SecurityLevel.SECRET,
            "employee_data": SecurityLevel.CONFIDENTIAL,
            "public_reports": SecurityLevel.PUBLIC,
        }

        # Configurations without a get() method contribute nothing.
        if not hasattr(self.config, 'get'):
            return table_levels

        for table_name, level in self.config.get("sensitive_tables", {}).items():
            if not isinstance(level, str):
                table_levels[table_name] = level
                continue
            # String levels become enum members; unknown strings mean INTERNAL.
            try:
                table_levels[table_name] = SecurityLevel(level.lower())
            except ValueError:
                table_levels[table_name] = SecurityLevel.INTERNAL
        return table_levels

    async def check_permission(
        self, auth_context: AuthContext, resource_uri: str, action: str
    ) -> bool:
        """Return True when the caller may perform ``action`` on the resource

        The security level check is mandatory; after it passes, either a role
        grant or a user permission grant is sufficient.
        """
        resource = self._parse_resource_uri(resource_uri)

        level_ok = await self._check_security_level_permission(auth_context, resource)
        if not level_ok:
            return False

        if await self._check_role_permission(auth_context, resource, action):
            return True

        if await self._check_user_permission(auth_context, resource, action):
            return True

        return False

    def _parse_resource_uri(self, uri: str) -> dict[str, str]:
        """Split a resource URI on "/" into type, name and schema"""
        segments = uri.split("/")
        if len(segments) < 3:
            return {"type": "unknown", "name": "", "schema": "default"}

        segment_count = len(segments)
        return {
            "type": segments[2],  # table, view, etc.
            "name": segments[3] if segment_count > 3 else "",
            "schema": segments[4] if segment_count > 4 else "default",
        }

    async def _check_role_permission(
        self, auth_context: AuthContext, resource_info: dict[str, str], action: str
    ) -> bool:
        """Return True when any of the caller's roles grants ``action`` on the resource type"""
        grants_by_role = {
            "data_analyst": {"table": ["read"], "view": ["read"]},
            "data_admin": {
                "table": ["read", "write", "admin"],
                "view": ["read", "write", "admin"],
            },
        }

        return any(
            action in grants_by_role.get(role, {}).get(resource_info["type"], [])
            for role in auth_context.roles
        )

    async def _check_user_permission(
        self, auth_context: AuthContext, resource_info: dict[str, str], action: str
    ) -> bool:
        """Return True when the caller's own permissions allow ``action``"""
        granted = auth_context.permissions
        # "admin" allows everything; "read_data" allows reads only.
        return "admin" in granted or (action == "read" and "read_data" in granted)

    async def _check_security_level_permission(
        self, auth_context: AuthContext, resource_info: dict[str, str]
    ) -> bool:
        """Return True when the caller's clearance is at least the resource's level"""
        required_level = self._get_resource_security_level(resource_info)

        # Rank of each level; levels missing from the table rank as 0.
        ordered_levels = (
            SecurityLevel.PUBLIC,
            SecurityLevel.INTERNAL,
            SecurityLevel.CONFIDENTIAL,
            SecurityLevel.SECRET,
        )
        rank_of = {level: rank for rank, level in enumerate(ordered_levels)}

        caller_rank = rank_of.get(auth_context.security_level, 0)
        required_rank = rank_of.get(required_level, 0)
        return caller_rank >= required_rank

    def _get_resource_security_level(
        self, resource_info: dict[str, str]
    ) -> SecurityLevel:
        """Look up the security level of the resource's table (INTERNAL when unlisted)"""
        table_name = resource_info.get("name", "")
        level = self.sensitive_tables.get(table_name, SecurityLevel.INTERNAL)

        if not isinstance(level, str):
            return level

        # String values are converted on the fly; unknown strings mean INTERNAL.
        try:
            return SecurityLevel(level.lower())
        except ValueError:
            return SecurityLevel.INTERNAL
class SQLSecurityValidator:
"""SQL security validator"""
def __init__(self, config):
self.config = config
self.logger = logging.getLogger(__name__)
# Handle DorisConfig object or dictionary configuration
if hasattr(config, 'get'):
# Dictionary configuration
self.blocked_keywords = set(config.get("blocked_keywords", []))
self.max_query_complexity = config.get("max_query_complexity", 100)
else:
# DorisConfig object, use default values
self.blocked_keywords = set(["DROP", "DELETE", "TRUNCATE", "ALTER", "CREATE", "INSERT", "UPDATE"])
self.max_query_complexity = 100
async def validate(self, sql: str, auth_context: AuthContext) -> ValidationResult:
"""Validate SQL query security"""
try:
# Parse SQL statement
parsed = sqlparse.parse(sql)[0]
# Check blocked operations first (more specific)
keyword_result = await self._check_blocked_keywords(parsed)
if not keyword_result.is_valid:
return keyword_result
# Check SQL injection risks
injection_result = await self._check_sql_injection(sql, parsed)
if not injection_result.is_valid:
return injection_result
# Check query complexity
complexity_result = await self._check_query_complexity(parsed)
if not complexity_result.is_valid:
return complexity_result
# Check table access permissions
table_result = await self._check_table_access(parsed, auth_context)
if not table_result.is_valid:
return table_result
return ValidationResult(is_valid=True)
except Exception as e:
self.logger.error(f"SQL security validation failed: {e}")
return ValidationResult(
is_valid=False,
error_message=f"SQL parsing error: {str(e)}",
risk_level="high",
)
async def _check_sql_injection(
self, sql: str, parsed: Statement
) -> ValidationResult:
"""Check SQL injection risks"""
# Check common SQL injection patterns
injection_patterns = [
r"(\s|^)(union|select|insert|update|delete|drop|create|alter)\s+.*\s+(union|select|insert|update|delete|drop|create|alter)",
r"(\s|^)(or|and)\s+\d+\s*=\s*\d+",
r"(\s|^)(or|and)\s+['\"].*['\"]",
r";\s*(drop|delete|truncate|alter|create)",
r"(exec|execute|sp_|xp_)",
r"(script|javascript|vbscript)",
r"(char|ascii|substring|concat)\s*\(",
]
sql_lower = sql.lower()
for pattern in injection_patterns:
if re.search(pattern, sql_lower, re.IGNORECASE):
return ValidationResult(
is_valid=False,
error_message="Potential SQL injection risk detected",
risk_level="high",
)
# Check suspicious quotes and comments
if self._has_suspicious_quotes_or_comments(sql):
return ValidationResult(
is_valid=False,
error_message="Suspicious quote or comment pattern detected",
risk_level="medium",
)
return ValidationResult(is_valid=True)
def _has_suspicious_quotes_or_comments(self, sql: str) -> bool:
"""Check suspicious quote and comment patterns"""
# Check unmatched quotes
single_quotes = sql.count("'")
double_quotes = sql.count('"')
if single_quotes % 2 != 0 or double_quotes % 2 != 0:
return True
# Check SQL comments
if "--" in sql or "/*" in sql:
return True
return False
async def _check_blocked_keywords(self, parsed: Statement) -> ValidationResult:
"""Check blocked keywords"""
blocked_operations = []
# Check all tokens in the parsed statement
for token in parsed.flatten():
# Check if token is a keyword (including DML/DDL) or name that matches blocked operations
if (token.ttype is Keyword or
token.ttype is Name or
(token.ttype and str(token.ttype).startswith('Token.Keyword'))):
token_value = token.value.upper().strip()
if token_value in self.blocked_keywords:
blocked_operations.append(token_value)
# Also check for DDL/DML keywords in token values
elif hasattr(token, 'value') and token.value:
token_value = token.value.upper().strip()
for blocked_keyword in self.blocked_keywords:
if blocked_keyword in token_value:
blocked_operations.append(blocked_keyword)
if blocked_operations:
return ValidationResult(
is_valid=False,
error_message=f"Contains blocked operations: {', '.join(set(blocked_operations))}",
risk_level="high",
blocked_operations=list(set(blocked_operations)),
)
return ValidationResult(is_valid=True)
async def _check_query_complexity(self, parsed: Statement) -> ValidationResult:
"""Check query complexity"""
complexity_score = 0
# Calculate complexity score
for token in parsed.flatten():
if token.ttype is Keyword:
keyword = token.value.upper()
if keyword in ["JOIN", "INNER", "LEFT", "RIGHT", "FULL"]:
complexity_score += 10
elif keyword in ["UNION", "INTERSECT", "EXCEPT"]:
complexity_score += 15
elif keyword in ["GROUP BY", "ORDER BY", "HAVING"]:
complexity_score += 5
elif keyword in ["SUBQUERY", "EXISTS", "IN"]:
complexity_score += 8
if complexity_score > self.max_query_complexity:
return ValidationResult(
is_valid=False,
error_message=f"Query complexity too high (score: {complexity_score}, limit: {self.max_query_complexity})",
risk_level="medium",
)
return ValidationResult(is_valid=True)
async def _check_table_access(
self, parsed: Statement, auth_context: AuthContext
) -> ValidationResult:
"""Check table access permissions"""
# Extract table names from query
tables = self._extract_table_names(parsed)
# Check access permissions for each table
unauthorized_tables = []
for table in tables:
# Should call authorization provider to check permissions
# Simplified implementation, assume some tables require special permissions
if (
table.lower() in ["sensitive_data", "admin_logs"]
and "admin" not in auth_context.roles
):
unauthorized_tables.append(table)
if unauthorized_tables:
return ValidationResult(
is_valid=False,
error_message=f"No access to tables: {', '.join(unauthorized_tables)}",
risk_level="high",
)
return ValidationResult(is_valid=True)
def _extract_table_names(self, parsed: Statement) -> list[str]:
    """Extract table names from SQL statement

    Simplified extraction: for every ``FROM`` keyword and every join keyword
    (``JOIN`` or a compound keyword ending in ``JOIN``), the first name that
    follows is recorded. Scanning for that name stops at the next plain
    ``Keyword`` token. For a qualified reference such as ``db.table`` or
    ``catalog.db.table`` the last component (the table) is returned.

    Args:
        parsed: Parsed SQL statement.

    Returns:
        Table names in the order they appear; duplicates are kept.
    """
    tables = []
    tokens = list(parsed.flatten())
    total = len(tokens)

    for index, token in enumerate(tokens):
        if token.ttype is not Keyword:
            continue
        # Collapse inner whitespace so "LEFT   JOIN" compares like "LEFT JOIN".
        keyword = " ".join(token.value.upper().split())
        if keyword not in ("FROM", "JOIN") and not keyword.endswith(" JOIN"):
            continue

        cursor = index + 1
        while cursor < total:
            candidate = tokens[cursor]
            if candidate.ttype is Keyword:
                # Another keyword before any name: nothing to record here.
                break
            if candidate.ttype is Name:
                # Walk "name . name . name" so the table, not the database or
                # catalog qualifier, is what gets checked.
                while (
                    cursor + 2 < total
                    and tokens[cursor + 1].value == "."
                    and tokens[cursor + 2].ttype is Name
                ):
                    cursor += 2
                # NOTE(review): quoted identifiers presumably keep their quote
                # characters in token.value - confirm callers normalise them.
                tables.append(tokens[cursor].value)
                break
            cursor += 1

    return tables
class DataMaskingProcessor:
    """Data masking processor

    Masks column values in query result rows. Each rule pairs a column-name
    regex with a named masking algorithm and a security level; a rule is
    applied to a caller depending on the caller's roles and security level.
    """

    def __init__(self, config):
        """Store the configuration and build the algorithm table and rule list.

        Args:
            config: Configuration object. Custom rules are read from it only
                if it exposes a ``get`` method.
        """
        self.config = config
        self.logger = logging.getLogger(__name__)
        self.masking_algorithms = self._init_masking_algorithms()
        self.masking_rules = self._load_masking_rules()

    def _load_masking_rules(self) -> list[MaskingRule]:
        """Load data masking rules

        Returns the built-in rules (phone/mobile, email, id_card/identity)
        followed by any custom rules found under ``masking_rules`` in the
        configuration. Custom entries may be ``MaskingRule`` instances or
        dictionaries of ``MaskingRule`` keyword arguments; a string
        ``security_level`` is converted to ``SecurityLevel`` and falls back to
        ``SecurityLevel.INTERNAL`` when it is not a valid level.
        """
        rules = [
            MaskingRule(
                column_pattern=r".*phone.*|.*mobile.*",
                algorithm="phone_mask",
                parameters={"mask_char": "*", "keep_prefix": 3, "keep_suffix": 4},
                security_level=SecurityLevel.INTERNAL,
            ),
            MaskingRule(
                column_pattern=r".*email.*",
                algorithm="email_mask",
                parameters={"mask_char": "*"},
                security_level=SecurityLevel.INTERNAL,
            ),
            MaskingRule(
                column_pattern=r".*id_card.*|.*identity.*",
                algorithm="id_mask",
                parameters={"mask_char": "*", "keep_prefix": 6, "keep_suffix": 4},
                security_level=SecurityLevel.CONFIDENTIAL,
            ),
        ]

        # Load custom rules from configuration
        if hasattr(self.config, 'get'):
            custom_rules = self.config.get("masking_rules", []) or []
            for rule_config in custom_rules:
                if isinstance(rule_config, dict):
                    # Work on a copy so the caller's configuration dictionary
                    # is not rewritten as a side effect of loading.
                    rule_kwargs = dict(rule_config)
                    level = rule_kwargs.get('security_level')
                    if isinstance(level, str):
                        try:
                            rule_kwargs['security_level'] = SecurityLevel(level.lower())
                        except ValueError:
                            self.logger.warning(
                                "Unknown security level %r in masking rule; using INTERNAL",
                                level,
                            )
                            rule_kwargs['security_level'] = SecurityLevel.INTERNAL
                    rules.append(MaskingRule(**rule_kwargs))
                elif isinstance(rule_config, MaskingRule):
                    rules.append(rule_config)
        return rules

    def _init_masking_algorithms(self) -> dict[str, callable]:
        """Initialize masking algorithms

        Returns:
            Mapping from algorithm name (as used in ``MaskingRule.algorithm``)
            to the bound method implementing it.
        """
        return {
            "phone_mask": self._mask_phone,
            "email_mask": self._mask_email,
            "id_mask": self._mask_id_card,
            "name_mask": self._mask_name,
            "partial_mask": self._mask_partial,
        }

    async def process(
        self, data: list[dict[str, Any]], auth_context: AuthContext
    ) -> list[dict[str, Any]]:
        """Process data masking

        Args:
            data: Result rows as column-name -> value dictionaries.
            auth_context: Caller context used to select the applicable rules.

        Returns:
            ``data`` itself when it is empty; otherwise a new list of new row
            dictionaries with masked values (the input rows are not modified).
        """
        if not data:
            return data

        # The applicable rules depend only on the caller, so resolve them once.
        applicable_rules = self._get_applicable_rules(auth_context)

        masked_data = []
        for row in data:
            masked_row = {}
            for column, value in row.items():
                masked_row[column] = await self._apply_masking_rules(
                    column, value, applicable_rules
                )
            masked_data.append(masked_row)
        return masked_data

    def _get_applicable_rules(self, auth_context: AuthContext) -> list[MaskingRule]:
        """Get applicable masking rules

        Returns the loaded rules, in their original order, for which
        ``_should_apply_rule`` is true for ``auth_context``.
        """
        return [
            rule
            for rule in self.masking_rules
            if self._should_apply_rule(rule, auth_context)
        ]

    def _should_apply_rule(self, rule: MaskingRule, auth_context: AuthContext) -> bool:
        """Determine whether masking rule should be applied

        Callers with the ``admin`` role are never masked. Otherwise the rule
        applies when the caller's security level is less than or equal to the
        rule's level; levels missing from the hierarchy count as 0.
        """
        # Admin users can see original data
        if "admin" in auth_context.roles:
            return False

        security_hierarchy = {
            SecurityLevel.PUBLIC: 0,
            SecurityLevel.INTERNAL: 1,
            SecurityLevel.CONFIDENTIAL: 2,
            SecurityLevel.SECRET: 3,
        }
        user_level = security_hierarchy.get(auth_context.security_level, 0)
        rule_level = security_hierarchy.get(rule.security_level, 0)
        # Apply masking if user level is less than or equal to rule level
        return user_level <= rule_level

    async def _apply_masking_rules(
        self, column: str, value: Any, rules: list[MaskingRule]
    ) -> Any:
        """Apply masking rules

        The first rule whose pattern matches the start of ``column``
        (case-insensitively) and whose algorithm is registered wins. The value
        is converted with ``str`` before masking, so a masked value is always a
        string. ``None`` and values with no usable rule are returned unchanged.
        """
        if value is None:
            return value

        for rule in rules:
            if re.match(rule.column_pattern, column, re.IGNORECASE):
                algorithm = self.masking_algorithms.get(rule.algorithm)
                # A matching rule naming an unknown algorithm is skipped so
                # that a later rule can still apply.
                if algorithm:
                    return algorithm(str(value), rule.parameters)
        return value

    @staticmethod
    def _mask_middle(
        value: str, mask_char: str, keep_prefix: int, keep_suffix: int
    ) -> str:
        """Mask everything except the first ``keep_prefix`` and last ``keep_suffix`` characters.

        When the value is not longer than the two kept parts together, the
        whole value is masked.
        """
        keep_prefix = max(0, keep_prefix)
        keep_suffix = max(0, keep_suffix)
        length = len(value)
        if length <= keep_prefix + keep_suffix:
            return mask_char * length

        hidden = length - keep_prefix - keep_suffix
        # Slice the tail from an explicit start index: value[-0:] is the WHOLE
        # string, which would append the unmasked value when keep_suffix is 0.
        return value[:keep_prefix] + mask_char * hidden + value[length - keep_suffix:]

    def _mask_phone(self, value: str, params: dict[str, Any]) -> str:
        """Phone number masking

        Values shorter than 7 characters are returned unchanged. Parameters:
        ``mask_char`` (default ``"*"``), ``keep_prefix`` (default 3) and
        ``keep_suffix`` (default 4).
        """
        if len(value) < 7:
            return value
        return self._mask_middle(
            value,
            params.get("mask_char", "*"),
            params.get("keep_prefix", 3),
            params.get("keep_suffix", 4),
        )

    def _mask_email(self, value: str, params: dict[str, Any]) -> str:
        """Email masking

        Masks the part before the first ``@``: it is fully masked when it has
        at most two characters, otherwise its first and last characters are
        kept. Values without ``@`` are returned unchanged. Parameter:
        ``mask_char`` (default ``"*"``).
        """
        if "@" not in value:
            return value

        mask_char = params.get("mask_char", "*")
        local, domain = value.split("@", 1)
        if len(local) <= 2:
            masked_local = mask_char * len(local)
        else:
            masked_local = local[0] + mask_char * (len(local) - 2) + local[-1]
        return f"{masked_local}@{domain}"

    def _mask_id_card(self, value: str, params: dict[str, Any]) -> str:
        """ID card number masking

        Values shorter than 10 characters are returned unchanged. Parameters:
        ``mask_char`` (default ``"*"``), ``keep_prefix`` (default 6) and
        ``keep_suffix`` (default 4).
        """
        if len(value) < 10:
            return value
        return self._mask_middle(
            value,
            params.get("mask_char", "*"),
            params.get("keep_prefix", 6),
            params.get("keep_suffix", 4),
        )

    def _mask_name(self, value: str, params: dict[str, Any]) -> str:
        """Name masking

        Keeps the first character; for names longer than two characters the
        last character is kept as well. Values of length 0 or 1 are returned
        unchanged. Parameter: ``mask_char`` (default ``"*"``).
        """
        if len(value) <= 1:
            return value

        mask_char = params.get("mask_char", "*")
        if len(value) == 2:
            return value[0] + mask_char
        return value[0] + mask_char * (len(value) - 2) + value[-1]

    def _mask_partial(self, value: str, params: dict[str, Any]) -> str:
        """Partial masking

        Masks a centred run covering ``mask_ratio`` of the value (default
        0.5, rounded down). Parameter ``mask_char`` defaults to ``"*"``.
        """
        mask_char = params.get("mask_char", "*")
        mask_ratio = params.get("mask_ratio", 0.5)
        # Clamp to the value's length: an oversized ratio would otherwise give
        # a negative start position and index past the front of the string.
        mask_length = max(0, min(len(value), int(len(value) * mask_ratio)))
        start_pos = (len(value) - mask_length) // 2
        return (
            value[:start_pos]
            + mask_char * mask_length
            + value[start_pos + mask_length:]
        )

View File

@@ -1,352 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
SQL Execution Tool
Responsible for executing SQL queries and handling results
"""
import os
import json
import logging
import traceback
import time
from typing import Dict, Any
import re
import datetime
from decimal import Decimal
# Module-level logger used by the SQL executor
logger = logging.getLogger("doris-mcp.sql-executor")
# SQL security checks are on unless the ENABLE_SQL_SECURITY_CHECK environment
# variable is set to anything other than "true" (compared case-insensitively)
ENABLE_SQL_SECURITY_CHECK = os.getenv("ENABLE_SQL_SECURITY_CHECK", "true").lower() == "true"
async def execute_sql_query(ctx) -> Dict[str, Any]:
    """
    Execute SQL query and return results

    Only the ``sql``, ``db_name`` and ``max_rows`` request parameters are used.
    ``db_name`` defaults to the ``DB_DATABASE`` environment variable and
    ``max_rows`` to 1000. A ``SELECT`` whose text does not contain ``limit``
    gets ``LIMIT <max_rows>`` appended before execution.

    Args:
        ctx: Context object exposing ``.params``, or a dictionary with a
            ``'params'`` key, containing request parameters
    Returns:
        Dict[str, Any]: Execution result as
        ``{"content": [{"type": "text", "text": <JSON string>}]}``. The JSON
        payload carries ``"success"``; failures are reported in the payload,
        not raised.
    """
    try:
        # Support the case where the passed argument is a dictionary
        if isinstance(ctx, dict) and 'params' in ctx:
            params = ctx['params']
        else:
            params = ctx.params

        sql = params.get("sql")
        db_name = params.get("db_name", os.getenv("DB_DATABASE", ""))

        if not sql:
            return _json_text_response({
                "success": False,
                "error": "Missing SQL parameter",
                "message": "Please provide the SQL query to execute"
            })

        # max_rows is interpolated into the statement AFTER the security
        # check has run, so it must be a real integer and never free text.
        raw_max_rows = params.get("max_rows", 1000)
        try:
            max_rows = 1000 if raw_max_rows is None else int(raw_max_rows)
        except (TypeError, ValueError):
            return _json_text_response({
                "success": False,
                "error": "Invalid max_rows parameter",
                "message": "max_rows must be an integer"
            })

        # First check SQL security
        security_result = await _check_sql_security(sql)
        if not security_result.get("is_safe", False):
            return _json_text_response({
                "success": False,
                "error": "SQL security check failed",
                "message": "Query contains unsafe operations and cannot be executed",
                "security_issues": security_result.get("security_issues", [])
            })

        # Import database connection tool
        from doris_mcp_server.utils.db import execute_query

        # Ensure SELECT statements include a LIMIT clause
        sql_lower = sql.lower().strip()
        if sql_lower.startswith("select") and "limit" not in sql_lower:
            sql = sql.rstrip(";") + f" LIMIT {max_rows};"

        start_time = time.time()
        try:
            # For federation queries, SQL must use three-part naming: catalog_name.db_name.table_name
            # This is enforced at the tool description level
            result = execute_query(sql, db_name)
            execution_time = time.time() - start_time

            if isinstance(result, list):
                return _json_text_response(
                    _rows_to_payload(sql, result, max_rows, execution_time)
                )

            # Handle other types of results
            other_response = {
                "success": True,
                "sql": sql,
                "result": str(result),
                "execution_time": execution_time
            }
            return _json_text_response(_serialize_row_data(other_response))
        except Exception as db_error:
            error_response = {
                "success": False,
                "error": str(db_error),
                "error_details": _classify_db_error(str(db_error)),
                "sql": sql,
                "db_name": db_name
            }
            # Ensure error response is also serializable
            return _json_text_response(_serialize_row_data(error_response))
    except Exception as e:
        logger.error(f"Failed to execute SQL query: {str(e)}")
        logger.error(traceback.format_exc())
        error_response = {
            "success": False,
            "error": str(e),
            "message": "Error occurred while executing SQL query"
        }
        # Ensure error response is also serializable
        return _json_text_response(_serialize_row_data(error_response))


def _json_text_response(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Wrap ``payload`` as one JSON text item in the tool response envelope."""
    return {
        "content": [
            {
                "type": "text",
                "text": json.dumps(payload, ensure_ascii=False)
            }
        ]
    }


def _rows_to_payload(sql: str, rows: list, max_rows: int, execution_time: float) -> Dict[str, Any]:
    """Build the success payload for a list of result rows (which may be empty)."""
    # An empty result set is a valid outcome, so never index rows[0] blindly.
    first_row = rows[0] if rows else None
    if hasattr(first_row, "_fields"):
        # Named tuple
        columns = list(first_row._fields)
    elif isinstance(first_row, dict):
        columns = list(first_row.keys())
    elif isinstance(first_row, (list, tuple)):
        # Plain sequences carry no column names; use their positions so each
        # row can still be turned into a dictionary.
        columns = [str(position) for position in range(len(first_row))]
    else:
        columns = []

    data = []
    for row in rows:
        if hasattr(row, "_asdict"):
            row_dict = row._asdict()
        elif isinstance(row, dict):
            row_dict = row
        else:
            row_dict = dict(zip(columns, row)) if columns else row
        # Handle special types to make them JSON serializable
        data.append(_serialize_row_data(row_dict))

    row_count = len(rows)
    return {
        "success": True,
        "sql": sql,
        "row_count": row_count,
        "columns": columns,
        "data": data[:max_rows],  # Limit returned rows
        "execution_time": execution_time,
        "truncated": row_count > max_rows
    }


def _classify_db_error(error_message: str) -> Dict[str, str]:
    """Map a database error message to a coarse error type and a suggestion."""
    lowered = error_message.lower()
    if "timeout" in lowered:
        return {
            "type": "timeout",
            "suggestion": "Query timed out, please optimize SQL or increase timeout"
        }
    if "syntax" in lowered:
        return {
            "type": "syntax",
            "suggestion": "SQL syntax error, please check syntax"
        }
    if "not found" in lowered or "doesn't exist" in lowered:
        return {
            "type": "not_found",
            "suggestion": "Table or column not found, please check table and column names"
        }
    return {
        "type": "unknown",
        "suggestion": "Please check the SQL statement and try simplifying the query"
    }
# Helper function
async def _check_sql_security(sql: str) -> Dict[str, Any]:
    """Check SQL security

    Scans the lower-cased statement for dangerous keywords. Returns
    immediately with a safe verdict when ``ENABLE_SQL_SECURITY_CHECK`` is
    false.

    A statement counts as read-only when it starts with ``select``, ``show``,
    ``desc``, ``describe`` or ``explain`` AND contains no statement separator
    followed by further text. For read-only statements the keywords
    create/drop/delete/insert/update/alter are only reported when followed by
    an object type (table, database, view, ...), and SQL comments are not
    reported. Every other statement is checked strictly.

    Args:
        sql: SQL text to inspect.

    Returns:
        Dict[str, Any]: ``{"is_safe": bool, "security_issues": [...]}`` where
        each issue has ``operation``, ``description`` and ``severity``.
    """
    # If environment variable is set to disable security check, return safe immediately
    if not ENABLE_SQL_SECURITY_CHECK:
        return {
            "is_safe": True,
            "security_issues": []
        }

    sql_lower = sql.lower()
    statement = sql_lower.strip()

    # "select 1; delete from t" starts like a read-only query but carries a
    # second statement. A ';' anywhere except at the very end therefore
    # disables the relaxed read-only handling (this also catches a ';' inside
    # a string literal, which errs on the strict side).
    has_stacked_statements = ";" in statement.rstrip("; \t\r\n")
    is_read_only = (
        statement.startswith(("select ", "show ", "desc ", "describe ", "explain "))
        and not has_stacked_statements
    )

    # Define list of dangerous operations (checked for both read-only and non-read-only queries)
    dangerous_operations = [
        (r'\bdelete\b', "DELETE operation"),
        (r'\bdrop\b', "DROP TABLE/DATABASE operation"),
        (r'\btruncate\b', "TRUNCATE TABLE operation"),
        (r'\bupdate\b', "UPDATE operation"),
        (r'\binsert\b', "INSERT operation"),
        (r'\balter\b', "ALTER TABLE structure operation"),
        (r'\bcreate\b', "CREATE TABLE/DATABASE operation"),
        (r'\bgrant\b', "GRANT operation"),
        (r'\brevoke\b', "REVOKE permission operation"),
        (r'\bexec\b', "EXECUTE stored procedure"),
        (r'\bxp_', "Extended stored procedure, potential security risk"),
        (r'\bshutdown\b', "SHUTDOWN database operation"),
        (r'\binto\s+outfile\b', "Write to file operation"),
        (r'\bload_file\b', "Load file operation")
    ]
    # In a read-only statement these words are only reported when used as a
    # DDL/DML keyword, e.g. CREATE TABLE, DROP DATABASE.
    context_sensitive = (
        r'\bcreate\b', r'\bdrop\b', r'\bdelete\b',
        r'\binsert\b', r'\bupdate\b', r'\balter\b',
    )
    object_target = r'\s+(?:table|database|view|index|procedure|function|trigger|event)'

    # Dangerous operations checked only for non-read-only queries
    non_readonly_operations = []
    if not is_read_only:
        non_readonly_operations = [
            (r'--', "SQL comment, potential SQL injection"),
            (r'/\*', "SQL block comment, potential SQL injection")
        ]

    def _issue(operation: str, description: str, severity: str) -> Dict[str, str]:
        # Report the pattern in a readable form by dropping the regex markers.
        return {
            "operation": operation.replace(r'\b', '').replace(r'\s+', ' '),
            "description": description,
            "severity": severity
        }

    security_issues = []

    # Check dangerous operations applicable to all queries
    for operation, description in dangerous_operations:
        if not re.search(operation, sql_lower):
            continue
        if (
            is_read_only
            and operation in context_sensitive
            and not re.search(operation + object_target, sql_lower)
        ):
            continue
        security_issues.append(_issue(operation, description, "High"))

    # Check dangerous operations specific to non-read-only queries
    for operation, description in non_readonly_operations:
        if re.search(operation, sql_lower):
            security_issues.append(_issue(operation, description, "Medium"))

    return {
        "is_safe": len(security_issues) == 0,
        "security_issues": security_issues
    }
def _serialize_row_data(row_data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Convert special types in row data (like date, time, Decimal) to JSON serializable format

    Conversion is applied at every nesting depth, including values held
    inside lists, tuples and nested dictionaries.

    Args:
        row_data: Row data dictionary
    Returns:
        Dict[str, Any]: New dictionary with the same keys; the input is not modified
    """
    return {key: _serialize_value(value) for key, value in row_data.items()}


def _serialize_value(value: Any) -> Any:
    """Convert one value (recursively for containers) to a JSON serializable form."""
    # datetime.datetime is a subclass of datetime.date, so it is covered here.
    if isinstance(value, (datetime.date, datetime.time)):
        return value.isoformat()
    if isinstance(value, Decimal):
        # Convert Decimal type to float
        return float(value)
    if isinstance(value, dict):
        return _serialize_row_data(value)
    if isinstance(value, (list, tuple)):
        # Tuples become lists; every element is converted, not only dicts,
        # so a list of Decimals or dates also survives json.dumps.
        return [_serialize_value(item) for item in value]
    return value