init doris mcp 0.2.0

This commit is contained in:
Yijia Su
2025-05-06 12:56:55 +08:00
parent 9dc25be87a
commit c190f19cb5
23 changed files with 6405 additions and 0 deletions

View File

@@ -0,0 +1 @@
# Mark directory as a package

View File

@@ -0,0 +1,100 @@
import os
import json
import pymysql
import pandas as pd
from typing import Dict, List, Optional, Any
from dotenv import load_dotenv
import re
# Load variables from a .env file before any os.getenv() lookup below.
# override=True is passed so that values from the .env file take precedence
# over variables already present in the process environment.
load_dotenv(override=True)
# Default connection settings. The whole dict is unpacked as keyword
# arguments into pymysql.connect() by get_db_connection().
# - DB_PORT is parsed with int(), so a non-numeric value raises ValueError
#   when this module is imported.
# - DictCursor makes every fetched row a dict keyed by column name.
DB_CONFIG = {
"host": os.getenv("DB_HOST", "localhost"),
"port": int(os.getenv("DB_PORT", "9030")),
"user": os.getenv("DB_USER", "root"),
"password": os.getenv("DB_PASSWORD", ""),
"database": os.getenv("DB_DATABASE", ""),
"charset": "utf8mb4",
"cursorclass": pymysql.cursors.DictCursor
}
def get_db_connection(db_name: Optional[str] = None):
    """
    Open a new pymysql connection.

    Args:
        db_name: Database to select for this connection. When falsy (None or
            an empty string) the database configured in DB_CONFIG is used.

    Returns:
        The connection object returned by pymysql.connect(). The caller is
        responsible for closing it.
    """
    if not db_name:
        # No override requested: connect with the module defaults as they are.
        return pymysql.connect(**DB_CONFIG)

    # Work on a copy so the shared DB_CONFIG dict is never mutated.
    settings = DB_CONFIG.copy()
    settings["database"] = db_name
    return pymysql.connect(**settings)
def get_db_name() -> str:
    """
    Return the default database name.

    The value captured in DB_CONFIG is preferred; when it is empty the
    DB_DATABASE environment variable is read again, falling back to "".
    """
    configured = DB_CONFIG["database"]
    if configured:
        return configured
    return os.getenv("DB_DATABASE", "")
def execute_query(sql, db_name: Optional[str] = None):
    """
    Run one SQL statement on a fresh connection and return every row.

    Args:
        sql: SQL statement to execute.
        db_name: Database to connect to; the configured default when None.

    Returns:
        Whatever cursor.fetchall() returns for the statement.
    """
    connection = get_db_connection(db_name)
    try:
        with connection.cursor() as cur:
            # NOTE(review): DB_CONFIG requests charset "utf8mb4" while this
            # statement switches the session to utf8 - confirm which one is
            # intended.
            cur.execute("SET NAMES utf8")
            cur.execute(sql)
            return cur.fetchall()
    finally:
        # Always release the connection, including when execution fails.
        connection.close()
def execute_query_df(sql, db_name: Optional[str] = None):
    """
    Run one SQL statement and return the rows as a pandas DataFrame.

    Args:
        sql: SQL statement to execute.
        db_name: Database to connect to; the configured default when None.

    Returns:
        A DataFrame built from the fetched rows, or an empty DataFrame when
        the statement produced no rows.
    """
    # execute_query() performs the same connect / SET NAMES / execute /
    # fetchall / close sequence, so the rows are fetched through it.
    rows = execute_query(sql, db_name)
    if rows:
        return pd.DataFrame(rows)
    return pd.DataFrame()

View File

