0.3.0 Release Version
This commit is contained in:
@@ -1 +1,13 @@
|
||||
# Mark directory as a package
|
||||
"""
|
||||
Doris MCP Server - A Model Context Protocol server for Apache Doris database integration.
|
||||
|
||||
This package provides:
|
||||
- MCP protocol implementation for Apache Doris
|
||||
- Multi-transport support (stdio, SSE, streamable HTTP)
|
||||
- Comprehensive database tools and resources
|
||||
- Enterprise-grade security and monitoring
|
||||
"""
|
||||
|
||||
__version__ = "1.0.0"
|
||||
__author__ = "Doris MCP Team"
|
||||
__description__ = "Apache Doris MCP Server Implementation"
|
||||
|
||||
8
doris_mcp_server/__main__.py
Normal file
8
doris_mcp_server/__main__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
Entry point for running doris_mcp_server as a module
|
||||
"""
|
||||
|
||||
from .main import main_sync
|
||||
|
||||
if __name__ == "__main__":
|
||||
main_sync()
|
||||
@@ -1,33 +0,0 @@
|
||||
# doris_mcp_server/config.py
|
||||
import os
|
||||
import logging
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv(override=True)
|
||||
|
||||
# Get Log Level from environment variable, default to 'info'
|
||||
LOG_LEVEL_STR = os.getenv('LOG_LEVEL', 'info').upper()
|
||||
|
||||
# Map string level to logging level constant
|
||||
LOG_LEVEL_MAP = {
|
||||
'DEBUG': logging.DEBUG,
|
||||
'INFO': logging.INFO,
|
||||
'WARNING': logging.WARNING,
|
||||
'ERROR': logging.ERROR,
|
||||
'CRITICAL': logging.CRITICAL
|
||||
}
|
||||
LOG_LEVEL = LOG_LEVEL_MAP.get(LOG_LEVEL_STR, logging.INFO)
|
||||
|
||||
# Function to load config (can be expanded later if needed)
|
||||
def load_config():
|
||||
"""Loads configuration settings."""
|
||||
# Currently, configuration is mainly handled by environment variables
|
||||
# and constants defined in this module.
|
||||
# This function can be used to perform additional setup if required.
|
||||
logging.getLogger(__name__).info("Configuration loaded (mainly from environment variables).")
|
||||
|
||||
# You can add other configuration constants here if needed
|
||||
# Example: DB_HOST = os.getenv("DB_HOST", "localhost")
|
||||
# But often it's better to access os.getenv directly where needed
|
||||
# or pass config dictionaries around.
|
||||
@@ -1,196 +1,515 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Apache Doris MCP Server Main Entry - Primarily handles SSE mode
|
||||
Apache Doris MCP Server - Enterprise Database Service Implementation
|
||||
|
||||
Stdio mode is handled by doris_mcp_server.mcp_core:run_stdio.
|
||||
Based on Apache Doris official MCP Server architecture design, providing complete MCP protocol support
|
||||
Supports independent encapsulation implementation of Resources, Tools, and Prompts
|
||||
Supports both stdio and streamable HTTP startup modes
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from collections.abc import AsyncIterator
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Any
|
||||
import uvicorn
|
||||
from uvicorn import Config, Server
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from dotenv import load_dotenv
|
||||
from typing import Any
|
||||
|
||||
# Add project root to path
|
||||
PROJECT_ROOT = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
|
||||
sys.path.insert(0, PROJECT_ROOT)
|
||||
from mcp.server import Server
|
||||
from mcp.server.models import InitializationOptions
|
||||
|
||||
# SSE related imports
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
from doris_mcp_server.sse_server import DorisMCPSseServer
|
||||
from doris_mcp_server.streamable_server import DorisMCPStreamableServer
|
||||
|
||||
# Stdio related imports (only needed for tools now, maybe move tool init?)
|
||||
# from mcp.server.stdio import stdio_server -> No longer used here
|
||||
|
||||
# Config and Tool Initializer
|
||||
from doris_mcp_server.config import load_config # LOG_LEVEL might not be needed here directly
|
||||
from doris_mcp_server.tools.tool_initializer import register_mcp_tools
|
||||
|
||||
# Load environment variables (load early for all modes)
|
||||
load_dotenv(override=True)
|
||||
|
||||
# Get logger
|
||||
logger = logging.getLogger("doris-mcp-main") # Changed logger name slightly
|
||||
|
||||
# --- Configuration Loading and Logging Setup ---
|
||||
load_config() # Loads .env
|
||||
|
||||
# --- Create FastAPI App (Global Scope for SSE Mode) ---
|
||||
# This 'app' object is targeted by 'mcp run doris_mcp_server/main.py:app --transport sse'
|
||||
# And used when running directly with --sse
|
||||
app = FastAPI(
|
||||
title="Doris MCP Server (SSE Mode)",
|
||||
# Lifespan will be added in start_sse_server
|
||||
from mcp.types import (
|
||||
Prompt,
|
||||
Resource,
|
||||
TextContent,
|
||||
Tool,
|
||||
)
|
||||
|
||||
# --- Removed StdioServerWrapper ---
|
||||
from .tools.tools_manager import DorisToolsManager
|
||||
from .tools.prompts_manager import DorisPromptsManager
|
||||
from .tools.resources_manager import DorisResourcesManager
|
||||
from .utils.config import DorisConfig
|
||||
from .utils.db import DorisConnectionManager
|
||||
from .utils.security import DorisSecurityManager
|
||||
|
||||
# --- Command Line Argument Parsing ---
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Apache Doris MCP Server (SSE Mode Entry)")
|
||||
# Only keep SSE related args here
|
||||
parser.add_argument('--sse', action='store_true', help='Start SSE Web server mode (required)')
|
||||
parser.add_argument('--host', type=str, default=os.getenv('SERVER_HOST', '0.0.0.0'), help='Host address')
|
||||
parser.add_argument('--port', type=int, default=int(os.getenv('SERVER_PORT', os.getenv('MCP_PORT', '3000'))), help='Port number')
|
||||
parser.add_argument('--debug', action='store_true', help='Enable debug mode')
|
||||
parser.add_argument('--reload', action='store_true', help='Enable auto-reload')
|
||||
return parser.parse_args()
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# --- SSE Mode Specific Code ---
|
||||
@dataclass
|
||||
class AppContext:
|
||||
config: Dict[str, Any]
|
||||
|
||||
@asynccontextmanager
|
||||
async def app_lifespan(app_instance: FastAPI) -> AsyncIterator[None]:
|
||||
logger.info("SSE application lifecycle start...")
|
||||
config = {
|
||||
# Simplified config - maybe get from elsewhere?
|
||||
"db_host": os.getenv("DB_HOST", "localhost"),
|
||||
"db_port": int(os.getenv("DB_PORT", "9030")),
|
||||
"db_user": os.getenv("DB_USER", "root"),
|
||||
"db_password": os.getenv("DB_PASSWORD", ""),
|
||||
"db_database": os.getenv("DB_DATABASE", "test"),
|
||||
}
|
||||
app_instance.state.config = config
|
||||
try:
|
||||
# Yield None implicitly or explicitly None
|
||||
yield
|
||||
finally:
|
||||
logger.info("Cleaning up SSE application resources...")
|
||||
class DorisServer:
|
||||
"""Apache Doris MCP Server main class"""
|
||||
|
||||
async def start_sse_server(args):
|
||||
"""Start SSE Web server mode (Configures the global 'app')"""
|
||||
logger.info("Starting SSE Web server mode...")
|
||||
global app
|
||||
def __init__(self, config: DorisConfig):
|
||||
self.config = config
|
||||
self.server = Server("doris-mcp-server")
|
||||
|
||||
# --- Initialize MCP and Tools for SSE ---
|
||||
# Create a *separate* MCP instance for SSE mode
|
||||
sse_mcp = FastMCP(
|
||||
name="doris-mcp-sse",
|
||||
description="Apache Doris MCP Server (SSE)",
|
||||
lifespan=None, # Managed by FastAPI
|
||||
dependencies=["fastapi", "uvicorn", "openai", "sse_starlette"]
|
||||
)
|
||||
logger.info("Registering MCP tools for SSE mode...")
|
||||
await register_mcp_tools(sse_mcp) # Register tools for the SSE instance
|
||||
logger.info("MCP tools registered for SSE.")
|
||||
# Initialize security manager
|
||||
self.security_manager = DorisSecurityManager(config)
|
||||
|
||||
# --- Configure Lifespan and CORS for the global app ---
|
||||
app.router.lifespan_context = app_lifespan
|
||||
origins = os.getenv("ALLOWED_ORIGINS", "*").split(",")
|
||||
allow_credentials = os.getenv("MCP_ALLOW_CREDENTIALS", "false").lower() == "true"
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins,
|
||||
allow_credentials=allow_credentials,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
expose_headers=["Mcp-Session-Id"],
|
||||
)
|
||||
# Initialize connection manager, pass in security manager
|
||||
self.connection_manager = DorisConnectionManager(config, self.security_manager)
|
||||
|
||||
# --- Initialize Handlers and Register Routes (Pass sse_mcp instance) ---
|
||||
logger.info("Initializing SSE server handlers and registering routes...")
|
||||
sse_server_handler = DorisMCPSseServer(sse_mcp, app)
|
||||
streamable_server_handler = DorisMCPStreamableServer(sse_mcp, app)
|
||||
logger.info("SSE Server handlers initialized and routes registered.")
|
||||
# Initialize independent managers
|
||||
self.resources_manager = DorisResourcesManager(self.connection_manager)
|
||||
self.tools_manager = DorisToolsManager(self.connection_manager)
|
||||
self.prompts_manager = DorisPromptsManager(self.connection_manager)
|
||||
|
||||
# --- Print Configuration and Endpoints ---
|
||||
print("--- SSE Mode Configuration ---")
|
||||
print(f"Server Host: {args.host}")
|
||||
print(f"Server Port: {args.port}")
|
||||
print(f"Allowed Origins: {origins}")
|
||||
print(f"Allow Credentials: {allow_credentials}")
|
||||
print(f"Log Level: {os.getenv('LOG_LEVEL', 'info')}")
|
||||
print(f"Debug Mode: {args.debug}")
|
||||
print(f"Reload Mode: {args.reload}")
|
||||
print(f"DB Host: {os.getenv('DB_HOST')}")
|
||||
print(f"DB Port: {os.getenv('DB_PORT')}")
|
||||
print(f"DB User: {os.getenv('DB_USER')}")
|
||||
print(f"DB Database: {os.getenv('DB_DATABASE')}")
|
||||
print(f"Force Refresh Metadata: {os.getenv('FORCE_REFRESH_METADATA', 'false')}")
|
||||
print("------------------------------")
|
||||
base_url = f"http://{args.host}:{args.port}"
|
||||
print(f"Service running at: {base_url}")
|
||||
print(f" Health Check: GET {base_url}/health")
|
||||
print(f" Status Check: GET {base_url}/status")
|
||||
print(f" SSE Init: GET {base_url}/sse")
|
||||
print(f" SSE/Legacy Messages: POST {base_url}/mcp/messages")
|
||||
print(f" Streamable HTTP: GET/POST/DELETE/OPTIONS {base_url}/mcp")
|
||||
print("------------------------------")
|
||||
print("Use Ctrl+C to stop the service")
|
||||
self.logger = logging.getLogger(f"{__name__}.DorisServer")
|
||||
self._setup_handlers()
|
||||
|
||||
# --- Start Uvicorn Server ---
|
||||
config = Config(
|
||||
app=app,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
log_level="debug" if args.debug else "info",
|
||||
reload=args.reload
|
||||
)
|
||||
server = Server(config=config)
|
||||
await server.serve()
|
||||
def _setup_handlers(self):
|
||||
"""Setup MCP protocol handlers"""
|
||||
|
||||
# --- Main Execution Logic (Simplified) ---
|
||||
@self.server.list_resources()
|
||||
async def handle_list_resources() -> list[Resource]:
|
||||
"""Handle resource list request"""
|
||||
try:
|
||||
self.logger.info("Handling resource list request")
|
||||
resources = await self.resources_manager.list_resources()
|
||||
self.logger.info(f"Returning {len(resources)} resources")
|
||||
return resources
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to handle resource list request: {e}")
|
||||
return []
|
||||
|
||||
def run_main_sync():
|
||||
"""Synchronous wrapper, primarily for SSE mode now."""
|
||||
sync_logger = logging.getLogger("run_main_sync")
|
||||
sync_logger.info("Entering run_main_sync (SSE focus)...")
|
||||
print("DEBUG: Entering run_main_sync (SSE focus)...", file=sys.stderr, flush=True)
|
||||
args = parse_args()
|
||||
@self.server.read_resource()
|
||||
async def handle_read_resource(uri: str) -> str:
|
||||
"""Handle resource read request"""
|
||||
try:
|
||||
self.logger.info(f"Handling resource read request: {uri}")
|
||||
content = await self.resources_manager.read_resource(uri)
|
||||
return content
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to handle resource read request: {e}")
|
||||
return json.dumps(
|
||||
{"error": f"Failed to read resource: {str(e)}", "uri": uri},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
@self.server.list_tools()
|
||||
async def handle_list_tools() -> list[Tool]:
|
||||
"""Handle tool list request"""
|
||||
try:
|
||||
self.logger.info("Handling tool list request")
|
||||
tools = await self.tools_manager.list_tools()
|
||||
self.logger.info(f"Returning {len(tools)} tools")
|
||||
return tools
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to handle tool list request: {e}")
|
||||
return []
|
||||
|
||||
@self.server.call_tool()
|
||||
async def handle_call_tool(
|
||||
name: str, arguments: dict[str, Any]
|
||||
) -> list[TextContent]:
|
||||
"""Handle tool call request"""
|
||||
try:
|
||||
self.logger.info(f"Handling tool call request: {name}")
|
||||
result = await self.tools_manager.call_tool(name, arguments)
|
||||
|
||||
return [TextContent(type="text", text=result)]
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to handle tool call request: {e}")
|
||||
error_result = json.dumps(
|
||||
{
|
||||
"error": f"Tool call failed: {str(e)}",
|
||||
"tool_name": name,
|
||||
"arguments": arguments,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
return [TextContent(type="text", text=error_result)]
|
||||
|
||||
@self.server.list_prompts()
|
||||
async def handle_list_prompts() -> list[Prompt]:
|
||||
"""Handle prompt list request"""
|
||||
try:
|
||||
self.logger.info("Handling prompt list request")
|
||||
prompts = await self.prompts_manager.list_prompts()
|
||||
self.logger.info(f"Returning {len(prompts)} prompts")
|
||||
return prompts
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to handle prompt list request: {e}")
|
||||
return []
|
||||
|
||||
@self.server.get_prompt()
|
||||
async def handle_get_prompt(name: str, arguments: dict[str, Any]) -> str:
|
||||
"""Handle prompt get request"""
|
||||
try:
|
||||
self.logger.info(f"Handling prompt get request: {name}")
|
||||
result = await self.prompts_manager.get_prompt(name, arguments)
|
||||
return result
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to handle prompt get request: {e}")
|
||||
error_result = json.dumps(
|
||||
{
|
||||
"error": f"Failed to get prompt: {str(e)}",
|
||||
"prompt_name": name,
|
||||
"arguments": arguments,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
return error_result
|
||||
|
||||
async def start_stdio(self):
|
||||
"""Start stdio transport mode"""
|
||||
self.logger.info("Starting Doris MCP Server (stdio mode)")
|
||||
|
||||
if args.sse:
|
||||
try:
|
||||
# Run the async SSE server setup and Uvicorn loop
|
||||
asyncio.run(start_sse_server(args))
|
||||
sync_logger.info("asyncio.run(start_sse_server) completed.")
|
||||
print("DEBUG: asyncio.run(start_sse_server) completed.", file=sys.stderr, flush=True)
|
||||
except KeyboardInterrupt:
|
||||
sync_logger.info("SSE server stopped by KeyboardInterrupt.")
|
||||
# Ensure connection manager is initialized
|
||||
await self.connection_manager.initialize()
|
||||
self.logger.info("Connection manager initialization completed")
|
||||
|
||||
# Start stdio server - using simpler approach
|
||||
from mcp.server.stdio import stdio_server
|
||||
|
||||
self.logger.info("Creating stdio_server transport...")
|
||||
|
||||
# Try different startup approaches
|
||||
try:
|
||||
async with stdio_server() as streams:
|
||||
read_stream, write_stream = streams
|
||||
self.logger.info("stdio_server streams created successfully")
|
||||
|
||||
# Create initialization options
|
||||
# MCP 1.8.0 requires parameters for get_capabilities
|
||||
from mcp.server.lowlevel.server import NotificationOptions
|
||||
|
||||
capabilities = self.server.get_capabilities(
|
||||
notification_options=NotificationOptions(
|
||||
prompts_changed=True,
|
||||
resources_changed=True,
|
||||
tools_changed=True
|
||||
),
|
||||
experimental_capabilities={}
|
||||
)
|
||||
|
||||
init_options = InitializationOptions(
|
||||
server_name="doris-mcp-server",
|
||||
server_version="1.0.0",
|
||||
capabilities=capabilities,
|
||||
)
|
||||
self.logger.info("Initialization options created successfully")
|
||||
|
||||
# Run server
|
||||
self.logger.info("Starting to run MCP server...")
|
||||
await self.server.run(read_stream, write_stream, init_options)
|
||||
|
||||
except Exception as inner_e:
|
||||
self.logger.error(f"stdio_server internal error: {inner_e}")
|
||||
self.logger.error(f"Error type: {type(inner_e)}")
|
||||
|
||||
# Try to get more error information
|
||||
import traceback
|
||||
self.logger.error("Complete error stack:")
|
||||
self.logger.error(traceback.format_exc())
|
||||
|
||||
# If it's ExceptionGroup, try to parse
|
||||
if hasattr(inner_e, 'exceptions'):
|
||||
self.logger.error(f"ExceptionGroup contains {len(inner_e.exceptions)} exceptions:")
|
||||
for i, exc in enumerate(inner_e.exceptions):
|
||||
self.logger.error(f" Exception {i+1}: {type(exc).__name__}: {exc}")
|
||||
|
||||
raise inner_e
|
||||
|
||||
except Exception as e:
|
||||
sync_logger.critical(f"Error during asyncio.run(start_sse_server): {e}", exc_info=True)
|
||||
print(f"DEBUG: Error during asyncio.run(start_sse_server): {e}", file=sys.stderr, flush=True)
|
||||
self.logger.error(f"stdio server startup failed: {e}")
|
||||
self.logger.error(f"Error type: {type(e)}")
|
||||
raise
|
||||
else:
|
||||
# If run without --sse, print help/error
|
||||
message = "Error: This entry point requires --sse flag. For stdio mode, use 'uv run mcp-doris' or the appropriate command for your stdio setup."
|
||||
sync_logger.error(message)
|
||||
print(message, file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
|
||||
async def start_http(self, host: str = "localhost", port: int = 3000):
|
||||
"""Start Streamable HTTP transport mode"""
|
||||
self.logger.info(f"Starting Doris MCP Server (Streamable HTTP mode) - {host}:{port}")
|
||||
|
||||
try:
|
||||
# Ensure connection manager is initialized
|
||||
await self.connection_manager.initialize()
|
||||
|
||||
# Use Starlette and StreamableHTTPSessionManager according to official example
|
||||
import uvicorn
|
||||
import contextlib
|
||||
from collections.abc import AsyncIterator
|
||||
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
|
||||
from starlette.applications import Starlette
|
||||
from starlette.routing import Mount, Route
|
||||
from starlette.responses import JSONResponse, Response
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
# Create session manager
|
||||
session_manager = StreamableHTTPSessionManager(
|
||||
app=self.server,
|
||||
json_response=True, # Enable JSON response
|
||||
stateless=False # Maintain session state
|
||||
)
|
||||
|
||||
self.logger.info(f"StreamableHTTP session manager created, will start at http://{host}:{port}")
|
||||
|
||||
# Health check endpoint
|
||||
async def health_check(request):
|
||||
return JSONResponse({"status": "healthy", "service": "doris-mcp-server"})
|
||||
|
||||
# Lifecycle manager - simplified since we manage session_manager externally
|
||||
@contextlib.asynccontextmanager
|
||||
async def lifespan(app: Starlette) -> AsyncIterator[None]:
|
||||
"""Context manager for managing application lifecycle"""
|
||||
self.logger.info("Application started!")
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.logger.info("Application is shutting down...")
|
||||
|
||||
# Create ASGI application - use direct session manager as ASGI app
|
||||
starlette_app = Starlette(
|
||||
debug=True,
|
||||
routes=[
|
||||
Route("/health", health_check, methods=["GET"]),
|
||||
],
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# Custom ASGI app that handles both /mcp and /mcp/ without redirects
|
||||
async def mcp_app(scope, receive, send):
|
||||
# Handle lifespan events
|
||||
if scope["type"] == "lifespan":
|
||||
await starlette_app(scope, receive, send)
|
||||
return
|
||||
|
||||
# Handle HTTP requests
|
||||
if scope["type"] == "http":
|
||||
path = scope.get("path", "")
|
||||
self.logger.info(f"Received request for path: {path}")
|
||||
|
||||
try:
|
||||
# Handle health check
|
||||
if path.startswith("/health"):
|
||||
await starlette_app(scope, receive, send)
|
||||
return
|
||||
|
||||
# Handle MCP requests - both /mcp and /mcp/ go to session manager
|
||||
if path == "/mcp" or path.startswith("/mcp/"):
|
||||
self.logger.info(f"Handling MCP request for path: {path}")
|
||||
# Log request details for debugging
|
||||
method = scope.get("method", "UNKNOWN")
|
||||
headers = dict(scope.get("headers", []))
|
||||
self.logger.info(f"MCP Request - Method: {method}")
|
||||
self.logger.info(f"MCP Request - Headers: {headers}")
|
||||
|
||||
# Handle Dify compatibility for GET requests
|
||||
if method == "GET":
|
||||
accept_header = headers.get(b'accept', b'').decode('utf-8')
|
||||
user_agent = headers.get(b'user-agent', b'').decode('utf-8')
|
||||
|
||||
|
||||
|
||||
# For other GET requests, try to add application/json to Accept header
|
||||
if 'text/event-stream' in accept_header and 'application/json' not in accept_header:
|
||||
self.logger.info("Adding application/json to Accept header for GET request")
|
||||
# Modify headers to include both content types
|
||||
new_headers = []
|
||||
for name, value in scope.get("headers", []):
|
||||
if name == b'accept':
|
||||
# Add application/json to the accept header
|
||||
new_value = value.decode('utf-8') + ', application/json'
|
||||
new_headers.append((name, new_value.encode('utf-8')))
|
||||
else:
|
||||
new_headers.append((name, value))
|
||||
# Update scope with modified headers
|
||||
scope = dict(scope)
|
||||
scope["headers"] = new_headers
|
||||
self.logger.info(f"Modified Accept header to: {new_value}")
|
||||
|
||||
await session_manager.handle_request(scope, receive, send)
|
||||
return
|
||||
|
||||
# 404 for other paths
|
||||
self.logger.info(f"Path not found: {path}")
|
||||
response = Response("Not Found", status_code=404)
|
||||
await response(scope, receive, send)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error handling request for {path}: {e}")
|
||||
import traceback
|
||||
self.logger.error(traceback.format_exc())
|
||||
response = Response("Internal Server Error", status_code=500)
|
||||
await response(scope, receive, send)
|
||||
else:
|
||||
# For other scope types, just return
|
||||
self.logger.warning(f"Unsupported scope type: {scope['type']}")
|
||||
return
|
||||
|
||||
# Start uvicorn server with session manager lifecycle
|
||||
config = uvicorn.Config(
|
||||
app=mcp_app,
|
||||
host=host,
|
||||
port=port,
|
||||
log_level="info"
|
||||
)
|
||||
server = uvicorn.Server(config)
|
||||
|
||||
# Run session manager and server together
|
||||
async with session_manager.run():
|
||||
self.logger.info("Session manager started, now starting HTTP server")
|
||||
await server.serve()
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Streamable HTTP server startup failed: {e}")
|
||||
import traceback
|
||||
self.logger.error("Complete error stack:")
|
||||
self.logger.error(traceback.format_exc())
|
||||
|
||||
# If it's ExceptionGroup, try to parse
|
||||
if hasattr(e, 'exceptions'):
|
||||
self.logger.error(f"ExceptionGroup contains {len(e.exceptions)} exceptions:")
|
||||
for i, exc in enumerate(e.exceptions):
|
||||
self.logger.error(f" Exception {i+1}: {type(exc).__name__}: {exc}")
|
||||
raise
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shutdown server"""
|
||||
self.logger.info("Shutting down Doris MCP Server")
|
||||
try:
|
||||
await self.connection_manager.close()
|
||||
self.logger.info("Doris MCP Server has been shut down")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error occurred while shutting down server: {e}")
|
||||
|
||||
|
||||
def create_arg_parser():
|
||||
"""Create command line argument parser"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Apache Doris MCP Server - Enterprise Database Service",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Transport Modes:
|
||||
stdio - Standard input/output (for local process communication)
|
||||
http - Streamable HTTP mode (MCP 2025-03-26 protocol)
|
||||
|
||||
Examples:
|
||||
python -m doris_mcp_server --transport stdio
|
||||
python -m doris_mcp_server --transport http --host 0.0.0.0 --port 3000
|
||||
"""
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--transport",
|
||||
type=str,
|
||||
choices=["stdio", "http"],
|
||||
default="stdio",
|
||||
help="Transport protocol type: stdio (local), http (Streamable HTTP)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default="localhost",
|
||||
help="Host address for HTTP mode (default: localhost)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--port", type=int, default=3000, help="Port number for HTTP mode (default: 3000)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--db-host",
|
||||
type=str,
|
||||
default="localhost",
|
||||
help="Doris database host address (default: localhost)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--db-port", type=int, default=9030, help="Doris database port number (default: 9030)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--db-user", type=str, default="root", help="Doris database username (default: root)"
|
||||
)
|
||||
|
||||
parser.add_argument("--db-password", type=str, default="", help="Doris database password")
|
||||
|
||||
parser.add_argument(
|
||||
"--db-database",
|
||||
type=str,
|
||||
default="information_schema",
|
||||
help="Doris database name (default: information_schema)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--log-level",
|
||||
type=str,
|
||||
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
|
||||
default="INFO",
|
||||
help="Log level (default: INFO)",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main function"""
|
||||
parser = create_arg_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set log level
|
||||
logging.getLogger().setLevel(getattr(logging, args.log_level))
|
||||
|
||||
# Create configuration - priority: command line arguments > .env file > default values
|
||||
config = DorisConfig.from_env() # First load from .env file and environment variables
|
||||
|
||||
# Command line arguments override configuration (if provided)
|
||||
if args.db_host != "localhost": # If not default value, use command line argument
|
||||
config.database.host = args.db_host
|
||||
if args.db_port != 9030:
|
||||
config.database.port = args.db_port
|
||||
if args.db_user != "root":
|
||||
config.database.user = args.db_user
|
||||
if args.db_password: # Use password if provided
|
||||
config.database.password = args.db_password
|
||||
if args.db_database != "information_schema":
|
||||
config.database.database = args.db_database
|
||||
if args.log_level != "INFO":
|
||||
config.logging.level = args.log_level
|
||||
|
||||
# Create server instance
|
||||
server = DorisServer(config)
|
||||
|
||||
try:
|
||||
if args.transport == "stdio":
|
||||
await server.start_stdio()
|
||||
elif args.transport == "http":
|
||||
await server.start_http(args.host, args.port)
|
||||
else:
|
||||
logger.error(f"Unsupported transport protocol: {args.transport}")
|
||||
await server.shutdown()
|
||||
return 1
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Received interrupt signal, shutting down server...")
|
||||
except Exception as e:
|
||||
logger.error(f"Server runtime error: {e}")
|
||||
# Clean up resources even in case of exception
|
||||
try:
|
||||
await server.shutdown()
|
||||
except Exception as shutdown_error:
|
||||
logger.error(f"Error occurred while shutting down server: {shutdown_error}")
|
||||
return 1
|
||||
finally:
|
||||
# Cleanup in case of normal shutdown
|
||||
try:
|
||||
await server.shutdown()
|
||||
except Exception as shutdown_error:
|
||||
logger.error(f"Error occurred while shutting down server: {shutdown_error}")
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def main_sync():
|
||||
"""Synchronous main function for entry point"""
|
||||
exit_code = asyncio.run(main())
|
||||
exit(exit_code)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_main_sync()
|
||||
main_sync()
|
||||
|
||||
@@ -1,159 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Core MCP instance and startup logic for stdio mode.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
import traceback
|
||||
import json
|
||||
from typing import Dict, Any
|
||||
|
||||
# Import necessary components from mcp and our project
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
logger = logging.getLogger("doris-mcp-core")
|
||||
|
||||
# --- Global MCP Instance for Stdio ---
|
||||
# Create the instance when the module is imported.
|
||||
# Tools will be registered synchronously(?) before running.
|
||||
stdio_mcp = FastMCP(
|
||||
name="doris-mcp-stdio-core",
|
||||
description="Apache Doris MCP Server (stdio via core)",
|
||||
)
|
||||
|
||||
# --- Removed async setup functions ---
|
||||
def run_stdio():
|
||||
"""
|
||||
Synchronous entry point for running the stdio server.
|
||||
Mimics the mcp-doris example by calling .run() on the instance.
|
||||
Handles tool registration beforehand.
|
||||
"""
|
||||
logger.info("Executing run_stdio (synchronous entry point)...")
|
||||
|
||||
# --- Run the stdio server using the instance's run() method ---
|
||||
logger.info("Calling stdio_mcp.run()...")
|
||||
try:
|
||||
# Assuming stdio_mcp has a synchronous run() method for stdio
|
||||
stdio_mcp.run()
|
||||
logger.info("stdio_mcp.run() completed.")
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Stdio server stopped by KeyboardInterrupt.")
|
||||
except AttributeError:
|
||||
logger.critical("Error: stdio_mcp object does not have a '.run()' method suitable for stdio.", exc_info=False)
|
||||
print("ERROR: stdio_mcp object does not have a '.run()' method.", file=sys.stderr, flush=True)
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
logger.critical(f"run_stdio encountered an error during stdio_mcp.run(): {e}", exc_info=True)
|
||||
traceback.print_exc(file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# Register Tool: Execute SQL Query
|
||||
@stdio_mcp.tool("exec_query", description="""[Function Description]: Execute SQL query and return result command with catalog federation support.\n
|
||||
[Parameter Content]:\n
|
||||
- sql (string) [Required] - SQL statement to execute. MUST use three-part naming for all table references: 'catalog_name.db_name.table_name'. For internal tables use 'internal.db_name.table_name', for external tables use 'catalog_name.db_name.table_name'\n
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database\n
|
||||
- catalog_name (string) [Optional] - Reference catalog name for context, defaults to current catalog\n
|
||||
- max_rows (integer) [Optional] - Maximum number of rows to return, default 100\n
|
||||
- timeout (integer) [Optional] - Query timeout in seconds, default 30\n""")
|
||||
async def exec_query_tool(sql: str, db_name: str = None, catalog_name: str = None, max_rows: int = 100, timeout: int = 30) -> Dict[str, Any]:
|
||||
"""Wrapper: Execute SQL query and return result command"""
|
||||
from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_exec_query
|
||||
return await mcp_doris_exec_query(sql=sql, db_name=db_name, catalog_name=catalog_name, max_rows=max_rows, timeout=timeout)
|
||||
|
||||
# Register Tool: Get Table Schema
|
||||
@stdio_mcp.tool("get_table_schema", description="""[Function Description]: Get detailed structure information of the specified table (columns, types, comments, etc.).\n
|
||||
[Parameter Content]:\n
|
||||
- table_name (string) [Required] - Name of the table to query\n
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database\n
|
||||
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""")
|
||||
async def get_table_schema_tool(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
|
||||
"""Wrapper: Get table schema"""
|
||||
from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_table_schema
|
||||
if not table_name: return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "Missing table_name parameter"})}]}
|
||||
return await mcp_doris_get_table_schema(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
|
||||
# Register Tool: Get Database Table List
|
||||
@stdio_mcp.tool("get_db_table_list", description="""[Function Description]: Get a list of all table names in the specified database.\n
|
||||
[Parameter Content]:\n
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database\n
|
||||
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""")
|
||||
async def get_db_table_list_tool(db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
|
||||
"""Wrapper: Get database table list"""
|
||||
from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_db_table_list
|
||||
return await mcp_doris_get_db_table_list(db_name=db_name, catalog_name=catalog_name)
|
||||
|
||||
# Register Tool: Get Database List
|
||||
@stdio_mcp.tool("get_db_list", description="""[Function Description]: Get a list of all database names on the server.\n
|
||||
[Parameter Content]:\n
|
||||
- random_string (string) [Required] - Unique identifier for the tool call\n
|
||||
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""")
|
||||
async def get_db_list_tool(catalog_name: str = None) -> Dict[str, Any]:
|
||||
"""Wrapper: Get database list"""
|
||||
from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_db_list
|
||||
return await mcp_doris_get_db_list(catalog_name=catalog_name)
|
||||
|
||||
# Register Tool: Get Table Comment
|
||||
@stdio_mcp.tool("get_table_comment", description="""[Function Description]: Get the comment information for the specified table.\n
|
||||
[Parameter Content]:\n
|
||||
- table_name (string) [Required] - Name of the table to query\n
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database\n
|
||||
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""")
|
||||
async def get_table_comment_tool(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
|
||||
"""Wrapper: Get table comment"""
|
||||
from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_table_comment
|
||||
if not table_name: return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "Missing table_name parameter"})}]}
|
||||
return await mcp_doris_get_table_comment(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
|
||||
# Register Tool: Get Table Column Comments
|
||||
@stdio_mcp.tool("get_table_column_comments", description="""[Function Description]: Get comment information for all columns in the specified table.\n
|
||||
[Parameter Content]:\n
|
||||
- table_name (string) [Required] - Name of the table to query\n
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database\n
|
||||
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""")
|
||||
async def get_table_column_comments_tool(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
|
||||
"""Wrapper: Get table column comments"""
|
||||
from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_table_column_comments
|
||||
if not table_name: return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "Missing table_name parameter"})}]}
|
||||
return await mcp_doris_get_table_column_comments(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
|
||||
# Register Tool: Get Table Indexes
|
||||
@stdio_mcp.tool("get_table_indexes", description="""[Function Description]: Get index information for the specified table.
|
||||
[Parameter Content]:\n
|
||||
- table_name (string) [Required] - Name of the table to query\n
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database\n
|
||||
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""")
|
||||
async def get_table_indexes_tool(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
|
||||
"""Wrapper: Get table indexes"""
|
||||
from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_table_indexes
|
||||
if not table_name: return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "Missing table_name parameter"})}]}
|
||||
return await mcp_doris_get_table_indexes(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
|
||||
# Register Tool: Get Recent Audit Logs
|
||||
@stdio_mcp.tool("get_recent_audit_logs", description="""[Function Description]: Get audit log records for a recent period.\n
|
||||
[Parameter Content]:\n
|
||||
- days (integer) [Optional] - Number of recent days of logs to retrieve, default is 7\n
|
||||
- limit (integer) [Optional] - Maximum number of records to return, default is 100\n""")
|
||||
async def get_recent_audit_logs_tool(days: int = 7, limit: int = 100) -> Dict[str, Any]:
|
||||
"""Wrapper: Get recent audit logs"""
|
||||
from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_recent_audit_logs
|
||||
try:
|
||||
days = int(days)
|
||||
limit = int(limit)
|
||||
except (ValueError, TypeError):
|
||||
return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "days and limit parameters must be integers"})}]}
|
||||
return await mcp_doris_get_recent_audit_logs(days=days, limit=limit)
|
||||
|
||||
# Register Tool: Get Catalog List
|
||||
@stdio_mcp.tool("get_catalog_list", description="""[Function Description]: Get a list of all catalog names on the server.\n
|
||||
[Parameter Content]:\n
|
||||
- random_string (string) [Required] - Unique identifier for the tool call\n""")
|
||||
async def get_catalog_list_tool() -> Dict[str, Any]:
|
||||
"""Wrapper: Get catalog list"""
|
||||
from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_catalog_list
|
||||
return await mcp_doris_get_catalog_list()
|
||||
|
||||
# --- Register Tools ---
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,912 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Doris MCP Streamable HTTP Server Implementation
|
||||
|
||||
Implements the MCP 2025-03-26 Streamable HTTP specification.
|
||||
Uses a unified /mcp endpoint for GET, POST, DELETE, OPTIONS.
|
||||
Manages sessions using Mcp-Session-Id header.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Optional, Dict, List
|
||||
from fastapi import FastAPI, Request, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
# Use a distinct logger name
|
||||
logger = logging.getLogger("doris-mcp-streamable")
|
||||
|
||||
# Special marker for closing streams
|
||||
STREAM_END_MARKER = "__MCP_STREAM_END__"
|
||||
|
||||
class DorisMCPStreamableServer:
|
||||
"""Doris MCP Streamable HTTP Server"""
|
||||
|
||||
def __init__(self, mcp_server, app: FastAPI):
|
||||
"""
|
||||
Initializes the Doris MCP Streamable HTTP server.
|
||||
|
||||
Args:
|
||||
mcp_server: The shared FastMCP server instance.
|
||||
app: The main FastAPI application instance.
|
||||
"""
|
||||
self.mcp_server = mcp_server
|
||||
self.app = app # We'll add routes to this app
|
||||
|
||||
# Note: CORS middleware should be added only once in main.py usually.
|
||||
# If added here, ensure it doesn't conflict or duplicate.
|
||||
# For separation, we might let main.py handle CORS entirely.
|
||||
|
||||
# Client session management for Streamable HTTP clients
|
||||
# key: session_id (from Mcp-Session-Id header)
|
||||
# value: {
|
||||
# "created_at": timestamp,
|
||||
# "last_active": timestamp,
|
||||
# "request_queues": { request_id: asyncio.Queue }, # For POST /mcp request streams
|
||||
# "general_sse_queues": List[asyncio.Queue] # For GET /mcp server push streams
|
||||
# }
|
||||
self.client_sessions: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
# Setup the unified MCP endpoint
|
||||
self._setup_streamable_http_routes()
|
||||
|
||||
# Register session cleanup task if this instance manages lifespan independently
|
||||
# Usually, startup events are tied to the main app lifespan managed in main.py
|
||||
# We might not need @app.on_event("startup") here if main.py handles it.
|
||||
# Let's assume main.py handles the cleanup task initiation.
|
||||
|
||||
def _setup_streamable_http_routes(self):
|
||||
"""Sets up the unified /mcp endpoint for Streamable HTTP.
|
||||
Uses a distinct tag for API docs.
|
||||
"""
|
||||
|
||||
@self.app.api_route("/mcp", methods=["GET", "POST", "DELETE", "OPTIONS"], tags=["Streamable HTTP"])
|
||||
async def mcp_endpoint_handler(request: Request):
|
||||
"""Handles GET, POST, DELETE, OPTIONS for the /mcp endpoint."""
|
||||
|
||||
# 1. Handle OPTIONS (CORS preflight)
|
||||
if request.method == "OPTIONS":
|
||||
# Assuming CORS headers are handled by middleware in main.py
|
||||
# If not, provide necessary headers here.
|
||||
# This minimal response might suffice if middleware handles the rest
|
||||
logger.debug("Handling OPTIONS request for /mcp")
|
||||
# Return basic OK allowing exposed headers if middleware handles the rest
|
||||
return JSONResponse({}, headers={"Access-Control-Expose-Headers": "Mcp-Session-Id"})
|
||||
|
||||
# Session ID from header is required for most methods
|
||||
session_id = request.headers.get("Mcp-Session-Id")
|
||||
|
||||
# 2. Handle DELETE (Terminate Session)
|
||||
if request.method == "DELETE":
|
||||
if not session_id:
|
||||
return JSONResponse({"jsonrpc": "2.0", "error": {"code": -32600, "message": "Mcp-Session-Id header is required for DELETE"}}, status_code=400)
|
||||
|
||||
logger.info(f"Handling DELETE request for session [Session ID: {session_id}]")
|
||||
session_data = self.client_sessions.pop(session_id, None)
|
||||
if session_data:
|
||||
await self._cleanup_session_resources(session_id, session_data)
|
||||
return JSONResponse({}, status_code=204) # No Content
|
||||
else:
|
||||
logger.warning(f"Attempted DELETE on non-existent session: {session_id}")
|
||||
return JSONResponse({"jsonrpc": "2.0", "error": {"code": -32001, "message": "Session not found"}}, status_code=404)
|
||||
|
||||
# 3. Handle GET (Server Push SSE Stream)
|
||||
if request.method == "GET":
|
||||
if not session_id:
|
||||
return JSONResponse({"jsonrpc": "2.0", "error": {"code": -32000, "message": "Mcp-Session-Id header is required for GET streams"}}, status_code=400)
|
||||
if session_id not in self.client_sessions:
|
||||
# Note: Unlike legacy SSE, GET here assumes session exists.
|
||||
return JSONResponse({"jsonrpc": "2.0", "error": {"code": -32001, "message": "Session not found. Initialize first."}}, status_code=404)
|
||||
|
||||
accept_header = request.headers.get("Accept", "")
|
||||
if "text/event-stream" not in accept_header:
|
||||
return JSONResponse({"jsonrpc": "2.0", "error": {"code": -32600, "message": "Accept header must include text/event-stream for GET"}}, status_code=406)
|
||||
|
||||
# TODO: Handle Last-Event-ID for stream recovery?
|
||||
|
||||
logger.info(f"Handling GET request, establishing server push SSE stream [Session ID: {session_id}]")
|
||||
|
||||
push_queue = asyncio.Queue()
|
||||
if self.client_sessions[session_id].get("general_sse_queues") is None:
|
||||
self.client_sessions[session_id]["general_sse_queues"] = []
|
||||
self.client_sessions[session_id]["general_sse_queues"].append(push_queue)
|
||||
self.client_sessions[session_id]["last_active"] = time.time()
|
||||
|
||||
return EventSourceResponse(self._create_general_sse_generator(session_id, push_queue), media_type="text/event-stream")
|
||||
|
||||
# 4. Handle POST (Client Messages & Initialize)
|
||||
if request.method == "POST":
|
||||
accept_header = request.headers.get("Accept", "")
|
||||
content_type = request.headers.get("Content-Type", "")
|
||||
|
||||
body = {}
|
||||
try:
|
||||
if "application/json" not in content_type:
|
||||
return JSONResponse({"jsonrpc": "2.0", "error": {"code": -32700, "message": "Content-Type must be application/json"}}, status_code=415)
|
||||
body = await request.json()
|
||||
if isinstance(body, list): return JSONResponse({"jsonrpc": "2.0", "error": {"code": -32600, "message": "Batch requests not supported"}}, status_code=400)
|
||||
if not isinstance(body, dict): return JSONResponse({"jsonrpc": "2.0", "error": {"code": -32700, "message": "Invalid JSON received"}}, status_code=400)
|
||||
|
||||
method = body.get("method")
|
||||
message_id = body.get("id") # Can be None for notifications
|
||||
|
||||
# Handle Initialize request (does not require Mcp-Session-Id header)
|
||||
if method == "initialize":
|
||||
if "application/json" not in accept_header:
|
||||
return JSONResponse({"jsonrpc": "2.0", "id": message_id, "error": {"code": -32600, "message": "Accept header must include application/json for initialize"}}, status_code=406)
|
||||
return await self._handle_initialize(request, body, message_id)
|
||||
|
||||
# Handle other POST requests (require Mcp-Session-Id)
|
||||
else:
|
||||
if not session_id:
|
||||
return JSONResponse({"jsonrpc": "2.0", "id": message_id, "error": {"code": -32000, "message": "Mcp-Session-Id header is required for this request"}}, status_code=400)
|
||||
if session_id not in self.client_sessions:
|
||||
return JSONResponse({"jsonrpc": "2.0", "id": message_id, "error": {"code": -32001, "message": "Session not found"}}, status_code=404)
|
||||
# Check Accept header for non-initialize POST
|
||||
if not ("application/json" in accept_header and "text/event-stream" in accept_header):
|
||||
return JSONResponse({"jsonrpc": "2.0", "id": message_id, "error": {"code": -32600, "message": "Accept header must include application/json and text/event-stream for POST"}}, status_code=406)
|
||||
|
||||
self.client_sessions[session_id]["last_active"] = time.time()
|
||||
return await self._handle_client_post(request, body, session_id, message_id)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
return JSONResponse({"jsonrpc": "2.0", "error": {"code": -32700, "message": "Parse error - Invalid JSON received"}}, status_code=400)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error handling POST /mcp: {str(e)}", exc_info=True)
|
||||
error_id = body.get("id") if isinstance(body, dict) else None
|
||||
return JSONResponse({"jsonrpc": "2.0", "id": error_id, "error": {"code": -32000, "message": "Internal server error"}}, status_code=500)
|
||||
|
||||
# Fallback for other methods like PUT, PATCH etc.
|
||||
return JSONResponse({"error": "Method Not Allowed"}, status_code=405)
|
||||
|
||||
async def _handle_initialize(self, request: Request, body: Dict, message_id: Any):
|
||||
"""Handles the 'initialize' method call via POST /mcp."""
|
||||
logger.info("Handling Streamable HTTP initialize request")
|
||||
# Optional: Validate params in body if needed
|
||||
# params = body.get("params", {})
|
||||
|
||||
new_session_id = str(uuid.uuid4())
|
||||
logger.info(f"Created new Streamable HTTP session [Session ID: {new_session_id}]")
|
||||
|
||||
self.client_sessions[new_session_id] = {
|
||||
"created_at": time.time(),
|
||||
"last_active": time.time(),
|
||||
# No transport_type needed here as this class *is* the streamable server
|
||||
"request_queues": {}, # Initialize request queues dict
|
||||
"general_sse_queues": [] # Initialize general queues list
|
||||
}
|
||||
|
||||
# Build InitializeResult based on spec
|
||||
initialize_result = {
|
||||
"protocolVersion": "2025-03-26",
|
||||
"name": self.mcp_server.name,
|
||||
"instructions": "Apache Doris MCP Server (Streamable HTTP Mode)",
|
||||
"serverInfo": { "version": "0.2.0", "name": "Doris MCP Streamable Server" }, # Adjust as needed
|
||||
"capabilities": {
|
||||
"tools": { "supportsStreaming": True, "supportsProgress": True },
|
||||
"resources": { "supportsStreaming": False }, # Example capability
|
||||
"prompts": { "supported": True }, # Example capability
|
||||
"session": { "supported": True }
|
||||
}
|
||||
}
|
||||
response_body = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": message_id,
|
||||
"result": initialize_result
|
||||
}
|
||||
|
||||
# Return JSON response with Mcp-Session-Id header
|
||||
return JSONResponse(
|
||||
content=response_body,
|
||||
media_type="application/json",
|
||||
headers={"Mcp-Session-Id": new_session_id}
|
||||
)
|
||||
|
||||
async def _handle_client_post(self, request: Request, body: Dict, session_id: str, message_id: Any):
|
||||
"""Handles non-initialize POST requests (notifications, responses, method calls)."""
|
||||
method = body.get("method")
|
||||
|
||||
# Handle Notifications/Responses from client
|
||||
is_notification = "method" in body and "id" not in body
|
||||
is_response = "result" in body or "error" in body
|
||||
if is_notification or is_response:
|
||||
logger.info(f"Received Streamable HTTP notification/response [Session ID: {session_id}] - Processing needed? (Ignoring for now)")
|
||||
# TODO: If the server sends requests that expect responses, process is_response here.
|
||||
# For now, just acknowledge client notifications/responses.
|
||||
return JSONResponse({}, status_code=202) # Accepted
|
||||
|
||||
# Handle Requests from client (method call)
|
||||
if "method" in body and "id" in body:
|
||||
logger.info(f"Received Streamable HTTP request [Session ID: {session_id}, ID: {message_id}, Method: {method}]")
|
||||
params = body.get("params", {})
|
||||
stream_required = params.get("stream", False) if method in ["tools/call", "mcp/callTool"] else False
|
||||
|
||||
if stream_required:
|
||||
# --- Return SSE stream for response parts ---
|
||||
logger.info(f"Using SSE stream for request [Session ID: {session_id}, ID: {message_id}]")
|
||||
response_queue = asyncio.Queue()
|
||||
# Ensure request_queues exists (should have been created during initialize)
|
||||
if self.client_sessions[session_id].get("request_queues") is None:
|
||||
logger.error(f"Session {session_id} is missing 'request_queues' dictionary!")
|
||||
# Handle this inconsistency, maybe return an error
|
||||
return JSONResponse({"jsonrpc": "2.0", "id": message_id, "error": {"code": -32000, "message": "Internal server error: Session state inconsistent"}}, status_code=500)
|
||||
self.client_sessions[session_id]["request_queues"][message_id] = response_queue
|
||||
|
||||
# Start background task to process and put results in the queue
|
||||
asyncio.create_task(self._process_request_and_respond(
|
||||
request, body, session_id, message_id, response_queue, is_stream=True
|
||||
))
|
||||
|
||||
# Return EventSourceResponse using the request-specific queue
|
||||
return EventSourceResponse(self._create_request_sse_generator(session_id, message_id, response_queue), media_type="text/event-stream")
|
||||
else:
|
||||
# --- Return single JSON response ---
|
||||
logger.info(f"Using JSON response for request [Session ID: {session_id}, ID: {message_id}]")
|
||||
try:
|
||||
# Process the request directly and get the result/error payload
|
||||
result_or_error_payload = await self._process_request_and_respond(
|
||||
request, body, session_id, message_id, None, is_stream=False
|
||||
)
|
||||
# This function now returns the final JSON body or raises HTTPException
|
||||
return JSONResponse(content=result_or_error_payload, media_type="application/json")
|
||||
except HTTPException as http_exc:
|
||||
# Format HTTPException details into JSON-RPC error
|
||||
return JSONResponse(
|
||||
{"jsonrpc": "2.0", "id": message_id, "error": {"code": -32000, "message": http_exc.detail}},
|
||||
status_code=http_exc.status_code
|
||||
)
|
||||
except Exception as e:
|
||||
# Catch unexpected errors during synchronous processing
|
||||
logger.error(f"Error processing non-stream request [Session ID: {session_id}, ID: {message_id}]: {str(e)}", exc_info=True)
|
||||
error_response = {"jsonrpc": "2.0", "id": message_id, "error": {"code": -32000, "message": f"Internal server error: {str(e)}"}}
|
||||
return JSONResponse(content=error_response, status_code=500)
|
||||
else:
|
||||
# Invalid JSON-RPC format (e.g., missing method or id for a request)
|
||||
return JSONResponse({"jsonrpc": "2.0", "id": message_id, "error": {"code": -32600, "message": "Invalid JSON-RPC request format"}}, status_code=400)
|
||||
|
||||
# === Generator Functions for SSE Streams ===
|
||||
|
||||
async def _create_general_sse_generator(self, session_id: str, queue: asyncio.Queue):
|
||||
"""Generator for GET /mcp server push streams."""
|
||||
queue_removed = False
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
if session_id not in self.client_sessions:
|
||||
logger.warning(f"General SSE stream generator: Session {session_id} closed.")
|
||||
break
|
||||
|
||||
message = await asyncio.wait_for(queue.get(), timeout=60.0)
|
||||
|
||||
if message == STREAM_END_MARKER:
|
||||
logger.debug(f"General SSE stream received end marker [Session ID: {session_id}]")
|
||||
break
|
||||
|
||||
if isinstance(message, dict) and ("result" in message or "error" in message) and "id" in message:
|
||||
logger.warning(f"Attempted to send response on GET stream, blocked [Session ID: {session_id}]: {message}")
|
||||
queue.task_done()
|
||||
continue
|
||||
|
||||
# TODO: Event ID for recovery?
|
||||
yield {"event": "message", "data": json.dumps(message)}
|
||||
queue.task_done()
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
if session_id not in self.client_sessions:
|
||||
logger.warning(f"General SSE stream generator (timeout): Session {session_id} closed.")
|
||||
break
|
||||
yield {"event": "ping", "data": "keepalive"}
|
||||
continue
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"General SSE stream cancelled [Session ID: {session_id}]")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"General SSE stream error [Session ID: {session_id}]: {str(e)}", exc_info=True)
|
||||
break
|
||||
finally:
|
||||
logger.info(f"General SSE stream ended [Session ID: {session_id}]")
|
||||
if not queue_removed and session_id in self.client_sessions:
|
||||
session = self.client_sessions[session_id]
|
||||
if session.get("general_sse_queues") is not None:
|
||||
try:
|
||||
session["general_sse_queues"].remove(queue)
|
||||
queue_removed = True
|
||||
logger.debug(f"General SSE queue removed from session [Session ID: {session_id}]")
|
||||
except ValueError:
|
||||
logger.warning(f"Failed to remove general SSE queue (not found) [Session ID: {session_id}]")
|
||||
except Exception as ce:
|
||||
logger.error(f"Error removing general SSE queue [Session ID: {session_id}]: {ce}")
|
||||
while not queue.empty():
|
||||
try: queue.get_nowait(); queue.task_done()
|
||||
except asyncio.QueueEmpty: break
|
||||
|
||||
async def _create_request_sse_generator(self, session_id: str, request_id: Any, queue: asyncio.Queue):
|
||||
"""Generator for POST /mcp request-response streams."""
|
||||
queue_removed = False
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
if session_id not in self.client_sessions or \
|
||||
request_id not in self.client_sessions.get(session_id, {}).get("request_queues", {}):
|
||||
logger.warning(f"Request SSE stream generator: Session/Request queue closed [Session ID: {session_id}, Request ID: {request_id}]")
|
||||
break
|
||||
|
||||
message = await asyncio.wait_for(queue.get(), timeout=120.0) # Longer timeout for requests?
|
||||
|
||||
if message == STREAM_END_MARKER:
|
||||
logger.debug(f"Request SSE stream received end marker [Session ID: {session_id}, Request ID: {request_id}]")
|
||||
break
|
||||
|
||||
# TODO: Event ID for parts?
|
||||
yield {"event": "message", "data": json.dumps(message)}
|
||||
queue.task_done()
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
if session_id not in self.client_sessions or \
|
||||
request_id not in self.client_sessions.get(session_id, {}).get("request_queues", {}):
|
||||
logger.warning(f"Request SSE stream generator (timeout): Session/Request queue closed [Session ID: {session_id}, Request ID: {request_id}]")
|
||||
break
|
||||
logger.debug(f"Request SSE stream timed out waiting for message/end [Session ID: {session_id}, Request ID: {request_id}]")
|
||||
# Unlike general stream, timeout here might indicate an issue or just long processing.
|
||||
# Continue waiting for the STREAM_END_MARKER.
|
||||
continue
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Request SSE stream cancelled [Session ID: {session_id}, Request ID: {request_id}]")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Request SSE stream error [Session ID: {session_id}, Request ID: {request_id}]: {str(e)}", exc_info=True)
|
||||
break
|
||||
finally:
|
||||
logger.info(f"Request SSE stream ended [Session ID: {session_id}, Request ID: {request_id}]")
|
||||
if not queue_removed and session_id in self.client_sessions:
|
||||
session = self.client_sessions[session_id]
|
||||
if session.get("request_queues") is not None:
|
||||
if session["request_queues"].pop(request_id, None):
|
||||
queue_removed = True
|
||||
logger.debug(f"Request SSE queue removed from session [Session ID: {session_id}, Request ID: {request_id}]")
|
||||
else:
|
||||
logger.warning(f"Failed to remove request SSE queue (not found) [Session ID: {session_id}, Request ID: {request_id}]")
|
||||
while not queue.empty():
|
||||
try: queue.get_nowait(); queue.task_done()
|
||||
except asyncio.QueueEmpty: break
|
||||
|
||||
# === Core Request Processing Logic ===
|
||||
|
||||
async def _process_request_and_respond(
|
||||
self, request: Request, body: Dict, session_id: str, message_id: Any,
|
||||
response_queue: Optional[asyncio.Queue], # Queue ONLY for streaming responses
|
||||
is_stream: bool # True if response should go via SSE queue
|
||||
):
|
||||
"""Processes client method calls and prepares response/error payload or sends to queue.
|
||||
Returns payload for non-streaming, returns None for streaming (uses queue).
|
||||
Raises HTTPException for non-streaming errors that need specific status codes.
|
||||
"""
|
||||
logger.info(f"Entering _process_request_and_respond for method '{body.get('method')}'...")
|
||||
method = body.get("method")
|
||||
params = body.get("params", {})
|
||||
response_payload = None # Holds the 'result' or 'error' part of JSON-RPC
|
||||
|
||||
try:
|
||||
# --- Handle Method Calls ---
|
||||
if method == "mcp/listOfferings":
|
||||
tools = await self.mcp_server.list_tools()
|
||||
tools_json = self._format_tools(tools)
|
||||
resources = await self.mcp_server.list_resources()
|
||||
resources_json = self._format_resources(resources)
|
||||
prompts = await self.mcp_server.list_prompts()
|
||||
prompts_json = self._format_prompts(prompts)
|
||||
response_payload = {"tools": tools_json, "resources": resources_json, "prompts": prompts_json}
|
||||
|
||||
elif method == "mcp/listTools" or method == "tools/list":
|
||||
tools = await self.mcp_server.list_tools()
|
||||
response_payload = {"tools": self._format_tools(tools)}
|
||||
|
||||
elif method == "mcp/listResources":
|
||||
resources = await self.mcp_server.list_resources()
|
||||
response_payload = {"resources": self._format_resources(resources)}
|
||||
|
||||
elif method == "mcp/listPrompts":
|
||||
prompts = await self.mcp_server.list_prompts()
|
||||
response_payload = {"prompts": self._format_prompts(prompts)}
|
||||
|
||||
elif method == "mcp/callTool" or method == "tools/call":
|
||||
tool_name = params.get("name")
|
||||
arguments = params.get("arguments", {})
|
||||
if not tool_name:
|
||||
# For non-streaming, raise HTTPException; for streaming, send error via queue
|
||||
error_detail = "Invalid params: tool name ('name') is required"
|
||||
if is_stream and response_queue:
|
||||
error_resp = {"jsonrpc": "2.0", "id": message_id, "error": {"code": -32602, "message": error_detail}}
|
||||
await response_queue.put(error_resp)
|
||||
# No return here for stream, let finally handle end marker
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail=error_detail)
|
||||
return # Exit after handling error
|
||||
|
||||
# --- Tool Calling ---
|
||||
if is_stream and response_queue:
|
||||
# Background task handles putting results/errors in queue
|
||||
logger.info(f"Launching stream tool task [Session: {session_id}, Req: {message_id}, Tool: {tool_name}]")
|
||||
asyncio.create_task(self._execute_stream_tool_wrapper(
|
||||
tool_name, arguments, message_id, session_id, request, response_queue
|
||||
))
|
||||
# Returns None, caller (_handle_client_post) returns EventSourceResponse
|
||||
return
|
||||
else:
|
||||
# Execute tool directly for non-streaming response
|
||||
logger.info(f"Executing non-stream tool [Session: {session_id}, Req: {message_id}, Tool: {tool_name}]")
|
||||
# Note: call_tool now raises ValueError on internal errors
|
||||
result = await self.call_tool(tool_name, arguments, request, None) # No callback needed
|
||||
logger.debug(f"Raw result from non-stream call_tool: {result}")
|
||||
response_payload = self._format_tool_call_result(result)
|
||||
else:
|
||||
# Method not found
|
||||
error_detail = f"Method not found: {method}"
|
||||
if is_stream and response_queue:
|
||||
error_resp = {"jsonrpc": "2.0", "id": message_id, "error": {"code": -32601, "message": error_detail}}
|
||||
await response_queue.put(error_resp)
|
||||
else:
|
||||
raise HTTPException(status_code=405, detail=error_detail)
|
||||
return # Exit after handling error
|
||||
|
||||
# --- Prepare final response payload (only if not streaming and successful) ---
|
||||
if response_payload is not None:
|
||||
final_response = {"jsonrpc": "2.0", "id": message_id, "result": response_payload}
|
||||
if is_stream and response_queue: # Should not happen if response_payload is set
|
||||
logger.error("Logic error: response_payload set for streaming call?")
|
||||
await response_queue.put(final_response) # Send anyway?
|
||||
elif not is_stream:
|
||||
logger.debug(f"Returning successful non-stream payload for {method}")
|
||||
return final_response # Return dict for JSONResponse
|
||||
|
||||
except Exception as e:
|
||||
# Handles errors raised by call_tool (ValueError) or other unexpected issues
|
||||
logger.error(f"Error processing request [Session: {session_id}, Req: {message_id}, Method: {method}]: {str(e)}", exc_info=True)
|
||||
error_code = -32000
|
||||
error_message = f"Internal server error: {str(e)}"
|
||||
status_code = 500 # Default for unexpected errors
|
||||
|
||||
if isinstance(e, HTTPException):
|
||||
# If it was an HTTPException raised earlier (e.g., 400, 405)
|
||||
error_message = e.detail
|
||||
status_code = e.status_code
|
||||
error_code = -32000 # Keep generic JSON-RPC code for now
|
||||
elif isinstance(e, ValueError):
|
||||
# Errors from call_tool (tool not found, execution error)
|
||||
error_message = str(e)
|
||||
status_code = 500 # Treat tool execution errors as internal server errors
|
||||
error_code = -32000 # Or a custom tool error code?
|
||||
|
||||
error_response_payload = {"code": error_code, "message": error_message}
|
||||
|
||||
if is_stream and response_queue:
|
||||
# Send error via queue for streaming calls
|
||||
final_error_response = {"jsonrpc": "2.0", "id": message_id, "error": error_response_payload}
|
||||
logger.debug(f"Putting error response into stream queue [Session: {session_id}, Req: {message_id}]")
|
||||
await response_queue.put(final_error_response)
|
||||
# Returns None, let finally send end marker
|
||||
return
|
||||
else:
|
||||
# For non-streaming, raise HTTPException to set status code
|
||||
logger.debug(f"Raising HTTPException for non-stream error (Status: {status_code})")
|
||||
raise HTTPException(status_code=status_code, detail=error_message)
|
||||
|
||||
finally:
|
||||
# If this was a streaming call, ensure the end marker is sent.
|
||||
# This runs even if the processing returns early (e.g., after launching task or handling error).
|
||||
if is_stream and response_queue:
|
||||
logger.debug(f"Putting stream end marker [Session: {session_id}, Req: {message_id}]")
|
||||
await response_queue.put(STREAM_END_MARKER)
|
||||
|
||||
|
||||
async def _execute_stream_tool_wrapper(
|
||||
self, tool_name: str, arguments: Dict, message_id: Any, session_id: str,
|
||||
request: Request, response_queue: asyncio.Queue
|
||||
):
|
||||
"""Wraps stream-capable tool calls, handles callback, puts results/errors into queue."""
|
||||
logger.info(f"Entering _execute_stream_tool_wrapper for tool '{tool_name}'...")
|
||||
try:
|
||||
logger.debug(f"Executing stream tool wrapper [Session: {session_id}, Req: {message_id}, Tool: {tool_name}]")
|
||||
|
||||
async def stream_callback(content, metadata=None):
|
||||
logger.debug(f"Stream callback received content [Session: {session_id}, Req: {message_id}]")
|
||||
partial_result_formatted = self._format_tool_call_result(content)
|
||||
|
||||
# Check session/queue validity before putting
|
||||
if session_id not in self.client_sessions or \
|
||||
message_id not in self.client_sessions.get(session_id, {}).get("request_queues", {}):
|
||||
logger.warning(f"Stream callback: Session/Queue closed, cannot send partial result [Session: {session_id}, Req: {message_id}]")
|
||||
return
|
||||
|
||||
# Send progress notification
|
||||
progress_notification = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": "tools/progress",
|
||||
"params": {
|
||||
"requestId": message_id,
|
||||
"toolName": tool_name,
|
||||
"progress": partial_result_formatted,
|
||||
}
|
||||
}
|
||||
try:
|
||||
await response_queue.put(progress_notification)
|
||||
except Exception as e:
|
||||
logger.error(f"Stream callback failed to send progress: {str(e)}")
|
||||
|
||||
# Handle visualization data
|
||||
if metadata and "visualization" in metadata:
|
||||
await self.send_visualization_data(session_id, message_id, metadata["visualization"])
|
||||
|
||||
# --- Call Tool ---
|
||||
kwargs = dict(arguments)
|
||||
# Simplification: Assume tool supports callback if streaming requested
|
||||
kwargs['callback'] = stream_callback
|
||||
|
||||
# call_tool handles its own internal errors and raises ValueError
|
||||
result = await self.call_tool(tool_name, kwargs, request, stream_callback)
|
||||
logger.debug(f"Stream wrapper received final result from call_tool: {result}")
|
||||
|
||||
# --- Send Final Result ---
|
||||
if session_id not in self.client_sessions or \
|
||||
message_id not in self.client_sessions.get(session_id, {}).get("request_queues", {}):
|
||||
logger.warning(f"Stream tool finished but Session/Queue closed [Session: {session_id}, Req: {message_id}]")
|
||||
return # Cannot send final result
|
||||
|
||||
final_result_formatted = self._format_tool_call_result(result)
|
||||
final_message = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": message_id,
|
||||
"result": final_result_formatted
|
||||
}
|
||||
logger.debug(f"Putting final stream result into queue [Session: {session_id}, Req: {message_id}]")
|
||||
await response_queue.put(final_message)
|
||||
logger.info(f"Stream tool execution successful [Session: {session_id}, Req: {message_id}]")
|
||||
|
||||
except Exception as e:
|
||||
# Catches errors from call_tool (ValueError) or other wrapper issues
|
||||
logger.error(f"Error during stream tool execution wrapper [Session: {session_id}, Req: {message_id}]: {str(e)}", exc_info=True)
|
||||
# Check session/queue validity before sending error
|
||||
if session_id not in self.client_sessions or \
|
||||
message_id not in self.client_sessions.get(session_id, {}).get("request_queues", {}):
|
||||
logger.warning(f"Stream tool failed but Session/Queue closed [Session: {session_id}, Req: {message_id}]")
|
||||
return # Cannot send error
|
||||
|
||||
error_code = -32000
|
||||
error_message = f"Tool execution error: {str(e)}"
|
||||
if isinstance(e, ValueError):
|
||||
error_code = -32602 # Or -32000?
|
||||
error_message = str(e)
|
||||
|
||||
error_response = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": message_id,
|
||||
"error": { "code": error_code, "message": error_message }
|
||||
}
|
||||
try:
|
||||
await response_queue.put(error_response)
|
||||
except Exception as qe:
|
||||
logger.error(f"Failed to put error response into stream queue: {qe}")
|
||||
# No finally block needed here, handled by _process_request_and_respond
|
||||
|
||||
|
||||
async def call_tool(self, tool_name, arguments, request, callback: Optional[callable] = None):
|
||||
"""Finds and executes the target tool function/method.
|
||||
Raises ValueError on tool not found or execution error.
|
||||
"""
|
||||
logger.info(f"Entering call_tool for tool '{tool_name}'...")
|
||||
# Log args excluding callback
|
||||
log_args = {k: v for k, v in arguments.items() if k != 'callback'}
|
||||
logger.info(f"Executing tool: {tool_name}, Args: {json.dumps(log_args, ensure_ascii=False, default=str)}")
|
||||
|
||||
recent_query = self._extract_recent_query(request)
|
||||
# Tool mapping might be needed if client uses different names
|
||||
tool_mapping = {
|
||||
# Example: "clientFacingName": "internalFunctionName"
|
||||
"status": "mcp_doris_status",
|
||||
"health": "mcp_doris_health",
|
||||
# Add other mappings if needed, ensure consistency with tool_initializer
|
||||
"nl2sql_query": "mcp_doris_nl2sql_query",
|
||||
"nl2sql_query_stream": "mcp_doris_nl2sql_query_stream",
|
||||
"list_database_tables": "mcp_doris_list_database_tables",
|
||||
"explain_table": "mcp_doris_explain_table",
|
||||
"get_nl2sql_status": "mcp_doris_get_nl2sql_status",
|
||||
"refresh_metadata": "mcp_doris_refresh_metadata",
|
||||
"sql_optimize": "mcp_doris_sql_optimize",
|
||||
"fix_sql": "mcp_doris_fix_sql",
|
||||
"count_chars": "mcp_doris_count_chars",
|
||||
"exec_query": "mcp_doris_exec_query",
|
||||
"get_schema_list": "mcp_doris_get_schema_list", # Deprecated?
|
||||
"save_metadata": "mcp_doris_save_metadata", # Likely internal
|
||||
"get_metadata": "mcp_doris_get_metadata", # Likely internal
|
||||
"analyze_query_result": "mcp_doris_analyze_query_result", # Internal?
|
||||
"generate_sql": "mcp_doris_generate_sql", # Likely internal
|
||||
"explain_sql": "mcp_doris_explain_sql", # Internal?
|
||||
"modify_sql": "mcp_doris_modify_sql", # Internal?
|
||||
"parse_query": "mcp_doris_parse_query", # Internal?
|
||||
"identify_query_type": "mcp_doris_identify_query_type", # Internal?
|
||||
"validate_sql_syntax": "mcp_doris_validate_sql_syntax", # Internal?
|
||||
"check_sql_security": "mcp_doris_check_sql_security", # Internal?
|
||||
"find_similar_examples": "mcp_doris_find_similar_examples", # Internal?
|
||||
"find_similar_history": "mcp_doris_find_similar_history", # Internal?
|
||||
"calculate_query_similarity": "mcp_doris_calculate_query_similarity", # Internal?
|
||||
"adapt_similar_query": "mcp_doris_adapt_similar_query", # Internal?
|
||||
"get_nl2sql_prompt": "mcp_doris_get_nl2sql_prompt" # Internal?
|
||||
}
|
||||
mapped_tool_name = tool_mapping.get(tool_name, tool_name)
|
||||
|
||||
try:
|
||||
# 1. Find the registered tool instance/function from FastMCP
|
||||
tool_instance = None
|
||||
mcp = self.app.state.mcp if hasattr(self.app.state, 'mcp') else self.mcp_server
|
||||
registered_tools = await mcp.list_tools()
|
||||
for tool in registered_tools:
|
||||
# The tool object returned by list_tools might be the wrapper function
|
||||
# defined in tool_initializer. We need its name.
|
||||
tool_registered_name = getattr(tool, 'name', getattr(tool, '__name__', None))
|
||||
if tool_registered_name == tool_name: # Match against the name used in @mcp.tool
|
||||
tool_instance = tool # This is likely the wrapper function itself
|
||||
logger.debug(f"Found registered tool wrapper: {tool_registered_name}")
|
||||
break
|
||||
|
||||
if not tool_instance:
|
||||
# Fallback: Try importing directly (less ideal as it bypasses registration)
|
||||
logger.warning(f"Tool '{tool_name}' not found in registered tools, trying direct import of {mapped_tool_name}")
|
||||
try:
|
||||
import doris_mcp_server.tools.mcp_doris_tools as mcp_tools
|
||||
tool_instance = getattr(mcp_tools, mapped_tool_name, None)
|
||||
if not tool_instance or not callable(tool_instance):
|
||||
raise ValueError(f"Tool function {mapped_tool_name} not found or not callable in mcp_doris_tools.")
|
||||
logger.debug(f"Using directly imported tool function: {mapped_tool_name}")
|
||||
# If using direct import, FastMCP context (ctx) is not available
|
||||
# We need to pass args directly
|
||||
processed_args = self._process_tool_arguments(mapped_tool_name, arguments, recent_query)
|
||||
# Inject callback if provided and applicable
|
||||
if callback and mapped_tool_name.endswith("_stream"):
|
||||
processed_args['callback'] = callback
|
||||
elif callback:
|
||||
processed_args.pop('callback', None)
|
||||
result = await tool_instance(**processed_args)
|
||||
logger.debug(f"Raw result from directly imported tool '{mapped_tool_name}': {result}")
|
||||
return result
|
||||
|
||||
except (ImportError, AttributeError, ValueError) as import_err:
|
||||
logger.error(f"Failed to find or import tool: {tool_name} / {mapped_tool_name}. Error: {import_err}")
|
||||
raise ValueError(f"Tool '{tool_name}' not found or failed to import.") from import_err
|
||||
|
||||
# 2. If found via registration, execute using FastMCP's mechanism (if possible)
|
||||
# or simulate the context passing if tool_instance is the wrapper.
|
||||
# The wrapper expects a Context object.
|
||||
logger.debug(f"Executing registered tool wrapper '{tool_name}'")
|
||||
# We need to manually create a mock or simplified Context if FastMCP doesn't handle this automatically
|
||||
# For simplicity, let's try passing parameters directly if the wrapper handles it.
|
||||
# Ideally, FastMCP would handle the execution via mcp.call_tool(tool_name, params=...) if available.
|
||||
# Let's assume the wrapper function handles **kwargs or a Context object.
|
||||
|
||||
# Create a pseudo-context or just pass params
|
||||
# Method 1: Pass params directly (assuming wrapper handles it)
|
||||
# processed_args = self._process_tool_arguments(mapped_tool_name, arguments, recent_query)
|
||||
# if callback:
|
||||
# processed_args['callback'] = callback
|
||||
# result = await tool_instance(**processed_args) # This likely won't work if it expects Context
|
||||
|
||||
# Method 2: Create a Context-like object (Requires Context class import)
|
||||
# from mcp.server.fastmcp import Context # Make sure imported
|
||||
# pseudo_ctx = Context(mcp=mcp, request=request, params=arguments, tool=tool_instance)
|
||||
# result = await tool_instance(pseudo_ctx)
|
||||
|
||||
# Method 3: Use mcp.call_tool internal method if accessible and appropriate
|
||||
# This is speculative based on potential FastMCP internals
|
||||
if hasattr(mcp, 'call_tool_by_name'): # Hypothetical method
|
||||
logger.debug("Attempting execution via mcp.call_tool_by_name")
|
||||
pseudo_ctx_params = arguments # Pass client args
|
||||
# pseudo_ctx_params['_request'] = request # Maybe pass request?
|
||||
if callback: pseudo_ctx_params['callback'] = callback # Pass callback?
|
||||
result = await mcp.call_tool_by_name(tool_name, params=pseudo_ctx_params)
|
||||
logger.debug(f"Result from mcp.call_tool_by_name: {result}")
|
||||
else:
|
||||
# Fallback to manual context simulation if no direct call method exists
|
||||
logger.debug("Falling back to manual context simulation for tool wrapper execution")
|
||||
from mcp.server.fastmcp import Context # Ensure imported
|
||||
# Prepare params for context, including potentially callback
|
||||
context_params = dict(arguments)
|
||||
if callback: context_params['callback'] = callback
|
||||
pseudo_ctx = Context(mcp=mcp, request=request, params=context_params, tool=tool_instance)
|
||||
result = await tool_instance(pseudo_ctx) # Call the wrapper with simulated context
|
||||
logger.debug(f"Result from manual context simulation: {result}")
|
||||
|
||||
logger.debug(f"Raw result received in call_tool from registered tool '{tool_name}': {result}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Exception during call_tool for '{tool_name}': {str(e)}", exc_info=True)
|
||||
raise ValueError(f"Error executing tool '{tool_name}': {str(e)}") from e
|
||||
|
||||
|
||||
# === Helper Methods (Formatting, Session Cleanup, etc.) ===
|
||||
|
||||
def _format_tools(self, tools):
|
||||
# Helper to format tool list for responses
|
||||
# Based on mcp/listTools structure
|
||||
tools_json = []
|
||||
for tool in tools:
|
||||
# Assuming tools from list_tools are the wrapper functions
|
||||
tool_registered_name = getattr(tool, 'name', getattr(tool, '__name__', None))
|
||||
if not tool_registered_name:
|
||||
logger.warning(f"Could not determine name for tool object: {tool}")
|
||||
continue
|
||||
|
||||
# Need a way to get description and schema associated with the wrapper
|
||||
# This might require inspecting the mcp instance's internal storage
|
||||
mcp = self.app.state.mcp if hasattr(self.app.state, 'mcp') else self.mcp_server
|
||||
# Hypothetical internal access - THIS IS FRAGILE
|
||||
tool_spec = mcp.tools.get(tool_registered_name) if hasattr(mcp, 'tools') else None
|
||||
|
||||
description = ""
|
||||
input_schema = {"type": "object", "properties": {}, "required": []}
|
||||
if tool_spec and hasattr(tool_spec, 'description'):
|
||||
description = tool_spec.description
|
||||
if tool_spec and hasattr(tool_spec, 'parameters'): # Assuming parameters holds the JSON schema
|
||||
input_schema = tool_spec.parameters
|
||||
|
||||
tools_json.append({
|
||||
"name": tool_registered_name,
|
||||
"description": description,
|
||||
"inputSchema": input_schema
|
||||
})
|
||||
return tools_json
|
||||
|
||||
def _format_resources(self, resources):
|
||||
# Helper to format resource list
|
||||
return [res.model_dump() if hasattr(res, "model_dump") else res for res in resources]
|
||||
|
||||
def _format_prompts(self, prompts):
|
||||
# Helper to format prompt list
|
||||
return [prompt.model_dump() if hasattr(prompt, "model_dump") else prompt for prompt in prompts]
|
||||
|
||||
def _format_tool_call_result(self, result: Any) -> Dict[str, Any]:
|
||||
# Helper to format tool results into MCP Content format
|
||||
content_list = []
|
||||
if isinstance(result, str):
|
||||
try:
|
||||
# If it looks like the tool already returned the full JSON RPC like structure
|
||||
parsed_json = json.loads(result)
|
||||
if isinstance(parsed_json, dict) and 'content' in parsed_json and isinstance(parsed_json['content'], list):
|
||||
logger.debug("Tool result already seems formatted with 'content', using as is.")
|
||||
return parsed_json # Use the structure directly
|
||||
else:
|
||||
# Assume it's JSON content, wrap it
|
||||
content_list.append({"type": "json", "json": parsed_json})
|
||||
except json.JSONDecodeError:
|
||||
# Not JSON, treat as text
|
||||
content_list.append({"type": "text", "text": result})
|
||||
elif isinstance(result, (dict, list)):
|
||||
# If result is already a dict with a 'content' list, use it directly
|
||||
if isinstance(result, dict) and 'content' in result and isinstance(result['content'], list):
|
||||
logger.debug("Tool result dictionary has 'content', using as is.")
|
||||
return result # Use the structure directly
|
||||
else:
|
||||
# Otherwise, assume it's JSON content to be wrapped
|
||||
content_list.append({"type": "json", "json": result})
|
||||
elif result is None:
|
||||
# Handle None result, maybe return empty content or specific type?
|
||||
logger.warning("_format_tool_call_result received None result")
|
||||
content_list.append({"type": "text", "text": ""}) # Example: empty text
|
||||
else:
|
||||
# Other types, convert to string and wrap as text
|
||||
content_list.append({"type": "text", "text": str(result)})
|
||||
# Always return a dict with a 'content' key containing a list
|
||||
return {"content": content_list}
|
||||
|
||||
def _process_tool_arguments(self, tool_name, arguments, recent_query):
|
||||
# Helper to process tool arguments, including random_string fallback
|
||||
# Note: Ensure callback is NOT passed here
|
||||
processed_args = dict(arguments)
|
||||
processed_args.pop('callback', None) # Explicitly remove callback
|
||||
|
||||
if "random_string" in arguments and tool_name.startswith("mcp_doris_"):
|
||||
random_string = processed_args.pop("random_string", "") # Remove from processed too
|
||||
logger.debug(f"Processing random_string '{random_string}' for tool {tool_name}")
|
||||
|
||||
# ... (rest of random_string logic as before) ...
|
||||
# Example for exec_query:
|
||||
if tool_name == "mcp_doris_exec_query" and not processed_args.get("sql"):
|
||||
sql_fallback = random_string or recent_query
|
||||
# ... (logic to extract SQL from fallback) ...
|
||||
if sql_extracted:
|
||||
processed_args["sql"] = sql_extracted
|
||||
else:
|
||||
logger.warning(f"Missing sql for {tool_name}, and fallback failed.")
|
||||
# ... (logic for table_name fallback) ...
|
||||
|
||||
return processed_args
|
||||
|
||||
def _extract_recent_query(self, request: Request) -> Optional[str]:
|
||||
# Helper to extract recent user query from request
|
||||
# (Implementation as provided previously)
|
||||
try:
|
||||
# Try to extract message history from request body
|
||||
body = None
|
||||
body_bytes = getattr(request, "_body", None)
|
||||
if body_bytes:
|
||||
try:
|
||||
body = json.loads(body_bytes)
|
||||
except: pass
|
||||
if not body: body = getattr(request, "_json", {})
|
||||
|
||||
messages = body.get("params", {}).get("messages", [])
|
||||
if messages:
|
||||
for msg in reversed(messages):
|
||||
if msg.get("role") == "user": return msg.get("content", "")
|
||||
|
||||
message = body.get("params", {}).get("message", {})
|
||||
if message and message.get("role") == "user": return message.get("content", "")
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting recent query: {str(e)}")
|
||||
return None
|
||||
|
||||
async def _cleanup_session_resources(self, session_id: str, session_data: Dict):
|
||||
# Helper to clean up queues when session is deleted
|
||||
logger.info(f"Cleaning up resources for session [Session ID: {session_id}]")
|
||||
# Close general SSE queues
|
||||
general_queues = session_data.get("general_sse_queues", [])
|
||||
for queue in general_queues:
|
||||
try:
|
||||
await queue.put(STREAM_END_MARKER)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error putting end marker in general queue for session {session_id}: {e}")
|
||||
# Close request-specific SSE queues
|
||||
request_queues = session_data.get("request_queues", {})
|
||||
for req_id, queue in request_queues.items():
|
||||
try:
|
||||
await queue.put(STREAM_END_MARKER)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error putting end marker in request queue {req_id} for session {session_id}: {e}")
|
||||
logger.info(f"Finished cleaning resources for session {session_id}")
|
||||
|
||||
# This method might belong in the main app or a shared utility if needed by both servers
|
||||
# async def cleanup_idle_sessions(self):
|
||||
# # ... (implementation - needs access to self.client_sessions) ...
|
||||
# pass
|
||||
|
||||
# This method might belong in the main app or a shared utility
|
||||
# async def broadcast_message(self, message):
|
||||
# # ... (implementation - needs access to self.client_sessions of BOTH servers?) ...
|
||||
# pass
|
||||
|
||||
# This method is specific to streamable http tool calls
|
||||
async def send_visualization_data(self, session_id: str, request_id: Any, visualization_data: Any):
|
||||
"""Sends visualization data as a notification on the request stream."""
|
||||
if session_id not in self.client_sessions:
|
||||
logger.warning(f"Cannot send visualization: Session {session_id} not found.")
|
||||
return
|
||||
queue = self.client_sessions.get(session_id, {}).get("request_queues", {}).get(request_id)
|
||||
if not queue:
|
||||
logger.warning(f"Cannot send visualization: Request queue {request_id} not found for session {session_id}.")
|
||||
return
|
||||
|
||||
notification = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": "tools/visualization",
|
||||
"params": visualization_data
|
||||
}
|
||||
try:
|
||||
await queue.put(notification)
|
||||
logger.info(f"Sent visualization notification [Session: {session_id}, Req: {request_id}]")
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending visualization notification [Session: {session_id}, Req: {request_id}]: {e}")
|
||||
|
||||
# This might belong in main app or shared utility
|
||||
# async def send_periodic_updates(self):
|
||||
# # ... (implementation) ...
|
||||
# pass
|
||||
|
||||
# End of class DorisMCPStreamableServer
|
||||
@@ -1,25 +1,9 @@
|
||||
from .mcp_doris_tools import (
|
||||
mcp_doris_exec_query,
|
||||
mcp_doris_get_table_schema,
|
||||
mcp_doris_get_db_table_list,
|
||||
mcp_doris_get_db_list,
|
||||
mcp_doris_get_table_comment,
|
||||
mcp_doris_get_table_column_comments,
|
||||
mcp_doris_get_table_indexes,
|
||||
mcp_doris_get_recent_audit_logs,
|
||||
mcp_doris_get_catalog_list
|
||||
)
|
||||
"""
|
||||
MCP Tools Package - Contains all MCP tool implementations.
|
||||
|
||||
# The __all__ list should reflect the registered tool names,
|
||||
# even though the implementation functions have the prefix.
|
||||
__all__ = [
|
||||
"exec_query",
|
||||
"get_table_schema",
|
||||
"get_db_table_list",
|
||||
"get_db_list",
|
||||
"get_table_comment",
|
||||
"get_table_column_comments",
|
||||
"get_table_indexes",
|
||||
"get_recent_audit_logs",
|
||||
"get_catalog_list"
|
||||
]
|
||||
This package includes:
|
||||
- Doris database tools
|
||||
- Resource managers
|
||||
- Prompt managers
|
||||
- Tool registration and initialization
|
||||
"""
|
||||
|
||||
@@ -1,230 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Doris MCP Tool Implementations
|
||||
|
||||
Includes exec_query and new tools based on schema_extractor.
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Any
|
||||
import pandas as pd
|
||||
|
||||
# --- Use absolute imports ---
|
||||
from doris_mcp_server.utils.schema_extractor import MetadataExtractor
|
||||
from doris_mcp_server.utils.sql_executor_tools import execute_sql_query
|
||||
|
||||
# Get logger
|
||||
logger = logging.getLogger("doris-mcp-tools")
|
||||
|
||||
# --- Helper Function to format response ---
|
||||
def _format_response(success: bool, result: Any = None, error: str = None, message: str = "") -> Dict[str, Any]:
|
||||
response_data = {
|
||||
"success": success,
|
||||
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
|
||||
}
|
||||
if success and result is not None:
|
||||
# Handle DataFrame serialization
|
||||
if isinstance(result, pd.DataFrame):
|
||||
try:
|
||||
# Convert DataFrame to JSON records format
|
||||
response_data["result"] = json.loads(result.to_json(orient='records', date_format='iso'))
|
||||
except Exception as df_err:
|
||||
logger.error(f"DataFrame to JSON conversion failed: {df_err}")
|
||||
# Fallback or specific error handling for DataFrame
|
||||
response_data["result"] = {"error": "Failed to serialize DataFrame result"}
|
||||
response_data["success"] = False # Mark as failed if serialization fails
|
||||
response_data["error"] = f"DataFrame serialization error: {str(df_err)}"
|
||||
else:
|
||||
response_data["result"] = result
|
||||
response_data["message"] = message or "Operation successful" # Translated: Operation successful
|
||||
elif not success:
|
||||
response_data["error"] = error or "Unknown error" # Translated: Unknown error
|
||||
response_data["message"] = message or "Operation failed" # Translated: Operation failed
|
||||
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps(response_data, ensure_ascii=False, default=str) # Use default=str for non-serializable types
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
async def mcp_doris_exec_query(sql: str = None, db_name: str = None, catalog_name: str = None, max_rows: int = 100, timeout: int = 30) -> Dict[str, Any]:
|
||||
"""
|
||||
Executes an SQL query and returns the result with catalog federation support.
|
||||
|
||||
Args:
|
||||
sql (str): The SQL query to execute. MUST use three-part naming for table references:
|
||||
- Internal tables: internal.db_name.table_name (e.g., "SELECT * FROM internal.ssb.customer")
|
||||
- External tables: catalog_name.db_name.table_name (e.g., "SELECT * FROM mysql.ssb.customer")
|
||||
- Cross-catalog queries: "SELECT * FROM mysql.ssb.customer m JOIN internal.ssb.orders o ON m.id = o.customer_id"
|
||||
|
||||
Examples:
|
||||
- Query internal catalog: "SELECT COUNT(*) FROM internal.ssb.customer"
|
||||
- Query MySQL catalog: "SELECT COUNT(*) FROM mysql.ssb.customer"
|
||||
- Cross-catalog join: "SELECT * FROM internal.ssb.customer c JOIN mysql.test.user_info u ON c.id = u.customer_id"
|
||||
|
||||
db_name (str, optional): Target database name. Only used for connection context, table names in SQL must be fully qualified.
|
||||
catalog_name (str, optional): Reference catalog name for context. Does not affect SQL execution - table names in SQL must be fully qualified.
|
||||
Available catalogs can be found using get_catalog_list tool.
|
||||
max_rows (int, optional): Maximum number of rows to return. Defaults to 100.
|
||||
timeout (int, optional): Query timeout in seconds. Defaults to 30.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: A dictionary containing the query result or an error.
|
||||
"""
|
||||
logger.info(f"MCP Tool Call: mcp_doris_exec_query, SQL: {sql}, DB: {db_name}, Catalog: {catalog_name}, MaxRows: {max_rows}, Timeout: {timeout}")
|
||||
try:
|
||||
if not sql:
|
||||
return _format_response(success=False, error="SQL statement not provided", message="Please provide the SQL statement to execute")
|
||||
|
||||
# Build parameters to pass to execute_sql_query
|
||||
exec_ctx = {
|
||||
"params": {
|
||||
"sql": sql,
|
||||
"db_name": db_name,
|
||||
"catalog_name": catalog_name,
|
||||
"max_rows": max_rows,
|
||||
"timeout": timeout
|
||||
}
|
||||
}
|
||||
|
||||
# Directly call execute_sql_query to execute the query
|
||||
exec_result = await execute_sql_query(exec_ctx)
|
||||
|
||||
# The format returned by execute_sql_query is {'content': [{'type': 'text', 'text': json_string}]}
|
||||
# Need to parse the internal JSON string
|
||||
if exec_result and 'content' in exec_result and len(exec_result['content']) > 0 and 'text' in exec_result['content'][0]:
|
||||
try:
|
||||
# Parse JSON string
|
||||
result_data = json.loads(exec_result['content'][0]['text'])
|
||||
|
||||
# Directly return the parsed result obtained from execute_sql_query
|
||||
# This result is already in the format {"success": ..., "data": ..., "columns": ...} or {"success": false, "error": ...}
|
||||
# _format_response would wrap it again, but here we directly use the parsed data
|
||||
# Note: This changes the original return structure of this function; it now directly returns the output of sql_executor
|
||||
# If the _format_response wrapper needs to be maintained, the code below needs adjustment
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps(result_data, ensure_ascii=False, default=str)
|
||||
}
|
||||
]
|
||||
}
|
||||
except json.JSONDecodeError as json_err:
|
||||
logger.error(f"Failed to parse execute_sql_query result: {json_err}")
|
||||
return _format_response(success=False, error=str(json_err), message="Error parsing SQL execution result")
|
||||
except Exception as parse_err:
|
||||
logger.error(f"Unexpected error occurred while processing execute_sql_query result: {parse_err}", exc_info=True)
|
||||
return _format_response(success=False, error=str(parse_err), message="Unknown error occurred while processing SQL execution result")
|
||||
else:
|
||||
logger.error(f"execute_sql_query returned an unexpected format: {exec_result}")
|
||||
return _format_response(success=False, error="SQL executor returned invalid format", message="Internal error executing SQL query")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"MCP tool execution failed mcp_doris_exec_query: {str(e)}", exc_info=True)
|
||||
return _format_response(success=False, error=str(e), message="Error executing SQL query")
|
||||
|
||||
|
||||
async def mcp_doris_get_table_schema(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
|
||||
logger.info(f"MCP Tool Call: mcp_doris_get_table_schema, Table: {table_name}, DB: {db_name}, Catalog: {catalog_name}")
|
||||
if not table_name:
|
||||
return _format_response(success=False, error="Missing table_name parameter")
|
||||
try:
|
||||
extractor = MetadataExtractor(db_name=db_name, catalog_name=catalog_name)
|
||||
schema = extractor.get_table_schema(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
if not schema:
|
||||
return _format_response(success=False, error="Table not found or has no columns", message=f"Could not get schema for table {catalog_name or 'default'}.{db_name or extractor.db_name}.{table_name}")
|
||||
return _format_response(success=True, result=schema)
|
||||
except Exception as e:
|
||||
logger.error(f"MCP tool execution failed mcp_doris_get_table_schema: {str(e)}", exc_info=True)
|
||||
return _format_response(success=False, error=str(e), message="Error getting table schema")
|
||||
|
||||
async def mcp_doris_get_db_table_list(db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
|
||||
logger.info(f"MCP Tool Call: mcp_doris_get_db_table_list, DB: {db_name}, Catalog: {catalog_name}")
|
||||
try:
|
||||
extractor = MetadataExtractor(db_name=db_name, catalog_name=catalog_name)
|
||||
tables = extractor.get_database_tables(db_name=db_name, catalog_name=catalog_name)
|
||||
return _format_response(success=True, result=tables)
|
||||
except Exception as e:
|
||||
logger.error(f"MCP tool execution failed mcp_doris_get_db_table_list: {str(e)}", exc_info=True)
|
||||
return _format_response(success=False, error=str(e), message="Error getting database table list")
|
||||
|
||||
async def mcp_doris_get_db_list(catalog_name: str = None) -> Dict[str, Any]:
|
||||
logger.info(f"MCP Tool Call: mcp_doris_get_db_list, Catalog: {catalog_name}")
|
||||
try:
|
||||
extractor = MetadataExtractor(catalog_name=catalog_name)
|
||||
databases = extractor.get_all_databases(catalog_name=catalog_name)
|
||||
return _format_response(success=True, result=databases)
|
||||
except Exception as e:
|
||||
logger.error(f"MCP tool execution failed mcp_doris_get_db_list: {str(e)}", exc_info=True)
|
||||
return _format_response(success=False, error=str(e), message="Error getting database list")
|
||||
|
||||
async def mcp_doris_get_table_comment(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
|
||||
logger.info(f"MCP Tool Call: mcp_doris_get_table_comment, Table: {table_name}, DB: {db_name}, Catalog: {catalog_name}")
|
||||
if not table_name:
|
||||
return _format_response(success=False, error="Missing table_name parameter")
|
||||
try:
|
||||
extractor = MetadataExtractor(db_name=db_name, catalog_name=catalog_name)
|
||||
comment = extractor.get_table_comment(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
return _format_response(success=True, result=comment)
|
||||
except Exception as e:
|
||||
logger.error(f"MCP tool execution failed mcp_doris_get_table_comment: {str(e)}", exc_info=True)
|
||||
return _format_response(success=False, error=str(e), message="Error getting table comment")
|
||||
|
||||
async def mcp_doris_get_table_column_comments(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
|
||||
logger.info(f"MCP Tool Call: mcp_doris_get_table_column_comments, Table: {table_name}, DB: {db_name}, Catalog: {catalog_name}")
|
||||
if not table_name:
|
||||
return _format_response(success=False, error="Missing table_name parameter")
|
||||
try:
|
||||
extractor = MetadataExtractor(db_name=db_name, catalog_name=catalog_name)
|
||||
comments = extractor.get_column_comments(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
return _format_response(success=True, result=comments)
|
||||
except Exception as e:
|
||||
logger.error(f"MCP tool execution failed mcp_doris_get_table_column_comments: {str(e)}", exc_info=True)
|
||||
return _format_response(success=False, error=str(e), message="Error getting column comments")
|
||||
|
||||
async def mcp_doris_get_table_indexes(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
|
||||
logger.info(f"MCP Tool Call: mcp_doris_get_table_indexes, Table: {table_name}, DB: {db_name}, Catalog: {catalog_name}")
|
||||
if not table_name:
|
||||
return _format_response(success=False, error="Missing table_name parameter")
|
||||
try:
|
||||
extractor = MetadataExtractor(db_name=db_name, catalog_name=catalog_name)
|
||||
indexes = extractor.get_table_indexes(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
return _format_response(success=True, result=indexes)
|
||||
except Exception as e:
|
||||
logger.error(f"MCP tool execution failed mcp_doris_get_table_indexes: {str(e)}", exc_info=True)
|
||||
return _format_response(success=False, error=str(e), message="Error getting table indexes")
|
||||
|
||||
async def mcp_doris_get_recent_audit_logs(days: int = 7, limit: int = 100) -> Dict[str, Any]:
|
||||
logger.info(f"MCP Tool Call: mcp_doris_get_recent_audit_logs, Days: {days}, Limit: {limit}")
|
||||
try:
|
||||
extractor = MetadataExtractor()
|
||||
logs_df = extractor.get_recent_audit_logs(days=days, limit=limit)
|
||||
return _format_response(success=True, result=logs_df)
|
||||
except Exception as e:
|
||||
logger.error(f"MCP tool execution failed mcp_doris_get_recent_audit_logs: {str(e)}", exc_info=True)
|
||||
return _format_response(success=False, error=str(e), message="Error getting audit logs")
|
||||
|
||||
async def mcp_doris_get_catalog_list() -> Dict[str, Any]:
|
||||
"""
|
||||
Get Doris catalog list
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Dictionary containing catalog list or error information
|
||||
"""
|
||||
logger.info(f"MCP Tool Call: mcp_doris_get_catalog_list")
|
||||
try:
|
||||
extractor = MetadataExtractor()
|
||||
catalogs = extractor.get_catalog_list()
|
||||
return _format_response(success=True, result=catalogs, message="Successfully retrieved catalog list")
|
||||
except Exception as e:
|
||||
logger.error(f"MCP tool execution failed mcp_doris_get_catalog_list: {str(e)}", exc_info=True)
|
||||
return _format_response(success=False, error=str(e), message="Error getting catalog list")
|
||||
455
doris_mcp_server/tools/prompts_manager.py
Normal file
455
doris_mcp_server/tools/prompts_manager.py
Normal file
@@ -0,0 +1,455 @@
|
||||
"""
|
||||
Apache Doris MCP Prompts Manager
|
||||
Provides standardized management of query templates and intelligent prompts
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from mcp.types import (
|
||||
GetPromptResult,
|
||||
Prompt,
|
||||
PromptArgument,
|
||||
PromptMessage,
|
||||
TextContent,
|
||||
)
|
||||
|
||||
from ..utils.db import DorisConnectionManager
|
||||
|
||||
|
||||
class PromptTemplate:
|
||||
"""Prompt template"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
template: str,
|
||||
arguments: list[PromptArgument] = None,
|
||||
category: str = "general",
|
||||
):
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.template = template
|
||||
self.arguments = arguments or []
|
||||
self.category = category
|
||||
self.created_at = datetime.now()
|
||||
|
||||
def render(self, arguments: dict[str, Any]) -> str:
|
||||
"""Render template content"""
|
||||
content = self.template
|
||||
for key, value in arguments.items():
|
||||
placeholder = f"{{{key}}}"
|
||||
content = content.replace(placeholder, str(value))
|
||||
return content
|
||||
|
||||
|
||||
class DorisPromptsManager:
|
||||
"""Apache Doris Prompts Manager"""
|
||||
|
||||
def __init__(self, connection_manager: DorisConnectionManager):
|
||||
self.connection_manager = connection_manager
|
||||
self.templates = self._init_prompt_templates()
|
||||
|
||||
def _init_prompt_templates(self) -> dict[str, PromptTemplate]:
|
||||
"""Initialize prompt templates"""
|
||||
templates = {}
|
||||
|
||||
# Sales data analysis template
|
||||
templates["sales_analysis"] = PromptTemplate(
|
||||
name="sales_analysis",
|
||||
description="Sales data analysis query template for generating sales statistics and trend analysis queries",
|
||||
template="""Please help me analyze sales data with the following requirements:
|
||||
|
||||
Analysis time range: {date_range}
|
||||
{product_filter}
|
||||
{region_filter}
|
||||
|
||||
Please generate SQL queries to analyze the following dimensions:
|
||||
1. Total sales amount and order quantity
|
||||
2. Sales trends by time dimension
|
||||
3. Top-selling product rankings
|
||||
4. Sales personnel performance statistics
|
||||
|
||||
Data table structure reference:
|
||||
- Order table: Contains order ID, customer ID, salesperson ID, order amount, order time and other fields
|
||||
- Product table: Contains product ID, product name, product category, price and other fields
|
||||
- Customer table: Contains customer ID, customer name, region and other fields
|
||||
|
||||
Please ensure query results are easy to understand and analyze.""",
|
||||
arguments=[
|
||||
PromptArgument(
|
||||
name="date_range",
|
||||
description="Date range for analysis, such as 'Q1 2024' or 'last 30 days'",
|
||||
required=True,
|
||||
),
|
||||
PromptArgument(
|
||||
name="product_category",
|
||||
description="Product category filter condition, such as 'electronics'",
|
||||
required=False,
|
||||
),
|
||||
PromptArgument(
|
||||
name="region",
|
||||
description="Sales region filter condition, such as 'East China'",
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
category="business_analysis",
|
||||
)
|
||||
|
||||
# User behavior analysis template
|
||||
templates["user_behavior_analysis"] = PromptTemplate(
|
||||
name="user_behavior_analysis",
|
||||
description="User behavior analysis query template for analyzing user activity patterns and preferences",
|
||||
template="""Please help me analyze user behavior data, analysis objectives:
|
||||
|
||||
User segment: {user_segment}
|
||||
{behavior_filter}
|
||||
Analysis period: {time_period}
|
||||
|
||||
Please generate SQL queries to analyze the following aspects:
|
||||
1. User activity statistics (DAU, MAU)
|
||||
2. User behavior path analysis
|
||||
3. Feature usage preference statistics
|
||||
4. User retention rate analysis
|
||||
|
||||
Data table structure reference:
|
||||
- User table: Contains user ID, registration time, user type, region and other fields
|
||||
- Behavior log table: Contains user ID, behavior type, behavior time, page path and other fields
|
||||
- Session table: Contains session ID, user ID, session start time, session duration and other fields
|
||||
|
||||
Please provide easy-to-understand statistical results and visualization suggestions.""",
|
||||
arguments=[
|
||||
PromptArgument(
|
||||
name="user_segment",
|
||||
description="User segment conditions, such as 'new users', 'active users'",
|
||||
required=True,
|
||||
),
|
||||
PromptArgument(
|
||||
name="behavior_type",
|
||||
description="Behavior type filter, such as 'login', 'purchase', 'browse'",
|
||||
required=False,
|
||||
),
|
||||
PromptArgument(
|
||||
name="time_period",
|
||||
description="Analysis time period, such as 'last 7 days', 'this month'",
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
category="user_analysis",
|
||||
)
|
||||
|
||||
# Performance optimization analysis template
|
||||
templates["performance_optimization"] = PromptTemplate(
|
||||
name="performance_optimization",
|
||||
description="Database performance optimization analysis template for identifying performance bottlenecks and optimization opportunities",
|
||||
template="""Please help me with database performance analysis and optimization recommendations:
|
||||
|
||||
Focus area: {focus_area}
|
||||
{table_scope}
|
||||
Performance metrics: {metrics}
|
||||
|
||||
Please generate SQL queries to analyze the following content:
|
||||
1. Table and query performance statistics
|
||||
2. Index usage efficiency analysis
|
||||
3. Slow query identification and analysis
|
||||
4. Storage space usage
|
||||
|
||||
Analysis objectives:
|
||||
- Identify performance bottlenecks
|
||||
- Provide optimization recommendations
|
||||
- Evaluate optimization effects
|
||||
|
||||
Please provide specific optimization recommendations and implementation steps.""",
|
||||
arguments=[
|
||||
PromptArgument(
|
||||
name="focus_area",
|
||||
description="Performance area of focus, such as 'query performance', 'storage optimization'",
|
||||
required=True,
|
||||
),
|
||||
PromptArgument(
|
||||
name="table_name",
|
||||
description="Specific table name (optional), if analyzing specific table performance",
|
||||
required=False,
|
||||
),
|
||||
PromptArgument(
|
||||
name="metrics",
|
||||
description="Performance metrics of interest, such as 'response time', 'throughput'",
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
category="performance",
|
||||
)
|
||||
|
||||
# Data quality check template
|
||||
templates["data_quality_check"] = PromptTemplate(
|
||||
name="data_quality_check",
|
||||
description="Data quality check template for detecting data integrity and consistency issues",
|
||||
template="""Please help me perform data quality checks:
|
||||
|
||||
Check target: {target_table}
|
||||
{quality_dimensions}
|
||||
Check level: {check_level}
|
||||
|
||||
Please generate SQL queries to check the following data quality issues:
|
||||
1. Data integrity (null values, duplicate values)
|
||||
2. Data consistency (format, range)
|
||||
3. Data accuracy (business rule validation)
|
||||
4. Data timeliness (update frequency)
|
||||
|
||||
Check items:
|
||||
- Required field null value checks
|
||||
- Primary key and unique constraint validation
|
||||
- Data format and type checks
|
||||
- Business logic consistency validation
|
||||
- Data distribution anomaly detection
|
||||
|
||||
Please provide detailed problem reports and fix recommendations.""",
|
||||
arguments=[
|
||||
PromptArgument(
|
||||
name="target_table", description="Target table name to check", required=True
|
||||
),
|
||||
PromptArgument(
|
||||
name="quality_dimensions",
|
||||
description="Quality check dimensions, such as 'integrity', 'consistency', 'accuracy'",
|
||||
required=False,
|
||||
),
|
||||
PromptArgument(
|
||||
name="check_level",
|
||||
description="Check level, such as 'basic check', 'deep check'",
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
category="data_quality",
|
||||
)
|
||||
|
||||
# Report generation template
|
||||
templates["report_generation"] = PromptTemplate(
|
||||
name="report_generation",
|
||||
description="Business report generation template for creating standardized business reports",
|
||||
template="""Please help me generate business reports:
|
||||
|
||||
Report type: {report_type}
|
||||
Report period: {report_period}
|
||||
{business_scope}
|
||||
|
||||
Please generate SQL queries to build the following report content:
|
||||
1. Key business indicator summary
|
||||
2. Trend analysis and year-over-year/month-over-month comparison
|
||||
3. Anomaly data identification and explanation
|
||||
4. Business insights and recommendations
|
||||
|
||||
Report requirements:
|
||||
- Data accuracy and timeliness
|
||||
- Clear hierarchical structure
|
||||
- Easy-to-understand data presentation
|
||||
- Decision-supporting analytical perspective
|
||||
|
||||
Please provide complete report structure and data acquisition logic.""",
|
||||
arguments=[
|
||||
PromptArgument(
|
||||
name="report_type",
|
||||
description="Report type, such as 'sales report', 'operations report', 'financial report'",
|
||||
required=True,
|
||||
),
|
||||
PromptArgument(
|
||||
name="report_period",
|
||||
description="Report period, such as 'daily report', 'weekly report', 'monthly report'",
|
||||
required=True,
|
||||
),
|
||||
PromptArgument(
|
||||
name="business_unit",
|
||||
description="Business unit scope, such as 'East China region', 'Product line A'",
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
category="reporting",
|
||||
)
|
||||
|
||||
# Real-time monitoring template
|
||||
templates["real_time_monitoring"] = PromptTemplate(
|
||||
name="real_time_monitoring",
|
||||
description="Real-time monitoring query template for building real-time data monitoring and alerting",
|
||||
template="""Please help me design real-time monitoring queries:
|
||||
|
||||
Monitoring target: {monitoring_target}
|
||||
Alert threshold: {alert_threshold}
|
||||
Monitoring frequency: {monitoring_frequency}
|
||||
|
||||
Please generate SQL queries to implement the following monitoring functions:
|
||||
1. Real-time statistics of key indicators
|
||||
2. Anomaly detection and alerting
|
||||
3. Trend change monitoring
|
||||
4. System health status checks
|
||||
|
||||
Monitoring dimensions:
|
||||
- Business indicator monitoring (transaction volume, user activity, etc.)
|
||||
- Technical indicator monitoring (performance, error rate, etc.)
|
||||
- Data quality monitoring (integrity, consistency, etc.)
|
||||
|
||||
Please provide complete monitoring solution and implementation recommendations.""",
|
||||
arguments=[
|
||||
PromptArgument(
|
||||
name="monitoring_target",
|
||||
description="Monitoring target, such as 'transaction system', 'user activity'",
|
||||
required=True,
|
||||
),
|
||||
PromptArgument(
|
||||
name="alert_threshold",
|
||||
description="Alert threshold setting, such as 'error rate > 5%'",
|
||||
required=False,
|
||||
),
|
||||
PromptArgument(
|
||||
name="monitoring_frequency",
|
||||
description="Monitoring frequency, such as 'real-time', 'every minute', 'every 5 minutes'",
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
category="monitoring",
|
||||
)
|
||||
|
||||
return templates
|
||||
|
||||
async def list_prompts(self) -> list[Prompt]:
|
||||
"""List all available prompt templates"""
|
||||
prompts = []
|
||||
|
||||
for template in self.templates.values():
|
||||
prompt = Prompt(
|
||||
name=template.name,
|
||||
description=template.description,
|
||||
arguments=template.arguments,
|
||||
)
|
||||
prompts.append(prompt)
|
||||
|
||||
return prompts
|
||||
|
||||
async def get_prompt(self, name: str, arguments: dict[str, Any]) -> GetPromptResult:
|
||||
"""Get content of specific prompt template"""
|
||||
if name not in self.templates:
|
||||
raise ValueError(f"Prompt template named '{name}' not found")
|
||||
|
||||
template = self.templates[name]
|
||||
|
||||
# Process optional arguments
|
||||
processed_args = await self._process_arguments(template, arguments)
|
||||
|
||||
# Render template content
|
||||
rendered_content = template.render(processed_args)
|
||||
|
||||
# Add database context information
|
||||
context_info = await self._get_database_context()
|
||||
|
||||
full_content = f"""{rendered_content}
|
||||
|
||||
Database context information:
|
||||
{context_info}
|
||||
|
||||
Please generate accurate and efficient SQL queries based on the above requirements and database structure."""
|
||||
|
||||
return GetPromptResult(
|
||||
description=template.description,
|
||||
messages=[
|
||||
PromptMessage(
|
||||
role="user", content=TextContent(type="text", text=full_content)
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
async def _process_arguments(
|
||||
self, template: PromptTemplate, arguments: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Process template arguments"""
|
||||
processed = {}
|
||||
|
||||
for arg in template.arguments:
|
||||
if arg.name in arguments:
|
||||
processed[arg.name] = arguments[arg.name]
|
||||
elif arg.required:
|
||||
raise ValueError(f"Missing required parameter: {arg.name}")
|
||||
else:
|
||||
# Provide default handling for optional parameters
|
||||
processed[arg.name] = self._get_default_argument_text(arg.name)
|
||||
|
||||
return processed
|
||||
|
||||
def _get_default_argument_text(self, arg_name: str) -> str:
|
||||
"""Get default text for optional parameters"""
|
||||
defaults = {
|
||||
"product_category": "",
|
||||
"region": "",
|
||||
"behavior_type": "",
|
||||
"time_period": "No time range restriction",
|
||||
"table_name": "",
|
||||
"metrics": "All performance metrics",
|
||||
"quality_dimensions": "All quality dimensions",
|
||||
"check_level": "Standard check",
|
||||
"business_unit": "Full business scope",
|
||||
"alert_threshold": "Use default threshold",
|
||||
"monitoring_frequency": "Real-time monitoring",
|
||||
}
|
||||
|
||||
return defaults.get(arg_name, "")
|
||||
|
||||
async def _get_database_context(self) -> str:
|
||||
"""Get database context information"""
|
||||
try:
|
||||
connection = await self.connection_manager.get_connection("system")
|
||||
|
||||
# Get basic database information
|
||||
db_info_sql = """
|
||||
SELECT
|
||||
COUNT(*) as table_count,
|
||||
SUM(table_rows) as total_rows
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_type = 'BASE TABLE'
|
||||
"""
|
||||
|
||||
db_result = await connection.execute(db_info_sql)
|
||||
db_info = db_result.data[0] if db_result.data else {}
|
||||
|
||||
# Get main table list
|
||||
tables_sql = """
|
||||
SELECT
|
||||
table_name,
|
||||
table_comment,
|
||||
table_rows
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_type = 'BASE TABLE'
|
||||
ORDER BY table_rows DESC
|
||||
LIMIT 10
|
||||
"""
|
||||
|
||||
tables_result = await connection.execute(tables_sql)
|
||||
|
||||
context = f"""Current database statistics:
|
||||
- Total number of tables: {db_info.get("table_count", 0)}
|
||||
- Total data rows: {db_info.get("total_rows", 0):,}
|
||||
|
||||
Main data tables:"""
|
||||
|
||||
for table in tables_result.data:
|
||||
context += f"\n- {table['table_name']}"
|
||||
if table.get("table_comment"):
|
||||
context += f": {table['table_comment']}"
|
||||
context += f" ({table.get('table_rows', 0):,} rows)"
|
||||
|
||||
return context
|
||||
|
||||
except Exception as e:
|
||||
return f"Unable to get database context information: {str(e)}"
|
||||
|
||||
def get_templates_by_category(self, category: str) -> list[PromptTemplate]:
|
||||
"""Get templates by category"""
|
||||
return [
|
||||
template
|
||||
for template in self.templates.values()
|
||||
if template.category == category
|
||||
]
|
||||
|
||||
def get_all_categories(self) -> list[str]:
|
||||
"""Get all template categories"""
|
||||
categories = {template.category for template in self.templates.values()}
|
||||
return sorted(categories)
|
||||
361
doris_mcp_server/tools/resources_manager.py
Normal file
361
doris_mcp_server/tools/resources_manager.py
Normal file
@@ -0,0 +1,361 @@
|
||||
"""
|
||||
Apache Doris MCP Resources Manager
|
||||
Provides standardized abstraction and access interface for database metadata
|
||||
"""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from mcp.types import Resource
|
||||
|
||||
from ..utils.db import DorisConnectionManager
|
||||
|
||||
|
||||
class TableMetadata:
|
||||
"""Data table metadata"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
comment: str = None,
|
||||
row_count: int = 0,
|
||||
columns: list[dict] = None,
|
||||
create_time: datetime = None,
|
||||
):
|
||||
self.name = name
|
||||
self.comment = comment
|
||||
self.row_count = row_count
|
||||
self.columns = columns or []
|
||||
self.create_time = create_time
|
||||
|
||||
|
||||
class ViewMetadata:
|
||||
"""Data view metadata"""
|
||||
|
||||
def __init__(self, name: str, comment: str = None, definition: str = None):
|
||||
self.name = name
|
||||
self.comment = comment
|
||||
self.definition = definition
|
||||
|
||||
|
||||
class MetadataCache:
|
||||
"""Metadata cache manager"""
|
||||
|
||||
def __init__(self, ttl_seconds: int = 300):
|
||||
self.cache = {}
|
||||
self.ttl = ttl_seconds
|
||||
|
||||
async def get(self, key: str) -> Any | None:
|
||||
if key in self.cache:
|
||||
data, timestamp = self.cache[key]
|
||||
if datetime.now().timestamp() - timestamp < self.ttl:
|
||||
return data
|
||||
else:
|
||||
del self.cache[key]
|
||||
return None
|
||||
|
||||
async def set(self, key: str, value: Any):
|
||||
self.cache[key] = (value, datetime.now().timestamp())
|
||||
|
||||
|
||||
class DorisResourcesManager:
|
||||
"""Apache Doris Resources Manager"""
|
||||
|
||||
def __init__(self, connection_manager: DorisConnectionManager):
|
||||
self.connection_manager = connection_manager
|
||||
self.metadata_cache = MetadataCache()
|
||||
|
||||
async def list_resources(self) -> list[Resource]:
|
||||
"""List all available database resources"""
|
||||
resources = []
|
||||
|
||||
try:
|
||||
# Get metadata for all tables
|
||||
tables = await self._get_table_metadata()
|
||||
for table in tables:
|
||||
resources.append(
|
||||
Resource(
|
||||
uri=f"doris://table/{table.name}",
|
||||
name=f"Data Table: {table.name}",
|
||||
description=f"{table.comment or 'Data table'} (rows: {table.row_count:,})",
|
||||
mimeType="application/json",
|
||||
)
|
||||
)
|
||||
|
||||
# Get metadata for all views
|
||||
views = await self._get_view_metadata()
|
||||
for view in views:
|
||||
resources.append(
|
||||
Resource(
|
||||
uri=f"doris://view/{view.name}",
|
||||
name=f"Data View: {view.name}",
|
||||
description=f"{view.comment or 'Data view'}",
|
||||
mimeType="application/json",
|
||||
)
|
||||
)
|
||||
|
||||
# Add database statistics resource
|
||||
resources.append(
|
||||
Resource(
|
||||
uri="doris://stats/database",
|
||||
name="Database Statistics",
|
||||
description="Overall database statistics and performance metrics",
|
||||
mimeType="application/json",
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to get resource list: {e}")
|
||||
|
||||
return resources
|
||||
|
||||
async def read_resource(self, uri: str) -> str:
|
||||
"""Read detailed information of specific resource"""
|
||||
try:
|
||||
resource_type, resource_name = self._parse_resource_uri(uri)
|
||||
|
||||
if resource_type == "table":
|
||||
return await self._get_table_schema(resource_name)
|
||||
elif resource_type == "view":
|
||||
return await self._get_view_definition(resource_name)
|
||||
elif resource_type == "stats" and resource_name == "database":
|
||||
return await self._get_database_stats()
|
||||
else:
|
||||
raise ValueError(f"Unsupported resource type: {resource_type}")
|
||||
|
||||
except Exception as e:
|
||||
return json.dumps(
|
||||
{"error": f"Failed to read resource: {str(e)}", "uri": uri},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
async def _get_table_metadata(self) -> list[TableMetadata]:
|
||||
"""Get metadata for all tables"""
|
||||
cache_key = "table_metadata"
|
||||
cached = await self.metadata_cache.get(cache_key)
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
connection = await self.connection_manager.get_connection("system")
|
||||
|
||||
# Query basic table information
|
||||
tables_query = """
|
||||
SELECT
|
||||
table_name,
|
||||
table_comment,
|
||||
table_rows as row_count,
|
||||
create_time
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_type = 'BASE TABLE'
|
||||
ORDER BY table_name
|
||||
"""
|
||||
|
||||
result = await connection.execute(tables_query)
|
||||
tables = []
|
||||
|
||||
for row in result.data:
|
||||
# Get column information for the table
|
||||
columns = await self._get_table_columns(connection, row["table_name"])
|
||||
|
||||
table = TableMetadata(
|
||||
name=row["table_name"],
|
||||
comment=row.get("table_comment"),
|
||||
row_count=row.get("row_count", 0),
|
||||
columns=columns,
|
||||
create_time=row.get("create_time"),
|
||||
)
|
||||
tables.append(table)
|
||||
|
||||
await self.metadata_cache.set(cache_key, tables)
|
||||
return tables
|
||||
|
||||
async def _get_table_columns(self, connection, table_name: str) -> list[dict]:
|
||||
"""Get column information for table"""
|
||||
columns_query = """
|
||||
SELECT
|
||||
column_name,
|
||||
data_type,
|
||||
is_nullable,
|
||||
column_default,
|
||||
column_comment,
|
||||
column_key
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_name = %s
|
||||
ORDER BY ordinal_position
|
||||
"""
|
||||
|
||||
result = await connection.execute(columns_query, (table_name,))
|
||||
return [dict(row) for row in result.data]
|
||||
|
||||
async def _get_view_metadata(self) -> list[ViewMetadata]:
|
||||
"""Get metadata for all views"""
|
||||
cache_key = "view_metadata"
|
||||
cached = await self.metadata_cache.get(cache_key)
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
connection = await self.connection_manager.get_connection("system")
|
||||
|
||||
views_query = """
|
||||
SELECT
|
||||
table_name,
|
||||
table_comment,
|
||||
view_definition
|
||||
FROM information_schema.views
|
||||
WHERE table_schema = DATABASE()
|
||||
ORDER BY table_name
|
||||
"""
|
||||
|
||||
result = await connection.execute(views_query)
|
||||
views = []
|
||||
|
||||
for row in result.data:
|
||||
view = ViewMetadata(
|
||||
name=row["table_name"],
|
||||
comment=row.get("table_comment"),
|
||||
definition=row.get("view_definition"),
|
||||
)
|
||||
views.append(view)
|
||||
|
||||
await self.metadata_cache.set(cache_key, views)
|
||||
return views
|
||||
|
||||
async def _get_table_schema(self, table_name: str) -> str:
|
||||
"""Get detailed structure information of table"""
|
||||
connection = await self.connection_manager.get_connection("system")
|
||||
|
||||
# Get basic table information
|
||||
table_info_query = """
|
||||
SELECT
|
||||
table_name,
|
||||
table_comment,
|
||||
table_rows,
|
||||
create_time,
|
||||
engine
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_name = %s
|
||||
"""
|
||||
|
||||
table_result = await connection.execute(table_info_query, (table_name,))
|
||||
if not table_result.data:
|
||||
raise ValueError(f"Table {table_name} does not exist")
|
||||
|
||||
table_info = table_result.data[0]
|
||||
|
||||
# Get column information
|
||||
columns = await self._get_table_columns(connection, table_name)
|
||||
|
||||
# Get index information
|
||||
indexes = await self._get_table_indexes(connection, table_name)
|
||||
|
||||
schema_info = {
|
||||
"table_name": table_info["table_name"],
|
||||
"comment": table_info.get("table_comment"),
|
||||
"row_count": table_info.get("table_rows", 0),
|
||||
"create_time": str(table_info.get("create_time")),
|
||||
"engine": table_info.get("engine"),
|
||||
"columns": columns,
|
||||
"indexes": indexes,
|
||||
}
|
||||
|
||||
return json.dumps(schema_info, ensure_ascii=False, indent=2)
|
||||
|
||||
async def _get_table_indexes(self, connection, table_name: str) -> list[dict]:
|
||||
"""Get index information for table"""
|
||||
indexes_query = """
|
||||
SELECT
|
||||
index_name,
|
||||
column_name,
|
||||
index_type,
|
||||
non_unique
|
||||
FROM information_schema.statistics
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_name = %s
|
||||
ORDER BY index_name, seq_in_index
|
||||
"""
|
||||
|
||||
result = await connection.execute(indexes_query, (table_name,))
|
||||
return [dict(row) for row in result.data]
|
||||
|
||||
async def _get_view_definition(self, view_name: str) -> str:
|
||||
"""Get definition information of view"""
|
||||
connection = await self.connection_manager.get_connection("system")
|
||||
|
||||
view_query = """
|
||||
SELECT
|
||||
table_name,
|
||||
table_comment,
|
||||
view_definition
|
||||
FROM information_schema.views
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_name = %s
|
||||
"""
|
||||
|
||||
result = await connection.execute(view_query, (view_name,))
|
||||
if not result.data:
|
||||
raise ValueError(f"View {view_name} does not exist")
|
||||
|
||||
view_info = result.data[0]
|
||||
|
||||
schema_info = {
|
||||
"view_name": view_info["table_name"],
|
||||
"comment": view_info.get("table_comment"),
|
||||
"definition": view_info.get("view_definition"),
|
||||
}
|
||||
|
||||
return json.dumps(schema_info, ensure_ascii=False, indent=2)
|
||||
|
||||
async def _get_database_stats(self) -> str:
|
||||
"""Get database statistics"""
|
||||
connection = await self.connection_manager.get_connection("system")
|
||||
|
||||
# Get table statistics
|
||||
table_stats_query = """
|
||||
SELECT
|
||||
COUNT(*) as table_count,
|
||||
SUM(table_rows) as total_rows
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_type = 'BASE TABLE'
|
||||
"""
|
||||
|
||||
table_result = await connection.execute(table_stats_query)
|
||||
table_stats = table_result.data[0] if table_result.data else {}
|
||||
|
||||
# Get view statistics
|
||||
view_stats_query = """
|
||||
SELECT COUNT(*) as view_count
|
||||
FROM information_schema.views
|
||||
WHERE table_schema = DATABASE()
|
||||
"""
|
||||
|
||||
view_result = await connection.execute(view_stats_query)
|
||||
view_stats = view_result.data[0] if view_result.data else {}
|
||||
|
||||
stats_info = {
|
||||
"database_name": "current_database",
|
||||
"table_count": table_stats.get("table_count", 0),
|
||||
"view_count": view_stats.get("view_count", 0),
|
||||
"total_rows": table_stats.get("total_rows", 0),
|
||||
"last_updated": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
return json.dumps(stats_info, ensure_ascii=False, indent=2)
|
||||
|
||||
def _parse_resource_uri(self, uri: str) -> tuple:
|
||||
"""Parse resource URI"""
|
||||
if not uri.startswith("doris://"):
|
||||
raise ValueError("Invalid resource URI format")
|
||||
|
||||
path = uri[8:] # Remove "doris://" prefix
|
||||
parts = path.split("/")
|
||||
|
||||
if len(parts) < 2:
|
||||
raise ValueError("Incomplete resource URI format")
|
||||
|
||||
return parts[0], parts[1]
|
||||
@@ -1,157 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Tool Initialization Module
|
||||
|
||||
Centralized initialization of all tools, ensuring they are correctly registered with MCP
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Dict, Any, Optional
|
||||
import json
|
||||
from datetime import datetime
|
||||
import traceback
|
||||
|
||||
# Import Context
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
# Import doris mcp tools
|
||||
from doris_mcp_server.tools.mcp_doris_tools import (
|
||||
mcp_doris_exec_query,
|
||||
mcp_doris_get_table_schema,
|
||||
mcp_doris_get_db_table_list,
|
||||
mcp_doris_get_db_list,
|
||||
mcp_doris_get_table_comment,
|
||||
mcp_doris_get_table_column_comments,
|
||||
mcp_doris_get_table_indexes,
|
||||
mcp_doris_get_recent_audit_logs,
|
||||
mcp_doris_get_catalog_list
|
||||
)
|
||||
|
||||
# Get logger
|
||||
logger = logging.getLogger("doris-mcp-tools-initializer")
|
||||
|
||||
async def register_mcp_tools(mcp):
|
||||
"""Register MCP tool functions
|
||||
|
||||
Args:
|
||||
mcp: FastMCP instance
|
||||
"""
|
||||
logger.info("Starting to register MCP tools...")
|
||||
|
||||
try:
|
||||
# Register Tool: Execute SQL Query (Using long description string including parameters)
|
||||
@mcp.tool("exec_query", description="""[Function Description]: Execute SQL query and return result command with catalog federation support.\n
|
||||
[Parameter Content]:\n
|
||||
- random_string (string) [Required] - Unique identifier for the tool call\n
|
||||
- sql (string) [Required] - SQL statement to execute. MUST use three-part naming for all table references: 'catalog_name.db_name.table_name'. For internal tables use 'internal.db_name.table_name', for external tables use 'catalog_name.db_name.table_name'\n
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database\n
|
||||
- catalog_name (string) [Optional] - Reference catalog name for context, defaults to current catalog\n
|
||||
- max_rows (integer) [Optional] - Maximum number of rows to return, default 100
|
||||
- timeout (integer) [Optional] - Query timeout in seconds, default 30""")
|
||||
async def exec_query_tool(sql: str, db_name: str = None, catalog_name: str = None, max_rows: int = 100, timeout: int = 30) -> Dict[str, Any]:
|
||||
"""Wrapper: Execute SQL query and return result command"""
|
||||
# Note: ctx parameter is no longer needed here as we receive named parameters directly
|
||||
return await mcp_doris_exec_query(sql=sql, db_name=db_name, catalog_name=catalog_name, max_rows=max_rows, timeout=timeout)
|
||||
|
||||
# Register Tool: Get Table Schema (Keep long description string including parameters)
|
||||
@mcp.tool("get_table_schema", description="""[Function Description]: Get detailed structure information of the specified table (columns, types, comments, etc.).\n
|
||||
[Parameter Content]:\n
|
||||
- random_string (string) [Required] - Unique identifier for the tool call\n
|
||||
- table_name (string) [Required] - Name of the table to query\n
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database\n
|
||||
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""")
|
||||
async def get_table_schema_tool(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
|
||||
"""Wrapper: Get table schema"""
|
||||
if not table_name: return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "Missing table_name parameter"})}]}
|
||||
return await mcp_doris_get_table_schema(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
|
||||
# Register Tool: Get Database Table List (Keep long description string including parameters)
|
||||
@mcp.tool("get_db_table_list", description="""[Function Description]: Get a list of all table names in the specified database.\n
|
||||
[Parameter Content]:\n
|
||||
- random_string (string) [Required] - Unique identifier for the tool call\n
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database\n
|
||||
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""")
|
||||
async def get_db_table_list_tool(db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
|
||||
"""Wrapper: Get database table list"""
|
||||
return await mcp_doris_get_db_table_list(db_name=db_name, catalog_name=catalog_name)
|
||||
|
||||
# Register Tool: Get Database List (Keep long description string including parameters)
|
||||
# Note: Although the description mentions random_string, the wrapper function signature does not. See how mcp handles this.
|
||||
@mcp.tool("get_db_list", description="""[Function Description]: Get a list of all database names on the server.\n
|
||||
[Parameter Content]:\n
|
||||
- random_string (string) [Required] - Unique identifier for the tool call\n
|
||||
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""")
|
||||
async def get_db_list_tool(catalog_name: str = None) -> Dict[str, Any]: # Function signature has no parameters
|
||||
"""Wrapper: Get database list"""
|
||||
return await mcp_doris_get_db_list(catalog_name=catalog_name)
|
||||
|
||||
# Register Tool: Get Table Comment (Keep long description string including parameters)
|
||||
@mcp.tool("get_table_comment", description="""[Function Description]: Get the comment information for the specified table.\n
|
||||
[Parameter Content]:\n
|
||||
- random_string (string) [Required] - Unique identifier for the tool call\n
|
||||
- table_name (string) [Required] - Name of the table to query\n
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database\n
|
||||
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""")
|
||||
async def get_table_comment_tool(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
|
||||
"""Wrapper: Get table comment"""
|
||||
if not table_name: return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "Missing table_name parameter"})}]}
|
||||
return await mcp_doris_get_table_comment(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
|
||||
# Register Tool: Get Table Column Comments (Keep long description string including parameters)
|
||||
@mcp.tool("get_table_column_comments", description="""[Function Description]: Get comment information for all columns in the specified table.\n
|
||||
[Parameter Content]:\n
|
||||
- random_string (string) [Required] - Unique identifier for the tool call\n
|
||||
- table_name (string) [Required] - Name of the table to query\n
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database\n
|
||||
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""")
|
||||
async def get_table_column_comments_tool(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
|
||||
"""Wrapper: Get table column comments"""
|
||||
if not table_name: return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "Missing table_name parameter"})}]}
|
||||
return await mcp_doris_get_table_column_comments(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
|
||||
# Register Tool: Get Table Indexes (Keep long description string including parameters)
|
||||
@mcp.tool("get_table_indexes", description="""[Function Description]: Get index information for the specified table.\n
|
||||
[Parameter Content]:\n
|
||||
- random_string (string) [Required] - Unique identifier for the tool call\n
|
||||
- table_name (string) [Required] - Name of the table to query\n
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database\n
|
||||
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""")
|
||||
async def get_table_indexes_tool(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
|
||||
"""Wrapper: Get table indexes"""
|
||||
if not table_name: return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "Missing table_name parameter"})}]}
|
||||
return await mcp_doris_get_table_indexes(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
|
||||
# Register Tool: Get Recent Audit Logs (Keep long description string including parameters)
|
||||
@mcp.tool("get_recent_audit_logs", description="""[Function Description]: Get audit log records for a recent period.\n
|
||||
[Parameter Content]:\n
|
||||
- random_string (string) [Required] - Unique identifier for the tool call\n
|
||||
- days (integer) [Optional] - Number of recent days of logs to retrieve, default is 7\n
|
||||
- limit (integer) [Optional] - Maximum number of records to return, default is 100\n""")
|
||||
async def get_recent_audit_logs_tool(days: int = 7, limit: int = 100) -> Dict[str, Any]:
|
||||
"""Wrapper: Get recent audit logs"""
|
||||
try:
|
||||
days = int(days)
|
||||
limit = int(limit)
|
||||
except (ValueError, TypeError):
|
||||
return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "days and limit parameters must be integers"})}]}
|
||||
return await mcp_doris_get_recent_audit_logs(days=days, limit=limit)
|
||||
|
||||
# Register Tool: Get Catalog List (Keep long description string including parameters)
|
||||
@mcp.tool("get_catalog_list", description="""[Function Description]: Get a list of all catalog names on the server.\n
|
||||
[Parameter Content]:\n
|
||||
- random_string (string) [Required] - Unique identifier for the tool call\n""")
|
||||
async def get_catalog_list_tool() -> Dict[str, Any]:
|
||||
"""Wrapper: Get catalog list"""
|
||||
return await mcp_doris_get_catalog_list()
|
||||
|
||||
# Get tool count
|
||||
tools_count = len(await mcp.list_tools())
|
||||
logger.info(f"Registered all MCP tools, total {tools_count} tools")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error registering MCP tools: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
766
doris_mcp_server/tools/tools_manager.py
Normal file
766
doris_mcp_server/tools/tools_manager.py
Normal file
@@ -0,0 +1,766 @@
|
||||
"""
|
||||
Apache Doris MCP Tools Manager
|
||||
Responsible for tool registration, management, scheduling and routing, does not contain specific business logic implementation
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from mcp.types import Tool
|
||||
|
||||
from ..utils.db import DorisConnectionManager
|
||||
from ..utils.query_executor import DorisQueryExecutor
|
||||
from ..utils.analysis_tools import TableAnalyzer, PerformanceMonitor
|
||||
from ..utils.schema_extractor import MetadataExtractor
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
|
||||
class DorisToolsManager:
|
||||
"""Apache Doris Tools Manager"""
|
||||
|
||||
def __init__(self, connection_manager: DorisConnectionManager):
|
||||
self.connection_manager = connection_manager
|
||||
|
||||
# Initialize business logic processors
|
||||
self.query_executor = DorisQueryExecutor(connection_manager)
|
||||
self.table_analyzer = TableAnalyzer(connection_manager)
|
||||
self.performance_monitor = PerformanceMonitor(connection_manager)
|
||||
self.metadata_extractor = MetadataExtractor(connection_manager=connection_manager)
|
||||
|
||||
logger.info("DorisToolsManager initialized with business logic processors")
|
||||
|
||||
async def register_tools_with_mcp(self, mcp):
|
||||
"""Register all tools to MCP server"""
|
||||
logger.info("Starting to register MCP tools")
|
||||
|
||||
# Column statistical analysis tool
|
||||
@mcp.tool(
|
||||
"column_analysis",
|
||||
description="""[Function Description]: Analyze statistical information and data distribution of the specified column.
|
||||
|
||||
[Parameter Content]:
|
||||
|
||||
- table_name (string) [Required] - Name of the table to analyze
|
||||
|
||||
- column_name (string) [Required] - Name of the column to analyze
|
||||
|
||||
- analysis_type (string) [Optional] - Type of analysis to perform, default is "basic"
|
||||
* "basic": Basic statistics (count, null values, distinct values)
|
||||
* "distribution": Data distribution analysis (frequency, percentiles)
|
||||
* "detailed": Comprehensive analysis including all above plus patterns and outliers
|
||||
""",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"table_name": {"type": "string", "description": "Table name"},
|
||||
"column_name": {
|
||||
"type": "string",
|
||||
"description": "Column name to analyze",
|
||||
},
|
||||
"analysis_type": {
|
||||
"type": "string",
|
||||
"enum": ["basic", "distribution", "detailed"],
|
||||
"description": "Analysis type",
|
||||
"default": "basic",
|
||||
},
|
||||
},
|
||||
"required": ["table_name", "column_name"],
|
||||
}
|
||||
)
|
||||
async def column_analysis_tool(
|
||||
table_name: str,
|
||||
column_name: str,
|
||||
analysis_type: str = "basic"
|
||||
) -> str:
|
||||
"""Column statistical analysis tool"""
|
||||
return await self.call_tool("column_analysis", {
|
||||
"table_name": table_name,
|
||||
"column_name": column_name,
|
||||
"analysis_type": analysis_type
|
||||
})
|
||||
|
||||
# Database performance monitoring tool
|
||||
@mcp.tool(
|
||||
"performance_stats[Experimental]",
|
||||
description="""[Important]: This tool is experimental and may not be fully functional!
|
||||
[Function Description]: Get database performance statistics information.
|
||||
|
||||
[Parameter Content]:
|
||||
|
||||
- metric_type (string) [Optional] - Type of performance metrics to retrieve, default is "queries"
|
||||
* "queries": Query performance metrics (execution time, frequency, etc.)
|
||||
* "connections": Connection statistics (active connections, connection pool status)
|
||||
* "tables": Table-level statistics (size, row count, access patterns)
|
||||
* "system": System-level metrics (CPU, memory, disk usage)
|
||||
|
||||
- time_range (string) [Optional] - Time range for statistics, default is "1h"
|
||||
* "1h": Last 1 hour
|
||||
* "6h": Last 6 hours
|
||||
* "24h": Last 24 hours
|
||||
* "7d": Last 7 days
|
||||
""",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"metric_type": {
|
||||
"type": "string",
|
||||
"enum": ["queries", "connections", "tables", "system"],
|
||||
"description": "Performance metric type",
|
||||
"default": "queries",
|
||||
},
|
||||
"time_range": {
|
||||
"type": "string",
|
||||
"enum": ["1h", "6h", "24h", "7d"],
|
||||
"description": "Time range",
|
||||
"default": "1h",
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
async def performance_stats_tool(
|
||||
metric_type: str = "queries",
|
||||
time_range: str = "1h"
|
||||
) -> str:
|
||||
"""Database performance monitoring tool"""
|
||||
return await self.call_tool("performance_stats", {
|
||||
"metric_type": metric_type,
|
||||
"time_range": time_range
|
||||
})
|
||||
|
||||
# SQL query execution tool (supports catalog federation queries)
|
||||
@mcp.tool(
|
||||
"exec_query",
|
||||
description="""[Function Description]: Execute SQL query and return result command with catalog federation support.
|
||||
|
||||
[Parameter Content]:
|
||||
|
||||
- sql (string) [Required] - SQL statement to execute. MUST use three-part naming for all table references: 'catalog_name.db_name.table_name'. For internal tables use 'internal.db_name.table_name', for external tables use 'catalog_name.db_name.table_name'
|
||||
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database
|
||||
|
||||
- catalog_name (string) [Optional] - Reference catalog name for context, defaults to current catalog
|
||||
|
||||
- max_rows (integer) [Optional] - Maximum number of rows to return, default 100
|
||||
|
||||
- timeout (integer) [Optional] - Query timeout in seconds, default 30
|
||||
""",
|
||||
)
|
||||
async def exec_query_tool(
|
||||
sql: str,
|
||||
db_name: str = None,
|
||||
catalog_name: str = None,
|
||||
max_rows: int = 100,
|
||||
timeout: int = 30,
|
||||
) -> str:
|
||||
"""Execute SQL query (supports federation queries)"""
|
||||
return await self.call_tool("exec_query", {
|
||||
"sql": sql,
|
||||
"db_name": db_name,
|
||||
"catalog_name": catalog_name,
|
||||
"max_rows": max_rows,
|
||||
"timeout": timeout
|
||||
})
|
||||
|
||||
# Get table schema tool
|
||||
@mcp.tool(
|
||||
"get_table_schema",
|
||||
description="""[Function Description]: Get detailed structure information of the specified table (columns, types, comments, etc.).
|
||||
|
||||
[Parameter Content]:
|
||||
|
||||
- table_name (string) [Required] - Name of the table to query
|
||||
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database
|
||||
|
||||
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
|
||||
""",
|
||||
)
|
||||
async def get_table_schema_tool(
|
||||
table_name: str, db_name: str = None, catalog_name: str = None
|
||||
) -> str:
|
||||
"""Get table schema information"""
|
||||
return await self.call_tool("get_table_schema", {
|
||||
"table_name": table_name,
|
||||
"db_name": db_name,
|
||||
"catalog_name": catalog_name
|
||||
})
|
||||
|
||||
# Get database table list tool
|
||||
@mcp.tool(
|
||||
"get_db_table_list",
|
||||
description="""[Function Description]: Get a list of all table names in the specified database.
|
||||
|
||||
[Parameter Content]:
|
||||
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database
|
||||
|
||||
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
|
||||
""",
|
||||
)
|
||||
async def get_db_table_list_tool(
|
||||
db_name: str = None, catalog_name: str = None
|
||||
) -> str:
|
||||
"""Get database table list"""
|
||||
return await self.call_tool("get_db_table_list", {
|
||||
"db_name": db_name,
|
||||
"catalog_name": catalog_name
|
||||
})
|
||||
|
||||
# Get database list tool
|
||||
@mcp.tool(
|
||||
"get_db_list",
|
||||
description="""[Function Description]: Get a list of all database names on the server.
|
||||
|
||||
[Parameter Content]:
|
||||
|
||||
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
|
||||
""",
|
||||
)
|
||||
async def get_db_list_tool(catalog_name: str = None) -> str:
|
||||
"""Get database list"""
|
||||
return await self.call_tool("get_db_list", {
|
||||
"catalog_name": catalog_name
|
||||
})
|
||||
|
||||
# Get table comment tool
|
||||
@mcp.tool(
|
||||
"get_table_comment",
|
||||
description="""[Function Description]: Get the comment information for the specified table.
|
||||
|
||||
[Parameter Content]:
|
||||
|
||||
- table_name (string) [Required] - Name of the table to query
|
||||
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database
|
||||
|
||||
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
|
||||
""",
|
||||
)
|
||||
async def get_table_comment_tool(
|
||||
table_name: str, db_name: str = None, catalog_name: str = None
|
||||
) -> str:
|
||||
"""Get table comment"""
|
||||
return await self.call_tool("get_table_comment", {
|
||||
"table_name": table_name,
|
||||
"db_name": db_name,
|
||||
"catalog_name": catalog_name
|
||||
})
|
||||
|
||||
# Get table column comments tool
|
||||
@mcp.tool(
|
||||
"get_table_column_comments",
|
||||
description="""[Function Description]: Get comment information for all columns in the specified table.
|
||||
|
||||
[Parameter Content]:
|
||||
|
||||
- table_name (string) [Required] - Name of the table to query
|
||||
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database
|
||||
|
||||
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
|
||||
""",
|
||||
)
|
||||
async def get_table_column_comments_tool(
|
||||
table_name: str, db_name: str = None, catalog_name: str = None
|
||||
) -> str:
|
||||
"""Get table column comments"""
|
||||
return await self.call_tool("get_table_column_comments", {
|
||||
"table_name": table_name,
|
||||
"db_name": db_name,
|
||||
"catalog_name": catalog_name
|
||||
})
|
||||
|
||||
# Get table indexes tool
|
||||
@mcp.tool(
|
||||
"get_table_indexes",
|
||||
description="""[Function Description]: Get index information for the specified table.
|
||||
|
||||
[Parameter Content]:
|
||||
|
||||
- table_name (string) [Required] - Name of the table to query
|
||||
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database
|
||||
|
||||
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
|
||||
""",
|
||||
)
|
||||
async def get_table_indexes_tool(
|
||||
table_name: str, db_name: str = None, catalog_name: str = None
|
||||
) -> str:
|
||||
"""Get table indexes"""
|
||||
return await self.call_tool("get_table_indexes", {
|
||||
"table_name": table_name,
|
||||
"db_name": db_name,
|
||||
"catalog_name": catalog_name
|
||||
})
|
||||
|
||||
# Get audit logs tool
|
||||
@mcp.tool(
|
||||
"get_recent_audit_logs",
|
||||
description="""[Function Description]: Get audit log records for a recent period.
|
||||
|
||||
[Parameter Content]:
|
||||
|
||||
- days (integer) [Optional] - Number of recent days of logs to retrieve, default is 7
|
||||
|
||||
- limit (integer) [Optional] - Maximum number of records to return, default is 100
|
||||
""",
|
||||
)
|
||||
async def get_recent_audit_logs_tool(
|
||||
days: int = 7, limit: int = 100
|
||||
) -> str:
|
||||
"""Get audit logs"""
|
||||
return await self.call_tool("get_recent_audit_logs", {
|
||||
"days": days,
|
||||
"limit": limit
|
||||
})
|
||||
|
||||
# Get catalog list tool
|
||||
@mcp.tool(
|
||||
"get_catalog_list",
|
||||
description="""[Function Description]: Get a list of all catalog names on the server.
|
||||
|
||||
[Parameter Content]:
|
||||
|
||||
- random_string (string) [Required] - Unique identifier for the tool call
|
||||
""",
|
||||
)
|
||||
async def get_catalog_list_tool(random_string: str) -> str:
|
||||
"""Get catalog list"""
|
||||
return await self.call_tool("get_catalog_list", {
|
||||
"random_string": random_string
|
||||
})
|
||||
|
||||
logger.info("Successfully registered 11 tools to MCP server (2 core tools + 9 migrated tools)")
|
||||
|
||||
async def list_tools(self) -> List[Tool]:
|
||||
"""List all available query tools (for stdio mode)"""
|
||||
tools = [
|
||||
Tool(
|
||||
name="column_analysis[Experimental]",
|
||||
description="""[Important]: This tool is experimental and may not be fully functional!
|
||||
[Function Description]: Analyze statistical information and data distribution of the specified column.
|
||||
|
||||
[Parameter Content]:
|
||||
|
||||
- table_name (string) [Required] - Name of the table to analyze
|
||||
|
||||
- column_name (string) [Required] - Name of the column to analyze
|
||||
|
||||
- analysis_type (string) [Optional] - Type of analysis to perform, default is "basic"
|
||||
* "basic": Basic statistics (count, null values, distinct values)
|
||||
* "distribution": Data distribution analysis (frequency, percentiles)
|
||||
* "detailed": Comprehensive analysis including all above plus patterns and outliers
|
||||
""",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"table_name": {"type": "string", "description": "Table name"},
|
||||
"column_name": {
|
||||
"type": "string",
|
||||
"description": "Column name to analyze",
|
||||
},
|
||||
"analysis_type": {
|
||||
"type": "string",
|
||||
"enum": ["basic", "distribution", "detailed"],
|
||||
"description": "Analysis type",
|
||||
"default": "basic",
|
||||
},
|
||||
},
|
||||
"required": ["table_name", "column_name"],
|
||||
},
|
||||
),
|
||||
Tool(
|
||||
name="performance_stats",
|
||||
description="""[Function Description]: Get database performance statistics information.
|
||||
|
||||
[Parameter Content]:
|
||||
|
||||
- metric_type (string) [Optional] - Type of performance metrics to retrieve, default is "queries"
|
||||
* "queries": Query performance metrics (execution time, frequency, etc.)
|
||||
* "connections": Connection statistics (active connections, connection pool status)
|
||||
* "tables": Table-level statistics (size, row count, access patterns)
|
||||
* "system": System-level metrics (CPU, memory, disk usage)
|
||||
|
||||
- time_range (string) [Optional] - Time range for statistics, default is "1h"
|
||||
* "1h": Last 1 hour
|
||||
* "6h": Last 6 hours
|
||||
* "24h": Last 24 hours
|
||||
* "7d": Last 7 days
|
||||
""",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"metric_type": {
|
||||
"type": "string",
|
||||
"enum": ["queries", "connections", "tables", "system"],
|
||||
"description": "Performance metric type",
|
||||
"default": "queries",
|
||||
},
|
||||
"time_range": {
|
||||
"type": "string",
|
||||
"enum": ["1h", "6h", "24h", "7d"],
|
||||
"description": "Time range",
|
||||
"default": "1h",
|
||||
},
|
||||
},
|
||||
},
|
||||
),
|
||||
Tool(
|
||||
name="exec_query",
|
||||
description="""[Function Description]: Execute SQL query and return result command with catalog federation support.
|
||||
|
||||
[Parameter Content]:
|
||||
|
||||
- sql (string) [Required] - SQL statement to execute. MUST use three-part naming for all table references: 'catalog_name.db_name.table_name'. For internal tables use 'internal.db_name.table_name', for external tables use 'catalog_name.db_name.table_name'
|
||||
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database
|
||||
|
||||
- catalog_name (string) [Optional] - Reference catalog name for context, defaults to current catalog
|
||||
|
||||
- max_rows (integer) [Optional] - Maximum number of rows to return, default 100
|
||||
|
||||
- timeout (integer) [Optional] - Query timeout in seconds, default 30
|
||||
""",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"sql": {"type": "string", "description": "SQL statement to execute, must use three-part naming"},
|
||||
"db_name": {"type": "string", "description": "Target database name"},
|
||||
"catalog_name": {"type": "string", "description": "Catalog name"},
|
||||
"max_rows": {"type": "integer", "description": "Maximum number of rows to return", "default": 100},
|
||||
"timeout": {"type": "integer", "description": "Timeout in seconds", "default": 30},
|
||||
},
|
||||
"required": ["sql"],
|
||||
},
|
||||
),
|
||||
Tool(
|
||||
name="get_table_schema",
|
||||
description="""[Function Description]: Get detailed structure information of the specified table (columns, types, comments, etc.).
|
||||
|
||||
[Parameter Content]:
|
||||
|
||||
- table_name (string) [Required] - Name of the table to query
|
||||
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database
|
||||
|
||||
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
|
||||
""",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"table_name": {"type": "string", "description": "Table name"},
|
||||
"db_name": {"type": "string", "description": "Database name"},
|
||||
"catalog_name": {"type": "string", "description": "Catalog name"},
|
||||
},
|
||||
"required": ["table_name"],
|
||||
},
|
||||
),
|
||||
Tool(
|
||||
name="get_db_table_list",
|
||||
description="""[Function Description]: Get a list of all table names in the specified database.
|
||||
|
||||
[Parameter Content]:
|
||||
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database
|
||||
|
||||
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
|
||||
""",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"db_name": {"type": "string", "description": "Database name"},
|
||||
"catalog_name": {"type": "string", "description": "Catalog name"},
|
||||
},
|
||||
},
|
||||
),
|
||||
Tool(
|
||||
name="get_db_list",
|
||||
description="""[Function Description]: Get a list of all database names on the server.
|
||||
|
||||
[Parameter Content]:
|
||||
|
||||
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
|
||||
""",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"catalog_name": {"type": "string", "description": "Catalog name"},
|
||||
},
|
||||
},
|
||||
),
|
||||
Tool(
|
||||
name="get_table_comment",
|
||||
description="""[Function Description]: Get the comment information for the specified table.
|
||||
|
||||
[Parameter Content]:
|
||||
|
||||
- table_name (string) [Required] - Name of the table to query
|
||||
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database
|
||||
|
||||
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
|
||||
""",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"table_name": {"type": "string", "description": "Table name"},
|
||||
"db_name": {"type": "string", "description": "Database name"},
|
||||
"catalog_name": {"type": "string", "description": "Catalog name"},
|
||||
},
|
||||
"required": ["table_name"],
|
||||
},
|
||||
),
|
||||
Tool(
|
||||
name="get_table_column_comments",
|
||||
description="""[Function Description]: Get comment information for all columns in the specified table.
|
||||
|
||||
[Parameter Content]:
|
||||
|
||||
- table_name (string) [Required] - Name of the table to query
|
||||
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database
|
||||
|
||||
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
|
||||
""",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"table_name": {"type": "string", "description": "Table name"},
|
||||
"db_name": {"type": "string", "description": "Database name"},
|
||||
"catalog_name": {"type": "string", "description": "Catalog name"},
|
||||
},
|
||||
"required": ["table_name"],
|
||||
},
|
||||
),
|
||||
Tool(
|
||||
name="get_table_indexes",
|
||||
description="""[Function Description]: Get index information for the specified table.
|
||||
|
||||
[Parameter Content]:
|
||||
|
||||
- table_name (string) [Required] - Name of the table to query
|
||||
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database
|
||||
|
||||
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
|
||||
""",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"table_name": {"type": "string", "description": "Table name"},
|
||||
"db_name": {"type": "string", "description": "Database name"},
|
||||
"catalog_name": {"type": "string", "description": "Catalog name"},
|
||||
},
|
||||
"required": ["table_name"],
|
||||
},
|
||||
),
|
||||
Tool(
|
||||
name="get_recent_audit_logs",
|
||||
description="""[Function Description]: Get audit log records for a recent period.
|
||||
|
||||
[Parameter Content]:
|
||||
|
||||
- days (integer) [Optional] - Number of recent days of logs to retrieve, default is 7
|
||||
|
||||
- limit (integer) [Optional] - Maximum number of records to return, default is 100
|
||||
""",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"days": {"type": "integer", "description": "Number of recent days", "default": 7},
|
||||
"limit": {"type": "integer", "description": "Maximum number of records", "default": 100},
|
||||
},
|
||||
},
|
||||
),
|
||||
Tool(
|
||||
name="get_catalog_list",
|
||||
description="""[Function Description]: Get a list of all catalog names on the server.
|
||||
|
||||
[Parameter Content]:
|
||||
|
||||
- random_string (string) [Required] - Unique identifier for the tool call
|
||||
""",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"random_string": {"type": "string", "description": "Unique identifier"},
|
||||
},
|
||||
"required": ["random_string"],
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
return tools
|
||||
|
||||
async def call_tool(self, name: str, arguments: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Call the specified query tool (tool routing and scheduling center)
|
||||
"""
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
# Tool routing - dispatch requests to corresponding business logic processors
|
||||
if name == "column_analysis":
|
||||
result = await self._column_analysis_tool(arguments)
|
||||
elif name == "performance_stats":
|
||||
result = await self._performance_stats_tool(arguments)
|
||||
# ===== 9 tool routes migrated from source project =====
|
||||
elif name == "exec_query":
|
||||
result = await self._exec_query_tool(arguments)
|
||||
elif name == "get_table_schema":
|
||||
result = await self._get_table_schema_tool(arguments)
|
||||
elif name == "get_db_table_list":
|
||||
result = await self._get_db_table_list_tool(arguments)
|
||||
elif name == "get_db_list":
|
||||
result = await self._get_db_list_tool(arguments)
|
||||
elif name == "get_table_comment":
|
||||
result = await self._get_table_comment_tool(arguments)
|
||||
elif name == "get_table_column_comments":
|
||||
result = await self._get_table_column_comments_tool(arguments)
|
||||
elif name == "get_table_indexes":
|
||||
result = await self._get_table_indexes_tool(arguments)
|
||||
elif name == "get_recent_audit_logs":
|
||||
result = await self._get_recent_audit_logs_tool(arguments)
|
||||
elif name == "get_catalog_list":
|
||||
result = await self._get_catalog_list_tool(arguments)
|
||||
else:
|
||||
raise ValueError(f"Unknown tool: {name}")
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
# Add execution information
|
||||
if isinstance(result, dict):
|
||||
result["_execution_info"] = {
|
||||
"tool_name": name,
|
||||
"execution_time": round(execution_time, 3),
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Tool call failed {name}: {str(e)}")
|
||||
error_result = {
|
||||
"error": str(e),
|
||||
"tool_name": name,
|
||||
"arguments": arguments,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
return json.dumps(error_result, ensure_ascii=False, indent=2)
|
||||
|
||||
# The following are tool routing methods, responsible for calling corresponding business logic processors
|
||||
|
||||
async def _column_analysis_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Column statistical analysis tool routing"""
|
||||
table_name = arguments.get("table_name")
|
||||
column_name = arguments.get("column_name")
|
||||
analysis_type = arguments.get("analysis_type", "basic")
|
||||
|
||||
# Delegate to table analyzer for processing
|
||||
return await self.table_analyzer.analyze_column(
|
||||
table_name, column_name, analysis_type
|
||||
)
|
||||
|
||||
async def _performance_stats_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Database performance statistics tool routing"""
|
||||
metric_type = arguments.get("metric_type", "queries")
|
||||
time_range = arguments.get("time_range", "1h")
|
||||
|
||||
# Delegate to performance monitor for processing
|
||||
return await self.performance_monitor.get_performance_stats(
|
||||
metric_type, time_range
|
||||
)
|
||||
|
||||
async def _exec_query_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""SQL query execution tool routing (supports federation queries)"""
|
||||
sql = arguments.get("sql")
|
||||
db_name = arguments.get("db_name")
|
||||
catalog_name = arguments.get("catalog_name")
|
||||
max_rows = arguments.get("max_rows", 100)
|
||||
timeout = arguments.get("timeout", 30)
|
||||
|
||||
# Delegate to metadata extractor for processing
|
||||
return await self.metadata_extractor.exec_query_for_mcp(
|
||||
sql, db_name, catalog_name, max_rows, timeout
|
||||
)
|
||||
|
||||
async def _get_table_schema_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Get table schema tool routing"""
|
||||
table_name = arguments.get("table_name")
|
||||
db_name = arguments.get("db_name")
|
||||
catalog_name = arguments.get("catalog_name")
|
||||
|
||||
# Delegate to metadata extractor for processing
|
||||
return await self.metadata_extractor.get_table_schema_for_mcp(
|
||||
table_name, db_name, catalog_name
|
||||
)
|
||||
|
||||
async def _get_db_table_list_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Get database table list tool routing"""
|
||||
db_name = arguments.get("db_name")
|
||||
catalog_name = arguments.get("catalog_name")
|
||||
|
||||
# Delegate to metadata extractor for processing
|
||||
return await self.metadata_extractor.get_db_table_list_for_mcp(db_name, catalog_name)
|
||||
|
||||
async def _get_db_list_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Get database list tool routing"""
|
||||
catalog_name = arguments.get("catalog_name")
|
||||
|
||||
# Delegate to metadata extractor for processing
|
||||
return await self.metadata_extractor.get_db_list_for_mcp(catalog_name)
|
||||
|
||||
async def _get_table_comment_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Get table comment tool routing"""
|
||||
table_name = arguments.get("table_name")
|
||||
db_name = arguments.get("db_name")
|
||||
catalog_name = arguments.get("catalog_name")
|
||||
|
||||
# Delegate to metadata extractor for processing
|
||||
return await self.metadata_extractor.get_table_comment_for_mcp(
|
||||
table_name, db_name, catalog_name
|
||||
)
|
||||
|
||||
async def _get_table_column_comments_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Get table column comments tool routing"""
|
||||
table_name = arguments.get("table_name")
|
||||
db_name = arguments.get("db_name")
|
||||
catalog_name = arguments.get("catalog_name")
|
||||
|
||||
# Delegate to metadata extractor for processing
|
||||
return await self.metadata_extractor.get_table_column_comments_for_mcp(
|
||||
table_name, db_name, catalog_name
|
||||
)
|
||||
|
||||
async def _get_table_indexes_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Get table indexes tool routing"""
|
||||
table_name = arguments.get("table_name")
|
||||
db_name = arguments.get("db_name")
|
||||
catalog_name = arguments.get("catalog_name")
|
||||
|
||||
# Delegate to metadata extractor for processing
|
||||
return await self.metadata_extractor.get_table_indexes_for_mcp(
|
||||
table_name, db_name, catalog_name
|
||||
)
|
||||
|
||||
async def _get_recent_audit_logs_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Get audit logs tool routing"""
|
||||
days = arguments.get("days", 7)
|
||||
limit = arguments.get("limit", 100)
|
||||
|
||||
# Delegate to metadata extractor for processing
|
||||
return await self.metadata_extractor.get_recent_audit_logs_for_mcp(days, limit)
|
||||
|
||||
async def _get_catalog_list_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Get catalog list tool routing"""
|
||||
# random_string parameter is required in the source project, but not actually used in business logic
|
||||
# Here we ignore it and directly call business logic
|
||||
|
||||
# Delegate to metadata extractor for processing
|
||||
return await self.metadata_extractor.get_catalog_list_for_mcp()
|
||||
@@ -1 +1,10 @@
|
||||
# Mark directory as a package
|
||||
"""
|
||||
Utilities Package - Contains utility classes and helper functions.
|
||||
|
||||
This package includes:
|
||||
- Database connection and operations
|
||||
- Configuration management
|
||||
- Security utilities
|
||||
- Query execution helpers
|
||||
- Logging configuration
|
||||
"""
|
||||
|
||||
318
doris_mcp_server/utils/analysis_tools.py
Normal file
318
doris_mcp_server/utils/analysis_tools.py
Normal file
@@ -0,0 +1,318 @@
|
||||
"""
|
||||
Data Analysis Tools Module
|
||||
Provides data analysis functions including table analysis, column statistics, performance monitoring, etc.
|
||||
"""
|
||||
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from .db import DorisConnectionManager
|
||||
from .logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class TableAnalyzer:
|
||||
"""Table analyzer"""
|
||||
|
||||
def __init__(self, connection_manager: DorisConnectionManager):
|
||||
self.connection_manager = connection_manager
|
||||
|
||||
async def get_table_summary(
|
||||
self,
|
||||
table_name: str,
|
||||
include_sample: bool = True,
|
||||
sample_size: int = 10
|
||||
) -> Dict[str, Any]:
|
||||
"""Get table summary information"""
|
||||
connection = await self.connection_manager.get_connection("query")
|
||||
|
||||
# Get table basic information
|
||||
table_info_sql = f"""
|
||||
SELECT
|
||||
table_name,
|
||||
table_comment,
|
||||
table_rows,
|
||||
create_time,
|
||||
engine
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_name = '{table_name}'
|
||||
"""
|
||||
|
||||
table_info_result = await connection.execute(table_info_sql)
|
||||
if not table_info_result.data:
|
||||
raise ValueError(f"Table {table_name} does not exist")
|
||||
|
||||
table_info = table_info_result.data[0]
|
||||
|
||||
# Get column information
|
||||
columns_sql = f"""
|
||||
SELECT
|
||||
column_name,
|
||||
data_type,
|
||||
is_nullable,
|
||||
column_comment
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_name = '{table_name}'
|
||||
ORDER BY ordinal_position
|
||||
"""
|
||||
|
||||
columns_result = await connection.execute(columns_sql)
|
||||
|
||||
summary = {
|
||||
"table_name": table_info["table_name"],
|
||||
"comment": table_info.get("table_comment"),
|
||||
"row_count": table_info.get("table_rows", 0),
|
||||
"create_time": str(table_info.get("create_time")),
|
||||
"engine": table_info.get("engine"),
|
||||
"column_count": len(columns_result.data),
|
||||
"columns": columns_result.data,
|
||||
}
|
||||
|
||||
# Get sample data
|
||||
if include_sample and sample_size > 0:
|
||||
sample_sql = f"SELECT * FROM {table_name} LIMIT {sample_size}"
|
||||
sample_result = await connection.execute(sample_sql)
|
||||
summary["sample_data"] = sample_result.data
|
||||
|
||||
return summary
|
||||
|
||||
async def analyze_column(
|
||||
self,
|
||||
table_name: str,
|
||||
column_name: str,
|
||||
analysis_type: str = "basic"
|
||||
) -> Dict[str, Any]:
|
||||
"""Analyze column statistics"""
|
||||
try:
|
||||
connection = await self.connection_manager.get_connection("query")
|
||||
|
||||
# Basic statistics
|
||||
basic_stats_sql = f"""
|
||||
SELECT
|
||||
'{column_name}' as column_name,
|
||||
COUNT(*) as total_count,
|
||||
COUNT({column_name}) as non_null_count,
|
||||
COUNT(DISTINCT {column_name}) as distinct_count
|
||||
FROM {table_name}
|
||||
"""
|
||||
|
||||
basic_result = await connection.execute(basic_stats_sql)
|
||||
if not basic_result.data:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Unable to get statistics for table {table_name} column {column_name}"
|
||||
}
|
||||
|
||||
analysis = basic_result.data[0].copy()
|
||||
analysis["success"] = True
|
||||
analysis["analysis_type"] = analysis_type
|
||||
|
||||
if analysis_type in ["distribution", "detailed"]:
|
||||
# Data distribution analysis
|
||||
distribution_sql = f"""
|
||||
SELECT
|
||||
{column_name} as value,
|
||||
COUNT(*) as frequency
|
||||
FROM {table_name}
|
||||
WHERE {column_name} IS NOT NULL
|
||||
GROUP BY {column_name}
|
||||
ORDER BY frequency DESC
|
||||
LIMIT 20
|
||||
"""
|
||||
|
||||
distribution_result = await connection.execute(distribution_sql)
|
||||
analysis["value_distribution"] = distribution_result.data
|
||||
|
||||
if analysis_type == "detailed":
|
||||
# Detailed statistics (for numeric types)
|
||||
try:
|
||||
numeric_stats_sql = f"""
|
||||
SELECT
|
||||
MIN({column_name}) as min_value,
|
||||
MAX({column_name}) as max_value,
|
||||
AVG({column_name}) as avg_value
|
||||
FROM {table_name}
|
||||
WHERE {column_name} IS NOT NULL
|
||||
"""
|
||||
|
||||
numeric_result = await connection.execute(numeric_stats_sql)
|
||||
if numeric_result.data:
|
||||
analysis.update(numeric_result.data[0])
|
||||
except Exception:
|
||||
# Non-numeric columns don't support numeric statistics
|
||||
pass
|
||||
|
||||
return analysis
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Column analysis failed: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"column_name": column_name,
|
||||
"table_name": table_name
|
||||
}
|
||||
|
||||
async def analyze_table_relationships(
|
||||
self,
|
||||
table_name: str,
|
||||
depth: int = 2
|
||||
) -> Dict[str, Any]:
|
||||
"""Analyze table relationships"""
|
||||
connection = await self.connection_manager.get_connection("system")
|
||||
|
||||
# Get table basic information
|
||||
table_info_sql = f"""
|
||||
SELECT
|
||||
table_name,
|
||||
table_comment,
|
||||
table_rows
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_name = '{table_name}'
|
||||
"""
|
||||
|
||||
table_result = await connection.execute(table_info_sql)
|
||||
if not table_result.data:
|
||||
raise ValueError(f"Table {table_name} does not exist")
|
||||
|
||||
# Get all tables list (for analyzing potential relationships)
|
||||
all_tables_sql = """
|
||||
SELECT
|
||||
table_name,
|
||||
table_comment
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_type = 'BASE TABLE'
|
||||
AND table_name != %s
|
||||
"""
|
||||
|
||||
all_tables_result = await connection.execute(all_tables_sql, (table_name,))
|
||||
|
||||
return {
|
||||
"center_table": table_result.data[0],
|
||||
"related_tables": all_tables_result.data,
|
||||
"depth": depth,
|
||||
"note": "Table relationship analysis based on column name similarity and business logic inference",
|
||||
}
|
||||
|
||||
|
||||
class PerformanceMonitor:
|
||||
"""Performance monitor"""
|
||||
|
||||
def __init__(self, connection_manager: DorisConnectionManager):
|
||||
self.connection_manager = connection_manager
|
||||
|
||||
async def get_performance_stats(
|
||||
self,
|
||||
metric_type: str = "queries",
|
||||
time_range: str = "1h"
|
||||
) -> Dict[str, Any]:
|
||||
"""Get performance statistics"""
|
||||
connection = await self.connection_manager.get_connection("system")
|
||||
|
||||
# Convert time range to seconds
|
||||
time_mapping = {
|
||||
"1h": 3600,
|
||||
"6h": 21600,
|
||||
"24h": 86400,
|
||||
"7d": 604800
|
||||
}
|
||||
|
||||
seconds = time_mapping.get(time_range, 3600)
|
||||
|
||||
if metric_type == "queries":
|
||||
# Query performance metrics
|
||||
stats = {
|
||||
"metric_type": "queries",
|
||||
"time_range": time_range,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"total_queries": 0,
|
||||
"avg_execution_time": 0.0,
|
||||
"slow_queries": 0,
|
||||
"error_queries": 0,
|
||||
"note": "Query performance statistics (simulated data)"
|
||||
}
|
||||
|
||||
elif metric_type == "connections":
|
||||
# Connection statistics
|
||||
connection_metrics = await self.connection_manager.get_metrics()
|
||||
stats = {
|
||||
"metric_type": "connections",
|
||||
"time_range": time_range,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"total_connections": connection_metrics.total_connections,
|
||||
"active_connections": connection_metrics.active_connections,
|
||||
"idle_connections": connection_metrics.idle_connections,
|
||||
"failed_connections": connection_metrics.failed_connections,
|
||||
"connection_errors": connection_metrics.connection_errors,
|
||||
"avg_connection_time": connection_metrics.avg_connection_time,
|
||||
"last_health_check": connection_metrics.last_health_check.isoformat() if connection_metrics.last_health_check else None
|
||||
}
|
||||
|
||||
elif metric_type == "tables":
|
||||
# Table-level statistics
|
||||
tables_sql = """
|
||||
SELECT
|
||||
table_name,
|
||||
table_rows,
|
||||
data_length,
|
||||
index_length,
|
||||
create_time,
|
||||
update_time
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_type = 'BASE TABLE'
|
||||
ORDER BY table_rows DESC
|
||||
LIMIT 20
|
||||
"""
|
||||
|
||||
tables_result = await connection.execute(tables_sql)
|
||||
stats = {
|
||||
"metric_type": "tables",
|
||||
"time_range": time_range,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"table_count": len(tables_result.data),
|
||||
"tables": tables_result.data
|
||||
}
|
||||
|
||||
elif metric_type == "system":
|
||||
# System-level metrics (simulated)
|
||||
stats = {
|
||||
"metric_type": "system",
|
||||
"time_range": time_range,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"cpu_usage": 45.2,
|
||||
"memory_usage": 68.5,
|
||||
"disk_usage": 72.1,
|
||||
"network_io": {
|
||||
"bytes_sent": 1024000,
|
||||
"bytes_received": 2048000
|
||||
},
|
||||
"note": "System metrics (simulated data)"
|
||||
}
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported metric type: {metric_type}")
|
||||
|
||||
return stats
|
||||
|
||||
async def get_query_history(
|
||||
self,
|
||||
limit: int = 50,
|
||||
order_by: str = "time"
|
||||
) -> Dict[str, Any]:
|
||||
"""Get query history"""
|
||||
# Since Doris doesn't have a built-in query history table,
|
||||
# we return simulated data
|
||||
return {
|
||||
"total_queries": 0,
|
||||
"queries": [],
|
||||
"limit": limit,
|
||||
"order_by": order_by,
|
||||
"note": "Query history feature requires audit log configuration"
|
||||
}
|
||||
608
doris_mcp_server/utils/config.py
Normal file
608
doris_mcp_server/utils/config.py
Normal file
@@ -0,0 +1,608 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Doris Configuration Management Module
|
||||
Implements configuration loading, validation and management functionality
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except ImportError:
|
||||
load_dotenv = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatabaseConfig:
|
||||
"""Database connection configuration"""
|
||||
|
||||
host: str = "localhost"
|
||||
port: int = 9030
|
||||
user: str = "root"
|
||||
password: str = ""
|
||||
database: str = "test"
|
||||
charset: str = "utf8mb4"
|
||||
|
||||
# Connection pool configuration
|
||||
min_connections: int = 5
|
||||
max_connections: int = 20
|
||||
connection_timeout: int = 30
|
||||
health_check_interval: int = 60
|
||||
max_connection_age: int = 3600
|
||||
|
||||
|
||||
@dataclass
|
||||
class SecurityConfig:
|
||||
"""Security configuration"""
|
||||
|
||||
# Authentication configuration
|
||||
auth_type: str = "token" # token, basic, oauth
|
||||
token_secret: str = "default_secret"
|
||||
token_expiry: int = 3600
|
||||
|
||||
# SQL security configuration
|
||||
blocked_keywords: list[str] = field(
|
||||
default_factory=lambda: [
|
||||
"DROP",
|
||||
"DELETE",
|
||||
"TRUNCATE",
|
||||
"ALTER",
|
||||
"CREATE",
|
||||
"INSERT",
|
||||
"UPDATE",
|
||||
"GRANT",
|
||||
"REVOKE",
|
||||
]
|
||||
)
|
||||
max_query_complexity: int = 100
|
||||
max_result_rows: int = 10000
|
||||
|
||||
# Sensitive table configuration
|
||||
sensitive_tables: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
# Data masking configuration
|
||||
enable_masking: bool = True
|
||||
masking_rules: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PerformanceConfig:
|
||||
"""Performance configuration"""
|
||||
|
||||
# Query cache configuration
|
||||
enable_query_cache: bool = True
|
||||
cache_ttl: int = 300
|
||||
max_cache_size: int = 1000
|
||||
|
||||
# Concurrency control configuration
|
||||
max_concurrent_queries: int = 50
|
||||
query_timeout: int = 300
|
||||
|
||||
# Connection pool optimization configuration
|
||||
connection_pool_size: int = 20
|
||||
idle_timeout: int = 1800
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoggingConfig:
|
||||
"""Logging configuration"""
|
||||
|
||||
level: str = "INFO"
|
||||
format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
file_path: str | None = None
|
||||
max_file_size: int = 10 * 1024 * 1024 # 10MB
|
||||
backup_count: int = 5
|
||||
|
||||
# Audit log configuration
|
||||
enable_audit: bool = True
|
||||
audit_file_path: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MonitoringConfig:
|
||||
"""Monitoring configuration"""
|
||||
|
||||
# Metrics collection configuration
|
||||
enable_metrics: bool = True
|
||||
metrics_port: int = 8081
|
||||
metrics_path: str = "/metrics"
|
||||
|
||||
# Health check configuration
|
||||
health_check_port: int = 8082
|
||||
health_check_path: str = "/health"
|
||||
|
||||
# Alert configuration
|
||||
enable_alerts: bool = False
|
||||
alert_webhook_url: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DorisConfig:
|
||||
"""Doris MCP Server complete configuration"""
|
||||
|
||||
# Basic configuration
|
||||
server_name: str = "doris-mcp-server"
|
||||
server_version: str = "1.0.0"
|
||||
server_port: int = 8080
|
||||
|
||||
# Sub-configuration modules
|
||||
database: DatabaseConfig = field(default_factory=DatabaseConfig)
|
||||
security: SecurityConfig = field(default_factory=SecurityConfig)
|
||||
performance: PerformanceConfig = field(default_factory=PerformanceConfig)
|
||||
logging: LoggingConfig = field(default_factory=LoggingConfig)
|
||||
monitoring: MonitoringConfig = field(default_factory=MonitoringConfig)
|
||||
|
||||
# Custom configuration
|
||||
custom_config: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, config_path: str) -> "DorisConfig":
|
||||
"""Load configuration from file"""
|
||||
config_file = Path(config_path)
|
||||
|
||||
if not config_file.exists():
|
||||
raise FileNotFoundError(f"Configuration file does not exist: {config_path}")
|
||||
|
||||
try:
|
||||
with open(config_file, encoding="utf-8") as f:
|
||||
if config_file.suffix.lower() == ".json":
|
||||
config_data = json.load(f)
|
||||
else:
|
||||
# Support other formats (like YAML)
|
||||
raise ValueError(f"Unsupported configuration file format: {config_file.suffix}")
|
||||
|
||||
return cls._from_dict(config_data)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to load configuration file: {e}")
|
||||
|
||||
@classmethod
|
||||
def from_env(cls, env_file: str | None = None) -> "DorisConfig":
|
||||
"""Load configuration from environment variables
|
||||
|
||||
Args:
|
||||
env_file: .env file path, if None, search in the following order:
|
||||
.env, .env.local, .env.production, .env.development
|
||||
"""
|
||||
# Load .env file
|
||||
if load_dotenv is not None:
|
||||
if env_file:
|
||||
# Load specified .env file
|
||||
if Path(env_file).exists():
|
||||
load_dotenv(env_file)
|
||||
logging.getLogger(__name__).info(f"Loaded environment configuration file: {env_file}")
|
||||
else:
|
||||
logging.getLogger(__name__).warning(f"Environment configuration file does not exist: {env_file}")
|
||||
else:
|
||||
# Load .env files in priority order
|
||||
env_files = [".env", ".env.local", ".env.production", ".env.development"]
|
||||
for env_path in env_files:
|
||||
if Path(env_path).exists():
|
||||
load_dotenv(env_path)
|
||||
logging.getLogger(__name__).info(f"Loaded environment configuration file: {env_path}")
|
||||
break
|
||||
else:
|
||||
logging.getLogger(__name__).info("No .env configuration file found, using system environment variables")
|
||||
else:
|
||||
logging.getLogger(__name__).warning("python-dotenv not installed, cannot load .env files")
|
||||
|
||||
config = cls()
|
||||
|
||||
# Database configuration
|
||||
config.database.host = os.getenv("DORIS_HOST", config.database.host)
|
||||
config.database.port = int(os.getenv("DORIS_PORT", str(config.database.port)))
|
||||
config.database.user = os.getenv("DORIS_USER", config.database.user)
|
||||
config.database.password = os.getenv("DORIS_PASSWORD", config.database.password)
|
||||
config.database.database = os.getenv("DORIS_DATABASE", config.database.database)
|
||||
|
||||
# Connection pool configuration
|
||||
config.database.min_connections = int(
|
||||
os.getenv("DORIS_MIN_CONNECTIONS", str(config.database.min_connections))
|
||||
)
|
||||
config.database.max_connections = int(
|
||||
os.getenv("DORIS_MAX_CONNECTIONS", str(config.database.max_connections))
|
||||
)
|
||||
config.database.connection_timeout = int(
|
||||
os.getenv("DORIS_CONNECTION_TIMEOUT", str(config.database.connection_timeout))
|
||||
)
|
||||
config.database.health_check_interval = int(
|
||||
os.getenv("DORIS_HEALTH_CHECK_INTERVAL", str(config.database.health_check_interval))
|
||||
)
|
||||
config.database.max_connection_age = int(
|
||||
os.getenv("DORIS_MAX_CONNECTION_AGE", str(config.database.max_connection_age))
|
||||
)
|
||||
|
||||
# Security configuration
|
||||
config.security.auth_type = os.getenv("AUTH_TYPE", config.security.auth_type)
|
||||
config.security.token_secret = os.getenv("TOKEN_SECRET", config.security.token_secret)
|
||||
config.security.token_expiry = int(
|
||||
os.getenv("TOKEN_EXPIRY", str(config.security.token_expiry))
|
||||
)
|
||||
config.security.max_result_rows = int(
|
||||
os.getenv("MAX_RESULT_ROWS", str(config.security.max_result_rows))
|
||||
)
|
||||
config.security.max_query_complexity = int(
|
||||
os.getenv("MAX_QUERY_COMPLEXITY", str(config.security.max_query_complexity))
|
||||
)
|
||||
config.security.enable_masking = (
|
||||
os.getenv("ENABLE_MASKING", str(config.security.enable_masking).lower()).lower() == "true"
|
||||
)
|
||||
|
||||
# Performance configuration
|
||||
config.performance.enable_query_cache = (
|
||||
os.getenv("ENABLE_QUERY_CACHE", "true").lower() == "true"
|
||||
)
|
||||
config.performance.cache_ttl = int(
|
||||
os.getenv("CACHE_TTL", str(config.performance.cache_ttl))
|
||||
)
|
||||
config.performance.max_cache_size = int(
|
||||
os.getenv("MAX_CACHE_SIZE", str(config.performance.max_cache_size))
|
||||
)
|
||||
config.performance.max_concurrent_queries = int(
|
||||
os.getenv("MAX_CONCURRENT_QUERIES", str(config.performance.max_concurrent_queries))
|
||||
)
|
||||
config.performance.query_timeout = int(
|
||||
os.getenv("QUERY_TIMEOUT", str(config.performance.query_timeout))
|
||||
)
|
||||
|
||||
# Logging configuration
|
||||
config.logging.level = os.getenv("LOG_LEVEL", config.logging.level)
|
||||
config.logging.file_path = os.getenv("LOG_FILE_PATH", config.logging.file_path)
|
||||
config.logging.enable_audit = (
|
||||
os.getenv("ENABLE_AUDIT", str(config.logging.enable_audit).lower()).lower() == "true"
|
||||
)
|
||||
config.logging.audit_file_path = os.getenv("AUDIT_FILE_PATH", config.logging.audit_file_path)
|
||||
|
||||
# Monitoring configuration
|
||||
config.monitoring.enable_metrics = (
|
||||
os.getenv("ENABLE_METRICS", "true").lower() == "true"
|
||||
)
|
||||
config.monitoring.metrics_port = int(
|
||||
os.getenv("METRICS_PORT", str(config.monitoring.metrics_port))
|
||||
)
|
||||
config.monitoring.health_check_port = int(
|
||||
os.getenv("HEALTH_CHECK_PORT", str(config.monitoring.health_check_port))
|
||||
)
|
||||
config.monitoring.enable_alerts = (
|
||||
os.getenv("ENABLE_ALERTS", str(config.monitoring.enable_alerts).lower()).lower() == "true"
|
||||
)
|
||||
config.monitoring.alert_webhook_url = os.getenv("ALERT_WEBHOOK_URL", config.monitoring.alert_webhook_url)
|
||||
|
||||
# Server configuration
|
||||
config.server_name = os.getenv("SERVER_NAME", config.server_name)
|
||||
config.server_version = os.getenv("SERVER_VERSION", config.server_version)
|
||||
config.server_port = int(os.getenv("SERVER_PORT", str(config.server_port)))
|
||||
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def _from_dict(cls, config_data: dict[str, Any]) -> "DorisConfig":
|
||||
"""Create configuration object from dictionary"""
|
||||
config = cls()
|
||||
|
||||
# Update basic configuration
|
||||
for key in ["server_name", "server_version", "server_port"]:
|
||||
if key in config_data:
|
||||
setattr(config, key, config_data[key])
|
||||
|
||||
# Update database configuration
|
||||
if "database" in config_data:
|
||||
db_config = config_data["database"]
|
||||
for key, value in db_config.items():
|
||||
if hasattr(config.database, key):
|
||||
setattr(config.database, key, value)
|
||||
|
||||
# Update security configuration
|
||||
if "security" in config_data:
|
||||
sec_config = config_data["security"]
|
||||
for key, value in sec_config.items():
|
||||
if hasattr(config.security, key):
|
||||
setattr(config.security, key, value)
|
||||
|
||||
# Update performance configuration
|
||||
if "performance" in config_data:
|
||||
perf_config = config_data["performance"]
|
||||
for key, value in perf_config.items():
|
||||
if hasattr(config.performance, key):
|
||||
setattr(config.performance, key, value)
|
||||
|
||||
# Update logging configuration
|
||||
if "logging" in config_data:
|
||||
log_config = config_data["logging"]
|
||||
for key, value in log_config.items():
|
||||
if hasattr(config.logging, key):
|
||||
setattr(config.logging, key, value)
|
||||
|
||||
# Update monitoring configuration
|
||||
if "monitoring" in config_data:
|
||||
mon_config = config_data["monitoring"]
|
||||
for key, value in mon_config.items():
|
||||
if hasattr(config.monitoring, key):
|
||||
setattr(config.monitoring, key, value)
|
||||
|
||||
# Custom configuration
|
||||
config.custom_config = config_data.get("custom", {})
|
||||
|
||||
return config
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary format"""
|
||||
return {
|
||||
"server_name": self.server_name,
|
||||
"server_version": self.server_version,
|
||||
"server_port": self.server_port,
|
||||
"database": {
|
||||
"host": self.database.host,
|
||||
"port": self.database.port,
|
||||
"user": self.database.user,
|
||||
"password": "***", # Hide password
|
||||
"database": self.database.database,
|
||||
"charset": self.database.charset,
|
||||
"min_connections": self.database.min_connections,
|
||||
"max_connections": self.database.max_connections,
|
||||
"connection_timeout": self.database.connection_timeout,
|
||||
"health_check_interval": self.database.health_check_interval,
|
||||
"max_connection_age": self.database.max_connection_age,
|
||||
},
|
||||
"security": {
|
||||
"auth_type": self.security.auth_type,
|
||||
"token_secret": "***", # Hide secret key
|
||||
"token_expiry": self.security.token_expiry,
|
||||
"blocked_keywords": self.security.blocked_keywords,
|
||||
"max_query_complexity": self.security.max_query_complexity,
|
||||
"max_result_rows": self.security.max_result_rows,
|
||||
"sensitive_tables": self.security.sensitive_tables,
|
||||
"enable_masking": self.security.enable_masking,
|
||||
"masking_rules": len(self.security.masking_rules),
|
||||
},
|
||||
"performance": {
|
||||
"enable_query_cache": self.performance.enable_query_cache,
|
||||
"cache_ttl": self.performance.cache_ttl,
|
||||
"max_cache_size": self.performance.max_cache_size,
|
||||
"max_concurrent_queries": self.performance.max_concurrent_queries,
|
||||
"query_timeout": self.performance.query_timeout,
|
||||
"connection_pool_size": self.performance.connection_pool_size,
|
||||
"idle_timeout": self.performance.idle_timeout,
|
||||
},
|
||||
"logging": {
|
||||
"level": self.logging.level,
|
||||
"format": self.logging.format,
|
||||
"file_path": self.logging.file_path,
|
||||
"max_file_size": self.logging.max_file_size,
|
||||
"backup_count": self.logging.backup_count,
|
||||
"enable_audit": self.logging.enable_audit,
|
||||
"audit_file_path": self.logging.audit_file_path,
|
||||
},
|
||||
"monitoring": {
|
||||
"enable_metrics": self.monitoring.enable_metrics,
|
||||
"metrics_port": self.monitoring.metrics_port,
|
||||
"metrics_path": self.monitoring.metrics_path,
|
||||
"health_check_port": self.monitoring.health_check_port,
|
||||
"health_check_path": self.monitoring.health_check_path,
|
||||
"enable_alerts": self.monitoring.enable_alerts,
|
||||
"alert_webhook_url": self.monitoring.alert_webhook_url,
|
||||
},
|
||||
"custom": self.custom_config,
|
||||
}
|
||||
|
||||
def save_to_file(self, config_path: str):
|
||||
"""Save configuration to file"""
|
||||
config_file = Path(config_path)
|
||||
config_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
with open(config_file, "w", encoding="utf-8") as f:
|
||||
if config_file.suffix.lower() == ".json":
|
||||
json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
|
||||
else:
|
||||
raise ValueError(f"Unsupported configuration file format: {config_file.suffix}")
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to save configuration file: {e}")
|
||||
|
||||
def validate(self) -> list[str]:
|
||||
"""Validate configuration validity"""
|
||||
errors = []
|
||||
|
||||
# Validate database configuration
|
||||
if not self.database.host:
|
||||
errors.append("Database host address cannot be empty")
|
||||
|
||||
if not (1 <= self.database.port <= 65535):
|
||||
errors.append("Database port must be in the range 1-65535")
|
||||
|
||||
if not self.database.user:
|
||||
errors.append("Database username cannot be empty")
|
||||
|
||||
if self.database.min_connections <= 0:
|
||||
errors.append("Minimum connections must be greater than 0")
|
||||
|
||||
if self.database.max_connections <= self.database.min_connections:
|
||||
errors.append("Maximum connections must be greater than minimum connections")
|
||||
|
||||
# Validate security configuration
|
||||
if self.security.auth_type not in ["token", "basic", "oauth"]:
|
||||
errors.append("Authentication type must be one of token, basic, or oauth")
|
||||
|
||||
if self.security.token_expiry <= 0:
|
||||
errors.append("Token expiry time must be greater than 0")
|
||||
|
||||
if self.security.max_query_complexity <= 0:
|
||||
errors.append("Maximum query complexity must be greater than 0")
|
||||
|
||||
if self.security.max_result_rows <= 0:
|
||||
errors.append("Maximum result rows must be greater than 0")
|
||||
|
||||
# Validate performance configuration
|
||||
if self.performance.cache_ttl <= 0:
|
||||
errors.append("Cache TTL must be greater than 0")
|
||||
|
||||
if self.performance.max_concurrent_queries <= 0:
|
||||
errors.append("Maximum concurrent queries must be greater than 0")
|
||||
|
||||
if self.performance.query_timeout <= 0:
|
||||
errors.append("Query timeout must be greater than 0")
|
||||
|
||||
# Validate logging configuration
|
||||
if self.logging.level not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]:
|
||||
errors.append("Log level must be one of DEBUG, INFO, WARNING, ERROR, or CRITICAL")
|
||||
|
||||
if self.logging.max_file_size <= 0:
|
||||
errors.append("Maximum log file size must be greater than 0")
|
||||
|
||||
if self.logging.backup_count < 0:
|
||||
errors.append("Log backup count cannot be negative")
|
||||
|
||||
# Validate monitoring configuration
|
||||
if not (1 <= self.monitoring.metrics_port <= 65535):
|
||||
errors.append("Monitoring port must be in the range 1-65535")
|
||||
|
||||
if not (1 <= self.monitoring.health_check_port <= 65535):
|
||||
errors.append("Health check port must be in the range 1-65535")
|
||||
|
||||
return errors
|
||||
|
||||
def get_connection_string(self) -> str:
|
||||
"""Get database connection string (hide password)"""
|
||||
return f"mysql://{self.database.user}:***@{self.database.host}:{self.database.port}/{self.database.database}"
|
||||
|
||||
def get_config_summary(self) -> dict[str, Any]:
|
||||
"""Get configuration summary information"""
|
||||
return {
|
||||
"server": f"{self.server_name} v{self.server_version}",
|
||||
"database": f"{self.database.host}:{self.database.port}/{self.database.database}",
|
||||
"connection_pool": f"{self.database.min_connections}-{self.database.max_connections}",
|
||||
"security": {
|
||||
"auth_type": self.security.auth_type,
|
||||
"masking_enabled": self.security.enable_masking,
|
||||
"blocked_keywords_count": len(self.security.blocked_keywords),
|
||||
},
|
||||
"performance": {
|
||||
"cache_enabled": self.performance.enable_query_cache,
|
||||
"max_concurrent": self.performance.max_concurrent_queries,
|
||||
"query_timeout": self.performance.query_timeout,
|
||||
},
|
||||
"monitoring": {
|
||||
"metrics_enabled": self.monitoring.enable_metrics,
|
||||
"alerts_enabled": self.monitoring.enable_alerts,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class ConfigManager:
|
||||
"""Configuration manager class"""
|
||||
|
||||
def __init__(self, config: DorisConfig):
|
||||
self.config = config
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
def setup_logging(self):
|
||||
"""Setup logging configuration"""
|
||||
# Configure root logger
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.setLevel(getattr(logging, self.config.logging.level.upper()))
|
||||
|
||||
# Clear existing handlers
|
||||
for handler in root_logger.handlers[:]:
|
||||
root_logger.removeHandler(handler)
|
||||
|
||||
# Create formatter
|
||||
formatter = logging.Formatter(self.config.logging.format)
|
||||
|
||||
# Console handler
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setFormatter(formatter)
|
||||
root_logger.addHandler(console_handler)
|
||||
|
||||
# File handler (if configured)
|
||||
if self.config.logging.file_path:
|
||||
try:
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
||||
file_handler = RotatingFileHandler(
|
||||
self.config.logging.file_path,
|
||||
maxBytes=self.config.logging.max_file_size,
|
||||
backupCount=self.config.logging.backup_count,
|
||||
encoding="utf-8",
|
||||
)
|
||||
file_handler.setFormatter(formatter)
|
||||
root_logger.addHandler(file_handler)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to setup file logging: {e}")
|
||||
|
||||
# Audit log handler (if configured)
|
||||
if self.config.logging.enable_audit and self.config.logging.audit_file_path:
|
||||
try:
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
||||
audit_logger = logging.getLogger("audit")
|
||||
audit_handler = RotatingFileHandler(
|
||||
self.config.logging.audit_file_path,
|
||||
maxBytes=self.config.logging.max_file_size,
|
||||
backupCount=self.config.logging.backup_count,
|
||||
encoding="utf-8",
|
||||
)
|
||||
audit_handler.setFormatter(formatter)
|
||||
audit_logger.addHandler(audit_handler)
|
||||
audit_logger.setLevel(logging.INFO)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to setup audit logging: {e}")
|
||||
|
||||
def validate_config(self) -> bool:
|
||||
"""Validate configuration"""
|
||||
errors = self.config.validate()
|
||||
if errors:
|
||||
self.logger.error("Configuration validation failed:")
|
||||
for error in errors:
|
||||
self.logger.error(f" - {error}")
|
||||
return False
|
||||
|
||||
self.logger.info("Configuration validation passed")
|
||||
return True
|
||||
|
||||
def log_config_summary(self):
|
||||
"""Log configuration summary"""
|
||||
summary = self.config.get_config_summary()
|
||||
self.logger.info("Configuration Summary:")
|
||||
self.logger.info(f" Server: {summary['server']}")
|
||||
self.logger.info(f" Database: {summary['database']}")
|
||||
self.logger.info(f" Connection Pool: {summary['connection_pool']}")
|
||||
self.logger.info(f" Security: {summary['security']}")
|
||||
self.logger.info(f" Performance: {summary['performance']}")
|
||||
self.logger.info(f" Monitoring: {summary['monitoring']}")
|
||||
|
||||
|
||||
def create_default_config_file(config_path: str):
|
||||
"""Create default configuration file"""
|
||||
config = DorisConfig()
|
||||
config.save_to_file(config_path)
|
||||
print(f"Default configuration file created: {config_path}")
|
||||
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
# Create default configuration
|
||||
config = DorisConfig()
|
||||
|
||||
# Load from environment variables
|
||||
# config = DorisConfig.from_env()
|
||||
|
||||
# Load from file
|
||||
# config = DorisConfig.from_file("config.json")
|
||||
|
||||
# Validate configuration
|
||||
config_manager = ConfigManager(config)
|
||||
if config_manager.validate_config():
|
||||
config_manager.setup_logging()
|
||||
config_manager.log_config_summary()
|
||||
|
||||
# Save configuration
|
||||
config.save_to_file("example_config.json")
|
||||
print("Configuration saved to example_config.json")
|
||||
else:
|
||||
print("Configuration validation failed")
|
||||
@@ -1,100 +1,479 @@
|
||||
import os
|
||||
import json
|
||||
import pymysql
|
||||
import pandas as pd
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dotenv import load_dotenv
|
||||
import re
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Apache Doris Database Connection Management Module
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv(override=True)
|
||||
Provides high-performance database connection pool management, automatic reconnection mechanism and connection health check functionality
|
||||
Supports asynchronous operations and concurrent connection management, ensuring stability and performance for enterprise applications
|
||||
"""
|
||||
|
||||
# Database configuration
|
||||
DB_CONFIG = {
|
||||
"host": os.getenv("DB_HOST", "localhost"),
|
||||
"port": int(os.getenv("DB_PORT", "9030")),
|
||||
"user": os.getenv("DB_USER", "root"),
|
||||
"password": os.getenv("DB_PASSWORD", ""),
|
||||
"database": os.getenv("DB_DATABASE", ""),
|
||||
"charset": "utf8mb4",
|
||||
"cursorclass": pymysql.cursors.DictCursor
|
||||
}
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List
|
||||
|
||||
def get_db_connection(db_name: Optional[str] = None):
|
||||
import aiomysql
|
||||
from aiomysql import Connection, Pool
|
||||
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConnectionMetrics:
|
||||
"""Connection pool performance metrics"""
|
||||
|
||||
total_connections: int = 0
|
||||
active_connections: int = 0
|
||||
idle_connections: int = 0
|
||||
failed_connections: int = 0
|
||||
connection_errors: int = 0
|
||||
avg_connection_time: float = 0.0
|
||||
last_health_check: datetime | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueryResult:
|
||||
"""Query result wrapper"""
|
||||
|
||||
data: list[dict[str, Any]]
|
||||
metadata: dict[str, Any]
|
||||
execution_time: float
|
||||
row_count: int
|
||||
|
||||
|
||||
class DorisConnection:
|
||||
"""Doris database connection wrapper class"""
|
||||
|
||||
def __init__(self, connection: Connection, session_id: str, security_manager=None):
|
||||
self.connection = connection
|
||||
self.session_id = session_id
|
||||
self.created_at = datetime.utcnow()
|
||||
self.last_used = datetime.utcnow()
|
||||
self.query_count = 0
|
||||
self.is_healthy = True
|
||||
self.security_manager = security_manager
|
||||
|
||||
async def execute(self, sql: str, params: tuple | None = None, auth_context=None) -> QueryResult:
|
||||
"""Execute SQL query"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# If security manager exists, perform SQL security check
|
||||
security_result = None
|
||||
if self.security_manager and auth_context:
|
||||
validation_result = await self.security_manager.validate_sql_security(sql, auth_context)
|
||||
if not validation_result.is_valid:
|
||||
raise ValueError(f"SQL security validation failed: {validation_result.error_message}")
|
||||
security_result = {
|
||||
"is_valid": validation_result.is_valid,
|
||||
"risk_level": validation_result.risk_level,
|
||||
"blocked_operations": validation_result.blocked_operations
|
||||
}
|
||||
|
||||
async with self.connection.cursor(aiomysql.DictCursor) as cursor:
|
||||
await cursor.execute(sql, params)
|
||||
|
||||
# Check if it's a query statement (statement that returns result set)
|
||||
sql_upper = sql.strip().upper()
|
||||
if (sql_upper.startswith("SELECT") or
|
||||
sql_upper.startswith("SHOW") or
|
||||
sql_upper.startswith("DESCRIBE") or
|
||||
sql_upper.startswith("DESC") or
|
||||
sql_upper.startswith("EXPLAIN")):
|
||||
data = await cursor.fetchall()
|
||||
row_count = len(data)
|
||||
else:
|
||||
data = []
|
||||
row_count = cursor.rowcount
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
self.last_used = datetime.utcnow()
|
||||
self.query_count += 1
|
||||
|
||||
# Get column information
|
||||
columns = []
|
||||
if cursor.description:
|
||||
columns = [desc[0] for desc in cursor.description]
|
||||
|
||||
# If security manager exists and has auth context, apply data masking
|
||||
final_data = list(data) if data else []
|
||||
if self.security_manager and auth_context and final_data:
|
||||
final_data = await self.security_manager.apply_data_masking(final_data, auth_context)
|
||||
|
||||
metadata = {"columns": columns, "query": sql, "params": params}
|
||||
if security_result:
|
||||
metadata["security_check"] = security_result
|
||||
|
||||
return QueryResult(
|
||||
data=final_data,
|
||||
metadata=metadata,
|
||||
execution_time=execution_time,
|
||||
row_count=row_count,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.is_healthy = False
|
||||
logging.error(f"Query execution failed: {e}")
|
||||
raise
|
||||
|
||||
async def ping(self) -> bool:
|
||||
"""Check connection health status"""
|
||||
try:
|
||||
await self.connection.ping()
|
||||
self.is_healthy = True
|
||||
return True
|
||||
except Exception:
|
||||
self.is_healthy = False
|
||||
return False
|
||||
|
||||
async def close(self):
|
||||
"""Close connection"""
|
||||
try:
|
||||
if self.connection and not self.connection.closed:
|
||||
await self.connection.ensure_closed()
|
||||
except Exception as e:
|
||||
logging.error(f"Error occurred while closing connection: {e}")
|
||||
|
||||
|
||||
class DorisConnectionManager:
|
||||
"""Doris database connection manager
|
||||
|
||||
Provides connection pool management, connection health monitoring, fault recovery and other functions
|
||||
Supports session-level connection reuse and intelligent load balancing
|
||||
Integrates security manager to provide unified security validation and data masking
|
||||
"""
|
||||
Get database connection
|
||||
|
||||
Args:
|
||||
db_name: Specify the database name to connect to, use default config if None
|
||||
|
||||
Returns:
|
||||
Database connection
|
||||
"""
|
||||
if db_name:
|
||||
# Use default config but override database name
|
||||
config = DB_CONFIG.copy()
|
||||
config["database"] = db_name
|
||||
return pymysql.connect(**config)
|
||||
else:
|
||||
# Use default config
|
||||
return pymysql.connect(**DB_CONFIG)
|
||||
|
||||
def get_db_name() -> str:
|
||||
"""Get the currently configured default database name"""
|
||||
return DB_CONFIG["database"] or os.getenv("DB_DATABASE", "")
|
||||
def __init__(self, config, security_manager=None):
|
||||
self.config = config
|
||||
self.pool: Pool | None = None
|
||||
self.session_connections: dict[str, DorisConnection] = {}
|
||||
self.metrics = ConnectionMetrics()
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.security_manager = security_manager
|
||||
|
||||
def execute_query(sql, db_name: Optional[str] = None):
|
||||
"""
|
||||
Execute SQL query and return results
|
||||
|
||||
Args:
|
||||
sql: SQL query statement
|
||||
db_name: Specify the database name to connect to, use default config if None
|
||||
|
||||
Returns:
|
||||
Query results
|
||||
"""
|
||||
conn = get_db_connection(db_name)
|
||||
try:
|
||||
with conn.cursor() as cursor:
|
||||
# Set connection character set to utf8 before executing query
|
||||
cursor.execute("SET NAMES utf8")
|
||||
# Health check configuration
|
||||
self.health_check_interval = config.database.health_check_interval or 60
|
||||
self.max_connection_age = config.database.max_connection_age or 3600
|
||||
self.connection_timeout = config.database.connection_timeout or 30
|
||||
|
||||
# Start background tasks
|
||||
self._health_check_task = None
|
||||
self._cleanup_task = None
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize connection manager"""
|
||||
try:
|
||||
# Create connection pool
|
||||
self.pool = await aiomysql.create_pool(
|
||||
host=self.config.database.host,
|
||||
port=self.config.database.port,
|
||||
user=self.config.database.user,
|
||||
password=self.config.database.password,
|
||||
db=self.config.database.database,
|
||||
charset="utf8",
|
||||
minsize=self.config.database.min_connections or 5,
|
||||
maxsize=self.config.database.max_connections or 20,
|
||||
autocommit=True,
|
||||
connect_timeout=self.connection_timeout,
|
||||
)
|
||||
|
||||
self.logger.info(
|
||||
f"Connection pool initialized successfully, min connections: {self.config.database.min_connections}, "
|
||||
f"max connections: {self.config.database.max_connections}"
|
||||
)
|
||||
|
||||
# Start background monitoring tasks
|
||||
self._health_check_task = asyncio.create_task(self._health_check_loop())
|
||||
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Connection pool initialization failed: {e}")
|
||||
raise
|
||||
|
||||
async def get_connection(self, session_id: str) -> DorisConnection:
|
||||
"""Get database connection
|
||||
|
||||
Supports session-level connection reuse to improve performance and consistency
|
||||
"""
|
||||
# Check if there's an existing session connection
|
||||
if session_id in self.session_connections:
|
||||
conn = self.session_connections[session_id]
|
||||
# Check connection health
|
||||
if await conn.ping():
|
||||
return conn
|
||||
else:
|
||||
# Connection is unhealthy, clean up and create new one
|
||||
await self._cleanup_session_connection(session_id)
|
||||
|
||||
# Create new connection
|
||||
return await self._create_new_connection(session_id)
|
||||
|
||||
async def _create_new_connection(self, session_id: str) -> DorisConnection:
|
||||
"""Create new database connection"""
|
||||
try:
|
||||
if not self.pool:
|
||||
raise RuntimeError("Connection pool not initialized")
|
||||
|
||||
# Get connection from pool
|
||||
raw_connection = await self.pool.acquire()
|
||||
|
||||
# Execute the actual query
|
||||
cursor.execute(sql)
|
||||
result = cursor.fetchall()
|
||||
return result
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def execute_query_df(sql, db_name: Optional[str] = None):
|
||||
"""
|
||||
Execute SQL query and return pandas DataFrame
|
||||
|
||||
Args:
|
||||
sql: SQL query statement
|
||||
db_name: Specify the database name to connect to, use default config if None
|
||||
|
||||
Returns:
|
||||
pandas DataFrame
|
||||
"""
|
||||
conn = get_db_connection(db_name)
|
||||
try:
|
||||
# Use a temporary cursor to execute the query and get results
|
||||
with conn.cursor() as cursor:
|
||||
# Set connection character set to utf8 before executing query
|
||||
cursor.execute("SET NAMES utf8")
|
||||
# Create wrapped connection
|
||||
doris_conn = DorisConnection(raw_connection, session_id, self.security_manager)
|
||||
|
||||
# Execute the actual query
|
||||
cursor.execute(sql)
|
||||
result = cursor.fetchall()
|
||||
# Store in session connections
|
||||
self.session_connections[session_id] = doris_conn
|
||||
|
||||
self.metrics.total_connections += 1
|
||||
self.logger.debug(f"Created new connection for session: {session_id}")
|
||||
|
||||
return doris_conn
|
||||
|
||||
except Exception as e:
|
||||
self.metrics.connection_errors += 1
|
||||
self.logger.error(f"Failed to create connection for session {session_id}: {e}")
|
||||
raise
|
||||
|
||||
async def release_connection(self, session_id: str):
|
||||
"""Release session connection"""
|
||||
if session_id in self.session_connections:
|
||||
await self._cleanup_session_connection(session_id)
|
||||
|
||||
async def _cleanup_session_connection(self, session_id: str):
|
||||
"""Clean up session connection"""
|
||||
if session_id in self.session_connections:
|
||||
conn = self.session_connections[session_id]
|
||||
try:
|
||||
# Return connection to pool
|
||||
if self.pool and conn.connection and not conn.connection.closed:
|
||||
self.pool.release(conn.connection)
|
||||
|
||||
# Close connection wrapper
|
||||
await conn.close()
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error cleaning up connection for session {session_id}: {e}")
|
||||
finally:
|
||||
# Remove from session connections
|
||||
del self.session_connections[session_id]
|
||||
self.logger.debug(f"Cleaned up connection for session: {session_id}")
|
||||
|
||||
async def _health_check_loop(self):
|
||||
"""Background health check loop"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(self.health_check_interval)
|
||||
await self._perform_health_check()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
self.logger.error(f"Health check error: {e}")
|
||||
|
||||
async def _perform_health_check(self):
|
||||
"""Perform health check"""
|
||||
try:
|
||||
unhealthy_sessions = []
|
||||
|
||||
for session_id, conn in self.session_connections.items():
|
||||
if not await conn.ping():
|
||||
unhealthy_sessions.append(session_id)
|
||||
|
||||
# Clean up unhealthy connections
|
||||
for session_id in unhealthy_sessions:
|
||||
await self._cleanup_session_connection(session_id)
|
||||
self.metrics.failed_connections += 1
|
||||
|
||||
# Update metrics
|
||||
await self._update_connection_metrics()
|
||||
self.metrics.last_health_check = datetime.utcnow()
|
||||
|
||||
if unhealthy_sessions:
|
||||
self.logger.warning(f"Cleaned up {len(unhealthy_sessions)} unhealthy connections")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Health check failed: {e}")
|
||||
|
||||
async def _cleanup_loop(self):
|
||||
"""Background cleanup loop"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(300) # Run every 5 minutes
|
||||
await self._cleanup_idle_connections()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
self.logger.error(f"Cleanup loop error: {e}")
|
||||
|
||||
async def _cleanup_idle_connections(self):
|
||||
"""Clean up idle connections"""
|
||||
current_time = datetime.utcnow()
|
||||
idle_sessions = []
|
||||
|
||||
# If no results, return empty DataFrame
|
||||
if not result:
|
||||
return pd.DataFrame()
|
||||
|
||||
# Manually convert dict results to DataFrame
|
||||
df = pd.DataFrame(result)
|
||||
return df
|
||||
finally:
|
||||
conn.close()
|
||||
for session_id, conn in self.session_connections.items():
|
||||
# Check if connection has exceeded maximum age
|
||||
age = (current_time - conn.created_at).total_seconds()
|
||||
if age > self.max_connection_age:
|
||||
idle_sessions.append(session_id)
|
||||
|
||||
# Clean up idle connections
|
||||
for session_id in idle_sessions:
|
||||
await self._cleanup_session_connection(session_id)
|
||||
|
||||
if idle_sessions:
|
||||
self.logger.info(f"Cleaned up {len(idle_sessions)} idle connections")
|
||||
|
||||
async def _update_connection_metrics(self):
|
||||
"""Update connection metrics"""
|
||||
self.metrics.active_connections = len(self.session_connections)
|
||||
if self.pool:
|
||||
self.metrics.idle_connections = self.pool.freesize
|
||||
|
||||
async def get_metrics(self) -> ConnectionMetrics:
|
||||
"""Get connection metrics"""
|
||||
await self._update_connection_metrics()
|
||||
return self.metrics
|
||||
|
||||
async def execute_query(
|
||||
self, session_id: str, sql: str, params: tuple | None = None, auth_context=None
|
||||
) -> QueryResult:
|
||||
"""Execute query"""
|
||||
conn = await self.get_connection(session_id)
|
||||
return await conn.execute(sql, params, auth_context)
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_connection_context(self, session_id: str):
|
||||
"""Get connection context manager"""
|
||||
conn = await self.get_connection(session_id)
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
# Connection will be reused, no need to close here
|
||||
pass
|
||||
|
||||
async def close(self):
|
||||
"""Close connection manager"""
|
||||
try:
|
||||
# Cancel background tasks
|
||||
if self._health_check_task:
|
||||
self._health_check_task.cancel()
|
||||
try:
|
||||
await self._health_check_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
if self._cleanup_task:
|
||||
self._cleanup_task.cancel()
|
||||
try:
|
||||
await self._cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Clean up all session connections
|
||||
for session_id in list(self.session_connections.keys()):
|
||||
await self._cleanup_session_connection(session_id)
|
||||
|
||||
# Close connection pool
|
||||
if self.pool:
|
||||
self.pool.close()
|
||||
await self.pool.wait_closed()
|
||||
|
||||
self.logger.info("Connection manager closed successfully")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error closing connection manager: {e}")
|
||||
|
||||
async def test_connection(self) -> bool:
|
||||
"""Test database connection"""
|
||||
try:
|
||||
if not self.pool:
|
||||
return False
|
||||
|
||||
async with self.pool.acquire() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute("SELECT 1")
|
||||
result = await cursor.fetchone()
|
||||
return result is not None
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Connection test failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
class ConnectionPoolMonitor:
|
||||
"""Connection pool monitor
|
||||
|
||||
Provides detailed monitoring and reporting capabilities for connection pool status
|
||||
"""
|
||||
|
||||
def __init__(self, connection_manager: DorisConnectionManager):
|
||||
self.connection_manager = connection_manager
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
async def get_pool_status(self) -> dict[str, Any]:
|
||||
"""Get connection pool status"""
|
||||
metrics = await self.connection_manager.get_metrics()
|
||||
|
||||
status = {
|
||||
"pool_size": self.connection_manager.pool.size if self.connection_manager.pool else 0,
|
||||
"free_connections": self.connection_manager.pool.freesize if self.connection_manager.pool else 0,
|
||||
"active_sessions": len(self.connection_manager.session_connections),
|
||||
"total_connections": metrics.total_connections,
|
||||
"failed_connections": metrics.failed_connections,
|
||||
"connection_errors": metrics.connection_errors,
|
||||
"avg_connection_time": metrics.avg_connection_time,
|
||||
"last_health_check": metrics.last_health_check.isoformat() if metrics.last_health_check else None,
|
||||
}
|
||||
|
||||
return status
|
||||
|
||||
async def get_session_details(self) -> list[dict[str, Any]]:
|
||||
"""Get session connection details"""
|
||||
sessions = []
|
||||
|
||||
for session_id, conn in self.connection_manager.session_connections.items():
|
||||
session_info = {
|
||||
"session_id": session_id,
|
||||
"created_at": conn.created_at.isoformat(),
|
||||
"last_used": conn.last_used.isoformat(),
|
||||
"query_count": conn.query_count,
|
||||
"is_healthy": conn.is_healthy,
|
||||
"connection_age": (datetime.utcnow() - conn.created_at).total_seconds(),
|
||||
}
|
||||
sessions.append(session_info)
|
||||
|
||||
return sessions
|
||||
|
||||
async def generate_health_report(self) -> dict[str, Any]:
|
||||
"""Generate connection health report"""
|
||||
pool_status = await self.get_pool_status()
|
||||
session_details = await self.get_session_details()
|
||||
|
||||
# Calculate health statistics
|
||||
healthy_sessions = sum(1 for s in session_details if s["is_healthy"])
|
||||
total_sessions = len(session_details)
|
||||
health_ratio = healthy_sessions / total_sessions if total_sessions > 0 else 1.0
|
||||
|
||||
report = {
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"pool_status": pool_status,
|
||||
"session_summary": {
|
||||
"total_sessions": total_sessions,
|
||||
"healthy_sessions": healthy_sessions,
|
||||
"health_ratio": health_ratio,
|
||||
},
|
||||
"session_details": session_details,
|
||||
"recommendations": [],
|
||||
}
|
||||
|
||||
# Add recommendations based on health status
|
||||
if health_ratio < 0.8:
|
||||
report["recommendations"].append("Consider checking database connectivity and network stability")
|
||||
|
||||
if pool_status["connection_errors"] > 10:
|
||||
report["recommendations"].append("High connection error rate detected, review connection configuration")
|
||||
|
||||
if pool_status["active_sessions"] > pool_status["pool_size"] * 0.9:
|
||||
report["recommendations"].append("Connection pool utilization is high, consider increasing pool size")
|
||||
|
||||
return report
|
||||
|
||||
@@ -1,226 +1,85 @@
|
||||
"""
|
||||
Unified Logging Configuration Module
|
||||
|
||||
Provides unified logging configuration, including:
|
||||
- General logs: Record all program execution information
|
||||
- Audit logs: Record JSON data for key operations and processing results
|
||||
- Error logs: Specifically record program exceptions and errors
|
||||
Logging configuration for Doris MCP Server.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import logging.handlers
|
||||
import logging.config
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from datetime import datetime
|
||||
from dotenv import load_dotenv
|
||||
from typing import Any
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv(override=True)
|
||||
|
||||
# Get project root directory
|
||||
PROJECT_ROOT = Path(__file__).parents[2].absolute()
|
||||
def setup_logging(
|
||||
level: str = "INFO",
|
||||
log_file: str | None = None,
|
||||
log_format: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Setup logging configuration.
|
||||
|
||||
# Get log configuration from environment variables
|
||||
LOG_DIR = os.getenv("LOG_DIR", str(PROJECT_ROOT / "logs"))
|
||||
LOG_PREFIX = os.getenv("LOG_PREFIX", "doris_mcp")
|
||||
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
|
||||
LOG_MAX_DAYS = int(os.getenv("LOG_MAX_DAYS", "30"))
|
||||
# Whether to output logs to the console (should be disabled when running as a service)
|
||||
CONSOLE_LOGGING = os.getenv("CONSOLE_LOGGING", "false").lower() == "true"
|
||||
# Whether stdio transport mode is being used
|
||||
STDIO_MODE = os.getenv("MCP_TRANSPORT_TYPE", "").lower() == "stdio"
|
||||
Args:
|
||||
level: Logging level (DEBUG, INFO, WARNING, ERROR)
|
||||
log_file: Optional log file path
|
||||
log_format: Optional custom log format
|
||||
"""
|
||||
if log_format is None:
|
||||
log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
|
||||
def purge_old_logs():
|
||||
"""Clean up expired log files"""
|
||||
# --- Only perform cleanup in non-Stdio mode ---
|
||||
if STDIO_MODE:
|
||||
return
|
||||
try:
|
||||
now = datetime.now()
|
||||
log_dir = Path(LOG_DIR)
|
||||
# Check if directory exists and is readable/writable
|
||||
if not log_dir.is_dir() or not os.access(LOG_DIR, os.W_OK):
|
||||
if not STDIO_MODE: # Avoid printing to stdout in stdio mode
|
||||
print(f"Warning: Log directory {LOG_DIR} not accessible, skipping log purge.", file=sys.stderr)
|
||||
return
|
||||
# Base configuration
|
||||
config: dict[str, Any] = {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"formatters": {
|
||||
"default": {"format": log_format, "datefmt": "%Y-%m-%d %H:%M:%S"}
|
||||
},
|
||||
"handlers": {
|
||||
"console": {
|
||||
"class": "logging.StreamHandler",
|
||||
"level": level,
|
||||
"formatter": "default",
|
||||
"stream": sys.stdout,
|
||||
}
|
||||
},
|
||||
"root": {"level": level, "handlers": ["console"]},
|
||||
"loggers": {
|
||||
"doris_mcp_server": {
|
||||
"level": level,
|
||||
"handlers": ["console"],
|
||||
"propagate": False,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
for log_file in log_dir.glob(f"{LOG_PREFIX}*.20*"):
|
||||
# Parse date
|
||||
file_name = log_file.name
|
||||
date_str = None
|
||||
|
||||
# Try to find the date part
|
||||
parts = file_name.split('.')
|
||||
for part in parts:
|
||||
if part.startswith('20') and len(part) == 8: # 20YYMMDD format
|
||||
date_str = part
|
||||
break
|
||||
|
||||
if date_str:
|
||||
try:
|
||||
file_date = datetime.strptime(date_str, '%Y%m%d')
|
||||
days_old = (now - file_date).days
|
||||
|
||||
if days_old > LOG_MAX_DAYS:
|
||||
os.remove(log_file)
|
||||
if not STDIO_MODE:
|
||||
print(f"Deleted expired log file: {log_file}")
|
||||
except (ValueError, OSError) as e:
|
||||
if not STDIO_MODE:
|
||||
print(f"Error processing log file {file_name}: {e}", file=sys.stderr)
|
||||
except Exception as e:
|
||||
if not STDIO_MODE:
|
||||
print(f"Error cleaning up logs: {e}", file=sys.stderr)
|
||||
# Add file handler if log_file is specified
|
||||
if log_file:
|
||||
# Ensure log directory exists
|
||||
log_path = Path(log_file)
|
||||
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Force disable console log output if in stdio mode
|
||||
if STDIO_MODE:
|
||||
CONSOLE_LOGGING = False
|
||||
config["handlers"]["file"] = {
|
||||
"class": "logging.handlers.RotatingFileHandler",
|
||||
"level": level,
|
||||
"formatter": "default",
|
||||
"filename": log_file,
|
||||
"maxBytes": 10485760, # 10MB
|
||||
"backupCount": 5,
|
||||
}
|
||||
|
||||
# --- Only create log directory and clean old logs in non-Stdio mode ---
|
||||
if not STDIO_MODE:
|
||||
try:
|
||||
os.makedirs(LOG_DIR, exist_ok=True)
|
||||
# Clean up expired logs on startup (also moved here, as it only handles file logs)
|
||||
purge_old_logs()
|
||||
except OSError as e:
|
||||
# If directory creation fails (e.g., permission issue), print warning but continue to avoid startup failure
|
||||
print(f"Warning: Failed to create log directory {LOG_DIR} or purge logs: {e}", file=sys.stderr)
|
||||
# Add file handler to root and package loggers
|
||||
config["root"]["handlers"].append("file")
|
||||
config["loggers"]["doris_mcp_server"]["handlers"].append("file")
|
||||
|
||||
# Log file paths (definition still needed, but files might not be created/used)
|
||||
LOG_FILE = os.path.join(LOG_DIR, f"{LOG_PREFIX}.log")
|
||||
AUDIT_LOG_FILE = os.path.join(LOG_DIR, f"{LOG_PREFIX}.audit")
|
||||
ERROR_LOG_FILE = os.path.join(LOG_DIR, f"{LOG_PREFIX}.error")
|
||||
|
||||
# Log level mapping
|
||||
LOG_LEVELS = {
|
||||
"DEBUG": logging.DEBUG,
|
||||
"INFO": logging.INFO,
|
||||
"WARNING": logging.WARNING,
|
||||
"ERROR": logging.ERROR,
|
||||
"CRITICAL": logging.CRITICAL
|
||||
}
|
||||
|
||||
# Log format
|
||||
LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
AUDIT_FORMAT = '%(asctime)s - %(name)s - %(message)s'
|
||||
ERROR_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(pathname)s:%(lineno)d - %(message)s'
|
||||
|
||||
# Dedicated audit log level
|
||||
AUDIT = 25 # Level between INFO and WARNING
|
||||
logging.addLevelName(AUDIT, "AUDIT")
|
||||
|
||||
# Logger object cache
|
||||
_loggers: Dict[str, logging.Logger] = {}
|
||||
|
||||
# Handler type mapping, used to ensure no duplicates are added
|
||||
_handler_types = {
|
||||
'console': logging.StreamHandler,
|
||||
'file': logging.handlers.TimedRotatingFileHandler,
|
||||
'audit': logging.handlers.TimedRotatingFileHandler,
|
||||
'error': logging.handlers.TimedRotatingFileHandler
|
||||
}
|
||||
logging.config.dictConfig(config)
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
"""
|
||||
Get a logger with the specified name
|
||||
|
||||
Get a logger instance.
|
||||
|
||||
Args:
|
||||
name: Logger name
|
||||
|
||||
|
||||
Returns:
|
||||
logging.Logger: Configured logger
|
||||
Logger instance
|
||||
"""
|
||||
if name in _loggers:
|
||||
return _loggers[name]
|
||||
|
||||
# Create logger
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(LOG_LEVELS.get(LOG_LEVEL, logging.INFO))
|
||||
|
||||
# Avoid duplicate logs caused by propagation
|
||||
logger.propagate = False
|
||||
|
||||
# Check if handlers already exist to avoid duplicates
|
||||
handler_types = set(type(h) for h in logger.handlers)
|
||||
|
||||
# Add audit log method
|
||||
def audit(self, message, *args, **kwargs):
|
||||
self.log(AUDIT, message, *args, **kwargs)
|
||||
|
||||
logger.audit = audit.__get__(logger)
|
||||
|
||||
# General log handler - output to console (only if enabled)
|
||||
if CONSOLE_LOGGING and _handler_types['console'] not in handler_types:
|
||||
# Use stderr instead of stdout to avoid conflicts with MCP communication
|
||||
console_handler = logging.StreamHandler(sys.stderr)
|
||||
console_handler.setFormatter(logging.Formatter(LOG_FORMAT))
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
# --- Only add file handlers in non-Stdio mode ---
|
||||
if not STDIO_MODE:
|
||||
# General log handler - daily rotating file
|
||||
if _handler_types['file'] not in handler_types:
|
||||
try: # Add try-except block
|
||||
file_handler = logging.handlers.TimedRotatingFileHandler(
|
||||
LOG_FILE,
|
||||
when='midnight',
|
||||
interval=1,
|
||||
backupCount=LOG_MAX_DAYS,
|
||||
encoding='utf-8'
|
||||
)
|
||||
file_handler.setFormatter(logging.Formatter(LOG_FORMAT))
|
||||
file_handler.suffix = "%Y%m%d"
|
||||
logger.addHandler(file_handler)
|
||||
except OSError as e:
|
||||
print(f"Warning: Failed to add file log handler for {LOG_FILE}: {e}", file=sys.stderr)
|
||||
|
||||
# Audit log handler - only logs AUDIT level
|
||||
if _handler_types['audit'] not in handler_types:
|
||||
try: # Add try-except block
|
||||
audit_handler = logging.handlers.TimedRotatingFileHandler(
|
||||
AUDIT_LOG_FILE,
|
||||
when='midnight',
|
||||
interval=1,
|
||||
backupCount=LOG_MAX_DAYS,
|
||||
encoding='utf-8'
|
||||
)
|
||||
audit_handler.setFormatter(logging.Formatter(AUDIT_FORMAT))
|
||||
audit_handler.suffix = "%Y%m%d"
|
||||
audit_handler.setLevel(AUDIT)
|
||||
audit_handler.addFilter(lambda record: record.levelno == AUDIT)
|
||||
logger.addHandler(audit_handler)
|
||||
except OSError as e:
|
||||
print(f"Warning: Failed to add audit log handler for {AUDIT_LOG_FILE}: {e}", file=sys.stderr)
|
||||
|
||||
# Error log handler - only logs ERROR level and above
|
||||
if _handler_types['error'] not in handler_types:
|
||||
try: # Add try-except block
|
||||
error_handler = logging.handlers.TimedRotatingFileHandler(
|
||||
ERROR_LOG_FILE,
|
||||
when='midnight',
|
||||
interval=1,
|
||||
backupCount=LOG_MAX_DAYS,
|
||||
encoding='utf-8'
|
||||
)
|
||||
error_handler.setFormatter(logging.Formatter(ERROR_FORMAT))
|
||||
error_handler.suffix = "%Y%m%d"
|
||||
error_handler.setLevel(logging.ERROR)
|
||||
logger.addHandler(error_handler)
|
||||
except OSError as e:
|
||||
print(f"Warning: Failed to add error log handler for {ERROR_LOG_FILE}: {e}", file=sys.stderr)
|
||||
|
||||
# Cache logger
|
||||
_loggers[name] = logger
|
||||
|
||||
return logger
|
||||
|
||||
# Default logger
|
||||
logger = get_logger('doris_mcp')
|
||||
|
||||
# Audit logger - for recording processing results, business operations, etc.
|
||||
audit_logger = get_logger('audit')
|
||||
|
||||
# Call to clean logs moved after directory creation, and added non-stdio check
|
||||
return logging.getLogger(name)
|
||||
|
||||
800
doris_mcp_server/utils/query_executor.py
Normal file
800
doris_mcp_server/utils/query_executor.py
Normal file
@@ -0,0 +1,800 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Doris Query Execution Module
|
||||
Implements query optimization, cache management and performance monitoring functionality
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import os
|
||||
import uuid
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, date
|
||||
from typing import Any, Dict
|
||||
from decimal import Decimal
|
||||
|
||||
from .db import DorisConnectionManager, QueryResult
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueryRequest:
|
||||
"""Query request wrapper"""
|
||||
|
||||
sql: str
|
||||
session_id: str
|
||||
user_id: str
|
||||
parameters: dict[str, Any] | None = None
|
||||
timeout: int | None = None
|
||||
cache_enabled: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class CachedQuery:
|
||||
"""Cached query result"""
|
||||
|
||||
result: QueryResult
|
||||
created_at: datetime
|
||||
ttl: int
|
||||
access_count: int = 0
|
||||
last_accessed: datetime | None = None
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if cache is expired"""
|
||||
if self.ttl <= 0:
|
||||
return False
|
||||
return (datetime.utcnow() - self.created_at).total_seconds() > self.ttl
|
||||
|
||||
def access(self):
|
||||
"""Record access"""
|
||||
self.access_count += 1
|
||||
self.last_accessed = datetime.utcnow()
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueryMetrics:
|
||||
"""Query performance metrics"""
|
||||
|
||||
total_queries: int = 0
|
||||
successful_queries: int = 0
|
||||
failed_queries: int = 0
|
||||
cache_hits: int = 0
|
||||
cache_misses: int = 0
|
||||
avg_execution_time: float = 0.0
|
||||
total_execution_time: float = 0.0
|
||||
slow_queries: int = 0
|
||||
concurrent_queries: int = 0
|
||||
|
||||
|
||||
class QueryCache:
|
||||
"""Query result cache manager"""
|
||||
|
||||
def __init__(self, max_size: int = 1000, default_ttl: int = 300):
|
||||
self.max_size = max_size
|
||||
self.default_ttl = default_ttl
|
||||
self.cache: dict[str, CachedQuery] = {}
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
def _generate_cache_key(
|
||||
self, sql: str, parameters: dict[str, Any] | None = None
|
||||
) -> str:
|
||||
"""Generate cache key"""
|
||||
cache_data = {"sql": sql.strip().lower(), "parameters": parameters or {}}
|
||||
cache_string = json.dumps(cache_data, sort_keys=True)
|
||||
return hashlib.md5(cache_string.encode()).hexdigest()
|
||||
|
||||
async def get(
|
||||
self, sql: str, parameters: dict[str, Any] | None = None
|
||||
) -> CachedQuery | None:
|
||||
"""Get cached query result"""
|
||||
cache_key = self._generate_cache_key(sql, parameters)
|
||||
|
||||
if cache_key in self.cache:
|
||||
cached_query = self.cache[cache_key]
|
||||
|
||||
if not cached_query.is_expired():
|
||||
cached_query.access()
|
||||
self.logger.debug(f"Cache hit: {cache_key}")
|
||||
return cached_query
|
||||
else:
|
||||
# Clean up expired cache
|
||||
del self.cache[cache_key]
|
||||
self.logger.debug(f"Cache expired, cleaned up: {cache_key}")
|
||||
|
||||
return None
|
||||
|
||||
async def set(
|
||||
self,
|
||||
sql: str,
|
||||
result: QueryResult,
|
||||
parameters: dict[str, Any] | None = None,
|
||||
ttl: int | None = None,
|
||||
) -> str:
|
||||
"""Set query result cache"""
|
||||
cache_key = self._generate_cache_key(sql, parameters)
|
||||
|
||||
# Check cache size limit
|
||||
if len(self.cache) >= self.max_size:
|
||||
await self._evict_oldest()
|
||||
|
||||
cached_query = CachedQuery(
|
||||
result=result, created_at=datetime.utcnow(), ttl=ttl or self.default_ttl
|
||||
)
|
||||
|
||||
self.cache[cache_key] = cached_query
|
||||
self.logger.debug(f"Cache set: {cache_key}")
|
||||
|
||||
return cache_key
|
||||
|
||||
async def _evict_oldest(self):
|
||||
"""Clean up oldest cache item"""
|
||||
if not self.cache:
|
||||
return
|
||||
|
||||
# Find oldest cache item
|
||||
oldest_key = min(self.cache.keys(), key=lambda k: self.cache[k].created_at)
|
||||
|
||||
del self.cache[oldest_key]
|
||||
self.logger.debug(f"Cleaned up oldest cache: {oldest_key}")
|
||||
|
||||
async def clear_expired(self):
|
||||
"""Clean up all expired cache"""
|
||||
expired_keys = [
|
||||
key for key, cached_query in self.cache.items() if cached_query.is_expired()
|
||||
]
|
||||
|
||||
for key in expired_keys:
|
||||
del self.cache[key]
|
||||
|
||||
if expired_keys:
|
||||
self.logger.info(f"Cleaned up {len(expired_keys)} expired cache items")
|
||||
|
||||
async def clear_all(self):
|
||||
"""Clean up all cache"""
|
||||
cache_count = len(self.cache)
|
||||
self.cache.clear()
|
||||
self.logger.info(f"Cleaned up all cache, total {cache_count} items")
|
||||
|
||||
def get_stats(self) -> dict[str, Any]:
|
||||
"""Get cache statistics"""
|
||||
total_access = sum(cached.access_count for cached in self.cache.values())
|
||||
|
||||
return {
|
||||
"cache_size": len(self.cache),
|
||||
"max_size": self.max_size,
|
||||
"total_access": total_access,
|
||||
"hit_rate": 0.0
|
||||
if total_access == 0
|
||||
else sum(cached.access_count for cached in self.cache.values())
|
||||
/ total_access,
|
||||
}
|
||||
|
||||
|
||||
class QueryOptimizer:
|
||||
"""Query optimizer"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.optimization_rules = self._load_optimization_rules()
|
||||
|
||||
def _load_optimization_rules(self) -> list[dict[str, Any]]:
|
||||
"""Load query optimization rules"""
|
||||
return [
|
||||
{
|
||||
"name": "add_limit_clause",
|
||||
"description": "Add default limit for SELECT queries without LIMIT",
|
||||
"pattern": r"^select\s+.*(?!.*limit\s+\d+)",
|
||||
"action": "add_limit",
|
||||
"params": {"default_limit": 1000},
|
||||
},
|
||||
{
|
||||
"name": "optimize_count_query",
|
||||
"description": "Optimize COUNT queries",
|
||||
"pattern": r"select\s+count\(\*\)\s+from\s+(\w+)",
|
||||
"action": "optimize_count",
|
||||
"params": {},
|
||||
},
|
||||
]
|
||||
|
||||
async def optimize_query(self, sql: str, context: dict[str, Any]) -> str:
|
||||
"""Apply query optimization"""
|
||||
optimized_sql = sql
|
||||
|
||||
for rule in self.optimization_rules:
|
||||
if self._should_apply_rule(rule, optimized_sql, context):
|
||||
optimized_sql = await self._apply_optimization_rule(
|
||||
optimized_sql, rule, context
|
||||
)
|
||||
self.logger.debug(f"Applied optimization rule: {rule['name']}")
|
||||
|
||||
return optimized_sql
|
||||
|
||||
def _should_apply_rule(
|
||||
self, rule: dict[str, Any], sql: str, context: dict[str, Any]
|
||||
) -> bool:
|
||||
"""Check if optimization rule should be applied"""
|
||||
import re
|
||||
|
||||
# Check pattern match
|
||||
if "pattern" in rule:
|
||||
if not re.search(rule["pattern"], sql, re.IGNORECASE):
|
||||
return False
|
||||
|
||||
# Check conditions
|
||||
if "conditions" in rule:
|
||||
for condition in rule["conditions"]:
|
||||
if not self._check_condition(condition, context):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _check_condition(
|
||||
self, condition: dict[str, Any], context: dict[str, Any]
|
||||
) -> bool:
|
||||
"""Check optimization condition"""
|
||||
condition_type = condition.get("type")
|
||||
|
||||
if condition_type == "user_role":
|
||||
required_roles = condition.get("roles", [])
|
||||
user_roles = context.get("user_roles", [])
|
||||
return any(role in user_roles for role in required_roles)
|
||||
|
||||
elif condition_type == "query_size":
|
||||
max_size = condition.get("max_size", 1000)
|
||||
return len(context.get("sql", "")) <= max_size
|
||||
|
||||
return True
|
||||
|
||||
async def _apply_optimization_rule(
|
||||
self, sql: str, rule: dict[str, Any], context: dict[str, Any]
|
||||
) -> str:
|
||||
"""Apply optimization rule"""
|
||||
action = rule.get("action")
|
||||
params = rule.get("params", {})
|
||||
|
||||
if action == "add_limit":
|
||||
return await self._add_limit_clause(sql, params)
|
||||
elif action == "optimize_count":
|
||||
return await self._optimize_count_query(sql, params)
|
||||
elif action == "add_hints":
|
||||
return await self._add_query_hints(sql, params)
|
||||
|
||||
return sql
|
||||
|
||||
async def _add_limit_clause(self, sql: str, params: dict[str, Any]) -> str:
|
||||
"""Add LIMIT clause to query"""
|
||||
import re
|
||||
|
||||
default_limit = params.get("default_limit", 1000)
|
||||
|
||||
# Check if LIMIT already exists
|
||||
if re.search(r"\blimit\s+\d+", sql, re.IGNORECASE):
|
||||
return sql
|
||||
|
||||
# Add LIMIT clause
|
||||
if sql.strip().endswith(";"):
|
||||
sql = sql.strip()[:-1]
|
||||
|
||||
return f"{sql} LIMIT {default_limit}"
|
||||
|
||||
async def _optimize_count_query(self, sql: str, params: dict[str, Any]) -> str:
|
||||
"""Optimize COUNT query"""
|
||||
# For COUNT queries, we can add optimization hints
|
||||
return sql.replace("COUNT(*)", "COUNT(1)")
|
||||
|
||||
async def _add_query_hints(self, sql: str, params: dict[str, Any]) -> str:
|
||||
"""Add query hints"""
|
||||
hints = params.get("hints", [])
|
||||
if not hints:
|
||||
return sql
|
||||
|
||||
hint_string = "/*+ " + " ".join(hints) + " */"
|
||||
return f"{hint_string} {sql}"
|
||||
|
||||
|
||||
class DorisQueryExecutor:
|
||||
"""Doris query executor with caching and optimization"""
|
||||
|
||||
def __init__(self, connection_manager: DorisConnectionManager, config=None):
|
||||
self.connection_manager = connection_manager
|
||||
self.config = config or self._create_default_config()
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize components
|
||||
cache_config = getattr(self.config, 'performance', None)
|
||||
if cache_config:
|
||||
cache_size = getattr(cache_config, 'max_cache_size', 1000)
|
||||
cache_ttl = getattr(cache_config, 'cache_ttl', 300)
|
||||
else:
|
||||
cache_size = 1000
|
||||
cache_ttl = 300
|
||||
|
||||
self.query_cache = QueryCache(max_size=cache_size, default_ttl=cache_ttl)
|
||||
self.query_optimizer = QueryOptimizer(self.config)
|
||||
self.metrics = QueryMetrics()
|
||||
|
||||
# Performance monitoring
|
||||
self.slow_query_threshold = 5.0 # seconds
|
||||
self.max_concurrent_queries = getattr(
|
||||
getattr(self.config, 'performance', None), 'max_concurrent_queries', 50
|
||||
) if hasattr(self.config, 'performance') else 50
|
||||
|
||||
# Background tasks
|
||||
self._background_tasks = []
|
||||
self._start_background_tasks()
|
||||
|
||||
def _create_default_config(self):
|
||||
"""Create default configuration"""
|
||||
class DefaultConfig:
|
||||
def __init__(self):
|
||||
self.performance = DefaultPerformanceConfig()
|
||||
|
||||
class DefaultPerformanceConfig:
|
||||
def __init__(self):
|
||||
self.max_cache_size = 1000
|
||||
self.cache_ttl = 300
|
||||
self.max_concurrent_queries = 50
|
||||
|
||||
return DefaultConfig()
|
||||
|
||||
def _start_background_tasks(self):
|
||||
"""Start background tasks"""
|
||||
try:
|
||||
# Cache cleanup task
|
||||
cleanup_task = asyncio.create_task(self._cache_cleanup_loop())
|
||||
self._background_tasks.append(cleanup_task)
|
||||
except RuntimeError:
|
||||
# No event loop running (e.g., in tests), skip background tasks
|
||||
self.logger.debug("No event loop running, skipping background tasks")
|
||||
|
||||
async def _cache_cleanup_loop(self):
|
||||
"""Background cache cleanup loop"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(300) # Run every 5 minutes
|
||||
await self.query_cache.clear_expired()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
self.logger.error(f"Cache cleanup error: {e}")
|
||||
|
||||
async def execute_query(
|
||||
self, query_request: QueryRequest, auth_context=None
|
||||
) -> QueryResult:
|
||||
"""Execute query with caching and optimization"""
|
||||
start_time = time.time()
|
||||
self.metrics.total_queries += 1
|
||||
self.metrics.concurrent_queries += 1
|
||||
|
||||
try:
|
||||
# Check cache first
|
||||
if query_request.cache_enabled:
|
||||
cached_result = await self.query_cache.get(
|
||||
query_request.sql, query_request.parameters
|
||||
)
|
||||
if cached_result:
|
||||
self.metrics.cache_hits += 1
|
||||
self.logger.debug(f"Cache hit for query: {query_request.sql[:50]}...")
|
||||
return cached_result.result
|
||||
|
||||
self.metrics.cache_misses += 1
|
||||
|
||||
# Execute query
|
||||
result = await self._execute_query_internal(query_request, auth_context)
|
||||
|
||||
# Cache result if enabled
|
||||
if query_request.cache_enabled and result.row_count > 0:
|
||||
await self.query_cache.set(
|
||||
query_request.sql, result, query_request.parameters
|
||||
)
|
||||
|
||||
self.metrics.successful_queries += 1
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
self.metrics.failed_queries += 1
|
||||
self.logger.error(f"Query execution failed: {e}")
|
||||
raise
|
||||
|
||||
finally:
|
||||
execution_time = time.time() - start_time
|
||||
self.metrics.concurrent_queries -= 1
|
||||
self._update_execution_metrics(execution_time)
|
||||
|
||||
async def _execute_query_internal(
|
||||
self, query_request: QueryRequest, auth_context
|
||||
) -> QueryResult:
|
||||
"""Internal query execution"""
|
||||
# Optimize query
|
||||
optimized_sql = await self.query_optimizer.optimize_query(
|
||||
query_request.sql, {"user_roles": getattr(auth_context, 'roles', [])}
|
||||
)
|
||||
|
||||
# Execute query
|
||||
connection = await self.connection_manager.get_connection(
|
||||
query_request.session_id
|
||||
)
|
||||
|
||||
# Set timeout if specified
|
||||
if query_request.timeout:
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
connection.execute(optimized_sql, query_request.parameters, auth_context),
|
||||
timeout=query_request.timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
raise Exception(f"Query timeout after {query_request.timeout} seconds")
|
||||
else:
|
||||
result = await connection.execute(optimized_sql, query_request.parameters, auth_context)
|
||||
|
||||
return result
|
||||
|
||||
def _update_execution_metrics(self, execution_time: float):
|
||||
"""Update execution metrics"""
|
||||
self.metrics.total_execution_time += execution_time
|
||||
|
||||
# Update average execution time
|
||||
if self.metrics.successful_queries > 0:
|
||||
self.metrics.avg_execution_time = (
|
||||
self.metrics.total_execution_time / self.metrics.successful_queries
|
||||
)
|
||||
|
||||
# Check for slow queries
|
||||
if execution_time > self.slow_query_threshold:
|
||||
self.metrics.slow_queries += 1
|
||||
self.logger.warning(
|
||||
f"Slow query detected: {execution_time:.2f}s (threshold: {self.slow_query_threshold}s)"
|
||||
)
|
||||
|
||||
async def execute_batch_queries(
|
||||
self, query_requests: list[QueryRequest], auth_context=None
|
||||
) -> list[QueryResult]:
|
||||
"""Execute multiple queries in batch"""
|
||||
results = []
|
||||
|
||||
# Check concurrent query limit
|
||||
if len(query_requests) > self.max_concurrent_queries:
|
||||
raise Exception(
|
||||
f"Batch size {len(query_requests)} exceeds maximum concurrent queries {self.max_concurrent_queries}"
|
||||
)
|
||||
|
||||
# Execute queries concurrently
|
||||
tasks = [
|
||||
self.execute_query(request, auth_context) for request in query_requests
|
||||
]
|
||||
|
||||
try:
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Batch query execution failed: {e}")
|
||||
raise
|
||||
|
||||
return results
|
||||
|
||||
async def explain_query(self, sql: str, session_id: str) -> dict[str, Any]:
|
||||
"""Get query execution plan"""
|
||||
explain_sql = f"EXPLAIN {sql}"
|
||||
|
||||
connection = await self.connection_manager.get_connection(session_id)
|
||||
result = await connection.execute(explain_sql)
|
||||
|
||||
return {
|
||||
"query": sql,
|
||||
"execution_plan": result.data,
|
||||
"estimated_cost": "N/A", # Doris doesn't provide cost estimates
|
||||
}
|
||||
|
||||
async def get_query_stats(self) -> dict[str, Any]:
|
||||
"""Get query execution statistics"""
|
||||
cache_stats = self.query_cache.get_stats()
|
||||
|
||||
return {
|
||||
"query_metrics": {
|
||||
"total_queries": self.metrics.total_queries,
|
||||
"successful_queries": self.metrics.successful_queries,
|
||||
"failed_queries": self.metrics.failed_queries,
|
||||
"success_rate": (
|
||||
self.metrics.successful_queries / self.metrics.total_queries
|
||||
if self.metrics.total_queries > 0
|
||||
else 0.0
|
||||
),
|
||||
"avg_execution_time": self.metrics.avg_execution_time,
|
||||
"slow_queries": self.metrics.slow_queries,
|
||||
"concurrent_queries": self.metrics.concurrent_queries,
|
||||
},
|
||||
"cache_metrics": {
|
||||
"cache_hits": self.metrics.cache_hits,
|
||||
"cache_misses": self.metrics.cache_misses,
|
||||
"hit_rate": (
|
||||
self.metrics.cache_hits
|
||||
/ (self.metrics.cache_hits + self.metrics.cache_misses)
|
||||
if (self.metrics.cache_hits + self.metrics.cache_misses) > 0
|
||||
else 0.0
|
||||
),
|
||||
**cache_stats,
|
||||
},
|
||||
}
|
||||
|
||||
async def clear_cache(self):
|
||||
"""Clear query cache"""
|
||||
await self.query_cache.clear_all()
|
||||
|
||||
async def execute_sql_for_mcp(
|
||||
self,
|
||||
sql: str,
|
||||
limit: int = 1000,
|
||||
timeout: int = 30,
|
||||
session_id: str = "mcp_session",
|
||||
user_id: str = "mcp_user"
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute SQL query for MCP interface - unified method"""
|
||||
try:
|
||||
if not sql:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "SQL query is required",
|
||||
"data": None
|
||||
}
|
||||
|
||||
# Add LIMIT if not present and it's a SELECT query
|
||||
if sql.upper().startswith("SELECT") and "LIMIT" not in sql.upper():
|
||||
if sql.endswith(";"):
|
||||
sql = sql[:-1]
|
||||
sql = f"{sql} LIMIT {limit}"
|
||||
|
||||
# Create auth context for MCP calls
|
||||
class MockAuthContext:
|
||||
def __init__(self):
|
||||
self.user_id = user_id
|
||||
self.roles = ["data_analyst"]
|
||||
self.permissions = ["read_data", "execute_query"]
|
||||
self.session_id = session_id
|
||||
self.security_level = "internal"
|
||||
|
||||
auth_context = MockAuthContext()
|
||||
|
||||
# Create query request
|
||||
query_request = QueryRequest(
|
||||
sql=sql,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
timeout=timeout,
|
||||
cache_enabled=True
|
||||
)
|
||||
|
||||
# Execute query
|
||||
result = await self.execute_query(query_request, auth_context)
|
||||
|
||||
# Process results
|
||||
processed_data = []
|
||||
if result.data:
|
||||
for row in result.data:
|
||||
processed_row = self._serialize_row_data(row)
|
||||
processed_data.append(processed_row)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": processed_data,
|
||||
"metadata": {
|
||||
"row_count": result.row_count,
|
||||
"execution_time": result.execution_time,
|
||||
"columns": result.metadata.get("columns", []),
|
||||
"query": sql
|
||||
},
|
||||
"error": None
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
self.logger.error(f"SQL execution error: {error_msg}")
|
||||
|
||||
# Analyze error for better user feedback
|
||||
error_analysis = self._analyze_error(error_msg)
|
||||
|
||||
return {
|
||||
"success": False,
|
||||
"error": error_analysis.get("user_message", error_msg),
|
||||
"error_type": error_analysis.get("error_type", "execution_error"),
|
||||
"data": None,
|
||||
"metadata": {
|
||||
"query": sql,
|
||||
"error_details": error_msg
|
||||
}
|
||||
}
|
||||
|
||||
def _serialize_row_data(self, row_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Serialize row data for JSON response"""
|
||||
serialized = {}
|
||||
|
||||
for key, value in row_data.items():
|
||||
if value is None:
|
||||
serialized[key] = None
|
||||
elif isinstance(value, (str, int, float, bool)):
|
||||
serialized[key] = value
|
||||
elif isinstance(value, Decimal):
|
||||
serialized[key] = float(value)
|
||||
elif isinstance(value, (datetime, date)):
|
||||
serialized[key] = value.isoformat()
|
||||
elif isinstance(value, bytes):
|
||||
try:
|
||||
serialized[key] = value.decode('utf-8')
|
||||
except UnicodeDecodeError:
|
||||
serialized[key] = str(value)
|
||||
else:
|
||||
serialized[key] = str(value)
|
||||
|
||||
return serialized
|
||||
|
||||
def _analyze_error(self, error_message: str) -> Dict[str, str]:
|
||||
"""Analyze error message and provide user-friendly feedback"""
|
||||
error_msg_lower = error_message.lower()
|
||||
|
||||
if "table" in error_msg_lower and "doesn't exist" in error_msg_lower:
|
||||
return {
|
||||
"error_type": "table_not_found",
|
||||
"user_message": "The specified table does not exist. Please check the table name and database."
|
||||
}
|
||||
elif "column" in error_msg_lower and ("unknown" in error_msg_lower or "doesn't exist" in error_msg_lower):
|
||||
return {
|
||||
"error_type": "column_not_found",
|
||||
"user_message": "One or more columns in the query do not exist. Please check column names."
|
||||
}
|
||||
elif "syntax error" in error_msg_lower or "sql syntax" in error_msg_lower:
|
||||
return {
|
||||
"error_type": "syntax_error",
|
||||
"user_message": "SQL syntax error. Please check your query syntax."
|
||||
}
|
||||
elif "access denied" in error_msg_lower or "permission" in error_msg_lower:
|
||||
return {
|
||||
"error_type": "permission_denied",
|
||||
"user_message": "Access denied. You don't have permission to execute this query."
|
||||
}
|
||||
elif "timeout" in error_msg_lower:
|
||||
return {
|
||||
"error_type": "timeout",
|
||||
"user_message": "Query execution timed out. Try simplifying your query or adding more specific filters."
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"error_type": "general_error",
|
||||
"user_message": f"Query execution failed: {error_message}"
|
||||
}
|
||||
|
||||
async def close(self):
|
||||
"""Close query executor and cleanup resources"""
|
||||
# Cancel background tasks
|
||||
for task in self._background_tasks:
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Clear cache
|
||||
await self.query_cache.clear_all()
|
||||
|
||||
self.logger.info("Query executor closed")
|
||||
|
||||
|
||||
class QueryPerformanceMonitor:
|
||||
"""Query performance monitor"""
|
||||
|
||||
def __init__(self, query_executor: DorisQueryExecutor):
|
||||
self.query_executor = query_executor
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.performance_records = []
|
||||
|
||||
async def record_query_performance(
|
||||
self, query_request: QueryRequest, result: QueryResult, execution_time: float
|
||||
):
|
||||
"""Record query performance"""
|
||||
record = {
|
||||
"timestamp": datetime.utcnow(),
|
||||
"sql": query_request.sql,
|
||||
"user_id": query_request.user_id,
|
||||
"session_id": query_request.session_id,
|
||||
"execution_time": execution_time,
|
||||
"row_count": result.row_count,
|
||||
"cache_hit": False, # This would need to be passed from executor
|
||||
}
|
||||
|
||||
self.performance_records.append(record)
|
||||
|
||||
# Keep only recent records (last 1000)
|
||||
if len(self.performance_records) > 1000:
|
||||
self.performance_records = self.performance_records[-1000:]
|
||||
|
||||
async def get_performance_report(
|
||||
self, time_range_minutes: int = 60
|
||||
) -> dict[str, Any]:
|
||||
"""Get performance report"""
|
||||
cutoff_time = datetime.utcnow() - timedelta(minutes=time_range_minutes)
|
||||
|
||||
recent_records = [
|
||||
record
|
||||
for record in self.performance_records
|
||||
if record["timestamp"] >= cutoff_time
|
||||
]
|
||||
|
||||
if not recent_records:
|
||||
return {"message": "No performance data available for the specified time range"}
|
||||
|
||||
# Calculate statistics
|
||||
execution_times = [record["execution_time"] for record in recent_records]
|
||||
row_counts = [record["row_count"] for record in recent_records]
|
||||
|
||||
return {
|
||||
"time_range_minutes": time_range_minutes,
|
||||
"total_queries": len(recent_records),
|
||||
"avg_execution_time": sum(execution_times) / len(execution_times),
|
||||
"max_execution_time": max(execution_times),
|
||||
"min_execution_time": min(execution_times),
|
||||
"avg_row_count": sum(row_counts) / len(row_counts),
|
||||
"query_distribution": self._analyze_query_distribution(recent_records),
|
||||
}
|
||||
|
||||
def _analyze_query_distribution(
|
||||
self, records: list[dict[str, Any]]
|
||||
) -> dict[str, Any]:
|
||||
"""Analyze query distribution"""
|
||||
query_types = {}
|
||||
user_distribution = {}
|
||||
|
||||
for record in records:
|
||||
# Analyze query type
|
||||
sql_upper = record["sql"].strip().upper()
|
||||
if sql_upper.startswith("SELECT"):
|
||||
query_type = "SELECT"
|
||||
elif sql_upper.startswith("INSERT"):
|
||||
query_type = "INSERT"
|
||||
elif sql_upper.startswith("UPDATE"):
|
||||
query_type = "UPDATE"
|
||||
elif sql_upper.startswith("DELETE"):
|
||||
query_type = "DELETE"
|
||||
else:
|
||||
query_type = "OTHER"
|
||||
|
||||
query_types[query_type] = query_types.get(query_type, 0) + 1
|
||||
|
||||
# Analyze user distribution
|
||||
user_id = record["user_id"]
|
||||
user_distribution[user_id] = user_distribution.get(user_id, 0) + 1
|
||||
|
||||
return {"query_types": query_types, "user_distribution": user_distribution}
|
||||
|
||||
|
||||
# Unified convenience function for MCP integration
|
||||
async def execute_sql_query(sql: str, connection_manager: DorisConnectionManager, **kwargs) -> Dict[str, Any]:
|
||||
"""Execute SQL query - unified convenience function for MCP tools"""
|
||||
try:
|
||||
# Create query executor
|
||||
executor = DorisQueryExecutor(connection_manager)
|
||||
|
||||
try:
|
||||
# Extract parameters from kwargs or use defaults
|
||||
limit = kwargs.get("limit", 1000)
|
||||
timeout = kwargs.get("timeout", 30)
|
||||
session_id = kwargs.get("session_id", "mcp_session")
|
||||
user_id = kwargs.get("user_id", "mcp_user")
|
||||
|
||||
result = await executor.execute_sql_for_mcp(
|
||||
sql=sql,
|
||||
limit=limit,
|
||||
timeout=timeout,
|
||||
session_id=session_id,
|
||||
user_id=user_id
|
||||
)
|
||||
return result
|
||||
finally:
|
||||
await executor.close()
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Query execution failed: {str(e)}",
|
||||
"data": None
|
||||
}
|
||||
@@ -8,6 +8,8 @@ import os
|
||||
import json
|
||||
import pandas as pd
|
||||
import re
|
||||
import uuid
|
||||
import time
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from dotenv import load_dotenv
|
||||
from datetime import datetime, timedelta
|
||||
@@ -26,23 +28,25 @@ ENABLE_MULTI_DATABASE=os.getenv("ENABLE_MULTI_DATABASE",True)
|
||||
MULTI_DATABASE_NAMES=os.getenv("MULTI_DATABASE_NAMES","")
|
||||
|
||||
# Import local modules
|
||||
from doris_mcp_server.utils.db import execute_query_df, execute_query
|
||||
from .db import DorisConnectionManager
|
||||
|
||||
class MetadataExtractor:
|
||||
"""Apache Doris Metadata Extractor"""
|
||||
|
||||
def __init__(self, db_name: str = None, catalog_name: str = None):
|
||||
def __init__(self, db_name: str = None, catalog_name: str = None, connection_manager=None):
|
||||
"""
|
||||
Initialize the metadata extractor
|
||||
|
||||
Args:
|
||||
db_name: Default database name, uses the currently connected database if not specified
|
||||
catalog_name: Default catalog name for federation queries, uses the current catalog if not specified
|
||||
connection_manager: DorisConnectionManager instance for database operations
|
||||
"""
|
||||
# Get configuration from environment variables
|
||||
self.db_name = db_name or os.getenv("DB_DATABASE", "")
|
||||
self.catalog_name = catalog_name # Store catalog name for federation support
|
||||
self.metadata_db = METADATA_DB_NAME # Use constant
|
||||
self.connection_manager = connection_manager
|
||||
|
||||
# Caching system
|
||||
self.metadata_cache = {}
|
||||
@@ -65,6 +69,9 @@ class MetadataExtractor:
|
||||
# List of excluded system databases
|
||||
self.excluded_databases = self._load_excluded_databases()
|
||||
|
||||
# Session ID for database queries
|
||||
self._session_id = f"metadata_extractor_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
def _load_excluded_databases(self) -> List[str]:
|
||||
"""
|
||||
Load the list of excluded databases configuration
|
||||
@@ -482,7 +489,7 @@ class MetadataExtractor:
|
||||
TABLE_SCHEMA = '{db_name}'
|
||||
AND TABLE_NAME = '{table_name}'
|
||||
"""
|
||||
table_type_result = execute_query(table_type_query)
|
||||
table_type_result = self._execute_query(table_type_query)
|
||||
if table_type_result:
|
||||
schema["table_type"] = table_type_result[0].get("TABLE_TYPE", "")
|
||||
schema["engine"] = table_type_result[0].get("ENGINE", "")
|
||||
@@ -633,31 +640,52 @@ class MetadataExtractor:
|
||||
else:
|
||||
query = f"SHOW INDEX FROM `{db_name}`.`{table_name}`"
|
||||
|
||||
df = execute_query_df(query)
|
||||
|
||||
# Process results
|
||||
indexes = []
|
||||
current_index = None
|
||||
|
||||
for _, row in df.iterrows():
|
||||
index_name = row['Key_name']
|
||||
column_name = row['Column_name']
|
||||
try:
|
||||
df = self._execute_query(query, return_dataframe=True)
|
||||
|
||||
if current_index is None or current_index['name'] != index_name:
|
||||
# Process results
|
||||
indexes = []
|
||||
current_index = None
|
||||
|
||||
if not df.empty:
|
||||
for _, row in df.iterrows():
|
||||
try:
|
||||
index_name = row['Key_name']
|
||||
column_name = row['Column_name']
|
||||
|
||||
if current_index is None or current_index['name'] != index_name:
|
||||
if current_index is not None:
|
||||
indexes.append(current_index)
|
||||
|
||||
current_index = {
|
||||
'name': index_name,
|
||||
'columns': [column_name],
|
||||
'unique': row['Non_unique'] == 0,
|
||||
'type': row['Index_type']
|
||||
}
|
||||
else:
|
||||
current_index['columns'].append(column_name)
|
||||
except Exception as row_error:
|
||||
logger.warning(f"Failed to process index row data: {row_error}")
|
||||
continue
|
||||
|
||||
if current_index is not None:
|
||||
indexes.append(current_index)
|
||||
|
||||
current_index = {
|
||||
'name': index_name,
|
||||
'columns': [column_name],
|
||||
'unique': row['Non_unique'] == 0,
|
||||
'type': row['Index_type']
|
||||
}
|
||||
else:
|
||||
current_index['columns'].append(column_name)
|
||||
|
||||
if current_index is not None:
|
||||
indexes.append(current_index)
|
||||
except Exception as df_error:
|
||||
logger.warning(f"DataFrame processing failed, trying regular query: {df_error}")
|
||||
# Fall back to regular query
|
||||
result = self._execute_query(query, return_dataframe=False)
|
||||
indexes = []
|
||||
if result:
|
||||
# Simple processing, no complex index grouping
|
||||
for row in result:
|
||||
if isinstance(row, dict):
|
||||
indexes.append({
|
||||
'name': row.get('Key_name', ''),
|
||||
'columns': [row.get('Column_name', '')],
|
||||
'unique': row.get('Non_unique', 1) == 0,
|
||||
'type': row.get('Index_type', '')
|
||||
})
|
||||
|
||||
# Update cache
|
||||
self.metadata_cache[cache_key] = indexes
|
||||
@@ -748,7 +776,7 @@ class MetadataExtractor:
|
||||
ORDER BY time DESC
|
||||
LIMIT {limit}
|
||||
"""
|
||||
df = execute_query_df(query)
|
||||
df = self._execute_query(query, return_dataframe=True)
|
||||
return df
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting audit logs: {str(e)}")
|
||||
@@ -768,7 +796,7 @@ class MetadataExtractor:
|
||||
try:
|
||||
# Use SHOW CATALOGS command to get catalog list
|
||||
query = "SHOW CATALOGS"
|
||||
result = execute_query(query)
|
||||
result = self._execute_query(query)
|
||||
|
||||
if not result:
|
||||
catalogs = []
|
||||
@@ -1057,7 +1085,7 @@ class MetadataExtractor:
|
||||
AND TABLE_NAME = '{table_name}'
|
||||
"""
|
||||
|
||||
partitions = execute_query(query)
|
||||
partitions = self._execute_query(query)
|
||||
|
||||
if not partitions:
|
||||
return {}
|
||||
@@ -1099,10 +1127,511 @@ class MetadataExtractor:
|
||||
# Replace 'information_schema' with 'catalog_name.information_schema'
|
||||
modified_query = query.replace('information_schema', f'{catalog_name}.information_schema')
|
||||
logger.info(f"Modified query for catalog {catalog_name}: {modified_query}")
|
||||
return execute_query(modified_query, db_name)
|
||||
return self._execute_query(modified_query, db_name)
|
||||
else:
|
||||
# Execute the original query
|
||||
return execute_query(query, db_name)
|
||||
return self._execute_query(query, db_name)
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing query with catalog: {str(e)}")
|
||||
raise
|
||||
raise
|
||||
|
||||
async def _execute_query_async(self, query: str, db_name: str = None, return_dataframe: bool = False):
|
||||
"""
|
||||
Execute database query asynchronously
|
||||
|
||||
Args:
|
||||
query: SQL query to execute
|
||||
db_name: Database name to use (optional)
|
||||
return_dataframe: Whether to return a pandas DataFrame instead of list
|
||||
|
||||
Returns:
|
||||
Query result data (list of dictionaries or pandas DataFrame)
|
||||
"""
|
||||
try:
|
||||
if self.connection_manager:
|
||||
# Use the injected connection manager directly (async)
|
||||
result = await self.connection_manager.execute_query(self._session_id, query, None)
|
||||
|
||||
# Extract data from QueryResult
|
||||
if hasattr(result, 'data'):
|
||||
data = result.data
|
||||
else:
|
||||
data = result
|
||||
|
||||
# Convert to DataFrame if requested
|
||||
if return_dataframe and data:
|
||||
import pandas as pd
|
||||
return pd.DataFrame(data)
|
||||
elif return_dataframe:
|
||||
import pandas as pd
|
||||
return pd.DataFrame()
|
||||
else:
|
||||
return data
|
||||
else:
|
||||
# Fallback: Return empty result
|
||||
logger.warning("No connection manager provided, returning empty result")
|
||||
if return_dataframe:
|
||||
import pandas as pd
|
||||
return pd.DataFrame()
|
||||
else:
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing query: {str(e)}")
|
||||
# Return empty result instead of raising exception to prevent cascade failures
|
||||
if return_dataframe:
|
||||
import pandas as pd
|
||||
return pd.DataFrame()
|
||||
else:
|
||||
return []
|
||||
|
||||
def _execute_query(self, query: str, db_name: str = None, return_dataframe: bool = False):
|
||||
"""
|
||||
Execute database query with proper session management (sync wrapper)
|
||||
|
||||
Args:
|
||||
query: SQL query to execute
|
||||
db_name: Database name to use (optional)
|
||||
return_dataframe: Whether to return a pandas DataFrame instead of list
|
||||
|
||||
Returns:
|
||||
Query result data (list of dictionaries or pandas DataFrame)
|
||||
"""
|
||||
try:
|
||||
if self.connection_manager:
|
||||
import asyncio
|
||||
|
||||
# Try to run the async query
|
||||
try:
|
||||
# Check if there's a running event loop
|
||||
loop = asyncio.get_running_loop()
|
||||
# If we're in an async context, we need to run in a separate thread
|
||||
import concurrent.futures
|
||||
|
||||
def run_in_new_loop():
|
||||
new_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(new_loop)
|
||||
try:
|
||||
return new_loop.run_until_complete(
|
||||
self._execute_query_async(query, db_name, return_dataframe)
|
||||
)
|
||||
finally:
|
||||
new_loop.close()
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(run_in_new_loop)
|
||||
return future.result(timeout=30)
|
||||
|
||||
except RuntimeError:
|
||||
# No running loop, we can safely create one
|
||||
return asyncio.run(
|
||||
self._execute_query_async(query, db_name, return_dataframe)
|
||||
)
|
||||
else:
|
||||
# Fallback: Return empty result
|
||||
logger.warning("No connection manager provided, returning empty result")
|
||||
if return_dataframe:
|
||||
import pandas as pd
|
||||
return pd.DataFrame()
|
||||
else:
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing query: {str(e)}")
|
||||
# Return empty result instead of raising exception to prevent cascade failures
|
||||
if return_dataframe:
|
||||
import pandas as pd
|
||||
return pd.DataFrame()
|
||||
else:
|
||||
return []
|
||||
|
||||
async def get_table_schema_async(self, table_name: str, db_name: str = None, catalog_name: str = None) -> List[Dict[str, Any]]:
|
||||
"""Asynchronously get table schema information"""
|
||||
try:
|
||||
# Use async query method
|
||||
effective_catalog = catalog_name or self.catalog_name
|
||||
|
||||
# Build query statement
|
||||
if effective_catalog and effective_catalog != "internal":
|
||||
query = f"DESCRIBE `{effective_catalog}`.`{db_name or self.db_name}`.`{table_name}`"
|
||||
else:
|
||||
query = f"DESCRIBE `{db_name or self.db_name}`.`{table_name}`"
|
||||
|
||||
# Execute async query
|
||||
result = await self._execute_query_async(query, db_name)
|
||||
|
||||
if not result:
|
||||
return []
|
||||
|
||||
# Process results
|
||||
schema = []
|
||||
for row in result:
|
||||
if isinstance(row, dict):
|
||||
schema.append({
|
||||
'column_name': row.get('Field', ''),
|
||||
'data_type': row.get('Type', ''),
|
||||
'is_nullable': row.get('Null', 'NO') == 'YES',
|
||||
'default_value': row.get('Default', None),
|
||||
'comment': row.get('Comment', ''),
|
||||
'key': row.get('Key', ''),
|
||||
'extra': row.get('Extra', '')
|
||||
})
|
||||
|
||||
return schema
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get table schema: {e}")
|
||||
return []
|
||||
|
||||
async def get_all_databases_async(self, catalog_name: str = None) -> List[str]:
|
||||
"""Asynchronously get all database list"""
|
||||
try:
|
||||
effective_catalog = catalog_name or self.catalog_name
|
||||
|
||||
if effective_catalog and effective_catalog != "internal":
|
||||
query = f"SHOW DATABASES FROM `{effective_catalog}`"
|
||||
else:
|
||||
query = "SHOW DATABASES"
|
||||
|
||||
result = await self._execute_query_async(query)
|
||||
|
||||
if not result:
|
||||
return []
|
||||
|
||||
# Extract database names
|
||||
databases = []
|
||||
for row in result:
|
||||
if isinstance(row, dict):
|
||||
# Get the value of the first field (usually Database field)
|
||||
db_name = list(row.values())[0] if row else None
|
||||
if db_name:
|
||||
databases.append(db_name)
|
||||
|
||||
return databases
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get database list: {e}")
|
||||
return []
|
||||
|
||||
async def get_database_tables_async(self, db_name: str = None, catalog_name: str = None) -> List[str]:
|
||||
"""Asynchronously get table list in database"""
|
||||
try:
|
||||
effective_catalog = catalog_name or self.catalog_name
|
||||
effective_db = db_name or self.db_name
|
||||
|
||||
if effective_catalog and effective_catalog != "internal":
|
||||
query = f"SHOW TABLES FROM `{effective_catalog}`.`{effective_db}`"
|
||||
else:
|
||||
query = f"SHOW TABLES FROM `{effective_db}`"
|
||||
|
||||
result = await self._execute_query_async(query, effective_db)
|
||||
|
||||
if not result:
|
||||
return []
|
||||
|
||||
# Extract table names
|
||||
tables = []
|
||||
for row in result:
|
||||
if isinstance(row, dict):
|
||||
# Get the value of the first field (usually Tables_in_xxx field)
|
||||
table_name = list(row.values())[0] if row else None
|
||||
if table_name:
|
||||
tables.append(table_name)
|
||||
|
||||
return tables
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get table list: {e}")
|
||||
return []
|
||||
|
||||
async def get_catalog_list_async(self) -> List[str]:
|
||||
"""Asynchronously get catalog list"""
|
||||
try:
|
||||
query = "SHOW CATALOGS"
|
||||
result = await self._execute_query_async(query)
|
||||
|
||||
if not result:
|
||||
return []
|
||||
|
||||
# Extract catalog names
|
||||
catalogs = []
|
||||
for row in result:
|
||||
if isinstance(row, dict):
|
||||
# SHOW CATALOGS returns fields including: CatalogId, CatalogName, Type, IsCurrent, CreateTime, LastUpdateTime, Comment
|
||||
# We need to get the CatalogName field (second field)
|
||||
if 'CatalogName' in row:
|
||||
catalog_name = row['CatalogName']
|
||||
else:
|
||||
# If no CatalogName field, try to get the second field
|
||||
values = list(row.values())
|
||||
catalog_name = values[1] if len(values) > 1 else values[0] if values else None
|
||||
|
||||
if catalog_name:
|
||||
catalogs.append(str(catalog_name))
|
||||
|
||||
return catalogs
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get catalog list: {e}")
|
||||
return []
|
||||
|
||||
# ==================== Business layer methods (original metadata_tools.py functionality) ====================
|
||||
|
||||
def _format_response(self, success: bool, result: Any = None, error: str = None, message: str = "") -> Dict[str, Any]:
|
||||
"""Format response result"""
|
||||
response_data = {
|
||||
"success": success,
|
||||
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
|
||||
}
|
||||
if success and result is not None:
|
||||
response_data["result"] = result
|
||||
response_data["message"] = message or "Operation successful"
|
||||
elif not success:
|
||||
response_data["error"] = error or "Unknown error"
|
||||
response_data["message"] = message or "Operation failed"
|
||||
|
||||
return response_data
|
||||
|
||||
async def exec_query_for_mcp(
|
||||
self,
|
||||
sql: str,
|
||||
db_name: str = None,
|
||||
catalog_name: str = None,
|
||||
max_rows: int = 100,
|
||||
timeout: int = 30
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Execute SQL query and return results, supports catalog federation queries
|
||||
Unified interface for MCP tools
|
||||
"""
|
||||
logger.info(f"Executing SQL query: {sql}, DB: {db_name}, Catalog: {catalog_name}, MaxRows: {max_rows}, Timeout: {timeout}")
|
||||
|
||||
try:
|
||||
if not sql:
|
||||
return self._format_response(success=False, error="No SQL statement provided", message="Please provide SQL statement to execute")
|
||||
|
||||
# Import query executor
|
||||
from .query_executor import execute_sql_query
|
||||
|
||||
# Call execute_sql_query to execute query
|
||||
exec_result = await execute_sql_query(
|
||||
sql=sql,
|
||||
connection_manager=self.connection_manager,
|
||||
limit=max_rows,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
return exec_result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to execute SQL query: {str(e)}", exc_info=True)
|
||||
return self._format_response(success=False, error=str(e), message="Error occurred while executing SQL query")
|
||||
|
||||
async def get_table_schema_for_mcp(
|
||||
self,
|
||||
table_name: str,
|
||||
db_name: str = None,
|
||||
catalog_name: str = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Get detailed schema information for specified table (columns, types, comments, etc.) - MCP interface"""
|
||||
logger.info(f"Getting table schema: Table: {table_name}, DB: {db_name}, Catalog: {catalog_name}")
|
||||
|
||||
if not table_name:
|
||||
return self._format_response(success=False, error="Missing table_name parameter")
|
||||
|
||||
try:
|
||||
schema = await self.get_table_schema_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
|
||||
if not schema:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error="Table does not exist or has no columns",
|
||||
message=f"Unable to get schema for table {catalog_name or 'default'}.{db_name or self.db_name}.{table_name}"
|
||||
)
|
||||
|
||||
return self._format_response(success=True, result=schema)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get table schema: {str(e)}", exc_info=True)
|
||||
return self._format_response(success=False, error=str(e), message="Error occurred while getting table schema")
|
||||
|
||||
async def get_db_table_list_for_mcp(
|
||||
self,
|
||||
db_name: str = None,
|
||||
catalog_name: str = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Get list of all table names in specified database - MCP interface"""
|
||||
logger.info(f"Getting database table list: DB: {db_name}, Catalog: {catalog_name}")
|
||||
|
||||
try:
|
||||
tables = await self.get_database_tables_async(db_name=db_name, catalog_name=catalog_name)
|
||||
return self._format_response(success=True, result=tables)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get database table list: {str(e)}", exc_info=True)
|
||||
return self._format_response(success=False, error=str(e), message="Error occurred while getting database table list")
|
||||
|
||||
async def get_db_list_for_mcp(self, catalog_name: str = None) -> Dict[str, Any]:
|
||||
"""Get list of all database names on server - MCP interface"""
|
||||
logger.info(f"Getting database list: Catalog: {catalog_name}")
|
||||
|
||||
try:
|
||||
databases = await self.get_all_databases_async(catalog_name=catalog_name)
|
||||
return self._format_response(success=True, result=databases)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get database list: {str(e)}", exc_info=True)
|
||||
return self._format_response(success=False, error=str(e), message="Error occurred while getting database list")
|
||||
|
||||
async def get_table_comment_for_mcp(
|
||||
self,
|
||||
table_name: str,
|
||||
db_name: str = None,
|
||||
catalog_name: str = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Get comment information for specified table - MCP interface"""
|
||||
logger.info(f"Getting table comment: Table: {table_name}, DB: {db_name}, Catalog: {catalog_name}")
|
||||
|
||||
if not table_name:
|
||||
return self._format_response(success=False, error="Missing table_name parameter")
|
||||
|
||||
try:
|
||||
comment = self.get_table_comment(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
return self._format_response(success=True, result=comment)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get table comment: {str(e)}", exc_info=True)
|
||||
return self._format_response(success=False, error=str(e), message="Error occurred while getting table comment")
|
||||
|
||||
async def get_table_column_comments_for_mcp(
|
||||
self,
|
||||
table_name: str,
|
||||
db_name: str = None,
|
||||
catalog_name: str = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Get comment information for all columns in specified table - MCP interface"""
|
||||
logger.info(f"Getting table column comments: Table: {table_name}, DB: {db_name}, Catalog: {catalog_name}")
|
||||
|
||||
if not table_name:
|
||||
return self._format_response(success=False, error="Missing table_name parameter")
|
||||
|
||||
try:
|
||||
comments = self.get_column_comments(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
return self._format_response(success=True, result=comments)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get table column comments: {str(e)}", exc_info=True)
|
||||
return self._format_response(success=False, error=str(e), message="Error occurred while getting table column comments")
|
||||
|
||||
async def get_table_indexes_for_mcp(
|
||||
self,
|
||||
table_name: str,
|
||||
db_name: str = None,
|
||||
catalog_name: str = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Get index information for specified table - MCP interface"""
|
||||
logger.info(f"Getting table indexes: Table: {table_name}, DB: {db_name}, Catalog: {catalog_name}")
|
||||
|
||||
if not table_name:
|
||||
return self._format_response(success=False, error="Missing table_name parameter")
|
||||
|
||||
try:
|
||||
indexes = self.get_table_indexes(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
return self._format_response(success=True, result=indexes)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get table indexes: {str(e)}", exc_info=True)
|
||||
return self._format_response(success=False, error=str(e), message="Error occurred while getting table indexes")
|
||||
|
||||
def _serialize_datetime_objects(self, data):
|
||||
"""Serialize datetime objects to JSON compatible format"""
|
||||
if isinstance(data, list):
|
||||
return [self._serialize_datetime_objects(item) for item in data]
|
||||
elif isinstance(data, dict):
|
||||
return {key: self._serialize_datetime_objects(value) for key, value in data.items()}
|
||||
elif hasattr(data, 'isoformat'): # datetime, date, time objects
|
||||
return data.isoformat()
|
||||
elif hasattr(data, 'strftime'): # pandas Timestamp objects
|
||||
return data.strftime('%Y-%m-%d %H:%M:%S')
|
||||
else:
|
||||
return data
|
||||
|
||||
async def get_recent_audit_logs_for_mcp(self, days: int = 7, limit: int = 100) -> Dict[str, Any]:
|
||||
"""Get recent audit log records - MCP interface"""
|
||||
logger.info(f"Getting audit logs: Days: {days}, Limit: {limit}")
|
||||
|
||||
try:
|
||||
logs_df = self.get_recent_audit_logs(days=days, limit=limit)
|
||||
|
||||
# Convert DataFrame to JSON format
|
||||
if hasattr(logs_df, 'to_dict'):
|
||||
try:
|
||||
logs_data = logs_df.to_dict('records')
|
||||
except Exception as e:
|
||||
logger.warning(f"DataFrame.to_dict failed, trying manual conversion: {e}")
|
||||
# Manually convert DataFrame to records format
|
||||
logs_data = []
|
||||
if not logs_df.empty:
|
||||
for _, row in logs_df.iterrows():
|
||||
logs_data.append(dict(row))
|
||||
# Serialize datetime objects
|
||||
logs_data = self._serialize_datetime_objects(logs_data)
|
||||
else:
|
||||
logs_data = self._serialize_datetime_objects(logs_df)
|
||||
|
||||
return self._format_response(success=True, result=logs_data)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get audit logs: {str(e)}", exc_info=True)
|
||||
return self._format_response(success=False, error=str(e), message="Error occurred while getting audit logs")
|
||||
|
||||
async def get_catalog_list_for_mcp(self) -> Dict[str, Any]:
|
||||
"""Get Doris catalog list - MCP interface"""
|
||||
logger.info("Getting catalog list")
|
||||
|
||||
try:
|
||||
catalogs = await self.get_catalog_list_async()
|
||||
return self._format_response(success=True, result=catalogs, message="Successfully retrieved catalog list")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get catalog list: {str(e)}", exc_info=True)
|
||||
return self._format_response(success=False, error=str(e), message="Error occurred while getting catalog list")
|
||||
|
||||
|
||||
# ==================== Compatibility aliases ====================
|
||||
|
||||
# For backward compatibility, create MetadataManager alias
|
||||
class MetadataManager:
|
||||
"""
|
||||
Metadata manager - backward compatibility class
|
||||
Actually a wrapper for MetadataExtractor
|
||||
"""
|
||||
|
||||
def __init__(self, connection_manager=None):
|
||||
self.extractor = MetadataExtractor(connection_manager=connection_manager)
|
||||
|
||||
async def exec_query(self, sql: str, db_name: str = None, catalog_name: str = None, max_rows: int = 100, timeout: int = 30) -> Dict[str, Any]:
|
||||
"""Execute SQL query and return results, supports catalog federation queries"""
|
||||
return await self.extractor.exec_query_for_mcp(sql, db_name, catalog_name, max_rows, timeout)
|
||||
|
||||
async def get_table_schema(self, table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
|
||||
"""Get detailed schema information for specified table (columns, types, comments, etc.)"""
|
||||
return await self.extractor.get_table_schema_for_mcp(table_name, db_name, catalog_name)
|
||||
|
||||
async def get_db_table_list(self, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
|
||||
"""Get list of all table names in specified database"""
|
||||
return await self.extractor.get_db_table_list_for_mcp(db_name, catalog_name)
|
||||
|
||||
async def get_db_list(self, catalog_name: str = None) -> Dict[str, Any]:
|
||||
"""Get list of all database names on server"""
|
||||
return await self.extractor.get_db_list_for_mcp(catalog_name)
|
||||
|
||||
async def get_table_comment(self, table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
|
||||
"""Get comment information for specified table"""
|
||||
return await self.extractor.get_table_comment_for_mcp(table_name, db_name, catalog_name)
|
||||
|
||||
async def get_table_column_comments(self, table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
|
||||
"""Get comment information for all columns in specified table"""
|
||||
return await self.extractor.get_table_column_comments_for_mcp(table_name, db_name, catalog_name)
|
||||
|
||||
async def get_table_indexes(self, table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
|
||||
"""Get index information for specified table"""
|
||||
return await self.extractor.get_table_indexes_for_mcp(table_name, db_name, catalog_name)
|
||||
|
||||
async def get_recent_audit_logs(self, days: int = 7, limit: int = 100) -> Dict[str, Any]:
|
||||
"""Get recent audit log records"""
|
||||
return await self.extractor.get_recent_audit_logs_for_mcp(days, limit)
|
||||
|
||||
async def get_catalog_list(self) -> Dict[str, Any]:
|
||||
"""Get Doris catalog list"""
|
||||
return await self.extractor.get_catalog_list_for_mcp()
|
||||
861
doris_mcp_server/utils/security.py
Normal file
861
doris_mcp_server/utils/security.py
Normal file
@@ -0,0 +1,861 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Doris Security Management Module
|
||||
Implements enterprise-level authentication, authorization, SQL security validation and data masking functionality
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import sqlparse
|
||||
from sqlparse.sql import Statement
|
||||
from sqlparse.tokens import Keyword, Name
|
||||
|
||||
|
||||
class SecurityLevel(Enum):
|
||||
"""Security level enumeration"""
|
||||
|
||||
PUBLIC = "public"
|
||||
INTERNAL = "internal"
|
||||
CONFIDENTIAL = "confidential"
|
||||
SECRET = "secret"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthContext:
|
||||
"""Authentication context"""
|
||||
|
||||
user_id: str
|
||||
roles: list[str]
|
||||
permissions: list[str]
|
||||
session_id: str
|
||||
login_time: datetime | None = None
|
||||
last_activity: datetime | None = None
|
||||
security_level: SecurityLevel = SecurityLevel.INTERNAL
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationResult:
|
||||
"""Validation result"""
|
||||
|
||||
is_valid: bool
|
||||
error_message: str | None = None
|
||||
risk_level: str = "low"
|
||||
blocked_operations: list[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.blocked_operations is None:
|
||||
self.blocked_operations = []
|
||||
|
||||
|
||||
@dataclass
|
||||
class MaskingRule:
|
||||
"""Data masking rule"""
|
||||
|
||||
column_pattern: str
|
||||
algorithm: str
|
||||
parameters: dict[str, Any]
|
||||
security_level: SecurityLevel
|
||||
|
||||
|
||||
class DorisSecurityManager:
|
||||
"""Doris security manager
|
||||
|
||||
Provides complete security control functionality, including authentication, authorization, SQL security validation and data masking
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize security components
|
||||
self.auth_provider = AuthenticationProvider(config)
|
||||
self.authz_provider = AuthorizationProvider(config)
|
||||
self.sql_validator = SQLSecurityValidator(config)
|
||||
self.masking_processor = DataMaskingProcessor(config)
|
||||
|
||||
# Security rule configuration
|
||||
self.blocked_keywords = self._load_blocked_keywords()
|
||||
self.sensitive_tables = self._load_sensitive_tables()
|
||||
self.masking_rules = self._load_masking_rules()
|
||||
|
||||
def _load_blocked_keywords(self) -> set[str]:
|
||||
"""Load blocked SQL keywords"""
|
||||
default_blocked = {
|
||||
"DROP",
|
||||
"DELETE",
|
||||
"TRUNCATE",
|
||||
"ALTER",
|
||||
"CREATE",
|
||||
"INSERT",
|
||||
"UPDATE",
|
||||
"GRANT",
|
||||
"REVOKE",
|
||||
"EXEC",
|
||||
"EXECUTE",
|
||||
"SHUTDOWN",
|
||||
"KILL",
|
||||
}
|
||||
|
||||
# Load custom rules from configuration file
|
||||
if hasattr(self.config, 'get'):
|
||||
custom_blocked = set(self.config.get("blocked_keywords", []))
|
||||
else:
|
||||
custom_blocked = set()
|
||||
|
||||
return default_blocked.union(custom_blocked)
|
||||
|
||||
def _load_sensitive_tables(self) -> dict[str, SecurityLevel]:
|
||||
"""Load sensitive table configuration"""
|
||||
default_tables = {
|
||||
"user_info": SecurityLevel.CONFIDENTIAL,
|
||||
"payment_records": SecurityLevel.SECRET,
|
||||
"employee_data": SecurityLevel.CONFIDENTIAL,
|
||||
"public_reports": SecurityLevel.PUBLIC,
|
||||
}
|
||||
|
||||
if hasattr(self.config, 'get'):
|
||||
config_tables = self.config.get("sensitive_tables", {})
|
||||
# Convert string values to SecurityLevel enum
|
||||
for table_name, level in config_tables.items():
|
||||
if isinstance(level, str):
|
||||
try:
|
||||
default_tables[table_name] = SecurityLevel(level.lower())
|
||||
except ValueError:
|
||||
default_tables[table_name] = SecurityLevel.INTERNAL
|
||||
else:
|
||||
default_tables[table_name] = level
|
||||
return default_tables
|
||||
else:
|
||||
return default_tables
|
||||
|
||||
def _load_masking_rules(self) -> list[MaskingRule]:
|
||||
"""Load data masking rules"""
|
||||
default_rules = [
|
||||
MaskingRule(
|
||||
column_pattern=r".*phone.*|.*mobile.*",
|
||||
algorithm="phone_mask",
|
||||
parameters={"mask_char": "*", "keep_prefix": 3, "keep_suffix": 4},
|
||||
security_level=SecurityLevel.INTERNAL,
|
||||
),
|
||||
MaskingRule(
|
||||
column_pattern=r".*email.*",
|
||||
algorithm="email_mask",
|
||||
parameters={"mask_char": "*"},
|
||||
security_level=SecurityLevel.INTERNAL,
|
||||
),
|
||||
MaskingRule(
|
||||
column_pattern=r".*id_card.*|.*identity.*",
|
||||
algorithm="id_mask",
|
||||
parameters={"mask_char": "*", "keep_prefix": 6, "keep_suffix": 4},
|
||||
security_level=SecurityLevel.CONFIDENTIAL,
|
||||
),
|
||||
]
|
||||
|
||||
# Load custom rules from configuration
|
||||
custom_rules = []
|
||||
if hasattr(self.config, 'get'):
|
||||
custom_rules = self.config.get("masking_rules", [])
|
||||
elif hasattr(self.config, 'security') and hasattr(self.config.security, 'masking_rules'):
|
||||
custom_rules = self.config.security.masking_rules
|
||||
|
||||
for rule_config in custom_rules:
|
||||
if isinstance(rule_config, dict):
|
||||
default_rules.append(MaskingRule(**rule_config))
|
||||
elif isinstance(rule_config, MaskingRule):
|
||||
default_rules.append(rule_config)
|
||||
|
||||
return default_rules
|
||||
|
||||
async def authenticate_request(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||
"""Validate request authentication information"""
|
||||
return await self.auth_provider.authenticate(auth_info)
|
||||
|
||||
async def authorize_resource_access(
|
||||
self, auth_context: AuthContext, resource_uri: str
|
||||
) -> bool:
|
||||
"""Validate resource access permissions"""
|
||||
return await self.authz_provider.check_permission(
|
||||
auth_context, resource_uri, "read"
|
||||
)
|
||||
|
||||
async def validate_sql_security(
|
||||
self, sql: str, auth_context: AuthContext
|
||||
) -> ValidationResult:
|
||||
"""Validate SQL query security"""
|
||||
return await self.sql_validator.validate(sql, auth_context)
|
||||
|
||||
async def apply_data_masking(
|
||||
self, data: list[dict[str, Any]], auth_context: AuthContext
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Apply data masking processing"""
|
||||
return await self.masking_processor.process(data, auth_context)
|
||||
|
||||
|
||||
class AuthenticationProvider:
|
||||
"""Authentication provider"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.session_cache = {}
|
||||
|
||||
async def authenticate(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||
"""Perform identity authentication"""
|
||||
auth_type = auth_info.get("type", "token")
|
||||
|
||||
if auth_type == "token":
|
||||
return await self._authenticate_token(auth_info)
|
||||
elif auth_type == "basic":
|
||||
return await self._authenticate_basic(auth_info)
|
||||
else:
|
||||
raise ValueError(f"Unsupported authentication type: {auth_type}")
|
||||
|
||||
async def _authenticate_token(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||
"""Token authentication"""
|
||||
token = auth_info.get("token")
|
||||
if not token:
|
||||
raise ValueError("Missing authentication token")
|
||||
|
||||
# Validate token (simplified implementation, should validate JWT or query authentication service in practice)
|
||||
user_info = await self._validate_token(token)
|
||||
|
||||
return AuthContext(
|
||||
user_id=user_info["user_id"],
|
||||
roles=user_info["roles"],
|
||||
permissions=user_info["permissions"],
|
||||
session_id=auth_info.get("session_id", "default"),
|
||||
login_time=datetime.utcnow(),
|
||||
security_level=SecurityLevel(user_info.get("security_level", "internal")),
|
||||
)
|
||||
|
||||
async def _authenticate_basic(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||
"""Basic authentication (username password)"""
|
||||
username = auth_info.get("username")
|
||||
password = auth_info.get("password")
|
||||
|
||||
if not username or not password:
|
||||
raise ValueError("Missing username or password")
|
||||
|
||||
# Validate username password (simplified implementation)
|
||||
user_info = await self._validate_credentials(username, password)
|
||||
|
||||
return AuthContext(
|
||||
user_id=user_info["user_id"],
|
||||
roles=user_info["roles"],
|
||||
permissions=user_info["permissions"],
|
||||
session_id=auth_info.get("session_id", "default"),
|
||||
login_time=datetime.utcnow(),
|
||||
security_level=SecurityLevel(user_info.get("security_level", "internal")),
|
||||
)
|
||||
|
||||
async def _validate_token(self, token: str) -> dict[str, Any]:
|
||||
"""Validate token validity"""
|
||||
# Simplified implementation for testing, should parse JWT or query authentication service in practice
|
||||
valid_tokens = {
|
||||
"valid_token_123": {
|
||||
"user_id": "test_user",
|
||||
"roles": ["data_analyst"],
|
||||
"permissions": ["read_data"],
|
||||
"security_level": SecurityLevel.INTERNAL,
|
||||
},
|
||||
"admin_token_456": {
|
||||
"user_id": "admin_user",
|
||||
"roles": ["data_admin"],
|
||||
"permissions": ["admin"],
|
||||
"security_level": SecurityLevel.SECRET,
|
||||
}
|
||||
}
|
||||
|
||||
if token in valid_tokens:
|
||||
return valid_tokens[token]
|
||||
else:
|
||||
raise ValueError("Invalid token")
|
||||
|
||||
async def _validate_credentials(
|
||||
self, username: str, password: str
|
||||
) -> dict[str, Any]:
|
||||
"""Validate user credentials"""
|
||||
# Simplified implementation for testing, should query user database in practice
|
||||
valid_users = {
|
||||
"admin": {
|
||||
"password": "admin123",
|
||||
"user_id": "admin_user",
|
||||
"roles": ["data_admin"],
|
||||
"permissions": ["admin", "read_data", "write_data"],
|
||||
"security_level": SecurityLevel.SECRET,
|
||||
},
|
||||
"analyst": {
|
||||
"password": "analyst123",
|
||||
"user_id": "analyst_user",
|
||||
"roles": ["data_analyst"],
|
||||
"permissions": ["read_data"],
|
||||
"security_level": SecurityLevel.INTERNAL,
|
||||
}
|
||||
}
|
||||
|
||||
if username in valid_users and valid_users[username]["password"] == password:
|
||||
user_info = valid_users[username].copy()
|
||||
del user_info["password"] # Remove password from returned info
|
||||
return user_info
|
||||
else:
|
||||
raise ValueError("Incorrect username or password")
|
||||
|
||||
|
||||
class AuthorizationProvider:
|
||||
"""Authorization provider"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.permission_cache = {}
|
||||
|
||||
# Load sensitive tables configuration
|
||||
self.sensitive_tables = self._load_sensitive_tables()
|
||||
|
||||
def _load_sensitive_tables(self) -> dict[str, SecurityLevel]:
|
||||
"""Load sensitive table configuration"""
|
||||
default_tables = {
|
||||
"user_info": SecurityLevel.CONFIDENTIAL,
|
||||
"payment_records": SecurityLevel.SECRET,
|
||||
"employee_data": SecurityLevel.CONFIDENTIAL,
|
||||
"public_reports": SecurityLevel.PUBLIC,
|
||||
}
|
||||
|
||||
if hasattr(self.config, 'get'):
|
||||
config_tables = self.config.get("sensitive_tables", {})
|
||||
# Convert string values to SecurityLevel enum
|
||||
for table_name, level in config_tables.items():
|
||||
if isinstance(level, str):
|
||||
try:
|
||||
default_tables[table_name] = SecurityLevel(level.lower())
|
||||
except ValueError:
|
||||
default_tables[table_name] = SecurityLevel.INTERNAL
|
||||
else:
|
||||
default_tables[table_name] = level
|
||||
return default_tables
|
||||
else:
|
||||
return default_tables
|
||||
|
||||
async def check_permission(
|
||||
self, auth_context: AuthContext, resource_uri: str, action: str
|
||||
) -> bool:
|
||||
"""Check permissions"""
|
||||
# Parse resource information
|
||||
resource_info = self._parse_resource_uri(resource_uri)
|
||||
|
||||
# First check security level - this is mandatory
|
||||
if not await self._check_security_level_permission(auth_context, resource_info):
|
||||
return False
|
||||
|
||||
# Then check role-based permissions
|
||||
if await self._check_role_permission(auth_context, resource_info, action):
|
||||
return True
|
||||
|
||||
# Finally check user-based permissions
|
||||
if await self._check_user_permission(auth_context, resource_info, action):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _parse_resource_uri(self, uri: str) -> dict[str, str]:
|
||||
"""Parse resource URI"""
|
||||
parts = uri.split("/")
|
||||
if len(parts) >= 3:
|
||||
return {
|
||||
"type": parts[2], # table, view, etc.
|
||||
"name": parts[3] if len(parts) > 3 else "",
|
||||
"schema": parts[4] if len(parts) > 4 else "default",
|
||||
}
|
||||
return {"type": "unknown", "name": "", "schema": "default"}
|
||||
|
||||
async def _check_role_permission(
|
||||
self, auth_context: AuthContext, resource_info: dict[str, str], action: str
|
||||
) -> bool:
|
||||
"""Check role-based permissions"""
|
||||
# Role permission mapping
|
||||
role_permissions = {
|
||||
"data_analyst": {"table": ["read"], "view": ["read"]},
|
||||
"data_admin": {
|
||||
"table": ["read", "write", "admin"],
|
||||
"view": ["read", "write", "admin"],
|
||||
},
|
||||
}
|
||||
|
||||
for role in auth_context.roles:
|
||||
role_perms = role_permissions.get(role, {})
|
||||
resource_perms = role_perms.get(resource_info["type"], [])
|
||||
if action in resource_perms:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def _check_user_permission(
|
||||
self, auth_context: AuthContext, resource_info: dict[str, str], action: str
|
||||
) -> bool:
|
||||
"""Check user-based permissions"""
|
||||
# User-specific permission check
|
||||
if "admin" in auth_context.permissions:
|
||||
return True
|
||||
|
||||
if action == "read" and "read_data" in auth_context.permissions:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def _check_security_level_permission(
|
||||
self, auth_context: AuthContext, resource_info: dict[str, str]
|
||||
) -> bool:
|
||||
"""Check security level permissions"""
|
||||
# Get resource security level
|
||||
resource_security_level = self._get_resource_security_level(resource_info)
|
||||
|
||||
# Check if user security level is sufficient
|
||||
security_hierarchy = {
|
||||
SecurityLevel.PUBLIC: 0,
|
||||
SecurityLevel.INTERNAL: 1,
|
||||
SecurityLevel.CONFIDENTIAL: 2,
|
||||
SecurityLevel.SECRET: 3,
|
||||
}
|
||||
|
||||
user_level = security_hierarchy.get(auth_context.security_level, 0)
|
||||
resource_level = security_hierarchy.get(resource_security_level, 0)
|
||||
|
||||
# User must have higher or equal security level to access resource
|
||||
return user_level >= resource_level
|
||||
|
||||
def _get_resource_security_level(
|
||||
self, resource_info: dict[str, str]
|
||||
) -> SecurityLevel:
|
||||
"""Get resource security level"""
|
||||
# Get table security level from configuration
|
||||
table_name = resource_info.get("name", "")
|
||||
|
||||
# Use the loaded sensitive tables
|
||||
sensitive_tables = self.sensitive_tables
|
||||
|
||||
# Convert string values to SecurityLevel enum if needed
|
||||
security_level = sensitive_tables.get(table_name, SecurityLevel.INTERNAL)
|
||||
if isinstance(security_level, str):
|
||||
try:
|
||||
security_level = SecurityLevel(security_level.lower())
|
||||
except ValueError:
|
||||
security_level = SecurityLevel.INTERNAL
|
||||
|
||||
return security_level
|
||||
|
||||
|
||||
class SQLSecurityValidator:
|
||||
"""SQL security validator"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# Handle DorisConfig object or dictionary configuration
|
||||
if hasattr(config, 'get'):
|
||||
# Dictionary configuration
|
||||
self.blocked_keywords = set(config.get("blocked_keywords", []))
|
||||
self.max_query_complexity = config.get("max_query_complexity", 100)
|
||||
else:
|
||||
# DorisConfig object, use default values
|
||||
self.blocked_keywords = set(["DROP", "DELETE", "TRUNCATE", "ALTER", "CREATE", "INSERT", "UPDATE"])
|
||||
self.max_query_complexity = 100
|
||||
|
||||
async def validate(self, sql: str, auth_context: AuthContext) -> ValidationResult:
|
||||
"""Validate SQL query security"""
|
||||
try:
|
||||
# Parse SQL statement
|
||||
parsed = sqlparse.parse(sql)[0]
|
||||
|
||||
# Check blocked operations first (more specific)
|
||||
keyword_result = await self._check_blocked_keywords(parsed)
|
||||
if not keyword_result.is_valid:
|
||||
return keyword_result
|
||||
|
||||
# Check SQL injection risks
|
||||
injection_result = await self._check_sql_injection(sql, parsed)
|
||||
if not injection_result.is_valid:
|
||||
return injection_result
|
||||
|
||||
# Check query complexity
|
||||
complexity_result = await self._check_query_complexity(parsed)
|
||||
if not complexity_result.is_valid:
|
||||
return complexity_result
|
||||
|
||||
# Check table access permissions
|
||||
table_result = await self._check_table_access(parsed, auth_context)
|
||||
if not table_result.is_valid:
|
||||
return table_result
|
||||
|
||||
return ValidationResult(is_valid=True)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"SQL security validation failed: {e}")
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
error_message=f"SQL parsing error: {str(e)}",
|
||||
risk_level="high",
|
||||
)
|
||||
|
||||
async def _check_sql_injection(
|
||||
self, sql: str, parsed: Statement
|
||||
) -> ValidationResult:
|
||||
"""Check SQL injection risks"""
|
||||
# Check common SQL injection patterns
|
||||
injection_patterns = [
|
||||
r"(\s|^)(union|select|insert|update|delete|drop|create|alter)\s+.*\s+(union|select|insert|update|delete|drop|create|alter)",
|
||||
r"(\s|^)(or|and)\s+\d+\s*=\s*\d+",
|
||||
r"(\s|^)(or|and)\s+['\"].*['\"]",
|
||||
r";\s*(drop|delete|truncate|alter|create)",
|
||||
r"(exec|execute|sp_|xp_)",
|
||||
r"(script|javascript|vbscript)",
|
||||
r"(char|ascii|substring|concat)\s*\(",
|
||||
]
|
||||
|
||||
sql_lower = sql.lower()
|
||||
for pattern in injection_patterns:
|
||||
if re.search(pattern, sql_lower, re.IGNORECASE):
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
error_message="Potential SQL injection risk detected",
|
||||
risk_level="high",
|
||||
)
|
||||
|
||||
# Check suspicious quotes and comments
|
||||
if self._has_suspicious_quotes_or_comments(sql):
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
error_message="Suspicious quote or comment pattern detected",
|
||||
risk_level="medium",
|
||||
)
|
||||
|
||||
return ValidationResult(is_valid=True)
|
||||
|
||||
def _has_suspicious_quotes_or_comments(self, sql: str) -> bool:
|
||||
"""Check suspicious quote and comment patterns"""
|
||||
# Check unmatched quotes
|
||||
single_quotes = sql.count("'")
|
||||
double_quotes = sql.count('"')
|
||||
|
||||
if single_quotes % 2 != 0 or double_quotes % 2 != 0:
|
||||
return True
|
||||
|
||||
# Check SQL comments
|
||||
if "--" in sql or "/*" in sql:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def _check_blocked_keywords(self, parsed: Statement) -> ValidationResult:
|
||||
"""Check blocked keywords"""
|
||||
blocked_operations = []
|
||||
|
||||
# Check all tokens in the parsed statement
|
||||
for token in parsed.flatten():
|
||||
# Check if token is a keyword (including DML/DDL) or name that matches blocked operations
|
||||
if (token.ttype is Keyword or
|
||||
token.ttype is Name or
|
||||
(token.ttype and str(token.ttype).startswith('Token.Keyword'))):
|
||||
token_value = token.value.upper().strip()
|
||||
if token_value in self.blocked_keywords:
|
||||
blocked_operations.append(token_value)
|
||||
# Also check for DDL/DML keywords in token values
|
||||
elif hasattr(token, 'value') and token.value:
|
||||
token_value = token.value.upper().strip()
|
||||
for blocked_keyword in self.blocked_keywords:
|
||||
if blocked_keyword in token_value:
|
||||
blocked_operations.append(blocked_keyword)
|
||||
|
||||
if blocked_operations:
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
error_message=f"Contains blocked operations: {', '.join(set(blocked_operations))}",
|
||||
risk_level="high",
|
||||
blocked_operations=list(set(blocked_operations)),
|
||||
)
|
||||
|
||||
return ValidationResult(is_valid=True)
|
||||
|
||||
async def _check_query_complexity(self, parsed: Statement) -> ValidationResult:
|
||||
"""Check query complexity"""
|
||||
complexity_score = 0
|
||||
|
||||
# Calculate complexity score
|
||||
for token in parsed.flatten():
|
||||
if token.ttype is Keyword:
|
||||
keyword = token.value.upper()
|
||||
if keyword in ["JOIN", "INNER", "LEFT", "RIGHT", "FULL"]:
|
||||
complexity_score += 10
|
||||
elif keyword in ["UNION", "INTERSECT", "EXCEPT"]:
|
||||
complexity_score += 15
|
||||
elif keyword in ["GROUP BY", "ORDER BY", "HAVING"]:
|
||||
complexity_score += 5
|
||||
elif keyword in ["SUBQUERY", "EXISTS", "IN"]:
|
||||
complexity_score += 8
|
||||
|
||||
if complexity_score > self.max_query_complexity:
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
error_message=f"Query complexity too high (score: {complexity_score}, limit: {self.max_query_complexity})",
|
||||
risk_level="medium",
|
||||
)
|
||||
|
||||
return ValidationResult(is_valid=True)
|
||||
|
||||
async def _check_table_access(
|
||||
self, parsed: Statement, auth_context: AuthContext
|
||||
) -> ValidationResult:
|
||||
"""Check table access permissions"""
|
||||
# Extract table names from query
|
||||
tables = self._extract_table_names(parsed)
|
||||
|
||||
# Check access permissions for each table
|
||||
unauthorized_tables = []
|
||||
for table in tables:
|
||||
# Should call authorization provider to check permissions
|
||||
# Simplified implementation, assume some tables require special permissions
|
||||
if (
|
||||
table.lower() in ["sensitive_data", "admin_logs"]
|
||||
and "admin" not in auth_context.roles
|
||||
):
|
||||
unauthorized_tables.append(table)
|
||||
|
||||
if unauthorized_tables:
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
error_message=f"No access to tables: {', '.join(unauthorized_tables)}",
|
||||
risk_level="high",
|
||||
)
|
||||
|
||||
return ValidationResult(is_valid=True)
|
||||
|
||||
def _extract_table_names(self, parsed: Statement) -> list[str]:
|
||||
"""Extract table names from SQL statement"""
|
||||
tables = []
|
||||
|
||||
# Simplified table name extraction logic
|
||||
tokens = list(parsed.flatten())
|
||||
for i, token in enumerate(tokens):
|
||||
if token.ttype is Keyword and token.value.upper() == "FROM":
|
||||
# Find table name after FROM
|
||||
for j in range(i + 1, len(tokens)):
|
||||
next_token = tokens[j]
|
||||
if next_token.ttype is Name:
|
||||
tables.append(next_token.value)
|
||||
break
|
||||
elif next_token.ttype is Keyword:
|
||||
break
|
||||
|
||||
return tables
|
||||
|
||||
|
||||
class DataMaskingProcessor:
|
||||
"""Data masking processor"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.masking_algorithms = self._init_masking_algorithms()
|
||||
self.masking_rules = self._load_masking_rules()
|
||||
|
||||
def _load_masking_rules(self) -> list[MaskingRule]:
|
||||
"""Load data masking rules"""
|
||||
default_rules = [
|
||||
MaskingRule(
|
||||
column_pattern=r".*phone.*|.*mobile.*",
|
||||
algorithm="phone_mask",
|
||||
parameters={"mask_char": "*", "keep_prefix": 3, "keep_suffix": 4},
|
||||
security_level=SecurityLevel.INTERNAL,
|
||||
),
|
||||
MaskingRule(
|
||||
column_pattern=r".*email.*",
|
||||
algorithm="email_mask",
|
||||
parameters={"mask_char": "*"},
|
||||
security_level=SecurityLevel.INTERNAL,
|
||||
),
|
||||
MaskingRule(
|
||||
column_pattern=r".*id_card.*|.*identity.*",
|
||||
algorithm="id_mask",
|
||||
parameters={"mask_char": "*", "keep_prefix": 6, "keep_suffix": 4},
|
||||
security_level=SecurityLevel.CONFIDENTIAL,
|
||||
),
|
||||
]
|
||||
|
||||
# Load custom rules from configuration
|
||||
if hasattr(self.config, 'get'):
|
||||
custom_rules = self.config.get("masking_rules", [])
|
||||
for rule_config in custom_rules:
|
||||
if isinstance(rule_config, dict):
|
||||
# Convert string security level to enum
|
||||
if 'security_level' in rule_config and isinstance(rule_config['security_level'], str):
|
||||
try:
|
||||
rule_config['security_level'] = SecurityLevel(rule_config['security_level'].lower())
|
||||
except ValueError:
|
||||
rule_config['security_level'] = SecurityLevel.INTERNAL
|
||||
default_rules.append(MaskingRule(**rule_config))
|
||||
elif isinstance(rule_config, MaskingRule):
|
||||
default_rules.append(rule_config)
|
||||
|
||||
return default_rules
|
||||
|
||||
def _init_masking_algorithms(self) -> dict[str, callable]:
|
||||
"""Initialize masking algorithms"""
|
||||
return {
|
||||
"phone_mask": self._mask_phone,
|
||||
"email_mask": self._mask_email,
|
||||
"id_mask": self._mask_id_card,
|
||||
"name_mask": self._mask_name,
|
||||
"partial_mask": self._mask_partial,
|
||||
}
|
||||
|
||||
async def process(
|
||||
self, data: list[dict[str, Any]], auth_context: AuthContext
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Process data masking"""
|
||||
if not data:
|
||||
return data
|
||||
|
||||
# Get applicable masking rules
|
||||
applicable_rules = self._get_applicable_rules(auth_context)
|
||||
|
||||
masked_data = []
|
||||
for row in data:
|
||||
masked_row = {}
|
||||
for column, value in row.items():
|
||||
masked_value = await self._apply_masking_rules(
|
||||
column, value, applicable_rules
|
||||
)
|
||||
masked_row[column] = masked_value
|
||||
masked_data.append(masked_row)
|
||||
|
||||
return masked_data
|
||||
|
||||
def _get_applicable_rules(self, auth_context: AuthContext) -> list[MaskingRule]:
|
||||
"""Get applicable masking rules"""
|
||||
applicable_rules = []
|
||||
|
||||
for rule in self.masking_rules:
|
||||
# Decide whether to apply masking rules based on user security level
|
||||
if self._should_apply_rule(rule, auth_context):
|
||||
applicable_rules.append(rule)
|
||||
|
||||
return applicable_rules
|
||||
|
||||
def _should_apply_rule(self, rule: MaskingRule, auth_context: AuthContext) -> bool:
|
||||
"""Determine whether masking rule should be applied"""
|
||||
# Admin users can see original data
|
||||
if "admin" in auth_context.roles:
|
||||
return False
|
||||
|
||||
# Decide based on security level
|
||||
security_hierarchy = {
|
||||
SecurityLevel.PUBLIC: 0,
|
||||
SecurityLevel.INTERNAL: 1,
|
||||
SecurityLevel.CONFIDENTIAL: 2,
|
||||
SecurityLevel.SECRET: 3,
|
||||
}
|
||||
|
||||
user_level = security_hierarchy.get(auth_context.security_level, 0)
|
||||
rule_level = security_hierarchy.get(rule.security_level, 0)
|
||||
|
||||
# Apply masking if user level is less than or equal to rule level
|
||||
return user_level <= rule_level
|
||||
|
||||
async def _apply_masking_rules(
|
||||
self, column: str, value: Any, rules: list[MaskingRule]
|
||||
) -> Any:
|
||||
"""Apply masking rules"""
|
||||
if value is None:
|
||||
return value
|
||||
|
||||
for rule in rules:
|
||||
if re.match(rule.column_pattern, column, re.IGNORECASE):
|
||||
algorithm = self.masking_algorithms.get(rule.algorithm)
|
||||
if algorithm:
|
||||
return algorithm(str(value), rule.parameters)
|
||||
|
||||
return value
|
||||
|
||||
def _mask_phone(self, value: str, params: dict[str, Any]) -> str:
|
||||
"""Phone number masking"""
|
||||
if len(value) < 7:
|
||||
return value
|
||||
|
||||
mask_char = params.get("mask_char", "*")
|
||||
keep_prefix = params.get("keep_prefix", 3)
|
||||
keep_suffix = params.get("keep_suffix", 4)
|
||||
|
||||
if len(value) <= keep_prefix + keep_suffix:
|
||||
return mask_char * len(value)
|
||||
|
||||
prefix = value[:keep_prefix]
|
||||
suffix = value[-keep_suffix:]
|
||||
middle_length = len(value) - keep_prefix - keep_suffix
|
||||
|
||||
return prefix + mask_char * middle_length + suffix
|
||||
|
||||
def _mask_email(self, value: str, params: dict[str, Any]) -> str:
|
||||
"""Email masking"""
|
||||
if "@" not in value:
|
||||
return value
|
||||
|
||||
mask_char = params.get("mask_char", "*")
|
||||
local, domain = value.split("@", 1)
|
||||
|
||||
if len(local) <= 2:
|
||||
masked_local = mask_char * len(local)
|
||||
else:
|
||||
masked_local = local[0] + mask_char * (len(local) - 2) + local[-1]
|
||||
|
||||
return f"{masked_local}@{domain}"
|
||||
|
||||
def _mask_id_card(self, value: str, params: dict[str, Any]) -> str:
|
||||
"""ID card number masking"""
|
||||
if len(value) < 10:
|
||||
return value
|
||||
|
||||
mask_char = params.get("mask_char", "*")
|
||||
keep_prefix = params.get("keep_prefix", 6)
|
||||
keep_suffix = params.get("keep_suffix", 4)
|
||||
|
||||
if len(value) <= keep_prefix + keep_suffix:
|
||||
return mask_char * len(value)
|
||||
|
||||
prefix = value[:keep_prefix]
|
||||
suffix = value[-keep_suffix:]
|
||||
middle_length = len(value) - keep_prefix - keep_suffix
|
||||
|
||||
return prefix + mask_char * middle_length + suffix
|
||||
|
||||
def _mask_name(self, value: str, params: dict[str, Any]) -> str:
|
||||
"""Name masking"""
|
||||
if len(value) <= 1:
|
||||
return value
|
||||
|
||||
mask_char = params.get("mask_char", "*")
|
||||
|
||||
if len(value) == 2:
|
||||
return value[0] + mask_char
|
||||
else:
|
||||
return value[0] + mask_char * (len(value) - 2) + value[-1]
|
||||
|
||||
def _mask_partial(self, value: str, params: dict[str, Any]) -> str:
|
||||
"""Partial masking"""
|
||||
mask_char = params.get("mask_char", "*")
|
||||
mask_ratio = params.get("mask_ratio", 0.5)
|
||||
|
||||
mask_length = int(len(value) * mask_ratio)
|
||||
start_pos = (len(value) - mask_length) // 2
|
||||
|
||||
result = list(value)
|
||||
for i in range(start_pos, start_pos + mask_length):
|
||||
if i < len(result):
|
||||
result[i] = mask_char
|
||||
|
||||
return "".join(result)
|
||||
@@ -1,352 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
SQL Execution Tool
|
||||
|
||||
Responsible for executing SQL queries and handling results
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import traceback
|
||||
import time
|
||||
from typing import Dict, Any
|
||||
import re
|
||||
import datetime
|
||||
from decimal import Decimal
|
||||
|
||||
# Get logger
|
||||
logger = logging.getLogger("doris-mcp.sql-executor")
|
||||
|
||||
# Add environment variable control for whether to perform SQL security checks
|
||||
ENABLE_SQL_SECURITY_CHECK = os.environ.get('ENABLE_SQL_SECURITY_CHECK', 'true').lower() == 'true'
|
||||
|
||||
async def execute_sql_query(ctx) -> Dict[str, Any]:
|
||||
"""
|
||||
Execute SQL query and return results
|
||||
|
||||
Args:
|
||||
ctx: Context object or dictionary containing request parameters
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Execution result
|
||||
"""
|
||||
try:
|
||||
# Support the case where the passed argument is a dictionary
|
||||
if isinstance(ctx, dict) and 'params' in ctx:
|
||||
params = ctx['params']
|
||||
else:
|
||||
params = ctx.params
|
||||
|
||||
sql = params.get("sql")
|
||||
db_name = params.get("db_name", os.getenv("DB_DATABASE", ""))
|
||||
catalog_name = params.get("catalog_name", None) # Add catalog parameter support
|
||||
max_rows = params.get("max_rows", 1000) # Maximum number of rows to return
|
||||
timeout = params.get("timeout", 30) # Timeout in seconds
|
||||
|
||||
if not sql:
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps({
|
||||
"success": False,
|
||||
"error": "Missing SQL parameter",
|
||||
"message": "Please provide the SQL query to execute"
|
||||
}, ensure_ascii=False)
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# First check SQL security
|
||||
security_result = await _check_sql_security(sql)
|
||||
if not security_result.get("is_safe", False):
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps({
|
||||
"success": False,
|
||||
"error": "SQL security check failed",
|
||||
"message": "Query contains unsafe operations and cannot be executed",
|
||||
"security_issues": security_result.get("security_issues", [])
|
||||
}, ensure_ascii=False)
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# Import database connection tool
|
||||
from doris_mcp_server.utils.db import execute_query
|
||||
|
||||
if not sql:
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps({
|
||||
"success": False,
|
||||
"error": "Missing SQL parameter",
|
||||
"message": "Please provide the SQL query to execute"
|
||||
}, ensure_ascii=False)
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# Ensure SELECT statements include a LIMIT clause
|
||||
sql_lower = sql.lower().strip()
|
||||
if sql_lower.startswith("select") and "limit" not in sql_lower:
|
||||
sql = sql.rstrip(";") + f" LIMIT {max_rows};"
|
||||
|
||||
# Start timer
|
||||
start_time = time.time()
|
||||
|
||||
# Execute query
|
||||
try:
|
||||
# For federation queries, SQL must use three-part naming: catalog_name.db_name.table_name
|
||||
# This is enforced at the tool description level
|
||||
|
||||
result = execute_query(sql, db_name)
|
||||
|
||||
# Calculate execution time
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
# Build return result
|
||||
if isinstance(result, list):
|
||||
# Handle list of query results
|
||||
row_count = len(result)
|
||||
|
||||
# Extract column names
|
||||
if hasattr(result[0], "_fields"):
|
||||
# If it's a named tuple
|
||||
columns = list(result[0]._fields)
|
||||
else:
|
||||
# Otherwise, assume it's a dictionary
|
||||
columns = list(result[0].keys()) if isinstance(result[0], dict) else []
|
||||
|
||||
# Convert results to serializable format
|
||||
data = []
|
||||
for row in result:
|
||||
row_dict = {}
|
||||
if hasattr(row, "_asdict"):
|
||||
# If it's a named tuple
|
||||
row_dict = row._asdict()
|
||||
elif isinstance(row, dict):
|
||||
# If it's a dictionary
|
||||
row_dict = row
|
||||
else:
|
||||
# If it's a list or tuple
|
||||
row_dict = dict(zip(columns, row)) if columns else row
|
||||
|
||||
# Handle special types to make them JSON serializable
|
||||
serialized_row = _serialize_row_data(row_dict)
|
||||
data.append(serialized_row)
|
||||
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps({
|
||||
"success": True,
|
||||
"sql": sql,
|
||||
"row_count": row_count,
|
||||
"columns": columns,
|
||||
"data": data[:max_rows], # Limit returned rows
|
||||
"execution_time": execution_time,
|
||||
"truncated": row_count > max_rows
|
||||
}, ensure_ascii=False)
|
||||
}
|
||||
]
|
||||
}
|
||||
else:
|
||||
# Handle other types of results
|
||||
other_response = {
|
||||
"success": True,
|
||||
"sql": sql,
|
||||
"result": str(result),
|
||||
"execution_time": execution_time
|
||||
}
|
||||
other_response = _serialize_row_data(other_response)
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps(other_response, ensure_ascii=False)
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
except Exception as db_error:
|
||||
error_message = str(db_error)
|
||||
|
||||
# Try to get more detailed error information
|
||||
error_details = {}
|
||||
if "timeout" in error_message.lower():
|
||||
error_details["type"] = "timeout"
|
||||
error_details["suggestion"] = "Query timed out, please optimize SQL or increase timeout"
|
||||
elif "syntax" in error_message.lower():
|
||||
error_details["type"] = "syntax"
|
||||
error_details["suggestion"] = "SQL syntax error, please check syntax"
|
||||
elif "not found" in error_message.lower() or "doesn't exist" in error_message.lower():
|
||||
error_details["type"] = "not_found"
|
||||
error_details["suggestion"] = "Table or column not found, please check table and column names"
|
||||
else:
|
||||
error_details["type"] = "unknown"
|
||||
error_details["suggestion"] = "Please check the SQL statement and try simplifying the query"
|
||||
|
||||
# Create error response
|
||||
error_response = {
|
||||
"success": False,
|
||||
"error": error_message,
|
||||
"error_details": error_details,
|
||||
"sql": sql,
|
||||
"db_name": db_name
|
||||
}
|
||||
|
||||
# Ensure error response is also serializable
|
||||
error_response = _serialize_row_data(error_response)
|
||||
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps(error_response, ensure_ascii=False)
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to execute SQL query: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
error_response = {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"message": "Error occurred while executing SQL query"
|
||||
}
|
||||
|
||||
# Ensure error response is also serializable
|
||||
error_response = _serialize_row_data(error_response)
|
||||
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps(error_response, ensure_ascii=False)
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# Helper function
|
||||
async def _check_sql_security(sql: str) -> Dict[str, Any]:
|
||||
"""Check SQL security"""
|
||||
# If environment variable is set to disable security check, return safe immediately
|
||||
if not ENABLE_SQL_SECURITY_CHECK:
|
||||
return {
|
||||
"is_safe": True,
|
||||
"security_issues": []
|
||||
}
|
||||
|
||||
# Check if SQL contains dangerous operations
|
||||
sql_lower = sql.lower()
|
||||
|
||||
# Check if it's a read-only query type
|
||||
is_read_only = sql_lower.strip().startswith(("select ", "show ", "desc ", "describe ", "explain "))
|
||||
|
||||
# Define list of dangerous operations (checked for both read-only and non-read-only queries)
|
||||
dangerous_operations = [
|
||||
(r'\bdelete\b', "DELETE operation"),
|
||||
(r'\bdrop\b', "DROP TABLE/DATABASE operation"),
|
||||
(r'\btruncate\b', "TRUNCATE TABLE operation"),
|
||||
(r'\bupdate\b', "UPDATE operation"),
|
||||
(r'\binsert\b', "INSERT operation"),
|
||||
(r'\balter\b', "ALTER TABLE structure operation"),
|
||||
(r'\bcreate\b', "CREATE TABLE/DATABASE operation"),
|
||||
(r'\bgrant\b', "GRANT operation"),
|
||||
(r'\brevoke\b', "REVOKE permission operation"),
|
||||
(r'\bexec\b', "EXECUTE stored procedure"),
|
||||
(r'\bxp_', "Extended stored procedure, potential security risk"),
|
||||
(r'\bshutdown\b', "SHUTDOWN database operation"),
|
||||
(r'\binto\s+outfile\b', "Write to file operation"),
|
||||
(r'\bload_file\b', "Load file operation")
|
||||
]
|
||||
|
||||
# Dangerous operations checked only for non-read-only queries
|
||||
non_readonly_operations = []
|
||||
if not is_read_only:
|
||||
non_readonly_operations = [
|
||||
(r'--', "SQL comment, potential SQL injection"),
|
||||
(r'/\*', "SQL block comment, potential SQL injection")
|
||||
]
|
||||
|
||||
# Check if dangerous operations are included
|
||||
security_issues = []
|
||||
|
||||
# Check dangerous operations applicable to all queries
|
||||
for operation, description in dangerous_operations:
|
||||
if re.search(operation, sql_lower):
|
||||
# For specific keywords in read-only queries, differentiate if used as independent operations
|
||||
if is_read_only and operation in [r'\bcreate\b', r'\bdrop\b', r'\bdelete\b', r'\binsert\b', r'\bupdate\b', r'\balter\b']:
|
||||
# Check if used as DDL/DML keyword, e.g., CREATE TABLE, DROP DATABASE
|
||||
pattern = operation + r'\s+(?:table|database|view|index|procedure|function|trigger|event)'
|
||||
if re.search(pattern, sql_lower):
|
||||
security_issues.append({
|
||||
"operation": operation.replace(r'\b', '').replace(r'\s+', ' '),
|
||||
"description": description,
|
||||
"severity": "High"
|
||||
})
|
||||
else:
|
||||
security_issues.append({
|
||||
"operation": operation.replace(r'\b', '').replace(r'\s+', ' '),
|
||||
"description": description,
|
||||
"severity": "High"
|
||||
})
|
||||
|
||||
# Check dangerous operations specific to non-read-only queries
|
||||
for operation, description in non_readonly_operations:
|
||||
if re.search(operation, sql_lower):
|
||||
security_issues.append({
|
||||
"operation": operation.replace(r'\b', '').replace(r'\s+', ' '),
|
||||
"description": description,
|
||||
"severity": "Medium"
|
||||
})
|
||||
|
||||
return {
|
||||
"is_safe": len(security_issues) == 0,
|
||||
"security_issues": security_issues
|
||||
}
|
||||
|
||||
|
||||
def _serialize_row_data(row_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert special types in row data (like date, time, Decimal) to JSON serializable format
|
||||
|
||||
Args:
|
||||
row_data: Row data dictionary
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Processed serializable dictionary
|
||||
"""
|
||||
serialized_data = {}
|
||||
for key, value in row_data.items():
|
||||
if value is None:
|
||||
serialized_data[key] = None
|
||||
elif isinstance(value, (datetime.date, datetime.datetime)):
|
||||
# Convert date and time types to ISO format string
|
||||
serialized_data[key] = value.isoformat()
|
||||
elif isinstance(value, Decimal):
|
||||
# Convert Decimal type to float
|
||||
serialized_data[key] = float(value)
|
||||
elif isinstance(value, (list, tuple)):
|
||||
# Recursively process elements in list or tuple
|
||||
serialized_data[key] = [
|
||||
_serialize_row_data(item) if isinstance(item, dict) else item
|
||||
for item in value
|
||||
]
|
||||
elif isinstance(value, dict):
|
||||
# Recursively process nested dictionaries
|
||||
serialized_data[key] = _serialize_row_data(value)
|
||||
else:
|
||||
serialized_data[key] = value
|
||||
return serialized_data
|
||||
Reference in New Issue
Block a user