@@ -0,0 +1,226 @@
"""
Unified Logging Configuration Module
Provides unified logging configuration, including:
- General logs: Record all program execution information
- Audit logs: Record JSON data for key operations and processing results
- Error logs: Specifically record program exceptions and errors
"""
import os
import sys
import logging
import logging.handlers
from pathlib import Path
from typing import Dict
from datetime import datetime
from dotenv import load_dotenv
# Load variables from a .env file before the os.getenv() lookups below.
# override=True is passed so that values from the .env file take precedence
# over variables already present in the process environment.
load_dotenv(override=True)
# Project root: the third ancestor of this file's path (parents[2]).
PROJECT_ROOT = Path(__file__).parents[2].absolute()
# Log configuration read from environment variables.
# LOG_DIR defaults to "<PROJECT_ROOT>/logs".
LOG_DIR = os.getenv("LOG_DIR", str(PROJECT_ROOT / "logs"))
# File-name prefix shared by the general, audit and error log files.
LOG_PREFIX = os.getenv("LOG_PREFIX", "doris_mcp")
# Upper-cased so it can be looked up in LOG_LEVELS.
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
# Retention in days: used by purge_old_logs() and as the rotating handlers'
# backupCount. int() raises ValueError at import time for non-numeric values.
LOG_MAX_DAYS = int(os.getenv("LOG_MAX_DAYS", "30"))
# Whether to output logs to the console (should be disabled when running as a service).
# Only the exact string "true" (case-insensitive) enables it.
CONSOLE_LOGGING = os.getenv("CONSOLE_LOGGING", "false").lower() == "true"
# Whether stdio transport mode is being used (MCP_TRANSPORT_TYPE == "stdio").
STDIO_MODE = os.getenv("MCP_TRANSPORT_TYPE", "").lower() == "stdio"
def purge_old_logs():
    """
    Delete dated log files in LOG_DIR that are older than LOG_MAX_DAYS.

    Does nothing in stdio mode. Candidate files match
    "<LOG_PREFIX>*.20*"; the first dot-separated part of the file name that
    starts with "20" and is 8 characters long is parsed as a YYYYMMDD date.
    Failures are reported on stderr and never raised to the caller.
    """
    if STDIO_MODE:
        return

    try:
        today = datetime.now()
        directory = Path(LOG_DIR)

        # Nothing to clean when the directory is missing or not writable.
        if not (directory.is_dir() and os.access(LOG_DIR, os.W_OK)):
            print(f"Warning: Log directory {LOG_DIR} not accessible, skipping log purge.", file=sys.stderr)
            return

        for log_file in directory.glob(f"{LOG_PREFIX}*.20*"):
            file_name = log_file.name
            # First name component that looks like a 20YYMMDD stamp, if any.
            stamp = next(
                (
                    piece
                    for piece in file_name.split('.')
                    if piece.startswith('20') and len(piece) == 8
                ),
                None,
            )
            if not stamp:
                continue

            try:
                age_in_days = (today - datetime.strptime(stamp, '%Y%m%d')).days
                if age_in_days > LOG_MAX_DAYS:
                    os.remove(log_file)
                    print(f"Deleted expired log file: {log_file}")
            except (ValueError, OSError) as e:
                # Bad date stamp or failed delete: report and keep going.
                print(f"Error processing log file {file_name}: {e}", file=sys.stderr)
    except Exception as e:
        print(f"Error cleaning up logs: {e}", file=sys.stderr)
if STDIO_MODE:
    # stdio transport: console log output is always switched off.
    CONSOLE_LOGGING = False
else:
    # File logging is only used outside stdio mode, so the log directory is
    # created and expired files are purged here, once, at import time.
    try:
        os.makedirs(LOG_DIR, exist_ok=True)
        purge_old_logs()
    except OSError as e:
        # A failure (e.g. a permission issue) is reported but must not
        # prevent startup.
        print(f"Warning: Failed to create log directory {LOG_DIR} or purge logs: {e}", file=sys.stderr)
# Log file paths. They are always defined, but get_logger() only opens them
# outside stdio mode.
LOG_FILE = os.path.join(LOG_DIR, f"{LOG_PREFIX}.log")
AUDIT_LOG_FILE = os.path.join(LOG_DIR, f"{LOG_PREFIX}.audit")
ERROR_LOG_FILE = os.path.join(LOG_DIR, f"{LOG_PREFIX}.error")
# Maps the LOG_LEVEL string to a logging level constant; get_logger() falls
# back to INFO for names that are not listed here.
LOG_LEVELS = {
"DEBUG": logging.DEBUG,
"INFO": logging.INFO,
"WARNING": logging.WARNING,
"ERROR": logging.ERROR,
"CRITICAL": logging.CRITICAL
}
# Log formats: the audit format omits the level name, the error format adds
# the source path and line number.
LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
AUDIT_FORMAT = '%(asctime)s - %(name)s - %(message)s'
ERROR_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(pathname)s:%(lineno)d - %(message)s'
# Dedicated audit log level, registered under the name "AUDIT".
AUDIT = 25 # Level between INFO (20) and WARNING (30)
logging.addLevelName(AUDIT, "AUDIT")
# Cache of loggers already configured by get_logger(), keyed by logger name.
_loggers: Dict[str, logging.Logger] = {}
# Handler classes get_logger() looks for before adding a handler.
# Note that 'file', 'audit' and 'error' all map to the same class, so a
# type-based check cannot tell those three handlers apart.
_handler_types = {
'console': logging.StreamHandler,
'file': logging.handlers.TimedRotatingFileHandler,
'audit': logging.handlers.TimedRotatingFileHandler,
'error': logging.handlers.TimedRotatingFileHandler
}
def _new_daily_handler(path, fmt):
    """Build a midnight-rotating UTF-8 file handler for *path* that formats records with *fmt*."""
    handler = logging.handlers.TimedRotatingFileHandler(
        path,
        when='midnight',
        interval=1,
        backupCount=LOG_MAX_DAYS,
        encoding='utf-8'
    )
    handler.setFormatter(logging.Formatter(fmt))
    # NOTE(review): rotated files get a compact YYYYMMDD suffix; confirm that
    # backupCount-based pruning still recognises files named this way.
    handler.suffix = "%Y%m%d"
    return handler


def get_logger(name: str) -> logging.Logger:
    """
    Return the logger registered under *name*, configuring it on first use.

    A configured logger does not propagate to its parents, exposes an extra
    ``audit(message, *args, **kwargs)`` method that logs at the AUDIT level,
    and receives these handlers:

    - a stderr console handler when CONSOLE_LOGGING is enabled;
    - outside stdio mode, three daily-rotating files: the general log, an
      audit log that accepts AUDIT records only, and an error log for ERROR
      and above.

    A file handler that cannot be created is reported on stderr and skipped.

    Args:
        name: Logger name.

    Returns:
        logging.Logger: The configured (and cached) logger.
    """
    cached = _loggers.get(name)
    if cached is not None:
        return cached

    log = logging.getLogger(name)
    log.setLevel(LOG_LEVELS.get(LOG_LEVEL, logging.INFO))
    # Keep records away from ancestor loggers so nothing is emitted twice.
    log.propagate = False

    # Snapshot of the handler classes present before anything is added here.
    existing = {type(h) for h in log.handlers}

    def audit(self, message, *args, **kwargs):
        self.log(AUDIT, message, *args, **kwargs)

    # Bind the function to this logger instance so it behaves like a method.
    log.audit = audit.__get__(log)

    if CONSOLE_LOGGING and _handler_types['console'] not in existing:
        # stderr rather than stdout, to avoid conflicts with MCP communication.
        stream = logging.StreamHandler(sys.stderr)
        stream.setFormatter(logging.Formatter(LOG_FORMAT))
        log.addHandler(stream)

    # File handlers are only attached outside stdio mode.
    if not STDIO_MODE:
        if _handler_types['file'] not in existing:
            try:
                log.addHandler(_new_daily_handler(LOG_FILE, LOG_FORMAT))
            except OSError as e:
                print(f"Warning: Failed to add file log handler for {LOG_FILE}: {e}", file=sys.stderr)

        if _handler_types['audit'] not in existing:
            try:
                audit_handler = _new_daily_handler(AUDIT_LOG_FILE, AUDIT_FORMAT)
                audit_handler.setLevel(AUDIT)
                # Accept AUDIT records only, not everything above that level.
                audit_handler.addFilter(lambda record: record.levelno == AUDIT)
                log.addHandler(audit_handler)
            except OSError as e:
                print(f"Warning: Failed to add audit log handler for {AUDIT_LOG_FILE}: {e}", file=sys.stderr)

        if _handler_types['error'] not in existing:
            try:
                error_handler = _new_daily_handler(ERROR_LOG_FILE, ERROR_FORMAT)
                error_handler.setLevel(logging.ERROR)
                log.addHandler(error_handler)
            except OSError as e:
                print(f"Warning: Failed to add error log handler for {ERROR_LOG_FILE}: {e}", file=sys.stderr)

    _loggers[name] = log
    return log
# Default logger, created at import time.
logger = get_logger('doris_mcp')
# Audit logger - for recording processing results, business operations, etc.
audit_logger = get_logger('audit')
# Expired logs are purged earlier in this module, right after the log
# directory is created and only outside stdio mode.

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,349 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
SQL Execution Tool
Responsible for executing SQL queries and handling results
"""
import os
import json
import logging
import traceback
import time
from typing import Dict, Any
import re
import datetime
from decimal import Decimal
# Module logger, obtained directly from the standard logging module.
logger = logging.getLogger("doris-mcp.sql-executor")
# Environment switch for the SQL security check. The check is enabled by
# default and stays enabled only when the variable equals "true"
# (case-insensitive); any other value disables it.
ENABLE_SQL_SECURITY_CHECK = os.environ.get('ENABLE_SQL_SECURITY_CHECK', 'true').lower() == 'true'
def _text_response(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Wrap *payload* as JSON text inside the {"content": [{"type": "text", ...}]} envelope."""
    return {
        "content": [
            {
                "type": "text",
                "text": json.dumps(payload, ensure_ascii=False)
            }
        ]
    }


def _ensure_row_limit(sql: str, max_rows: int) -> str:
    """Append ``LIMIT max_rows`` to a SELECT statement that has no LIMIT keyword."""
    if not sql.strip().lower().startswith("select"):
        return sql
    # Match LIMIT as a whole word so identifiers such as "rate_limits" do not
    # suppress the safety limit.
    if re.search(r'\blimit\b', sql, re.IGNORECASE):
        return sql
    # Drop trailing semicolons together with surrounding whitespace, so that
    # "select 1; " does not become "select 1;  LIMIT n;".
    return re.sub(r'[\s;]+\Z', '', sql) + f" LIMIT {max_rows};"


def _describe_db_error(error_message: str) -> Dict[str, str]:
    """Classify a database error message and attach a suggestion for the caller."""
    lowered = error_message.lower()
    if "timeout" in lowered:
        return {
            "type": "timeout",
            "suggestion": "Query timed out, please optimize SQL or increase timeout"
        }
    if "syntax" in lowered:
        return {
            "type": "syntax",
            "suggestion": "SQL syntax error, please check syntax"
        }
    if "not found" in lowered or "doesn't exist" in lowered:
        return {
            "type": "not_found",
            "suggestion": "Table or column not found, please check table and column names"
        }
    return {
        "type": "unknown",
        "suggestion": "Please check the SQL statement and try simplifying the query"
    }


def _build_rows_payload(rows, sql: str, max_rows: int, execution_time: float) -> Dict[str, Any]:
    """Build the success payload for a (possibly empty) sequence of result rows."""
    row_count = len(rows)

    # Column names come from the first row; an empty result has no columns.
    columns = []
    if rows:
        first = rows[0]
        if hasattr(first, "_fields"):
            # Named tuple
            columns = list(first._fields)
        elif isinstance(first, dict):
            columns = list(first.keys())

    data = []
    for row in rows:
        if hasattr(row, "_asdict"):
            # Named tuple
            row_dict = row._asdict()
        elif isinstance(row, dict):
            row_dict = row
        else:
            # NOTE(review): a plain sequence row without known column names is
            # passed through unchanged; _serialize_row_data expects a mapping,
            # so such rows end up as an error response. Confirm whether
            # tuple-returning cursors need to be supported.
            row_dict = dict(zip(columns, row)) if columns else row
        # Convert special types so the row is JSON serializable.
        data.append(_serialize_row_data(row_dict))

    return {
        "success": True,
        "sql": sql,
        "row_count": row_count,
        "columns": columns,
        "data": data[:max_rows],  # Limit returned rows
        "execution_time": execution_time,
        "truncated": row_count > max_rows
    }


async def execute_sql_query(ctx) -> Dict[str, Any]:
    """
    Execute a SQL query and return the outcome as a JSON text response.

    Request parameters (read from ``ctx['params']`` or ``ctx.params``):
        sql: SQL statement to execute (required).
        db_name: Target database; defaults to the DB_DATABASE environment
            variable.
        max_rows: Maximum number of rows returned (default 1000). It is also
            appended as a LIMIT clause to SELECT statements without one.
        timeout: Accepted for compatibility; this function does not apply it.

    Args:
        ctx: Context object or dictionary containing request parameters.

    Returns:
        Dict[str, Any]: ``{"content": [{"type": "text", "text": <json>}]}``.
        The JSON document always has a "success" flag. Failures (missing SQL,
        rejected by the security check, database error, unexpected error) are
        reported inside the document rather than raised.
    """
    try:
        # Support both a plain dictionary and an object exposing .params.
        if isinstance(ctx, dict) and 'params' in ctx:
            params = ctx['params']
        else:
            params = ctx.params

        sql = params.get("sql")
        db_name = params.get("db_name", os.getenv("DB_DATABASE", ""))
        max_rows = params.get("max_rows", 1000)  # Maximum number of rows to return

        if not sql:
            return _text_response({
                "success": False,
                "error": "Missing SQL parameter",
                "message": "Please provide the SQL query to execute"
            })

        # Check SQL security before touching the database.
        security_result = await _check_sql_security(sql)
        if not security_result.get("is_safe", False):
            return _text_response({
                "success": False,
                "error": "SQL security check failed",
                "message": "Query contains unsafe operations and cannot be executed",
                "security_issues": security_result.get("security_issues", [])
            })

        # Imported here, after validation, as in the original implementation.
        from doris_mcp_server.utils.db import execute_query

        # Ensure SELECT statements include a LIMIT clause.
        sql = _ensure_row_limit(sql, max_rows)

        start_time = time.time()
        try:
            result = execute_query(sql, db_name)
            execution_time = time.time() - start_time

            if isinstance(result, (list, tuple)) and not result:
                # Empty result set. The driver may hand back an empty list or
                # an empty tuple; report zero rows instead of indexing into it.
                payload = _build_rows_payload([], sql, max_rows, execution_time)
            elif isinstance(result, list):
                payload = _build_rows_payload(result, sql, max_rows, execution_time)
            else:
                # Any other result type is reported through its string form.
                payload = _serialize_row_data({
                    "success": True,
                    "sql": sql,
                    "result": str(result),
                    "execution_time": execution_time
                })
            return _text_response(payload)
        except Exception as db_error:
            # Database-level failure: return a structured, classified error.
            error_message = str(db_error)
            error_response = {
                "success": False,
                "error": error_message,
                "error_details": _describe_db_error(error_message),
                "sql": sql,
                "db_name": db_name
            }
            return _text_response(_serialize_row_data(error_response))
    except Exception as e:
        logger.error(f"Failed to execute SQL query: {str(e)}")
        logger.error(traceback.format_exc())
        error_response = {
            "success": False,
            "error": str(e),
            "message": "Error occurred while executing SQL query"
        }
        return _text_response(_serialize_row_data(error_response))
# Helper function
async def _check_sql_security(sql: str) -> Dict[str, Any]:
"""Check SQL security"""
# If environment variable is set to disable security check, return safe immediately
if not ENABLE_SQL_SECURITY_CHECK:
return {
"is_safe": True,
"security_issues": []
}
# Check if SQL contains dangerous operations
sql_lower = sql.lower()
# Check if it's a read-only query type
is_read_only = sql_lower.strip().startswith(("select ", "show ", "desc ", "describe ", "explain "))
# Define list of dangerous operations (checked for both read-only and non-read-only queries)
dangerous_operations = [
(r'\bdelete\b', "DELETE operation"),
(r'\bdrop\b', "DROP TABLE/DATABASE operation"),
(r'\btruncate\b', "TRUNCATE TABLE operation"),
(r'\bupdate\b', "UPDATE operation"),
(r'\binsert\b', "INSERT operation"),
(r'\balter\b', "ALTER TABLE structure operation"),
(r'\bcreate\b', "CREATE TABLE/DATABASE operation"),
(r'\bgrant\b', "GRANT operation"),
(r'\brevoke\b', "REVOKE permission operation"),
(r'\bexec\b', "EXECUTE stored procedure"),
(r'\bxp_', "Extended stored procedure, potential security risk"),
(r'\bshutdown\b', "SHUTDOWN database operation"),
(r'\bunion\s+all\s+select\b', "UNION statement, potential SQL injection"),
(r'\bunion\s+select\b', "UNION statement, potential SQL injection"),
(r'\binto\s+outfile\b', "Write to file operation"),
(r'\bload_file\b', "Load file operation")
]
# Dangerous operations checked only for non-read-only queries
non_readonly_operations = []
if not is_read_only:
non_readonly_operations = [
(r'--', "SQL comment, potential SQL injection"),
(r'/\*', "SQL block comment, potential SQL injection")
]
# Check if dangerous operations are included
security_issues = []
# Check dangerous operations applicable to all queries
for operation, description in dangerous_operations:
if re.search(operation, sql_lower):
# For specific keywords in read-only queries, differentiate if used as independent operations
if is_read_only and operation in [r'\bcreate\b', r'\bdrop\b', r'\bdelete\b', r'\binsert\b', r'\bupdate\b', r'\balter\b']:
# Check if used as DDL/DML keyword, e.g., CREATE TABLE, DROP DATABASE
pattern = operation + r'\s+(?:table|database|view|index|procedure|function|trigger|event)'
if re.search(pattern, sql_lower):
security_issues.append({
"operation": operation.replace(r'\b', '').replace(r'\s+', ' '),
"description": description,
"severity": "High"
})
else:
security_issues.append({
"operation": operation.replace(r'\b', '').replace(r'\s+', ' '),
"description": description,
"severity": "High"
})
# Check dangerous operations specific to non-read-only queries
for operation, description in non_readonly_operations:
if re.search(operation, sql_lower):
security_issues.append({
"operation": operation.replace(r'\b', '').replace(r'\s+', ' '),
"description": description,
"severity": "Medium"
})
return {
"is_safe": len(security_issues) == 0,
"security_issues": security_issues
}
def _serialize_row_data(row_data: Dict[str, Any]) -> Dict[str, Any]:
"""
Convert special types in row data (like date, time, Decimal) to JSON serializable format
Args:
row_data: Row data dictionary
Returns:
Dict[str, Any]: Processed serializable dictionary
"""
serialized_data = {}
for key, value in row_data.items():
if value is None:
serialized_data[key] = None
elif isinstance(value, (datetime.date, datetime.datetime)):
# Convert date and time types to ISO format string
serialized_data[key] = value.isoformat()
elif isinstance(value, Decimal):
# Convert Decimal type to float
serialized_data[key] = float(value)
elif isinstance(value, (list, tuple)):
# Recursively process elements in list or tuple
serialized_data[key] = [
_serialize_row_data(item) if isinstance(item, dict) else item
for item in value
]
elif isinstance(value, dict):
# Recursively process nested dictionaries
serialized_data[key] = _serialize_row_data(value)
else:
serialized_data[key] = value
return serialized_data