0.3.0 Release Version

This commit is contained in:
FreeOnePlus
2025-06-08 18:44:40 +08:00
parent d9fed06c92
commit 4c913743c7
54 changed files with 12649 additions and 4667 deletions

View File

@@ -1 +1,10 @@
# Mark directory as a package
"""
Utilities Package - Contains utility classes and helper functions.
This package includes:
- Database connection and operations
- Configuration management
- Security utilities
- Query execution helpers
- Logging configuration
"""

View File

@@ -0,0 +1,318 @@
"""
Data Analysis Tools Module
Provides data analysis functions including table analysis, column statistics, performance monitoring, etc.
"""
import time
from datetime import datetime
from typing import Any, Dict, List
from .db import DorisConnectionManager
from .logger import get_logger
logger = get_logger(__name__)
class TableAnalyzer:
"""Table analyzer"""
def __init__(self, connection_manager: DorisConnectionManager):
self.connection_manager = connection_manager
async def get_table_summary(
self,
table_name: str,
include_sample: bool = True,
sample_size: int = 10
) -> Dict[str, Any]:
"""Get table summary information"""
connection = await self.connection_manager.get_connection("query")
# Get table basic information
table_info_sql = f"""
SELECT
table_name,
table_comment,
table_rows,
create_time,
engine
FROM information_schema.tables
WHERE table_schema = DATABASE()
AND table_name = '{table_name}'
"""
table_info_result = await connection.execute(table_info_sql)
if not table_info_result.data:
raise ValueError(f"Table {table_name} does not exist")
table_info = table_info_result.data[0]
# Get column information
columns_sql = f"""
SELECT
column_name,
data_type,
is_nullable,
column_comment
FROM information_schema.columns
WHERE table_schema = DATABASE()
AND table_name = '{table_name}'
ORDER BY ordinal_position
"""
columns_result = await connection.execute(columns_sql)
summary = {
"table_name": table_info["table_name"],
"comment": table_info.get("table_comment"),
"row_count": table_info.get("table_rows", 0),
"create_time": str(table_info.get("create_time")),
"engine": table_info.get("engine"),
"column_count": len(columns_result.data),
"columns": columns_result.data,
}
# Get sample data
if include_sample and sample_size > 0:
sample_sql = f"SELECT * FROM {table_name} LIMIT {sample_size}"
sample_result = await connection.execute(sample_sql)
summary["sample_data"] = sample_result.data
return summary
async def analyze_column(
self,
table_name: str,
column_name: str,
analysis_type: str = "basic"
) -> Dict[str, Any]:
"""Analyze column statistics"""
try:
connection = await self.connection_manager.get_connection("query")
# Basic statistics
basic_stats_sql = f"""
SELECT
'{column_name}' as column_name,
COUNT(*) as total_count,
COUNT({column_name}) as non_null_count,
COUNT(DISTINCT {column_name}) as distinct_count
FROM {table_name}
"""
basic_result = await connection.execute(basic_stats_sql)
if not basic_result.data:
return {
"success": False,
"error": f"Unable to get statistics for table {table_name} column {column_name}"
}
analysis = basic_result.data[0].copy()
analysis["success"] = True
analysis["analysis_type"] = analysis_type
if analysis_type in ["distribution", "detailed"]:
# Data distribution analysis
distribution_sql = f"""
SELECT
{column_name} as value,
COUNT(*) as frequency
FROM {table_name}
WHERE {column_name} IS NOT NULL
GROUP BY {column_name}
ORDER BY frequency DESC
LIMIT 20
"""
distribution_result = await connection.execute(distribution_sql)
analysis["value_distribution"] = distribution_result.data
if analysis_type == "detailed":
# Detailed statistics (for numeric types)
try:
numeric_stats_sql = f"""
SELECT
MIN({column_name}) as min_value,
MAX({column_name}) as max_value,
AVG({column_name}) as avg_value
FROM {table_name}
WHERE {column_name} IS NOT NULL
"""
numeric_result = await connection.execute(numeric_stats_sql)
if numeric_result.data:
analysis.update(numeric_result.data[0])
except Exception:
# Non-numeric columns don't support numeric statistics
pass
return analysis
except Exception as e:
logger.error(f"Column analysis failed: {e}")
return {
"success": False,
"error": str(e),
"column_name": column_name,
"table_name": table_name
}
async def analyze_table_relationships(
self,
table_name: str,
depth: int = 2
) -> Dict[str, Any]:
"""Analyze table relationships"""
connection = await self.connection_manager.get_connection("system")
# Get table basic information
table_info_sql = f"""
SELECT
table_name,
table_comment,
table_rows
FROM information_schema.tables
WHERE table_schema = DATABASE()
AND table_name = '{table_name}'
"""
table_result = await connection.execute(table_info_sql)
if not table_result.data:
raise ValueError(f"Table {table_name} does not exist")
# Get all tables list (for analyzing potential relationships)
all_tables_sql = """
SELECT
table_name,
table_comment
FROM information_schema.tables
WHERE table_schema = DATABASE()
AND table_type = 'BASE TABLE'
AND table_name != %s
"""
all_tables_result = await connection.execute(all_tables_sql, (table_name,))
return {
"center_table": table_result.data[0],
"related_tables": all_tables_result.data,
"depth": depth,
"note": "Table relationship analysis based on column name similarity and business logic inference",
}
class PerformanceMonitor:
"""Performance monitor"""
def __init__(self, connection_manager: DorisConnectionManager):
self.connection_manager = connection_manager
async def get_performance_stats(
self,
metric_type: str = "queries",
time_range: str = "1h"
) -> Dict[str, Any]:
"""Get performance statistics"""
connection = await self.connection_manager.get_connection("system")
# Convert time range to seconds
time_mapping = {
"1h": 3600,
"6h": 21600,
"24h": 86400,
"7d": 604800
}
seconds = time_mapping.get(time_range, 3600)
if metric_type == "queries":
# Query performance metrics
stats = {
"metric_type": "queries",
"time_range": time_range,
"timestamp": datetime.now().isoformat(),
"total_queries": 0,
"avg_execution_time": 0.0,
"slow_queries": 0,
"error_queries": 0,
"note": "Query performance statistics (simulated data)"
}
elif metric_type == "connections":
# Connection statistics
connection_metrics = await self.connection_manager.get_metrics()
stats = {
"metric_type": "connections",
"time_range": time_range,
"timestamp": datetime.now().isoformat(),
"total_connections": connection_metrics.total_connections,
"active_connections": connection_metrics.active_connections,
"idle_connections": connection_metrics.idle_connections,
"failed_connections": connection_metrics.failed_connections,
"connection_errors": connection_metrics.connection_errors,
"avg_connection_time": connection_metrics.avg_connection_time,
"last_health_check": connection_metrics.last_health_check.isoformat() if connection_metrics.last_health_check else None
}
elif metric_type == "tables":
# Table-level statistics
tables_sql = """
SELECT
table_name,
table_rows,
data_length,
index_length,
create_time,
update_time
FROM information_schema.tables
WHERE table_schema = DATABASE()
AND table_type = 'BASE TABLE'
ORDER BY table_rows DESC
LIMIT 20
"""
tables_result = await connection.execute(tables_sql)
stats = {
"metric_type": "tables",
"time_range": time_range,
"timestamp": datetime.now().isoformat(),
"table_count": len(tables_result.data),
"tables": tables_result.data
}
elif metric_type == "system":
# System-level metrics (simulated)
stats = {
"metric_type": "system",
"time_range": time_range,
"timestamp": datetime.now().isoformat(),
"cpu_usage": 45.2,
"memory_usage": 68.5,
"disk_usage": 72.1,
"network_io": {
"bytes_sent": 1024000,
"bytes_received": 2048000
},
"note": "System metrics (simulated data)"
}
else:
raise ValueError(f"Unsupported metric type: {metric_type}")
return stats
async def get_query_history(
self,
limit: int = 50,
order_by: str = "time"
) -> Dict[str, Any]:
"""Get query history"""
# Since Doris doesn't have a built-in query history table,
# we return simulated data
return {
"total_queries": 0,
"queries": [],
"limit": limit,
"order_by": order_by,
"note": "Query history feature requires audit log configuration"
}

View File

@@ -0,0 +1,608 @@
#!/usr/bin/env python3
"""
Doris Configuration Management Module
Implements configuration loading, validation and management functionality
"""
import json
import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
try:
from dotenv import load_dotenv
except ImportError:
load_dotenv = None
@dataclass
class DatabaseConfig:
"""Database connection configuration"""
host: str = "localhost"
port: int = 9030
user: str = "root"
password: str = ""
database: str = "test"
charset: str = "utf8mb4"
# Connection pool configuration
min_connections: int = 5
max_connections: int = 20
connection_timeout: int = 30
health_check_interval: int = 60
max_connection_age: int = 3600
@dataclass
class SecurityConfig:
"""Security configuration"""
# Authentication configuration
auth_type: str = "token" # token, basic, oauth
token_secret: str = "default_secret"
token_expiry: int = 3600
# SQL security configuration
blocked_keywords: list[str] = field(
default_factory=lambda: [
"DROP",
"DELETE",
"TRUNCATE",
"ALTER",
"CREATE",
"INSERT",
"UPDATE",
"GRANT",
"REVOKE",
]
)
max_query_complexity: int = 100
max_result_rows: int = 10000
# Sensitive table configuration
sensitive_tables: dict[str, str] = field(default_factory=dict)
# Data masking configuration
enable_masking: bool = True
masking_rules: list[dict[str, Any]] = field(default_factory=list)
@dataclass
class PerformanceConfig:
"""Performance configuration"""
# Query cache configuration
enable_query_cache: bool = True
cache_ttl: int = 300
max_cache_size: int = 1000
# Concurrency control configuration
max_concurrent_queries: int = 50
query_timeout: int = 300
# Connection pool optimization configuration
connection_pool_size: int = 20
idle_timeout: int = 1800
@dataclass
class LoggingConfig:
"""Logging configuration"""
level: str = "INFO"
format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
file_path: str | None = None
max_file_size: int = 10 * 1024 * 1024 # 10MB
backup_count: int = 5
# Audit log configuration
enable_audit: bool = True
audit_file_path: str | None = None
@dataclass
class MonitoringConfig:
"""Monitoring configuration"""
# Metrics collection configuration
enable_metrics: bool = True
metrics_port: int = 8081
metrics_path: str = "/metrics"
# Health check configuration
health_check_port: int = 8082
health_check_path: str = "/health"
# Alert configuration
enable_alerts: bool = False
alert_webhook_url: str | None = None
@dataclass
class DorisConfig:
"""Doris MCP Server complete configuration"""
# Basic configuration
server_name: str = "doris-mcp-server"
server_version: str = "1.0.0"
server_port: int = 8080
# Sub-configuration modules
database: DatabaseConfig = field(default_factory=DatabaseConfig)
security: SecurityConfig = field(default_factory=SecurityConfig)
performance: PerformanceConfig = field(default_factory=PerformanceConfig)
logging: LoggingConfig = field(default_factory=LoggingConfig)
monitoring: MonitoringConfig = field(default_factory=MonitoringConfig)
# Custom configuration
custom_config: dict[str, Any] = field(default_factory=dict)
@classmethod
def from_file(cls, config_path: str) -> "DorisConfig":
"""Load configuration from file"""
config_file = Path(config_path)
if not config_file.exists():
raise FileNotFoundError(f"Configuration file does not exist: {config_path}")
try:
with open(config_file, encoding="utf-8") as f:
if config_file.suffix.lower() == ".json":
config_data = json.load(f)
else:
# Support other formats (like YAML)
raise ValueError(f"Unsupported configuration file format: {config_file.suffix}")
return cls._from_dict(config_data)
except Exception as e:
raise ValueError(f"Failed to load configuration file: {e}")
@classmethod
def from_env(cls, env_file: str | None = None) -> "DorisConfig":
"""Load configuration from environment variables
Args:
env_file: .env file path, if None, search in the following order:
.env, .env.local, .env.production, .env.development
"""
# Load .env file
if load_dotenv is not None:
if env_file:
# Load specified .env file
if Path(env_file).exists():
load_dotenv(env_file)
logging.getLogger(__name__).info(f"Loaded environment configuration file: {env_file}")
else:
logging.getLogger(__name__).warning(f"Environment configuration file does not exist: {env_file}")
else:
# Load .env files in priority order
env_files = [".env", ".env.local", ".env.production", ".env.development"]
for env_path in env_files:
if Path(env_path).exists():
load_dotenv(env_path)
logging.getLogger(__name__).info(f"Loaded environment configuration file: {env_path}")
break
else:
logging.getLogger(__name__).info("No .env configuration file found, using system environment variables")
else:
logging.getLogger(__name__).warning("python-dotenv not installed, cannot load .env files")
config = cls()
# Database configuration
config.database.host = os.getenv("DORIS_HOST", config.database.host)
config.database.port = int(os.getenv("DORIS_PORT", str(config.database.port)))
config.database.user = os.getenv("DORIS_USER", config.database.user)
config.database.password = os.getenv("DORIS_PASSWORD", config.database.password)
config.database.database = os.getenv("DORIS_DATABASE", config.database.database)
# Connection pool configuration
config.database.min_connections = int(
os.getenv("DORIS_MIN_CONNECTIONS", str(config.database.min_connections))
)
config.database.max_connections = int(
os.getenv("DORIS_MAX_CONNECTIONS", str(config.database.max_connections))
)
config.database.connection_timeout = int(
os.getenv("DORIS_CONNECTION_TIMEOUT", str(config.database.connection_timeout))
)
config.database.health_check_interval = int(
os.getenv("DORIS_HEALTH_CHECK_INTERVAL", str(config.database.health_check_interval))
)
config.database.max_connection_age = int(
os.getenv("DORIS_MAX_CONNECTION_AGE", str(config.database.max_connection_age))
)
# Security configuration
config.security.auth_type = os.getenv("AUTH_TYPE", config.security.auth_type)
config.security.token_secret = os.getenv("TOKEN_SECRET", config.security.token_secret)
config.security.token_expiry = int(
os.getenv("TOKEN_EXPIRY", str(config.security.token_expiry))
)
config.security.max_result_rows = int(
os.getenv("MAX_RESULT_ROWS", str(config.security.max_result_rows))
)
config.security.max_query_complexity = int(
os.getenv("MAX_QUERY_COMPLEXITY", str(config.security.max_query_complexity))
)
config.security.enable_masking = (
os.getenv("ENABLE_MASKING", str(config.security.enable_masking).lower()).lower() == "true"
)
# Performance configuration
config.performance.enable_query_cache = (
os.getenv("ENABLE_QUERY_CACHE", "true").lower() == "true"
)
config.performance.cache_ttl = int(
os.getenv("CACHE_TTL", str(config.performance.cache_ttl))
)
config.performance.max_cache_size = int(
os.getenv("MAX_CACHE_SIZE", str(config.performance.max_cache_size))
)
config.performance.max_concurrent_queries = int(
os.getenv("MAX_CONCURRENT_QUERIES", str(config.performance.max_concurrent_queries))
)
config.performance.query_timeout = int(
os.getenv("QUERY_TIMEOUT", str(config.performance.query_timeout))
)
# Logging configuration
config.logging.level = os.getenv("LOG_LEVEL", config.logging.level)
config.logging.file_path = os.getenv("LOG_FILE_PATH", config.logging.file_path)
config.logging.enable_audit = (
os.getenv("ENABLE_AUDIT", str(config.logging.enable_audit).lower()).lower() == "true"
)
config.logging.audit_file_path = os.getenv("AUDIT_FILE_PATH", config.logging.audit_file_path)
# Monitoring configuration
config.monitoring.enable_metrics = (
os.getenv("ENABLE_METRICS", "true").lower() == "true"
)
config.monitoring.metrics_port = int(
os.getenv("METRICS_PORT", str(config.monitoring.metrics_port))
)
config.monitoring.health_check_port = int(
os.getenv("HEALTH_CHECK_PORT", str(config.monitoring.health_check_port))
)
config.monitoring.enable_alerts = (
os.getenv("ENABLE_ALERTS", str(config.monitoring.enable_alerts).lower()).lower() == "true"
)
config.monitoring.alert_webhook_url = os.getenv("ALERT_WEBHOOK_URL", config.monitoring.alert_webhook_url)
# Server configuration
config.server_name = os.getenv("SERVER_NAME", config.server_name)
config.server_version = os.getenv("SERVER_VERSION", config.server_version)
config.server_port = int(os.getenv("SERVER_PORT", str(config.server_port)))
return config
@classmethod
def _from_dict(cls, config_data: dict[str, Any]) -> "DorisConfig":
"""Create configuration object from dictionary"""
config = cls()
# Update basic configuration
for key in ["server_name", "server_version", "server_port"]:
if key in config_data:
setattr(config, key, config_data[key])
# Update database configuration
if "database" in config_data:
db_config = config_data["database"]
for key, value in db_config.items():
if hasattr(config.database, key):
setattr(config.database, key, value)
# Update security configuration
if "security" in config_data:
sec_config = config_data["security"]
for key, value in sec_config.items():
if hasattr(config.security, key):
setattr(config.security, key, value)
# Update performance configuration
if "performance" in config_data:
perf_config = config_data["performance"]
for key, value in perf_config.items():
if hasattr(config.performance, key):
setattr(config.performance, key, value)
# Update logging configuration
if "logging" in config_data:
log_config = config_data["logging"]
for key, value in log_config.items():
if hasattr(config.logging, key):
setattr(config.logging, key, value)
# Update monitoring configuration
if "monitoring" in config_data:
mon_config = config_data["monitoring"]
for key, value in mon_config.items():
if hasattr(config.monitoring, key):
setattr(config.monitoring, key, value)
# Custom configuration
config.custom_config = config_data.get("custom", {})
return config
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary format"""
return {
"server_name": self.server_name,
"server_version": self.server_version,
"server_port": self.server_port,
"database": {
"host": self.database.host,
"port": self.database.port,
"user": self.database.user,
"password": "***", # Hide password
"database": self.database.database,
"charset": self.database.charset,
"min_connections": self.database.min_connections,
"max_connections": self.database.max_connections,
"connection_timeout": self.database.connection_timeout,
"health_check_interval": self.database.health_check_interval,
"max_connection_age": self.database.max_connection_age,
},
"security": {
"auth_type": self.security.auth_type,
"token_secret": "***", # Hide secret key
"token_expiry": self.security.token_expiry,
"blocked_keywords": self.security.blocked_keywords,
"max_query_complexity": self.security.max_query_complexity,
"max_result_rows": self.security.max_result_rows,
"sensitive_tables": self.security.sensitive_tables,
"enable_masking": self.security.enable_masking,
"masking_rules": len(self.security.masking_rules),
},
"performance": {
"enable_query_cache": self.performance.enable_query_cache,
"cache_ttl": self.performance.cache_ttl,
"max_cache_size": self.performance.max_cache_size,
"max_concurrent_queries": self.performance.max_concurrent_queries,
"query_timeout": self.performance.query_timeout,
"connection_pool_size": self.performance.connection_pool_size,
"idle_timeout": self.performance.idle_timeout,
},
"logging": {
"level": self.logging.level,
"format": self.logging.format,
"file_path": self.logging.file_path,
"max_file_size": self.logging.max_file_size,
"backup_count": self.logging.backup_count,
"enable_audit": self.logging.enable_audit,
"audit_file_path": self.logging.audit_file_path,
},
"monitoring": {
"enable_metrics": self.monitoring.enable_metrics,
"metrics_port": self.monitoring.metrics_port,
"metrics_path": self.monitoring.metrics_path,
"health_check_port": self.monitoring.health_check_port,
"health_check_path": self.monitoring.health_check_path,
"enable_alerts": self.monitoring.enable_alerts,
"alert_webhook_url": self.monitoring.alert_webhook_url,
},
"custom": self.custom_config,
}
def save_to_file(self, config_path: str):
"""Save configuration to file"""
config_file = Path(config_path)
config_file.parent.mkdir(parents=True, exist_ok=True)
try:
with open(config_file, "w", encoding="utf-8") as f:
if config_file.suffix.lower() == ".json":
json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
else:
raise ValueError(f"Unsupported configuration file format: {config_file.suffix}")
except Exception as e:
raise ValueError(f"Failed to save configuration file: {e}")
def validate(self) -> list[str]:
"""Validate configuration validity"""
errors = []
# Validate database configuration
if not self.database.host:
errors.append("Database host address cannot be empty")
if not (1 <= self.database.port <= 65535):
errors.append("Database port must be in the range 1-65535")
if not self.database.user:
errors.append("Database username cannot be empty")
if self.database.min_connections <= 0:
errors.append("Minimum connections must be greater than 0")
if self.database.max_connections <= self.database.min_connections:
errors.append("Maximum connections must be greater than minimum connections")
# Validate security configuration
if self.security.auth_type not in ["token", "basic", "oauth"]:
errors.append("Authentication type must be one of token, basic, or oauth")
if self.security.token_expiry <= 0:
errors.append("Token expiry time must be greater than 0")
if self.security.max_query_complexity <= 0:
errors.append("Maximum query complexity must be greater than 0")
if self.security.max_result_rows <= 0:
errors.append("Maximum result rows must be greater than 0")
# Validate performance configuration
if self.performance.cache_ttl <= 0:
errors.append("Cache TTL must be greater than 0")
if self.performance.max_concurrent_queries <= 0:
errors.append("Maximum concurrent queries must be greater than 0")
if self.performance.query_timeout <= 0:
errors.append("Query timeout must be greater than 0")
# Validate logging configuration
if self.logging.level not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]:
errors.append("Log level must be one of DEBUG, INFO, WARNING, ERROR, or CRITICAL")
if self.logging.max_file_size <= 0:
errors.append("Maximum log file size must be greater than 0")
if self.logging.backup_count < 0:
errors.append("Log backup count cannot be negative")
# Validate monitoring configuration
if not (1 <= self.monitoring.metrics_port <= 65535):
errors.append("Monitoring port must be in the range 1-65535")
if not (1 <= self.monitoring.health_check_port <= 65535):
errors.append("Health check port must be in the range 1-65535")
return errors
def get_connection_string(self) -> str:
"""Get database connection string (hide password)"""
return f"mysql://{self.database.user}:***@{self.database.host}:{self.database.port}/{self.database.database}"
def get_config_summary(self) -> dict[str, Any]:
"""Get configuration summary information"""
return {
"server": f"{self.server_name} v{self.server_version}",
"database": f"{self.database.host}:{self.database.port}/{self.database.database}",
"connection_pool": f"{self.database.min_connections}-{self.database.max_connections}",
"security": {
"auth_type": self.security.auth_type,
"masking_enabled": self.security.enable_masking,
"blocked_keywords_count": len(self.security.blocked_keywords),
},
"performance": {
"cache_enabled": self.performance.enable_query_cache,
"max_concurrent": self.performance.max_concurrent_queries,
"query_timeout": self.performance.query_timeout,
},
"monitoring": {
"metrics_enabled": self.monitoring.enable_metrics,
"alerts_enabled": self.monitoring.enable_alerts,
},
}
class ConfigManager:
"""Configuration manager class"""
def __init__(self, config: DorisConfig):
self.config = config
self.logger = logging.getLogger(__name__)
def setup_logging(self):
"""Setup logging configuration"""
# Configure root logger
root_logger = logging.getLogger()
root_logger.setLevel(getattr(logging, self.config.logging.level.upper()))
# Clear existing handlers
for handler in root_logger.handlers[:]:
root_logger.removeHandler(handler)
# Create formatter
formatter = logging.Formatter(self.config.logging.format)
# Console handler
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
root_logger.addHandler(console_handler)
# File handler (if configured)
if self.config.logging.file_path:
try:
from logging.handlers import RotatingFileHandler
file_handler = RotatingFileHandler(
self.config.logging.file_path,
maxBytes=self.config.logging.max_file_size,
backupCount=self.config.logging.backup_count,
encoding="utf-8",
)
file_handler.setFormatter(formatter)
root_logger.addHandler(file_handler)
except Exception as e:
self.logger.warning(f"Failed to setup file logging: {e}")
# Audit log handler (if configured)
if self.config.logging.enable_audit and self.config.logging.audit_file_path:
try:
from logging.handlers import RotatingFileHandler
audit_logger = logging.getLogger("audit")
audit_handler = RotatingFileHandler(
self.config.logging.audit_file_path,
maxBytes=self.config.logging.max_file_size,
backupCount=self.config.logging.backup_count,
encoding="utf-8",
)
audit_handler.setFormatter(formatter)
audit_logger.addHandler(audit_handler)
audit_logger.setLevel(logging.INFO)
except Exception as e:
self.logger.warning(f"Failed to setup audit logging: {e}")
def validate_config(self) -> bool:
"""Validate configuration"""
errors = self.config.validate()
if errors:
self.logger.error("Configuration validation failed:")
for error in errors:
self.logger.error(f" - {error}")
return False
self.logger.info("Configuration validation passed")
return True
def log_config_summary(self):
"""Log configuration summary"""
summary = self.config.get_config_summary()
self.logger.info("Configuration Summary:")
self.logger.info(f" Server: {summary['server']}")
self.logger.info(f" Database: {summary['database']}")
self.logger.info(f" Connection Pool: {summary['connection_pool']}")
self.logger.info(f" Security: {summary['security']}")
self.logger.info(f" Performance: {summary['performance']}")
self.logger.info(f" Monitoring: {summary['monitoring']}")
def create_default_config_file(config_path: str):
"""Create default configuration file"""
config = DorisConfig()
config.save_to_file(config_path)
print(f"Default configuration file created: {config_path}")
# Example usage
if __name__ == "__main__":
# Create default configuration
config = DorisConfig()
# Load from environment variables
# config = DorisConfig.from_env()
# Load from file
# config = DorisConfig.from_file("config.json")
# Validate configuration
config_manager = ConfigManager(config)
if config_manager.validate_config():
config_manager.setup_logging()
config_manager.log_config_summary()
# Save configuration
config.save_to_file("example_config.json")
print("Configuration saved to example_config.json")
else:
print("Configuration validation failed")

View File

@@ -1,100 +1,479 @@
import os
import json
import pymysql
import pandas as pd
from typing import Dict, List, Optional, Any
from dotenv import load_dotenv
import re
#!/usr/bin/env python3
"""
Apache Doris Database Connection Management Module
# Load environment variables
load_dotenv(override=True)
Provides high-performance database connection pool management, automatic reconnection mechanism and connection health check functionality
Supports asynchronous operations and concurrent connection management, ensuring stability and performance for enterprise applications
"""
# Database configuration
DB_CONFIG = {
"host": os.getenv("DB_HOST", "localhost"),
"port": int(os.getenv("DB_PORT", "9030")),
"user": os.getenv("DB_USER", "root"),
"password": os.getenv("DB_PASSWORD", ""),
"database": os.getenv("DB_DATABASE", ""),
"charset": "utf8mb4",
"cursorclass": pymysql.cursors.DictCursor
}
import asyncio
import logging
import time
from contextlib import asynccontextmanager
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List
def get_db_connection(db_name: Optional[str] = None):
import aiomysql
from aiomysql import Connection, Pool
@dataclass
class ConnectionMetrics:
"""Connection pool performance metrics"""
total_connections: int = 0
active_connections: int = 0
idle_connections: int = 0
failed_connections: int = 0
connection_errors: int = 0
avg_connection_time: float = 0.0
last_health_check: datetime | None = None
@dataclass
class QueryResult:
"""Query result wrapper"""
data: list[dict[str, Any]]
metadata: dict[str, Any]
execution_time: float
row_count: int
class DorisConnection:
"""Doris database connection wrapper class"""
def __init__(self, connection: Connection, session_id: str, security_manager=None):
self.connection = connection
self.session_id = session_id
self.created_at = datetime.utcnow()
self.last_used = datetime.utcnow()
self.query_count = 0
self.is_healthy = True
self.security_manager = security_manager
async def execute(self, sql: str, params: tuple | None = None, auth_context=None) -> QueryResult:
"""Execute SQL query"""
start_time = time.time()
try:
# If security manager exists, perform SQL security check
security_result = None
if self.security_manager and auth_context:
validation_result = await self.security_manager.validate_sql_security(sql, auth_context)
if not validation_result.is_valid:
raise ValueError(f"SQL security validation failed: {validation_result.error_message}")
security_result = {
"is_valid": validation_result.is_valid,
"risk_level": validation_result.risk_level,
"blocked_operations": validation_result.blocked_operations
}
async with self.connection.cursor(aiomysql.DictCursor) as cursor:
await cursor.execute(sql, params)
# Check if it's a query statement (statement that returns result set)
sql_upper = sql.strip().upper()
if (sql_upper.startswith("SELECT") or
sql_upper.startswith("SHOW") or
sql_upper.startswith("DESCRIBE") or
sql_upper.startswith("DESC") or
sql_upper.startswith("EXPLAIN")):
data = await cursor.fetchall()
row_count = len(data)
else:
data = []
row_count = cursor.rowcount
execution_time = time.time() - start_time
self.last_used = datetime.utcnow()
self.query_count += 1
# Get column information
columns = []
if cursor.description:
columns = [desc[0] for desc in cursor.description]
# If security manager exists and has auth context, apply data masking
final_data = list(data) if data else []
if self.security_manager and auth_context and final_data:
final_data = await self.security_manager.apply_data_masking(final_data, auth_context)
metadata = {"columns": columns, "query": sql, "params": params}
if security_result:
metadata["security_check"] = security_result
return QueryResult(
data=final_data,
metadata=metadata,
execution_time=execution_time,
row_count=row_count,
)
except Exception as e:
self.is_healthy = False
logging.error(f"Query execution failed: {e}")
raise
async def ping(self) -> bool:
"""Check connection health status"""
try:
await self.connection.ping()
self.is_healthy = True
return True
except Exception:
self.is_healthy = False
return False
async def close(self):
"""Close connection"""
try:
if self.connection and not self.connection.closed:
await self.connection.ensure_closed()
except Exception as e:
logging.error(f"Error occurred while closing connection: {e}")
class DorisConnectionManager:
"""Doris database connection manager
Provides connection pool management, connection health monitoring, fault recovery and other functions
Supports session-level connection reuse and intelligent load balancing
Integrates security manager to provide unified security validation and data masking
"""
Get database connection
Args:
db_name: Specify the database name to connect to, use default config if None
Returns:
Database connection
"""
if db_name:
# Use default config but override database name
config = DB_CONFIG.copy()
config["database"] = db_name
return pymysql.connect(**config)
else:
# Use default config
return pymysql.connect(**DB_CONFIG)
def get_db_name() -> str:
"""Get the currently configured default database name"""
return DB_CONFIG["database"] or os.getenv("DB_DATABASE", "")
def __init__(self, config, security_manager=None):
self.config = config
self.pool: Pool | None = None
self.session_connections: dict[str, DorisConnection] = {}
self.metrics = ConnectionMetrics()
self.logger = logging.getLogger(__name__)
self.security_manager = security_manager
def execute_query(sql, db_name: Optional[str] = None):
"""
Execute SQL query and return results
Args:
sql: SQL query statement
db_name: Specify the database name to connect to, use default config if None
Returns:
Query results
"""
conn = get_db_connection(db_name)
try:
with conn.cursor() as cursor:
# Set connection character set to utf8 before executing query
cursor.execute("SET NAMES utf8")
# Health check configuration
self.health_check_interval = config.database.health_check_interval or 60
self.max_connection_age = config.database.max_connection_age or 3600
self.connection_timeout = config.database.connection_timeout or 30
# Start background tasks
self._health_check_task = None
self._cleanup_task = None
async def initialize(self):
"""Initialize connection manager"""
try:
# Create connection pool
self.pool = await aiomysql.create_pool(
host=self.config.database.host,
port=self.config.database.port,
user=self.config.database.user,
password=self.config.database.password,
db=self.config.database.database,
charset="utf8",
minsize=self.config.database.min_connections or 5,
maxsize=self.config.database.max_connections or 20,
autocommit=True,
connect_timeout=self.connection_timeout,
)
self.logger.info(
f"Connection pool initialized successfully, min connections: {self.config.database.min_connections}, "
f"max connections: {self.config.database.max_connections}"
)
# Start background monitoring tasks
self._health_check_task = asyncio.create_task(self._health_check_loop())
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
except Exception as e:
self.logger.error(f"Connection pool initialization failed: {e}")
raise
async def get_connection(self, session_id: str) -> DorisConnection:
"""Get database connection
Supports session-level connection reuse to improve performance and consistency
"""
# Check if there's an existing session connection
if session_id in self.session_connections:
conn = self.session_connections[session_id]
# Check connection health
if await conn.ping():
return conn
else:
# Connection is unhealthy, clean up and create new one
await self._cleanup_session_connection(session_id)
# Create new connection
return await self._create_new_connection(session_id)
async def _create_new_connection(self, session_id: str) -> DorisConnection:
"""Create new database connection"""
try:
if not self.pool:
raise RuntimeError("Connection pool not initialized")
# Get connection from pool
raw_connection = await self.pool.acquire()
# Execute the actual query
cursor.execute(sql)
result = cursor.fetchall()
return result
finally:
conn.close()
def execute_query_df(sql, db_name: Optional[str] = None):
"""
Execute SQL query and return pandas DataFrame
Args:
sql: SQL query statement
db_name: Specify the database name to connect to, use default config if None
Returns:
pandas DataFrame
"""
conn = get_db_connection(db_name)
try:
# Use a temporary cursor to execute the query and get results
with conn.cursor() as cursor:
# Set connection character set to utf8 before executing query
cursor.execute("SET NAMES utf8")
# Create wrapped connection
doris_conn = DorisConnection(raw_connection, session_id, self.security_manager)
# Execute the actual query
cursor.execute(sql)
result = cursor.fetchall()
# Store in session connections
self.session_connections[session_id] = doris_conn
self.metrics.total_connections += 1
self.logger.debug(f"Created new connection for session: {session_id}")
return doris_conn
except Exception as e:
self.metrics.connection_errors += 1
self.logger.error(f"Failed to create connection for session {session_id}: {e}")
raise
async def release_connection(self, session_id: str):
"""Release session connection"""
if session_id in self.session_connections:
await self._cleanup_session_connection(session_id)
async def _cleanup_session_connection(self, session_id: str):
"""Clean up session connection"""
if session_id in self.session_connections:
conn = self.session_connections[session_id]
try:
# Return connection to pool
if self.pool and conn.connection and not conn.connection.closed:
self.pool.release(conn.connection)
# Close connection wrapper
await conn.close()
except Exception as e:
self.logger.error(f"Error cleaning up connection for session {session_id}: {e}")
finally:
# Remove from session connections
del self.session_connections[session_id]
self.logger.debug(f"Cleaned up connection for session: {session_id}")
async def _health_check_loop(self):
"""Background health check loop"""
while True:
try:
await asyncio.sleep(self.health_check_interval)
await self._perform_health_check()
except asyncio.CancelledError:
break
except Exception as e:
self.logger.error(f"Health check error: {e}")
async def _perform_health_check(self):
"""Perform health check"""
try:
unhealthy_sessions = []
for session_id, conn in self.session_connections.items():
if not await conn.ping():
unhealthy_sessions.append(session_id)
# Clean up unhealthy connections
for session_id in unhealthy_sessions:
await self._cleanup_session_connection(session_id)
self.metrics.failed_connections += 1
# Update metrics
await self._update_connection_metrics()
self.metrics.last_health_check = datetime.utcnow()
if unhealthy_sessions:
self.logger.warning(f"Cleaned up {len(unhealthy_sessions)} unhealthy connections")
except Exception as e:
self.logger.error(f"Health check failed: {e}")
async def _cleanup_loop(self):
"""Background cleanup loop"""
while True:
try:
await asyncio.sleep(300) # Run every 5 minutes
await self._cleanup_idle_connections()
except asyncio.CancelledError:
break
except Exception as e:
self.logger.error(f"Cleanup loop error: {e}")
async def _cleanup_idle_connections(self):
"""Clean up idle connections"""
current_time = datetime.utcnow()
idle_sessions = []
# If no results, return empty DataFrame
if not result:
return pd.DataFrame()
# Manually convert dict results to DataFrame
df = pd.DataFrame(result)
return df
finally:
conn.close()
for session_id, conn in self.session_connections.items():
# Check if connection has exceeded maximum age
age = (current_time - conn.created_at).total_seconds()
if age > self.max_connection_age:
idle_sessions.append(session_id)
# Clean up idle connections
for session_id in idle_sessions:
await self._cleanup_session_connection(session_id)
if idle_sessions:
self.logger.info(f"Cleaned up {len(idle_sessions)} idle connections")
async def _update_connection_metrics(self):
"""Update connection metrics"""
self.metrics.active_connections = len(self.session_connections)
if self.pool:
self.metrics.idle_connections = self.pool.freesize
async def get_metrics(self) -> ConnectionMetrics:
"""Get connection metrics"""
await self._update_connection_metrics()
return self.metrics
async def execute_query(
self, session_id: str, sql: str, params: tuple | None = None, auth_context=None
) -> QueryResult:
"""Execute query"""
conn = await self.get_connection(session_id)
return await conn.execute(sql, params, auth_context)
@asynccontextmanager
async def get_connection_context(self, session_id: str):
"""Get connection context manager"""
conn = await self.get_connection(session_id)
try:
yield conn
finally:
# Connection will be reused, no need to close here
pass
async def close(self):
"""Close connection manager"""
try:
# Cancel background tasks
if self._health_check_task:
self._health_check_task.cancel()
try:
await self._health_check_task
except asyncio.CancelledError:
pass
if self._cleanup_task:
self._cleanup_task.cancel()
try:
await self._cleanup_task
except asyncio.CancelledError:
pass
# Clean up all session connections
for session_id in list(self.session_connections.keys()):
await self._cleanup_session_connection(session_id)
# Close connection pool
if self.pool:
self.pool.close()
await self.pool.wait_closed()
self.logger.info("Connection manager closed successfully")
except Exception as e:
self.logger.error(f"Error closing connection manager: {e}")
async def test_connection(self) -> bool:
"""Test database connection"""
try:
if not self.pool:
return False
async with self.pool.acquire() as conn:
async with conn.cursor() as cursor:
await cursor.execute("SELECT 1")
result = await cursor.fetchone()
return result is not None
except Exception as e:
self.logger.error(f"Connection test failed: {e}")
return False
class ConnectionPoolMonitor:
"""Connection pool monitor
Provides detailed monitoring and reporting capabilities for connection pool status
"""
def __init__(self, connection_manager: DorisConnectionManager):
self.connection_manager = connection_manager
self.logger = logging.getLogger(__name__)
async def get_pool_status(self) -> dict[str, Any]:
"""Get connection pool status"""
metrics = await self.connection_manager.get_metrics()
status = {
"pool_size": self.connection_manager.pool.size if self.connection_manager.pool else 0,
"free_connections": self.connection_manager.pool.freesize if self.connection_manager.pool else 0,
"active_sessions": len(self.connection_manager.session_connections),
"total_connections": metrics.total_connections,
"failed_connections": metrics.failed_connections,
"connection_errors": metrics.connection_errors,
"avg_connection_time": metrics.avg_connection_time,
"last_health_check": metrics.last_health_check.isoformat() if metrics.last_health_check else None,
}
return status
async def get_session_details(self) -> list[dict[str, Any]]:
"""Get session connection details"""
sessions = []
for session_id, conn in self.connection_manager.session_connections.items():
session_info = {
"session_id": session_id,
"created_at": conn.created_at.isoformat(),
"last_used": conn.last_used.isoformat(),
"query_count": conn.query_count,
"is_healthy": conn.is_healthy,
"connection_age": (datetime.utcnow() - conn.created_at).total_seconds(),
}
sessions.append(session_info)
return sessions
async def generate_health_report(self) -> dict[str, Any]:
"""Generate connection health report"""
pool_status = await self.get_pool_status()
session_details = await self.get_session_details()
# Calculate health statistics
healthy_sessions = sum(1 for s in session_details if s["is_healthy"])
total_sessions = len(session_details)
health_ratio = healthy_sessions / total_sessions if total_sessions > 0 else 1.0
report = {
"timestamp": datetime.utcnow().isoformat(),
"pool_status": pool_status,
"session_summary": {
"total_sessions": total_sessions,
"healthy_sessions": healthy_sessions,
"health_ratio": health_ratio,
},
"session_details": session_details,
"recommendations": [],
}
# Add recommendations based on health status
if health_ratio < 0.8:
report["recommendations"].append("Consider checking database connectivity and network stability")
if pool_status["connection_errors"] > 10:
report["recommendations"].append("High connection error rate detected, review connection configuration")
if pool_status["active_sessions"] > pool_status["pool_size"] * 0.9:
report["recommendations"].append("Connection pool utilization is high, consider increasing pool size")
return report

View File

@@ -1,226 +1,85 @@
"""
Unified Logging Configuration Module
Provides unified logging configuration, including:
- General logs: Record all program execution information
- Audit logs: Record JSON data for key operations and processing results
- Error logs: Specifically record program exceptions and errors
Logging configuration for Doris MCP Server.
"""
import os
import sys
import logging
import logging.handlers
import logging.config
import sys
from pathlib import Path
from typing import Dict
from datetime import datetime
from dotenv import load_dotenv
from typing import Any
# Load environment variables
load_dotenv(override=True)
# Get project root directory
PROJECT_ROOT = Path(__file__).parents[2].absolute()
def setup_logging(
level: str = "INFO",
log_file: str | None = None,
log_format: str | None = None,
) -> None:
"""
Setup logging configuration.
# Get log configuration from environment variables
LOG_DIR = os.getenv("LOG_DIR", str(PROJECT_ROOT / "logs"))
LOG_PREFIX = os.getenv("LOG_PREFIX", "doris_mcp")
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
LOG_MAX_DAYS = int(os.getenv("LOG_MAX_DAYS", "30"))
# Whether to output logs to the console (should be disabled when running as a service)
CONSOLE_LOGGING = os.getenv("CONSOLE_LOGGING", "false").lower() == "true"
# Whether stdio transport mode is being used
STDIO_MODE = os.getenv("MCP_TRANSPORT_TYPE", "").lower() == "stdio"
Args:
level: Logging level (DEBUG, INFO, WARNING, ERROR)
log_file: Optional log file path
log_format: Optional custom log format
"""
if log_format is None:
log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
def purge_old_logs():
"""Clean up expired log files"""
# --- Only perform cleanup in non-Stdio mode ---
if STDIO_MODE:
return
try:
now = datetime.now()
log_dir = Path(LOG_DIR)
# Check if directory exists and is readable/writable
if not log_dir.is_dir() or not os.access(LOG_DIR, os.W_OK):
if not STDIO_MODE: # Avoid printing to stdout in stdio mode
print(f"Warning: Log directory {LOG_DIR} not accessible, skipping log purge.", file=sys.stderr)
return
# Base configuration
config: dict[str, Any] = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"default": {"format": log_format, "datefmt": "%Y-%m-%d %H:%M:%S"}
},
"handlers": {
"console": {
"class": "logging.StreamHandler",
"level": level,
"formatter": "default",
"stream": sys.stdout,
}
},
"root": {"level": level, "handlers": ["console"]},
"loggers": {
"doris_mcp_server": {
"level": level,
"handlers": ["console"],
"propagate": False,
}
},
}
for log_file in log_dir.glob(f"{LOG_PREFIX}*.20*"):
# Parse date
file_name = log_file.name
date_str = None
# Try to find the date part
parts = file_name.split('.')
for part in parts:
if part.startswith('20') and len(part) == 8: # 20YYMMDD format
date_str = part
break
if date_str:
try:
file_date = datetime.strptime(date_str, '%Y%m%d')
days_old = (now - file_date).days
if days_old > LOG_MAX_DAYS:
os.remove(log_file)
if not STDIO_MODE:
print(f"Deleted expired log file: {log_file}")
except (ValueError, OSError) as e:
if not STDIO_MODE:
print(f"Error processing log file {file_name}: {e}", file=sys.stderr)
except Exception as e:
if not STDIO_MODE:
print(f"Error cleaning up logs: {e}", file=sys.stderr)
# Add file handler if log_file is specified
if log_file:
# Ensure log directory exists
log_path = Path(log_file)
log_path.parent.mkdir(parents=True, exist_ok=True)
# Force disable console log output if in stdio mode
if STDIO_MODE:
CONSOLE_LOGGING = False
config["handlers"]["file"] = {
"class": "logging.handlers.RotatingFileHandler",
"level": level,
"formatter": "default",
"filename": log_file,
"maxBytes": 10485760, # 10MB
"backupCount": 5,
}
# --- Only create log directory and clean old logs in non-Stdio mode ---
if not STDIO_MODE:
try:
os.makedirs(LOG_DIR, exist_ok=True)
# Clean up expired logs on startup (also moved here, as it only handles file logs)
purge_old_logs()
except OSError as e:
# If directory creation fails (e.g., permission issue), print warning but continue to avoid startup failure
print(f"Warning: Failed to create log directory {LOG_DIR} or purge logs: {e}", file=sys.stderr)
# Add file handler to root and package loggers
config["root"]["handlers"].append("file")
config["loggers"]["doris_mcp_server"]["handlers"].append("file")
# Log file paths (definition still needed, but files might not be created/used)
LOG_FILE = os.path.join(LOG_DIR, f"{LOG_PREFIX}.log")
AUDIT_LOG_FILE = os.path.join(LOG_DIR, f"{LOG_PREFIX}.audit")
ERROR_LOG_FILE = os.path.join(LOG_DIR, f"{LOG_PREFIX}.error")
# Log level mapping
LOG_LEVELS = {
"DEBUG": logging.DEBUG,
"INFO": logging.INFO,
"WARNING": logging.WARNING,
"ERROR": logging.ERROR,
"CRITICAL": logging.CRITICAL
}
# Log format
LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
AUDIT_FORMAT = '%(asctime)s - %(name)s - %(message)s'
ERROR_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(pathname)s:%(lineno)d - %(message)s'
# Dedicated audit log level
AUDIT = 25 # Level between INFO and WARNING
logging.addLevelName(AUDIT, "AUDIT")
# Logger object cache
_loggers: Dict[str, logging.Logger] = {}
# Handler type mapping, used to ensure no duplicates are added
_handler_types = {
'console': logging.StreamHandler,
'file': logging.handlers.TimedRotatingFileHandler,
'audit': logging.handlers.TimedRotatingFileHandler,
'error': logging.handlers.TimedRotatingFileHandler
}
logging.config.dictConfig(config)
def get_logger(name: str) -> logging.Logger:
"""
Get a logger with the specified name
Get a logger instance.
Args:
name: Logger name
Returns:
logging.Logger: Configured logger
Logger instance
"""
if name in _loggers:
return _loggers[name]
# Create logger
logger = logging.getLogger(name)
logger.setLevel(LOG_LEVELS.get(LOG_LEVEL, logging.INFO))
# Avoid duplicate logs caused by propagation
logger.propagate = False
# Check if handlers already exist to avoid duplicates
handler_types = set(type(h) for h in logger.handlers)
# Add audit log method
def audit(self, message, *args, **kwargs):
self.log(AUDIT, message, *args, **kwargs)
logger.audit = audit.__get__(logger)
# General log handler - output to console (only if enabled)
if CONSOLE_LOGGING and _handler_types['console'] not in handler_types:
# Use stderr instead of stdout to avoid conflicts with MCP communication
console_handler = logging.StreamHandler(sys.stderr)
console_handler.setFormatter(logging.Formatter(LOG_FORMAT))
logger.addHandler(console_handler)
# --- Only add file handlers in non-Stdio mode ---
if not STDIO_MODE:
# General log handler - daily rotating file
if _handler_types['file'] not in handler_types:
try: # Add try-except block
file_handler = logging.handlers.TimedRotatingFileHandler(
LOG_FILE,
when='midnight',
interval=1,
backupCount=LOG_MAX_DAYS,
encoding='utf-8'
)
file_handler.setFormatter(logging.Formatter(LOG_FORMAT))
file_handler.suffix = "%Y%m%d"
logger.addHandler(file_handler)
except OSError as e:
print(f"Warning: Failed to add file log handler for {LOG_FILE}: {e}", file=sys.stderr)
# Audit log handler - only logs AUDIT level
if _handler_types['audit'] not in handler_types:
try: # Add try-except block
audit_handler = logging.handlers.TimedRotatingFileHandler(
AUDIT_LOG_FILE,
when='midnight',
interval=1,
backupCount=LOG_MAX_DAYS,
encoding='utf-8'
)
audit_handler.setFormatter(logging.Formatter(AUDIT_FORMAT))
audit_handler.suffix = "%Y%m%d"
audit_handler.setLevel(AUDIT)
audit_handler.addFilter(lambda record: record.levelno == AUDIT)
logger.addHandler(audit_handler)
except OSError as e:
print(f"Warning: Failed to add audit log handler for {AUDIT_LOG_FILE}: {e}", file=sys.stderr)
# Error log handler - only logs ERROR level and above
if _handler_types['error'] not in handler_types:
try: # Add try-except block
error_handler = logging.handlers.TimedRotatingFileHandler(
ERROR_LOG_FILE,
when='midnight',
interval=1,
backupCount=LOG_MAX_DAYS,
encoding='utf-8'
)
error_handler.setFormatter(logging.Formatter(ERROR_FORMAT))
error_handler.suffix = "%Y%m%d"
error_handler.setLevel(logging.ERROR)
logger.addHandler(error_handler)
except OSError as e:
print(f"Warning: Failed to add error log handler for {ERROR_LOG_FILE}: {e}", file=sys.stderr)
# Cache logger
_loggers[name] = logger
return logger
# Default logger
logger = get_logger('doris_mcp')
# Audit logger - for recording processing results, business operations, etc.
audit_logger = get_logger('audit')
# Call to clean logs moved after directory creation, and added non-stdio check
return logging.getLogger(name)

View File

@@ -0,0 +1,800 @@
#!/usr/bin/env python3
"""
Doris Query Execution Module
Implements query optimization, cache management and performance monitoring functionality
"""
import asyncio
import hashlib
import json
import logging
import time
import os
import uuid
import traceback
from dataclasses import dataclass
from datetime import datetime, timedelta, date
from typing import Any, Dict
from decimal import Decimal
from .db import DorisConnectionManager, QueryResult
@dataclass
class QueryRequest:
"""Query request wrapper"""
sql: str
session_id: str
user_id: str
parameters: dict[str, Any] | None = None
timeout: int | None = None
cache_enabled: bool = True
@dataclass
class CachedQuery:
"""Cached query result"""
result: QueryResult
created_at: datetime
ttl: int
access_count: int = 0
last_accessed: datetime | None = None
def is_expired(self) -> bool:
"""Check if cache is expired"""
if self.ttl <= 0:
return False
return (datetime.utcnow() - self.created_at).total_seconds() > self.ttl
def access(self):
"""Record access"""
self.access_count += 1
self.last_accessed = datetime.utcnow()
@dataclass
class QueryMetrics:
"""Query performance metrics"""
total_queries: int = 0
successful_queries: int = 0
failed_queries: int = 0
cache_hits: int = 0
cache_misses: int = 0
avg_execution_time: float = 0.0
total_execution_time: float = 0.0
slow_queries: int = 0
concurrent_queries: int = 0
class QueryCache:
"""Query result cache manager"""
def __init__(self, max_size: int = 1000, default_ttl: int = 300):
self.max_size = max_size
self.default_ttl = default_ttl
self.cache: dict[str, CachedQuery] = {}
self.logger = logging.getLogger(__name__)
def _generate_cache_key(
self, sql: str, parameters: dict[str, Any] | None = None
) -> str:
"""Generate cache key"""
cache_data = {"sql": sql.strip().lower(), "parameters": parameters or {}}
cache_string = json.dumps(cache_data, sort_keys=True)
return hashlib.md5(cache_string.encode()).hexdigest()
async def get(
self, sql: str, parameters: dict[str, Any] | None = None
) -> CachedQuery | None:
"""Get cached query result"""
cache_key = self._generate_cache_key(sql, parameters)
if cache_key in self.cache:
cached_query = self.cache[cache_key]
if not cached_query.is_expired():
cached_query.access()
self.logger.debug(f"Cache hit: {cache_key}")
return cached_query
else:
# Clean up expired cache
del self.cache[cache_key]
self.logger.debug(f"Cache expired, cleaned up: {cache_key}")
return None
async def set(
self,
sql: str,
result: QueryResult,
parameters: dict[str, Any] | None = None,
ttl: int | None = None,
) -> str:
"""Set query result cache"""
cache_key = self._generate_cache_key(sql, parameters)
# Check cache size limit
if len(self.cache) >= self.max_size:
await self._evict_oldest()
cached_query = CachedQuery(
result=result, created_at=datetime.utcnow(), ttl=ttl or self.default_ttl
)
self.cache[cache_key] = cached_query
self.logger.debug(f"Cache set: {cache_key}")
return cache_key
async def _evict_oldest(self):
"""Clean up oldest cache item"""
if not self.cache:
return
# Find oldest cache item
oldest_key = min(self.cache.keys(), key=lambda k: self.cache[k].created_at)
del self.cache[oldest_key]
self.logger.debug(f"Cleaned up oldest cache: {oldest_key}")
async def clear_expired(self):
"""Clean up all expired cache"""
expired_keys = [
key for key, cached_query in self.cache.items() if cached_query.is_expired()
]
for key in expired_keys:
del self.cache[key]
if expired_keys:
self.logger.info(f"Cleaned up {len(expired_keys)} expired cache items")
async def clear_all(self):
"""Clean up all cache"""
cache_count = len(self.cache)
self.cache.clear()
self.logger.info(f"Cleaned up all cache, total {cache_count} items")
def get_stats(self) -> dict[str, Any]:
"""Get cache statistics"""
total_access = sum(cached.access_count for cached in self.cache.values())
return {
"cache_size": len(self.cache),
"max_size": self.max_size,
"total_access": total_access,
"hit_rate": 0.0
if total_access == 0
else sum(cached.access_count for cached in self.cache.values())
/ total_access,
}
class QueryOptimizer:
"""Query optimizer"""
def __init__(self, config):
self.config = config
self.logger = logging.getLogger(__name__)
self.optimization_rules = self._load_optimization_rules()
def _load_optimization_rules(self) -> list[dict[str, Any]]:
"""Load query optimization rules"""
return [
{
"name": "add_limit_clause",
"description": "Add default limit for SELECT queries without LIMIT",
"pattern": r"^select\s+.*(?!.*limit\s+\d+)",
"action": "add_limit",
"params": {"default_limit": 1000},
},
{
"name": "optimize_count_query",
"description": "Optimize COUNT queries",
"pattern": r"select\s+count\(\*\)\s+from\s+(\w+)",
"action": "optimize_count",
"params": {},
},
]
async def optimize_query(self, sql: str, context: dict[str, Any]) -> str:
"""Apply query optimization"""
optimized_sql = sql
for rule in self.optimization_rules:
if self._should_apply_rule(rule, optimized_sql, context):
optimized_sql = await self._apply_optimization_rule(
optimized_sql, rule, context
)
self.logger.debug(f"Applied optimization rule: {rule['name']}")
return optimized_sql
def _should_apply_rule(
self, rule: dict[str, Any], sql: str, context: dict[str, Any]
) -> bool:
"""Check if optimization rule should be applied"""
import re
# Check pattern match
if "pattern" in rule:
if not re.search(rule["pattern"], sql, re.IGNORECASE):
return False
# Check conditions
if "conditions" in rule:
for condition in rule["conditions"]:
if not self._check_condition(condition, context):
return False
return True
def _check_condition(
self, condition: dict[str, Any], context: dict[str, Any]
) -> bool:
"""Check optimization condition"""
condition_type = condition.get("type")
if condition_type == "user_role":
required_roles = condition.get("roles", [])
user_roles = context.get("user_roles", [])
return any(role in user_roles for role in required_roles)
elif condition_type == "query_size":
max_size = condition.get("max_size", 1000)
return len(context.get("sql", "")) <= max_size
return True
async def _apply_optimization_rule(
self, sql: str, rule: dict[str, Any], context: dict[str, Any]
) -> str:
"""Apply optimization rule"""
action = rule.get("action")
params = rule.get("params", {})
if action == "add_limit":
return await self._add_limit_clause(sql, params)
elif action == "optimize_count":
return await self._optimize_count_query(sql, params)
elif action == "add_hints":
return await self._add_query_hints(sql, params)
return sql
async def _add_limit_clause(self, sql: str, params: dict[str, Any]) -> str:
"""Add LIMIT clause to query"""
import re
default_limit = params.get("default_limit", 1000)
# Check if LIMIT already exists
if re.search(r"\blimit\s+\d+", sql, re.IGNORECASE):
return sql
# Add LIMIT clause
if sql.strip().endswith(";"):
sql = sql.strip()[:-1]
return f"{sql} LIMIT {default_limit}"
async def _optimize_count_query(self, sql: str, params: dict[str, Any]) -> str:
"""Optimize COUNT query"""
# For COUNT queries, we can add optimization hints
return sql.replace("COUNT(*)", "COUNT(1)")
async def _add_query_hints(self, sql: str, params: dict[str, Any]) -> str:
"""Add query hints"""
hints = params.get("hints", [])
if not hints:
return sql
hint_string = "/*+ " + " ".join(hints) + " */"
return f"{hint_string} {sql}"
class DorisQueryExecutor:
"""Doris query executor with caching and optimization"""
def __init__(self, connection_manager: DorisConnectionManager, config=None):
self.connection_manager = connection_manager
self.config = config or self._create_default_config()
self.logger = logging.getLogger(__name__)
# Initialize components
cache_config = getattr(self.config, 'performance', None)
if cache_config:
cache_size = getattr(cache_config, 'max_cache_size', 1000)
cache_ttl = getattr(cache_config, 'cache_ttl', 300)
else:
cache_size = 1000
cache_ttl = 300
self.query_cache = QueryCache(max_size=cache_size, default_ttl=cache_ttl)
self.query_optimizer = QueryOptimizer(self.config)
self.metrics = QueryMetrics()
# Performance monitoring
self.slow_query_threshold = 5.0 # seconds
self.max_concurrent_queries = getattr(
getattr(self.config, 'performance', None), 'max_concurrent_queries', 50
) if hasattr(self.config, 'performance') else 50
# Background tasks
self._background_tasks = []
self._start_background_tasks()
def _create_default_config(self):
"""Create default configuration"""
class DefaultConfig:
def __init__(self):
self.performance = DefaultPerformanceConfig()
class DefaultPerformanceConfig:
def __init__(self):
self.max_cache_size = 1000
self.cache_ttl = 300
self.max_concurrent_queries = 50
return DefaultConfig()
def _start_background_tasks(self):
"""Start background tasks"""
try:
# Cache cleanup task
cleanup_task = asyncio.create_task(self._cache_cleanup_loop())
self._background_tasks.append(cleanup_task)
except RuntimeError:
# No event loop running (e.g., in tests), skip background tasks
self.logger.debug("No event loop running, skipping background tasks")
async def _cache_cleanup_loop(self):
"""Background cache cleanup loop"""
while True:
try:
await asyncio.sleep(300) # Run every 5 minutes
await self.query_cache.clear_expired()
except asyncio.CancelledError:
break
except Exception as e:
self.logger.error(f"Cache cleanup error: {e}")
async def execute_query(
self, query_request: QueryRequest, auth_context=None
) -> QueryResult:
"""Execute query with caching and optimization"""
start_time = time.time()
self.metrics.total_queries += 1
self.metrics.concurrent_queries += 1
try:
# Check cache first
if query_request.cache_enabled:
cached_result = await self.query_cache.get(
query_request.sql, query_request.parameters
)
if cached_result:
self.metrics.cache_hits += 1
self.logger.debug(f"Cache hit for query: {query_request.sql[:50]}...")
return cached_result.result
self.metrics.cache_misses += 1
# Execute query
result = await self._execute_query_internal(query_request, auth_context)
# Cache result if enabled
if query_request.cache_enabled and result.row_count > 0:
await self.query_cache.set(
query_request.sql, result, query_request.parameters
)
self.metrics.successful_queries += 1
return result
except Exception as e:
self.metrics.failed_queries += 1
self.logger.error(f"Query execution failed: {e}")
raise
finally:
execution_time = time.time() - start_time
self.metrics.concurrent_queries -= 1
self._update_execution_metrics(execution_time)
async def _execute_query_internal(
self, query_request: QueryRequest, auth_context
) -> QueryResult:
"""Internal query execution"""
# Optimize query
optimized_sql = await self.query_optimizer.optimize_query(
query_request.sql, {"user_roles": getattr(auth_context, 'roles', [])}
)
# Execute query
connection = await self.connection_manager.get_connection(
query_request.session_id
)
# Set timeout if specified
if query_request.timeout:
try:
result = await asyncio.wait_for(
connection.execute(optimized_sql, query_request.parameters, auth_context),
timeout=query_request.timeout
)
except asyncio.TimeoutError:
raise Exception(f"Query timeout after {query_request.timeout} seconds")
else:
result = await connection.execute(optimized_sql, query_request.parameters, auth_context)
return result
def _update_execution_metrics(self, execution_time: float):
"""Update execution metrics"""
self.metrics.total_execution_time += execution_time
# Update average execution time
if self.metrics.successful_queries > 0:
self.metrics.avg_execution_time = (
self.metrics.total_execution_time / self.metrics.successful_queries
)
# Check for slow queries
if execution_time > self.slow_query_threshold:
self.metrics.slow_queries += 1
self.logger.warning(
f"Slow query detected: {execution_time:.2f}s (threshold: {self.slow_query_threshold}s)"
)
async def execute_batch_queries(
self, query_requests: list[QueryRequest], auth_context=None
) -> list[QueryResult]:
"""Execute multiple queries in batch"""
results = []
# Check concurrent query limit
if len(query_requests) > self.max_concurrent_queries:
raise Exception(
f"Batch size {len(query_requests)} exceeds maximum concurrent queries {self.max_concurrent_queries}"
)
# Execute queries concurrently
tasks = [
self.execute_query(request, auth_context) for request in query_requests
]
try:
results = await asyncio.gather(*tasks, return_exceptions=True)
except Exception as e:
self.logger.error(f"Batch query execution failed: {e}")
raise
return results
async def explain_query(self, sql: str, session_id: str) -> dict[str, Any]:
"""Get query execution plan"""
explain_sql = f"EXPLAIN {sql}"
connection = await self.connection_manager.get_connection(session_id)
result = await connection.execute(explain_sql)
return {
"query": sql,
"execution_plan": result.data,
"estimated_cost": "N/A", # Doris doesn't provide cost estimates
}
async def get_query_stats(self) -> dict[str, Any]:
"""Get query execution statistics"""
cache_stats = self.query_cache.get_stats()
return {
"query_metrics": {
"total_queries": self.metrics.total_queries,
"successful_queries": self.metrics.successful_queries,
"failed_queries": self.metrics.failed_queries,
"success_rate": (
self.metrics.successful_queries / self.metrics.total_queries
if self.metrics.total_queries > 0
else 0.0
),
"avg_execution_time": self.metrics.avg_execution_time,
"slow_queries": self.metrics.slow_queries,
"concurrent_queries": self.metrics.concurrent_queries,
},
"cache_metrics": {
"cache_hits": self.metrics.cache_hits,
"cache_misses": self.metrics.cache_misses,
"hit_rate": (
self.metrics.cache_hits
/ (self.metrics.cache_hits + self.metrics.cache_misses)
if (self.metrics.cache_hits + self.metrics.cache_misses) > 0
else 0.0
),
**cache_stats,
},
}
async def clear_cache(self):
"""Clear query cache"""
await self.query_cache.clear_all()
async def execute_sql_for_mcp(
self,
sql: str,
limit: int = 1000,
timeout: int = 30,
session_id: str = "mcp_session",
user_id: str = "mcp_user"
) -> Dict[str, Any]:
"""Execute SQL query for MCP interface - unified method"""
try:
if not sql:
return {
"success": False,
"error": "SQL query is required",
"data": None
}
# Add LIMIT if not present and it's a SELECT query
if sql.upper().startswith("SELECT") and "LIMIT" not in sql.upper():
if sql.endswith(";"):
sql = sql[:-1]
sql = f"{sql} LIMIT {limit}"
# Create auth context for MCP calls
class MockAuthContext:
def __init__(self):
self.user_id = user_id
self.roles = ["data_analyst"]
self.permissions = ["read_data", "execute_query"]
self.session_id = session_id
self.security_level = "internal"
auth_context = MockAuthContext()
# Create query request
query_request = QueryRequest(
sql=sql,
session_id=session_id,
user_id=user_id,
timeout=timeout,
cache_enabled=True
)
# Execute query
result = await self.execute_query(query_request, auth_context)
# Process results
processed_data = []
if result.data:
for row in result.data:
processed_row = self._serialize_row_data(row)
processed_data.append(processed_row)
return {
"success": True,
"data": processed_data,
"metadata": {
"row_count": result.row_count,
"execution_time": result.execution_time,
"columns": result.metadata.get("columns", []),
"query": sql
},
"error": None
}
except Exception as e:
error_msg = str(e)
self.logger.error(f"SQL execution error: {error_msg}")
# Analyze error for better user feedback
error_analysis = self._analyze_error(error_msg)
return {
"success": False,
"error": error_analysis.get("user_message", error_msg),
"error_type": error_analysis.get("error_type", "execution_error"),
"data": None,
"metadata": {
"query": sql,
"error_details": error_msg
}
}
def _serialize_row_data(self, row_data: Dict[str, Any]) -> Dict[str, Any]:
"""Serialize row data for JSON response"""
serialized = {}
for key, value in row_data.items():
if value is None:
serialized[key] = None
elif isinstance(value, (str, int, float, bool)):
serialized[key] = value
elif isinstance(value, Decimal):
serialized[key] = float(value)
elif isinstance(value, (datetime, date)):
serialized[key] = value.isoformat()
elif isinstance(value, bytes):
try:
serialized[key] = value.decode('utf-8')
except UnicodeDecodeError:
serialized[key] = str(value)
else:
serialized[key] = str(value)
return serialized
def _analyze_error(self, error_message: str) -> Dict[str, str]:
"""Analyze error message and provide user-friendly feedback"""
error_msg_lower = error_message.lower()
if "table" in error_msg_lower and "doesn't exist" in error_msg_lower:
return {
"error_type": "table_not_found",
"user_message": "The specified table does not exist. Please check the table name and database."
}
elif "column" in error_msg_lower and ("unknown" in error_msg_lower or "doesn't exist" in error_msg_lower):
return {
"error_type": "column_not_found",
"user_message": "One or more columns in the query do not exist. Please check column names."
}
elif "syntax error" in error_msg_lower or "sql syntax" in error_msg_lower:
return {
"error_type": "syntax_error",
"user_message": "SQL syntax error. Please check your query syntax."
}
elif "access denied" in error_msg_lower or "permission" in error_msg_lower:
return {
"error_type": "permission_denied",
"user_message": "Access denied. You don't have permission to execute this query."
}
elif "timeout" in error_msg_lower:
return {
"error_type": "timeout",
"user_message": "Query execution timed out. Try simplifying your query or adding more specific filters."
}
else:
return {
"error_type": "general_error",
"user_message": f"Query execution failed: {error_message}"
}
async def close(self):
"""Close query executor and cleanup resources"""
# Cancel background tasks
for task in self._background_tasks:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
# Clear cache
await self.query_cache.clear_all()
self.logger.info("Query executor closed")
class QueryPerformanceMonitor:
"""Query performance monitor"""
def __init__(self, query_executor: DorisQueryExecutor):
self.query_executor = query_executor
self.logger = logging.getLogger(__name__)
self.performance_records = []
async def record_query_performance(
self, query_request: QueryRequest, result: QueryResult, execution_time: float
):
"""Record query performance"""
record = {
"timestamp": datetime.utcnow(),
"sql": query_request.sql,
"user_id": query_request.user_id,
"session_id": query_request.session_id,
"execution_time": execution_time,
"row_count": result.row_count,
"cache_hit": False, # This would need to be passed from executor
}
self.performance_records.append(record)
# Keep only recent records (last 1000)
if len(self.performance_records) > 1000:
self.performance_records = self.performance_records[-1000:]
async def get_performance_report(
self, time_range_minutes: int = 60
) -> dict[str, Any]:
"""Get performance report"""
cutoff_time = datetime.utcnow() - timedelta(minutes=time_range_minutes)
recent_records = [
record
for record in self.performance_records
if record["timestamp"] >= cutoff_time
]
if not recent_records:
return {"message": "No performance data available for the specified time range"}
# Calculate statistics
execution_times = [record["execution_time"] for record in recent_records]
row_counts = [record["row_count"] for record in recent_records]
return {
"time_range_minutes": time_range_minutes,
"total_queries": len(recent_records),
"avg_execution_time": sum(execution_times) / len(execution_times),
"max_execution_time": max(execution_times),
"min_execution_time": min(execution_times),
"avg_row_count": sum(row_counts) / len(row_counts),
"query_distribution": self._analyze_query_distribution(recent_records),
}
def _analyze_query_distribution(
self, records: list[dict[str, Any]]
) -> dict[str, Any]:
"""Analyze query distribution"""
query_types = {}
user_distribution = {}
for record in records:
# Analyze query type
sql_upper = record["sql"].strip().upper()
if sql_upper.startswith("SELECT"):
query_type = "SELECT"
elif sql_upper.startswith("INSERT"):
query_type = "INSERT"
elif sql_upper.startswith("UPDATE"):
query_type = "UPDATE"
elif sql_upper.startswith("DELETE"):
query_type = "DELETE"
else:
query_type = "OTHER"
query_types[query_type] = query_types.get(query_type, 0) + 1
# Analyze user distribution
user_id = record["user_id"]
user_distribution[user_id] = user_distribution.get(user_id, 0) + 1
return {"query_types": query_types, "user_distribution": user_distribution}
# Unified convenience function for MCP integration
async def execute_sql_query(sql: str, connection_manager: DorisConnectionManager, **kwargs) -> Dict[str, Any]:
"""Execute SQL query - unified convenience function for MCP tools"""
try:
# Create query executor
executor = DorisQueryExecutor(connection_manager)
try:
# Extract parameters from kwargs or use defaults
limit = kwargs.get("limit", 1000)
timeout = kwargs.get("timeout", 30)
session_id = kwargs.get("session_id", "mcp_session")
user_id = kwargs.get("user_id", "mcp_user")
result = await executor.execute_sql_for_mcp(
sql=sql,
limit=limit,
timeout=timeout,
session_id=session_id,
user_id=user_id
)
return result
finally:
await executor.close()
except Exception as e:
return {
"success": False,
"error": f"Query execution failed: {str(e)}",
"data": None
}

View File

@@ -8,6 +8,8 @@ import os
import json
import pandas as pd
import re
import uuid
import time
from typing import Dict, List, Any, Optional, Tuple
from dotenv import load_dotenv
from datetime import datetime, timedelta
@@ -26,23 +28,25 @@ ENABLE_MULTI_DATABASE=os.getenv("ENABLE_MULTI_DATABASE",True)
MULTI_DATABASE_NAMES=os.getenv("MULTI_DATABASE_NAMES","")
# Import local modules
from doris_mcp_server.utils.db import execute_query_df, execute_query
from .db import DorisConnectionManager
class MetadataExtractor:
"""Apache Doris Metadata Extractor"""
def __init__(self, db_name: str = None, catalog_name: str = None):
def __init__(self, db_name: str = None, catalog_name: str = None, connection_manager=None):
"""
Initialize the metadata extractor
Args:
db_name: Default database name, uses the currently connected database if not specified
catalog_name: Default catalog name for federation queries, uses the current catalog if not specified
connection_manager: DorisConnectionManager instance for database operations
"""
# Get configuration from environment variables
self.db_name = db_name or os.getenv("DB_DATABASE", "")
self.catalog_name = catalog_name # Store catalog name for federation support
self.metadata_db = METADATA_DB_NAME # Use constant
self.connection_manager = connection_manager
# Caching system
self.metadata_cache = {}
@@ -65,6 +69,9 @@ class MetadataExtractor:
# List of excluded system databases
self.excluded_databases = self._load_excluded_databases()
# Session ID for database queries
self._session_id = f"metadata_extractor_{uuid.uuid4().hex[:8]}"
def _load_excluded_databases(self) -> List[str]:
"""
Load the list of excluded databases configuration
@@ -482,7 +489,7 @@ class MetadataExtractor:
TABLE_SCHEMA = '{db_name}'
AND TABLE_NAME = '{table_name}'
"""
table_type_result = execute_query(table_type_query)
table_type_result = self._execute_query(table_type_query)
if table_type_result:
schema["table_type"] = table_type_result[0].get("TABLE_TYPE", "")
schema["engine"] = table_type_result[0].get("ENGINE", "")
@@ -633,31 +640,52 @@ class MetadataExtractor:
else:
query = f"SHOW INDEX FROM `{db_name}`.`{table_name}`"
df = execute_query_df(query)
# Process results
indexes = []
current_index = None
for _, row in df.iterrows():
index_name = row['Key_name']
column_name = row['Column_name']
try:
df = self._execute_query(query, return_dataframe=True)
if current_index is None or current_index['name'] != index_name:
# Process results
indexes = []
current_index = None
if not df.empty:
for _, row in df.iterrows():
try:
index_name = row['Key_name']
column_name = row['Column_name']
if current_index is None or current_index['name'] != index_name:
if current_index is not None:
indexes.append(current_index)
current_index = {
'name': index_name,
'columns': [column_name],
'unique': row['Non_unique'] == 0,
'type': row['Index_type']
}
else:
current_index['columns'].append(column_name)
except Exception as row_error:
logger.warning(f"Failed to process index row data: {row_error}")
continue
if current_index is not None:
indexes.append(current_index)
current_index = {
'name': index_name,
'columns': [column_name],
'unique': row['Non_unique'] == 0,
'type': row['Index_type']
}
else:
current_index['columns'].append(column_name)
if current_index is not None:
indexes.append(current_index)
except Exception as df_error:
logger.warning(f"DataFrame processing failed, trying regular query: {df_error}")
# Fall back to regular query
result = self._execute_query(query, return_dataframe=False)
indexes = []
if result:
# Simple processing, no complex index grouping
for row in result:
if isinstance(row, dict):
indexes.append({
'name': row.get('Key_name', ''),
'columns': [row.get('Column_name', '')],
'unique': row.get('Non_unique', 1) == 0,
'type': row.get('Index_type', '')
})
# Update cache
self.metadata_cache[cache_key] = indexes
@@ -748,7 +776,7 @@ class MetadataExtractor:
ORDER BY time DESC
LIMIT {limit}
"""
df = execute_query_df(query)
df = self._execute_query(query, return_dataframe=True)
return df
except Exception as e:
logger.error(f"Error getting audit logs: {str(e)}")
@@ -768,7 +796,7 @@ class MetadataExtractor:
try:
# Use SHOW CATALOGS command to get catalog list
query = "SHOW CATALOGS"
result = execute_query(query)
result = self._execute_query(query)
if not result:
catalogs = []
@@ -1057,7 +1085,7 @@ class MetadataExtractor:
AND TABLE_NAME = '{table_name}'
"""
partitions = execute_query(query)
partitions = self._execute_query(query)
if not partitions:
return {}
@@ -1099,10 +1127,511 @@ class MetadataExtractor:
# Replace 'information_schema' with 'catalog_name.information_schema'
modified_query = query.replace('information_schema', f'{catalog_name}.information_schema')
logger.info(f"Modified query for catalog {catalog_name}: {modified_query}")
return execute_query(modified_query, db_name)
return self._execute_query(modified_query, db_name)
else:
# Execute the original query
return execute_query(query, db_name)
return self._execute_query(query, db_name)
except Exception as e:
logger.error(f"Error executing query with catalog: {str(e)}")
raise
raise
async def _execute_query_async(self, query: str, db_name: str = None, return_dataframe: bool = False):
"""
Execute database query asynchronously
Args:
query: SQL query to execute
db_name: Database name to use (optional)
return_dataframe: Whether to return a pandas DataFrame instead of list
Returns:
Query result data (list of dictionaries or pandas DataFrame)
"""
try:
if self.connection_manager:
# Use the injected connection manager directly (async)
result = await self.connection_manager.execute_query(self._session_id, query, None)
# Extract data from QueryResult
if hasattr(result, 'data'):
data = result.data
else:
data = result
# Convert to DataFrame if requested
if return_dataframe and data:
import pandas as pd
return pd.DataFrame(data)
elif return_dataframe:
import pandas as pd
return pd.DataFrame()
else:
return data
else:
# Fallback: Return empty result
logger.warning("No connection manager provided, returning empty result")
if return_dataframe:
import pandas as pd
return pd.DataFrame()
else:
return []
except Exception as e:
logger.error(f"Error executing query: {str(e)}")
# Return empty result instead of raising exception to prevent cascade failures
if return_dataframe:
import pandas as pd
return pd.DataFrame()
else:
return []
def _execute_query(self, query: str, db_name: str = None, return_dataframe: bool = False):
"""
Execute database query with proper session management (sync wrapper)
Args:
query: SQL query to execute
db_name: Database name to use (optional)
return_dataframe: Whether to return a pandas DataFrame instead of list
Returns:
Query result data (list of dictionaries or pandas DataFrame)
"""
try:
if self.connection_manager:
import asyncio
# Try to run the async query
try:
# Check if there's a running event loop
loop = asyncio.get_running_loop()
# If we're in an async context, we need to run in a separate thread
import concurrent.futures
def run_in_new_loop():
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
try:
return new_loop.run_until_complete(
self._execute_query_async(query, db_name, return_dataframe)
)
finally:
new_loop.close()
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_in_new_loop)
return future.result(timeout=30)
except RuntimeError:
# No running loop, we can safely create one
return asyncio.run(
self._execute_query_async(query, db_name, return_dataframe)
)
else:
# Fallback: Return empty result
logger.warning("No connection manager provided, returning empty result")
if return_dataframe:
import pandas as pd
return pd.DataFrame()
else:
return []
except Exception as e:
logger.error(f"Error executing query: {str(e)}")
# Return empty result instead of raising exception to prevent cascade failures
if return_dataframe:
import pandas as pd
return pd.DataFrame()
else:
return []
async def get_table_schema_async(self, table_name: str, db_name: str = None, catalog_name: str = None) -> List[Dict[str, Any]]:
"""Asynchronously get table schema information"""
try:
# Use async query method
effective_catalog = catalog_name or self.catalog_name
# Build query statement
if effective_catalog and effective_catalog != "internal":
query = f"DESCRIBE `{effective_catalog}`.`{db_name or self.db_name}`.`{table_name}`"
else:
query = f"DESCRIBE `{db_name or self.db_name}`.`{table_name}`"
# Execute async query
result = await self._execute_query_async(query, db_name)
if not result:
return []
# Process results
schema = []
for row in result:
if isinstance(row, dict):
schema.append({
'column_name': row.get('Field', ''),
'data_type': row.get('Type', ''),
'is_nullable': row.get('Null', 'NO') == 'YES',
'default_value': row.get('Default', None),
'comment': row.get('Comment', ''),
'key': row.get('Key', ''),
'extra': row.get('Extra', '')
})
return schema
except Exception as e:
logger.error(f"Failed to get table schema: {e}")
return []
async def get_all_databases_async(self, catalog_name: str = None) -> List[str]:
"""Asynchronously get all database list"""
try:
effective_catalog = catalog_name or self.catalog_name
if effective_catalog and effective_catalog != "internal":
query = f"SHOW DATABASES FROM `{effective_catalog}`"
else:
query = "SHOW DATABASES"
result = await self._execute_query_async(query)
if not result:
return []
# Extract database names
databases = []
for row in result:
if isinstance(row, dict):
# Get the value of the first field (usually Database field)
db_name = list(row.values())[0] if row else None
if db_name:
databases.append(db_name)
return databases
except Exception as e:
logger.error(f"Failed to get database list: {e}")
return []
async def get_database_tables_async(self, db_name: str = None, catalog_name: str = None) -> List[str]:
"""Asynchronously get table list in database"""
try:
effective_catalog = catalog_name or self.catalog_name
effective_db = db_name or self.db_name
if effective_catalog and effective_catalog != "internal":
query = f"SHOW TABLES FROM `{effective_catalog}`.`{effective_db}`"
else:
query = f"SHOW TABLES FROM `{effective_db}`"
result = await self._execute_query_async(query, effective_db)
if not result:
return []
# Extract table names
tables = []
for row in result:
if isinstance(row, dict):
# Get the value of the first field (usually Tables_in_xxx field)
table_name = list(row.values())[0] if row else None
if table_name:
tables.append(table_name)
return tables
except Exception as e:
logger.error(f"Failed to get table list: {e}")
return []
async def get_catalog_list_async(self) -> List[str]:
"""Asynchronously get catalog list"""
try:
query = "SHOW CATALOGS"
result = await self._execute_query_async(query)
if not result:
return []
# Extract catalog names
catalogs = []
for row in result:
if isinstance(row, dict):
# SHOW CATALOGS returns fields including: CatalogId, CatalogName, Type, IsCurrent, CreateTime, LastUpdateTime, Comment
# We need to get the CatalogName field (second field)
if 'CatalogName' in row:
catalog_name = row['CatalogName']
else:
# If no CatalogName field, try to get the second field
values = list(row.values())
catalog_name = values[1] if len(values) > 1 else values[0] if values else None
if catalog_name:
catalogs.append(str(catalog_name))
return catalogs
except Exception as e:
logger.error(f"Failed to get catalog list: {e}")
return []
# ==================== Business layer methods (original metadata_tools.py functionality) ====================
def _format_response(self, success: bool, result: Any = None, error: str = None, message: str = "") -> Dict[str, Any]:
"""Format response result"""
response_data = {
"success": success,
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
}
if success and result is not None:
response_data["result"] = result
response_data["message"] = message or "Operation successful"
elif not success:
response_data["error"] = error or "Unknown error"
response_data["message"] = message or "Operation failed"
return response_data
async def exec_query_for_mcp(
self,
sql: str,
db_name: str = None,
catalog_name: str = None,
max_rows: int = 100,
timeout: int = 30
) -> Dict[str, Any]:
"""
Execute SQL query and return results, supports catalog federation queries
Unified interface for MCP tools
"""
logger.info(f"Executing SQL query: {sql}, DB: {db_name}, Catalog: {catalog_name}, MaxRows: {max_rows}, Timeout: {timeout}")
try:
if not sql:
return self._format_response(success=False, error="No SQL statement provided", message="Please provide SQL statement to execute")
# Import query executor
from .query_executor import execute_sql_query
# Call execute_sql_query to execute query
exec_result = await execute_sql_query(
sql=sql,
connection_manager=self.connection_manager,
limit=max_rows,
timeout=timeout
)
return exec_result
except Exception as e:
logger.error(f"Failed to execute SQL query: {str(e)}", exc_info=True)
return self._format_response(success=False, error=str(e), message="Error occurred while executing SQL query")
async def get_table_schema_for_mcp(
self,
table_name: str,
db_name: str = None,
catalog_name: str = None
) -> Dict[str, Any]:
"""Get detailed schema information for specified table (columns, types, comments, etc.) - MCP interface"""
logger.info(f"Getting table schema: Table: {table_name}, DB: {db_name}, Catalog: {catalog_name}")
if not table_name:
return self._format_response(success=False, error="Missing table_name parameter")
try:
schema = await self.get_table_schema_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
if not schema:
return self._format_response(
success=False,
error="Table does not exist or has no columns",
message=f"Unable to get schema for table {catalog_name or 'default'}.{db_name or self.db_name}.{table_name}"
)
return self._format_response(success=True, result=schema)
except Exception as e:
logger.error(f"Failed to get table schema: {str(e)}", exc_info=True)
return self._format_response(success=False, error=str(e), message="Error occurred while getting table schema")
async def get_db_table_list_for_mcp(
self,
db_name: str = None,
catalog_name: str = None
) -> Dict[str, Any]:
"""Get list of all table names in specified database - MCP interface"""
logger.info(f"Getting database table list: DB: {db_name}, Catalog: {catalog_name}")
try:
tables = await self.get_database_tables_async(db_name=db_name, catalog_name=catalog_name)
return self._format_response(success=True, result=tables)
except Exception as e:
logger.error(f"Failed to get database table list: {str(e)}", exc_info=True)
return self._format_response(success=False, error=str(e), message="Error occurred while getting database table list")
async def get_db_list_for_mcp(self, catalog_name: str = None) -> Dict[str, Any]:
"""Get list of all database names on server - MCP interface"""
logger.info(f"Getting database list: Catalog: {catalog_name}")
try:
databases = await self.get_all_databases_async(catalog_name=catalog_name)
return self._format_response(success=True, result=databases)
except Exception as e:
logger.error(f"Failed to get database list: {str(e)}", exc_info=True)
return self._format_response(success=False, error=str(e), message="Error occurred while getting database list")
async def get_table_comment_for_mcp(
self,
table_name: str,
db_name: str = None,
catalog_name: str = None
) -> Dict[str, Any]:
"""Get comment information for specified table - MCP interface"""
logger.info(f"Getting table comment: Table: {table_name}, DB: {db_name}, Catalog: {catalog_name}")
if not table_name:
return self._format_response(success=False, error="Missing table_name parameter")
try:
comment = self.get_table_comment(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
return self._format_response(success=True, result=comment)
except Exception as e:
logger.error(f"Failed to get table comment: {str(e)}", exc_info=True)
return self._format_response(success=False, error=str(e), message="Error occurred while getting table comment")
async def get_table_column_comments_for_mcp(
self,
table_name: str,
db_name: str = None,
catalog_name: str = None
) -> Dict[str, Any]:
"""Get comment information for all columns in specified table - MCP interface"""
logger.info(f"Getting table column comments: Table: {table_name}, DB: {db_name}, Catalog: {catalog_name}")
if not table_name:
return self._format_response(success=False, error="Missing table_name parameter")
try:
comments = self.get_column_comments(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
return self._format_response(success=True, result=comments)
except Exception as e:
logger.error(f"Failed to get table column comments: {str(e)}", exc_info=True)
return self._format_response(success=False, error=str(e), message="Error occurred while getting table column comments")
async def get_table_indexes_for_mcp(
self,
table_name: str,
db_name: str = None,
catalog_name: str = None
) -> Dict[str, Any]:
"""Get index information for specified table - MCP interface"""
logger.info(f"Getting table indexes: Table: {table_name}, DB: {db_name}, Catalog: {catalog_name}")
if not table_name:
return self._format_response(success=False, error="Missing table_name parameter")
try:
indexes = self.get_table_indexes(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
return self._format_response(success=True, result=indexes)
except Exception as e:
logger.error(f"Failed to get table indexes: {str(e)}", exc_info=True)
return self._format_response(success=False, error=str(e), message="Error occurred while getting table indexes")
def _serialize_datetime_objects(self, data):
"""Serialize datetime objects to JSON compatible format"""
if isinstance(data, list):
return [self._serialize_datetime_objects(item) for item in data]
elif isinstance(data, dict):
return {key: self._serialize_datetime_objects(value) for key, value in data.items()}
elif hasattr(data, 'isoformat'): # datetime, date, time objects
return data.isoformat()
elif hasattr(data, 'strftime'): # pandas Timestamp objects
return data.strftime('%Y-%m-%d %H:%M:%S')
else:
return data
async def get_recent_audit_logs_for_mcp(self, days: int = 7, limit: int = 100) -> Dict[str, Any]:
"""Get recent audit log records - MCP interface"""
logger.info(f"Getting audit logs: Days: {days}, Limit: {limit}")
try:
logs_df = self.get_recent_audit_logs(days=days, limit=limit)
# Convert DataFrame to JSON format
if hasattr(logs_df, 'to_dict'):
try:
logs_data = logs_df.to_dict('records')
except Exception as e:
logger.warning(f"DataFrame.to_dict failed, trying manual conversion: {e}")
# Manually convert DataFrame to records format
logs_data = []
if not logs_df.empty:
for _, row in logs_df.iterrows():
logs_data.append(dict(row))
# Serialize datetime objects
logs_data = self._serialize_datetime_objects(logs_data)
else:
logs_data = self._serialize_datetime_objects(logs_df)
return self._format_response(success=True, result=logs_data)
except Exception as e:
logger.error(f"Failed to get audit logs: {str(e)}", exc_info=True)
return self._format_response(success=False, error=str(e), message="Error occurred while getting audit logs")
async def get_catalog_list_for_mcp(self) -> Dict[str, Any]:
"""Get Doris catalog list - MCP interface"""
logger.info("Getting catalog list")
try:
catalogs = await self.get_catalog_list_async()
return self._format_response(success=True, result=catalogs, message="Successfully retrieved catalog list")
except Exception as e:
logger.error(f"Failed to get catalog list: {str(e)}", exc_info=True)
return self._format_response(success=False, error=str(e), message="Error occurred while getting catalog list")
# ==================== Compatibility aliases ====================
# For backward compatibility, create MetadataManager alias
class MetadataManager:
"""
Metadata manager - backward compatibility class
Actually a wrapper for MetadataExtractor
"""
def __init__(self, connection_manager=None):
self.extractor = MetadataExtractor(connection_manager=connection_manager)
async def exec_query(self, sql: str, db_name: str = None, catalog_name: str = None, max_rows: int = 100, timeout: int = 30) -> Dict[str, Any]:
"""Execute SQL query and return results, supports catalog federation queries"""
return await self.extractor.exec_query_for_mcp(sql, db_name, catalog_name, max_rows, timeout)
async def get_table_schema(self, table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
"""Get detailed schema information for specified table (columns, types, comments, etc.)"""
return await self.extractor.get_table_schema_for_mcp(table_name, db_name, catalog_name)
async def get_db_table_list(self, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
"""Get list of all table names in specified database"""
return await self.extractor.get_db_table_list_for_mcp(db_name, catalog_name)
async def get_db_list(self, catalog_name: str = None) -> Dict[str, Any]:
"""Get list of all database names on server"""
return await self.extractor.get_db_list_for_mcp(catalog_name)
async def get_table_comment(self, table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
"""Get comment information for specified table"""
return await self.extractor.get_table_comment_for_mcp(table_name, db_name, catalog_name)
async def get_table_column_comments(self, table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
"""Get comment information for all columns in specified table"""
return await self.extractor.get_table_column_comments_for_mcp(table_name, db_name, catalog_name)
async def get_table_indexes(self, table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
"""Get index information for specified table"""
return await self.extractor.get_table_indexes_for_mcp(table_name, db_name, catalog_name)
async def get_recent_audit_logs(self, days: int = 7, limit: int = 100) -> Dict[str, Any]:
"""Get recent audit log records"""
return await self.extractor.get_recent_audit_logs_for_mcp(days, limit)
async def get_catalog_list(self) -> Dict[str, Any]:
"""Get Doris catalog list"""
return await self.extractor.get_catalog_list_for_mcp()

View File

@@ -0,0 +1,861 @@
#!/usr/bin/env python3
"""
Doris Security Management Module
Implements enterprise-level authentication, authorization, SQL security validation and data masking functionality
"""
import hashlib
import logging
import re
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any
import sqlparse
from sqlparse.sql import Statement
from sqlparse.tokens import Keyword, Name
class SecurityLevel(Enum):
"""Security level enumeration"""
PUBLIC = "public"
INTERNAL = "internal"
CONFIDENTIAL = "confidential"
SECRET = "secret"
@dataclass
class AuthContext:
"""Authentication context"""
user_id: str
roles: list[str]
permissions: list[str]
session_id: str
login_time: datetime | None = None
last_activity: datetime | None = None
security_level: SecurityLevel = SecurityLevel.INTERNAL
@dataclass
class ValidationResult:
"""Validation result"""
is_valid: bool
error_message: str | None = None
risk_level: str = "low"
blocked_operations: list[str] = None
def __post_init__(self):
if self.blocked_operations is None:
self.blocked_operations = []
@dataclass
class MaskingRule:
"""Data masking rule"""
column_pattern: str
algorithm: str
parameters: dict[str, Any]
security_level: SecurityLevel
class DorisSecurityManager:
"""Doris security manager
Provides complete security control functionality, including authentication, authorization, SQL security validation and data masking
"""
def __init__(self, config):
self.config = config
self.logger = logging.getLogger(__name__)
# Initialize security components
self.auth_provider = AuthenticationProvider(config)
self.authz_provider = AuthorizationProvider(config)
self.sql_validator = SQLSecurityValidator(config)
self.masking_processor = DataMaskingProcessor(config)
# Security rule configuration
self.blocked_keywords = self._load_blocked_keywords()
self.sensitive_tables = self._load_sensitive_tables()
self.masking_rules = self._load_masking_rules()
def _load_blocked_keywords(self) -> set[str]:
"""Load blocked SQL keywords"""
default_blocked = {
"DROP",
"DELETE",
"TRUNCATE",
"ALTER",
"CREATE",
"INSERT",
"UPDATE",
"GRANT",
"REVOKE",
"EXEC",
"EXECUTE",
"SHUTDOWN",
"KILL",
}
# Load custom rules from configuration file
if hasattr(self.config, 'get'):
custom_blocked = set(self.config.get("blocked_keywords", []))
else:
custom_blocked = set()
return default_blocked.union(custom_blocked)
def _load_sensitive_tables(self) -> dict[str, SecurityLevel]:
"""Load sensitive table configuration"""
default_tables = {
"user_info": SecurityLevel.CONFIDENTIAL,
"payment_records": SecurityLevel.SECRET,
"employee_data": SecurityLevel.CONFIDENTIAL,
"public_reports": SecurityLevel.PUBLIC,
}
if hasattr(self.config, 'get'):
config_tables = self.config.get("sensitive_tables", {})
# Convert string values to SecurityLevel enum
for table_name, level in config_tables.items():
if isinstance(level, str):
try:
default_tables[table_name] = SecurityLevel(level.lower())
except ValueError:
default_tables[table_name] = SecurityLevel.INTERNAL
else:
default_tables[table_name] = level
return default_tables
else:
return default_tables
def _load_masking_rules(self) -> list[MaskingRule]:
"""Load data masking rules"""
default_rules = [
MaskingRule(
column_pattern=r".*phone.*|.*mobile.*",
algorithm="phone_mask",
parameters={"mask_char": "*", "keep_prefix": 3, "keep_suffix": 4},
security_level=SecurityLevel.INTERNAL,
),
MaskingRule(
column_pattern=r".*email.*",
algorithm="email_mask",
parameters={"mask_char": "*"},
security_level=SecurityLevel.INTERNAL,
),
MaskingRule(
column_pattern=r".*id_card.*|.*identity.*",
algorithm="id_mask",
parameters={"mask_char": "*", "keep_prefix": 6, "keep_suffix": 4},
security_level=SecurityLevel.CONFIDENTIAL,
),
]
# Load custom rules from configuration
custom_rules = []
if hasattr(self.config, 'get'):
custom_rules = self.config.get("masking_rules", [])
elif hasattr(self.config, 'security') and hasattr(self.config.security, 'masking_rules'):
custom_rules = self.config.security.masking_rules
for rule_config in custom_rules:
if isinstance(rule_config, dict):
default_rules.append(MaskingRule(**rule_config))
elif isinstance(rule_config, MaskingRule):
default_rules.append(rule_config)
return default_rules
async def authenticate_request(self, auth_info: dict[str, Any]) -> AuthContext:
"""Validate request authentication information"""
return await self.auth_provider.authenticate(auth_info)
async def authorize_resource_access(
self, auth_context: AuthContext, resource_uri: str
) -> bool:
"""Validate resource access permissions"""
return await self.authz_provider.check_permission(
auth_context, resource_uri, "read"
)
async def validate_sql_security(
self, sql: str, auth_context: AuthContext
) -> ValidationResult:
"""Validate SQL query security"""
return await self.sql_validator.validate(sql, auth_context)
async def apply_data_masking(
self, data: list[dict[str, Any]], auth_context: AuthContext
) -> list[dict[str, Any]]:
"""Apply data masking processing"""
return await self.masking_processor.process(data, auth_context)
class AuthenticationProvider:
"""Authentication provider"""
def __init__(self, config):
self.config = config
self.logger = logging.getLogger(__name__)
self.session_cache = {}
async def authenticate(self, auth_info: dict[str, Any]) -> AuthContext:
"""Perform identity authentication"""
auth_type = auth_info.get("type", "token")
if auth_type == "token":
return await self._authenticate_token(auth_info)
elif auth_type == "basic":
return await self._authenticate_basic(auth_info)
else:
raise ValueError(f"Unsupported authentication type: {auth_type}")
async def _authenticate_token(self, auth_info: dict[str, Any]) -> AuthContext:
"""Token authentication"""
token = auth_info.get("token")
if not token:
raise ValueError("Missing authentication token")
# Validate token (simplified implementation, should validate JWT or query authentication service in practice)
user_info = await self._validate_token(token)
return AuthContext(
user_id=user_info["user_id"],
roles=user_info["roles"],
permissions=user_info["permissions"],
session_id=auth_info.get("session_id", "default"),
login_time=datetime.utcnow(),
security_level=SecurityLevel(user_info.get("security_level", "internal")),
)
async def _authenticate_basic(self, auth_info: dict[str, Any]) -> AuthContext:
"""Basic authentication (username password)"""
username = auth_info.get("username")
password = auth_info.get("password")
if not username or not password:
raise ValueError("Missing username or password")
# Validate username password (simplified implementation)
user_info = await self._validate_credentials(username, password)
return AuthContext(
user_id=user_info["user_id"],
roles=user_info["roles"],
permissions=user_info["permissions"],
session_id=auth_info.get("session_id", "default"),
login_time=datetime.utcnow(),
security_level=SecurityLevel(user_info.get("security_level", "internal")),
)
async def _validate_token(self, token: str) -> dict[str, Any]:
"""Validate token validity"""
# Simplified implementation for testing, should parse JWT or query authentication service in practice
valid_tokens = {
"valid_token_123": {
"user_id": "test_user",
"roles": ["data_analyst"],
"permissions": ["read_data"],
"security_level": SecurityLevel.INTERNAL,
},
"admin_token_456": {
"user_id": "admin_user",
"roles": ["data_admin"],
"permissions": ["admin"],
"security_level": SecurityLevel.SECRET,
}
}
if token in valid_tokens:
return valid_tokens[token]
else:
raise ValueError("Invalid token")
async def _validate_credentials(
self, username: str, password: str
) -> dict[str, Any]:
"""Validate user credentials"""
# Simplified implementation for testing, should query user database in practice
valid_users = {
"admin": {
"password": "admin123",
"user_id": "admin_user",
"roles": ["data_admin"],
"permissions": ["admin", "read_data", "write_data"],
"security_level": SecurityLevel.SECRET,
},
"analyst": {
"password": "analyst123",
"user_id": "analyst_user",
"roles": ["data_analyst"],
"permissions": ["read_data"],
"security_level": SecurityLevel.INTERNAL,
}
}
if username in valid_users and valid_users[username]["password"] == password:
user_info = valid_users[username].copy()
del user_info["password"] # Remove password from returned info
return user_info
else:
raise ValueError("Incorrect username or password")
class AuthorizationProvider:
"""Authorization provider"""
def __init__(self, config):
self.config = config
self.logger = logging.getLogger(__name__)
self.permission_cache = {}
# Load sensitive tables configuration
self.sensitive_tables = self._load_sensitive_tables()
def _load_sensitive_tables(self) -> dict[str, SecurityLevel]:
"""Load sensitive table configuration"""
default_tables = {
"user_info": SecurityLevel.CONFIDENTIAL,
"payment_records": SecurityLevel.SECRET,
"employee_data": SecurityLevel.CONFIDENTIAL,
"public_reports": SecurityLevel.PUBLIC,
}
if hasattr(self.config, 'get'):
config_tables = self.config.get("sensitive_tables", {})
# Convert string values to SecurityLevel enum
for table_name, level in config_tables.items():
if isinstance(level, str):
try:
default_tables[table_name] = SecurityLevel(level.lower())
except ValueError:
default_tables[table_name] = SecurityLevel.INTERNAL
else:
default_tables[table_name] = level
return default_tables
else:
return default_tables
async def check_permission(
self, auth_context: AuthContext, resource_uri: str, action: str
) -> bool:
"""Check permissions"""
# Parse resource information
resource_info = self._parse_resource_uri(resource_uri)
# First check security level - this is mandatory
if not await self._check_security_level_permission(auth_context, resource_info):
return False
# Then check role-based permissions
if await self._check_role_permission(auth_context, resource_info, action):
return True
# Finally check user-based permissions
if await self._check_user_permission(auth_context, resource_info, action):
return True
return False
def _parse_resource_uri(self, uri: str) -> dict[str, str]:
"""Parse resource URI"""
parts = uri.split("/")
if len(parts) >= 3:
return {
"type": parts[2], # table, view, etc.
"name": parts[3] if len(parts) > 3 else "",
"schema": parts[4] if len(parts) > 4 else "default",
}
return {"type": "unknown", "name": "", "schema": "default"}
async def _check_role_permission(
self, auth_context: AuthContext, resource_info: dict[str, str], action: str
) -> bool:
"""Check role-based permissions"""
# Role permission mapping
role_permissions = {
"data_analyst": {"table": ["read"], "view": ["read"]},
"data_admin": {
"table": ["read", "write", "admin"],
"view": ["read", "write", "admin"],
},
}
for role in auth_context.roles:
role_perms = role_permissions.get(role, {})
resource_perms = role_perms.get(resource_info["type"], [])
if action in resource_perms:
return True
return False
async def _check_user_permission(
self, auth_context: AuthContext, resource_info: dict[str, str], action: str
) -> bool:
"""Check user-based permissions"""
# User-specific permission check
if "admin" in auth_context.permissions:
return True
if action == "read" and "read_data" in auth_context.permissions:
return True
return False
async def _check_security_level_permission(
self, auth_context: AuthContext, resource_info: dict[str, str]
) -> bool:
"""Check security level permissions"""
# Get resource security level
resource_security_level = self._get_resource_security_level(resource_info)
# Check if user security level is sufficient
security_hierarchy = {
SecurityLevel.PUBLIC: 0,
SecurityLevel.INTERNAL: 1,
SecurityLevel.CONFIDENTIAL: 2,
SecurityLevel.SECRET: 3,
}
user_level = security_hierarchy.get(auth_context.security_level, 0)
resource_level = security_hierarchy.get(resource_security_level, 0)
# User must have higher or equal security level to access resource
return user_level >= resource_level
def _get_resource_security_level(
self, resource_info: dict[str, str]
) -> SecurityLevel:
"""Get resource security level"""
# Get table security level from configuration
table_name = resource_info.get("name", "")
# Use the loaded sensitive tables
sensitive_tables = self.sensitive_tables
# Convert string values to SecurityLevel enum if needed
security_level = sensitive_tables.get(table_name, SecurityLevel.INTERNAL)
if isinstance(security_level, str):
try:
security_level = SecurityLevel(security_level.lower())
except ValueError:
security_level = SecurityLevel.INTERNAL
return security_level
class SQLSecurityValidator:
"""SQL security validator"""
def __init__(self, config):
self.config = config
self.logger = logging.getLogger(__name__)
# Handle DorisConfig object or dictionary configuration
if hasattr(config, 'get'):
# Dictionary configuration
self.blocked_keywords = set(config.get("blocked_keywords", []))
self.max_query_complexity = config.get("max_query_complexity", 100)
else:
# DorisConfig object, use default values
self.blocked_keywords = set(["DROP", "DELETE", "TRUNCATE", "ALTER", "CREATE", "INSERT", "UPDATE"])
self.max_query_complexity = 100
async def validate(self, sql: str, auth_context: AuthContext) -> ValidationResult:
"""Validate SQL query security"""
try:
# Parse SQL statement
parsed = sqlparse.parse(sql)[0]
# Check blocked operations first (more specific)
keyword_result = await self._check_blocked_keywords(parsed)
if not keyword_result.is_valid:
return keyword_result
# Check SQL injection risks
injection_result = await self._check_sql_injection(sql, parsed)
if not injection_result.is_valid:
return injection_result
# Check query complexity
complexity_result = await self._check_query_complexity(parsed)
if not complexity_result.is_valid:
return complexity_result
# Check table access permissions
table_result = await self._check_table_access(parsed, auth_context)
if not table_result.is_valid:
return table_result
return ValidationResult(is_valid=True)
except Exception as e:
self.logger.error(f"SQL security validation failed: {e}")
return ValidationResult(
is_valid=False,
error_message=f"SQL parsing error: {str(e)}",
risk_level="high",
)
async def _check_sql_injection(
self, sql: str, parsed: Statement
) -> ValidationResult:
"""Check SQL injection risks"""
# Check common SQL injection patterns
injection_patterns = [
r"(\s|^)(union|select|insert|update|delete|drop|create|alter)\s+.*\s+(union|select|insert|update|delete|drop|create|alter)",
r"(\s|^)(or|and)\s+\d+\s*=\s*\d+",
r"(\s|^)(or|and)\s+['\"].*['\"]",
r";\s*(drop|delete|truncate|alter|create)",
r"(exec|execute|sp_|xp_)",
r"(script|javascript|vbscript)",
r"(char|ascii|substring|concat)\s*\(",
]
sql_lower = sql.lower()
for pattern in injection_patterns:
if re.search(pattern, sql_lower, re.IGNORECASE):
return ValidationResult(
is_valid=False,
error_message="Potential SQL injection risk detected",
risk_level="high",
)
# Check suspicious quotes and comments
if self._has_suspicious_quotes_or_comments(sql):
return ValidationResult(
is_valid=False,
error_message="Suspicious quote or comment pattern detected",
risk_level="medium",
)
return ValidationResult(is_valid=True)
def _has_suspicious_quotes_or_comments(self, sql: str) -> bool:
"""Check suspicious quote and comment patterns"""
# Check unmatched quotes
single_quotes = sql.count("'")
double_quotes = sql.count('"')
if single_quotes % 2 != 0 or double_quotes % 2 != 0:
return True
# Check SQL comments
if "--" in sql or "/*" in sql:
return True
return False
async def _check_blocked_keywords(self, parsed: Statement) -> ValidationResult:
"""Check blocked keywords"""
blocked_operations = []
# Check all tokens in the parsed statement
for token in parsed.flatten():
# Check if token is a keyword (including DML/DDL) or name that matches blocked operations
if (token.ttype is Keyword or
token.ttype is Name or
(token.ttype and str(token.ttype).startswith('Token.Keyword'))):
token_value = token.value.upper().strip()
if token_value in self.blocked_keywords:
blocked_operations.append(token_value)
# Also check for DDL/DML keywords in token values
elif hasattr(token, 'value') and token.value:
token_value = token.value.upper().strip()
for blocked_keyword in self.blocked_keywords:
if blocked_keyword in token_value:
blocked_operations.append(blocked_keyword)
if blocked_operations:
return ValidationResult(
is_valid=False,
error_message=f"Contains blocked operations: {', '.join(set(blocked_operations))}",
risk_level="high",
blocked_operations=list(set(blocked_operations)),
)
return ValidationResult(is_valid=True)
async def _check_query_complexity(self, parsed: Statement) -> ValidationResult:
"""Check query complexity"""
complexity_score = 0
# Calculate complexity score
for token in parsed.flatten():
if token.ttype is Keyword:
keyword = token.value.upper()
if keyword in ["JOIN", "INNER", "LEFT", "RIGHT", "FULL"]:
complexity_score += 10
elif keyword in ["UNION", "INTERSECT", "EXCEPT"]:
complexity_score += 15
elif keyword in ["GROUP BY", "ORDER BY", "HAVING"]:
complexity_score += 5
elif keyword in ["SUBQUERY", "EXISTS", "IN"]:
complexity_score += 8
if complexity_score > self.max_query_complexity:
return ValidationResult(
is_valid=False,
error_message=f"Query complexity too high (score: {complexity_score}, limit: {self.max_query_complexity})",
risk_level="medium",
)
return ValidationResult(is_valid=True)
async def _check_table_access(
self, parsed: Statement, auth_context: AuthContext
) -> ValidationResult:
"""Check table access permissions"""
# Extract table names from query
tables = self._extract_table_names(parsed)
# Check access permissions for each table
unauthorized_tables = []
for table in tables:
# Should call authorization provider to check permissions
# Simplified implementation, assume some tables require special permissions
if (
table.lower() in ["sensitive_data", "admin_logs"]
and "admin" not in auth_context.roles
):
unauthorized_tables.append(table)
if unauthorized_tables:
return ValidationResult(
is_valid=False,
error_message=f"No access to tables: {', '.join(unauthorized_tables)}",
risk_level="high",
)
return ValidationResult(is_valid=True)
def _extract_table_names(self, parsed: Statement) -> list[str]:
"""Extract table names from SQL statement"""
tables = []
# Simplified table name extraction logic
tokens = list(parsed.flatten())
for i, token in enumerate(tokens):
if token.ttype is Keyword and token.value.upper() == "FROM":
# Find table name after FROM
for j in range(i + 1, len(tokens)):
next_token = tokens[j]
if next_token.ttype is Name:
tables.append(next_token.value)
break
elif next_token.ttype is Keyword:
break
return tables
class DataMaskingProcessor:
"""Data masking processor"""
def __init__(self, config):
self.config = config
self.logger = logging.getLogger(__name__)
self.masking_algorithms = self._init_masking_algorithms()
self.masking_rules = self._load_masking_rules()
def _load_masking_rules(self) -> list[MaskingRule]:
"""Load data masking rules"""
default_rules = [
MaskingRule(
column_pattern=r".*phone.*|.*mobile.*",
algorithm="phone_mask",
parameters={"mask_char": "*", "keep_prefix": 3, "keep_suffix": 4},
security_level=SecurityLevel.INTERNAL,
),
MaskingRule(
column_pattern=r".*email.*",
algorithm="email_mask",
parameters={"mask_char": "*"},
security_level=SecurityLevel.INTERNAL,
),
MaskingRule(
column_pattern=r".*id_card.*|.*identity.*",
algorithm="id_mask",
parameters={"mask_char": "*", "keep_prefix": 6, "keep_suffix": 4},
security_level=SecurityLevel.CONFIDENTIAL,
),
]
# Load custom rules from configuration
if hasattr(self.config, 'get'):
custom_rules = self.config.get("masking_rules", [])
for rule_config in custom_rules:
if isinstance(rule_config, dict):
# Convert string security level to enum
if 'security_level' in rule_config and isinstance(rule_config['security_level'], str):
try:
rule_config['security_level'] = SecurityLevel(rule_config['security_level'].lower())
except ValueError:
rule_config['security_level'] = SecurityLevel.INTERNAL
default_rules.append(MaskingRule(**rule_config))
elif isinstance(rule_config, MaskingRule):
default_rules.append(rule_config)
return default_rules
def _init_masking_algorithms(self) -> dict[str, callable]:
"""Initialize masking algorithms"""
return {
"phone_mask": self._mask_phone,
"email_mask": self._mask_email,
"id_mask": self._mask_id_card,
"name_mask": self._mask_name,
"partial_mask": self._mask_partial,
}
async def process(
self, data: list[dict[str, Any]], auth_context: AuthContext
) -> list[dict[str, Any]]:
"""Process data masking"""
if not data:
return data
# Get applicable masking rules
applicable_rules = self._get_applicable_rules(auth_context)
masked_data = []
for row in data:
masked_row = {}
for column, value in row.items():
masked_value = await self._apply_masking_rules(
column, value, applicable_rules
)
masked_row[column] = masked_value
masked_data.append(masked_row)
return masked_data
def _get_applicable_rules(self, auth_context: AuthContext) -> list[MaskingRule]:
"""Get applicable masking rules"""
applicable_rules = []
for rule in self.masking_rules:
# Decide whether to apply masking rules based on user security level
if self._should_apply_rule(rule, auth_context):
applicable_rules.append(rule)
return applicable_rules
def _should_apply_rule(self, rule: MaskingRule, auth_context: AuthContext) -> bool:
"""Determine whether masking rule should be applied"""
# Admin users can see original data
if "admin" in auth_context.roles:
return False
# Decide based on security level
security_hierarchy = {
SecurityLevel.PUBLIC: 0,
SecurityLevel.INTERNAL: 1,
SecurityLevel.CONFIDENTIAL: 2,
SecurityLevel.SECRET: 3,
}
user_level = security_hierarchy.get(auth_context.security_level, 0)
rule_level = security_hierarchy.get(rule.security_level, 0)
# Apply masking if user level is less than or equal to rule level
return user_level <= rule_level
async def _apply_masking_rules(
self, column: str, value: Any, rules: list[MaskingRule]
) -> Any:
"""Apply masking rules"""
if value is None:
return value
for rule in rules:
if re.match(rule.column_pattern, column, re.IGNORECASE):
algorithm = self.masking_algorithms.get(rule.algorithm)
if algorithm:
return algorithm(str(value), rule.parameters)
return value
def _mask_phone(self, value: str, params: dict[str, Any]) -> str:
"""Phone number masking"""
if len(value) < 7:
return value
mask_char = params.get("mask_char", "*")
keep_prefix = params.get("keep_prefix", 3)
keep_suffix = params.get("keep_suffix", 4)
if len(value) <= keep_prefix + keep_suffix:
return mask_char * len(value)
prefix = value[:keep_prefix]
suffix = value[-keep_suffix:]
middle_length = len(value) - keep_prefix - keep_suffix
return prefix + mask_char * middle_length + suffix
def _mask_email(self, value: str, params: dict[str, Any]) -> str:
"""Email masking"""
if "@" not in value:
return value
mask_char = params.get("mask_char", "*")
local, domain = value.split("@", 1)
if len(local) <= 2:
masked_local = mask_char * len(local)
else:
masked_local = local[0] + mask_char * (len(local) - 2) + local[-1]
return f"{masked_local}@{domain}"
def _mask_id_card(self, value: str, params: dict[str, Any]) -> str:
"""ID card number masking"""
if len(value) < 10:
return value
mask_char = params.get("mask_char", "*")
keep_prefix = params.get("keep_prefix", 6)
keep_suffix = params.get("keep_suffix", 4)
if len(value) <= keep_prefix + keep_suffix:
return mask_char * len(value)
prefix = value[:keep_prefix]
suffix = value[-keep_suffix:]
middle_length = len(value) - keep_prefix - keep_suffix
return prefix + mask_char * middle_length + suffix
def _mask_name(self, value: str, params: dict[str, Any]) -> str:
"""Name masking"""
if len(value) <= 1:
return value
mask_char = params.get("mask_char", "*")
if len(value) == 2:
return value[0] + mask_char
else:
return value[0] + mask_char * (len(value) - 2) + value[-1]
def _mask_partial(self, value: str, params: dict[str, Any]) -> str:
"""Partial masking"""
mask_char = params.get("mask_char", "*")
mask_ratio = params.get("mask_ratio", 0.5)
mask_length = int(len(value) * mask_ratio)
start_pos = (len(value) - mask_length) // 2
result = list(value)
for i in range(start_pos, start_pos + mask_length):
if i < len(result):
result[i] = mask_char
return "".join(result)

View File

@@ -1,352 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
SQL Execution Tool
Responsible for executing SQL queries and handling results
"""
import os
import json
import logging
import traceback
import time
from typing import Dict, Any
import re
import datetime
from decimal import Decimal
# Get logger
logger = logging.getLogger("doris-mcp.sql-executor")
# Add environment variable control for whether to perform SQL security checks
ENABLE_SQL_SECURITY_CHECK = os.environ.get('ENABLE_SQL_SECURITY_CHECK', 'true').lower() == 'true'
async def execute_sql_query(ctx) -> Dict[str, Any]:
"""
Execute SQL query and return results
Args:
ctx: Context object or dictionary containing request parameters
Returns:
Dict[str, Any]: Execution result
"""
try:
# Support the case where the passed argument is a dictionary
if isinstance(ctx, dict) and 'params' in ctx:
params = ctx['params']
else:
params = ctx.params
sql = params.get("sql")
db_name = params.get("db_name", os.getenv("DB_DATABASE", ""))
catalog_name = params.get("catalog_name", None) # Add catalog parameter support
max_rows = params.get("max_rows", 1000) # Maximum number of rows to return
timeout = params.get("timeout", 30) # Timeout in seconds
if not sql:
return {
"content": [
{
"type": "text",
"text": json.dumps({
"success": False,
"error": "Missing SQL parameter",
"message": "Please provide the SQL query to execute"
}, ensure_ascii=False)
}
]
}
# First check SQL security
security_result = await _check_sql_security(sql)
if not security_result.get("is_safe", False):
return {
"content": [
{
"type": "text",
"text": json.dumps({
"success": False,
"error": "SQL security check failed",
"message": "Query contains unsafe operations and cannot be executed",
"security_issues": security_result.get("security_issues", [])
}, ensure_ascii=False)
}
]
}
# Import database connection tool
from doris_mcp_server.utils.db import execute_query
if not sql:
return {
"content": [
{
"type": "text",
"text": json.dumps({
"success": False,
"error": "Missing SQL parameter",
"message": "Please provide the SQL query to execute"
}, ensure_ascii=False)
}
]
}
# Ensure SELECT statements include a LIMIT clause
sql_lower = sql.lower().strip()
if sql_lower.startswith("select") and "limit" not in sql_lower:
sql = sql.rstrip(";") + f" LIMIT {max_rows};"
# Start timer
start_time = time.time()
# Execute query
try:
# For federation queries, SQL must use three-part naming: catalog_name.db_name.table_name
# This is enforced at the tool description level
result = execute_query(sql, db_name)
# Calculate execution time
execution_time = time.time() - start_time
# Build return result
if isinstance(result, list):
# Handle list of query results
row_count = len(result)
# Extract column names
if hasattr(result[0], "_fields"):
# If it's a named tuple
columns = list(result[0]._fields)
else:
# Otherwise, assume it's a dictionary
columns = list(result[0].keys()) if isinstance(result[0], dict) else []
# Convert results to serializable format
data = []
for row in result:
row_dict = {}
if hasattr(row, "_asdict"):
# If it's a named tuple
row_dict = row._asdict()
elif isinstance(row, dict):
# If it's a dictionary
row_dict = row
else:
# If it's a list or tuple
row_dict = dict(zip(columns, row)) if columns else row
# Handle special types to make them JSON serializable
serialized_row = _serialize_row_data(row_dict)
data.append(serialized_row)
return {
"content": [
{
"type": "text",
"text": json.dumps({
"success": True,
"sql": sql,
"row_count": row_count,
"columns": columns,
"data": data[:max_rows], # Limit returned rows
"execution_time": execution_time,
"truncated": row_count > max_rows
}, ensure_ascii=False)
}
]
}
else:
# Handle other types of results
other_response = {
"success": True,
"sql": sql,
"result": str(result),
"execution_time": execution_time
}
other_response = _serialize_row_data(other_response)
return {
"content": [
{
"type": "text",
"text": json.dumps(other_response, ensure_ascii=False)
}
]
}
except Exception as db_error:
error_message = str(db_error)
# Try to get more detailed error information
error_details = {}
if "timeout" in error_message.lower():
error_details["type"] = "timeout"
error_details["suggestion"] = "Query timed out, please optimize SQL or increase timeout"
elif "syntax" in error_message.lower():
error_details["type"] = "syntax"
error_details["suggestion"] = "SQL syntax error, please check syntax"
elif "not found" in error_message.lower() or "doesn't exist" in error_message.lower():
error_details["type"] = "not_found"
error_details["suggestion"] = "Table or column not found, please check table and column names"
else:
error_details["type"] = "unknown"
error_details["suggestion"] = "Please check the SQL statement and try simplifying the query"
# Create error response
error_response = {
"success": False,
"error": error_message,
"error_details": error_details,
"sql": sql,
"db_name": db_name
}
# Ensure error response is also serializable
error_response = _serialize_row_data(error_response)
return {
"content": [
{
"type": "text",
"text": json.dumps(error_response, ensure_ascii=False)
}
]
}
except Exception as e:
logger.error(f"Failed to execute SQL query: {str(e)}")
logger.error(traceback.format_exc())
error_response = {
"success": False,
"error": str(e),
"message": "Error occurred while executing SQL query"
}
# Ensure error response is also serializable
error_response = _serialize_row_data(error_response)
return {
"content": [
{
"type": "text",
"text": json.dumps(error_response, ensure_ascii=False)
}
]
}
# Helper function
async def _check_sql_security(sql: str) -> Dict[str, Any]:
"""Check SQL security"""
# If environment variable is set to disable security check, return safe immediately
if not ENABLE_SQL_SECURITY_CHECK:
return {
"is_safe": True,
"security_issues": []
}
# Check if SQL contains dangerous operations
sql_lower = sql.lower()
# Check if it's a read-only query type
is_read_only = sql_lower.strip().startswith(("select ", "show ", "desc ", "describe ", "explain "))
# Define list of dangerous operations (checked for both read-only and non-read-only queries)
dangerous_operations = [
(r'\bdelete\b', "DELETE operation"),
(r'\bdrop\b', "DROP TABLE/DATABASE operation"),
(r'\btruncate\b', "TRUNCATE TABLE operation"),
(r'\bupdate\b', "UPDATE operation"),
(r'\binsert\b', "INSERT operation"),
(r'\balter\b', "ALTER TABLE structure operation"),
(r'\bcreate\b', "CREATE TABLE/DATABASE operation"),
(r'\bgrant\b', "GRANT operation"),
(r'\brevoke\b', "REVOKE permission operation"),
(r'\bexec\b', "EXECUTE stored procedure"),
(r'\bxp_', "Extended stored procedure, potential security risk"),
(r'\bshutdown\b', "SHUTDOWN database operation"),
(r'\binto\s+outfile\b', "Write to file operation"),
(r'\bload_file\b', "Load file operation")
]
# Dangerous operations checked only for non-read-only queries
non_readonly_operations = []
if not is_read_only:
non_readonly_operations = [
(r'--', "SQL comment, potential SQL injection"),
(r'/\*', "SQL block comment, potential SQL injection")
]
# Check if dangerous operations are included
security_issues = []
# Check dangerous operations applicable to all queries
for operation, description in dangerous_operations:
if re.search(operation, sql_lower):
# For specific keywords in read-only queries, differentiate if used as independent operations
if is_read_only and operation in [r'\bcreate\b', r'\bdrop\b', r'\bdelete\b', r'\binsert\b', r'\bupdate\b', r'\balter\b']:
# Check if used as DDL/DML keyword, e.g., CREATE TABLE, DROP DATABASE
pattern = operation + r'\s+(?:table|database|view|index|procedure|function|trigger|event)'
if re.search(pattern, sql_lower):
security_issues.append({
"operation": operation.replace(r'\b', '').replace(r'\s+', ' '),
"description": description,
"severity": "High"
})
else:
security_issues.append({
"operation": operation.replace(r'\b', '').replace(r'\s+', ' '),
"description": description,
"severity": "High"
})
# Check dangerous operations specific to non-read-only queries
for operation, description in non_readonly_operations:
if re.search(operation, sql_lower):
security_issues.append({
"operation": operation.replace(r'\b', '').replace(r'\s+', ' '),
"description": description,
"severity": "Medium"
})
return {
"is_safe": len(security_issues) == 0,
"security_issues": security_issues
}
def _serialize_row_data(row_data: Dict[str, Any]) -> Dict[str, Any]:
"""
Convert special types in row data (like date, time, Decimal) to JSON serializable format
Args:
row_data: Row data dictionary
Returns:
Dict[str, Any]: Processed serializable dictionary
"""
serialized_data = {}
for key, value in row_data.items():
if value is None:
serialized_data[key] = None
elif isinstance(value, (datetime.date, datetime.datetime)):
# Convert date and time types to ISO format string
serialized_data[key] = value.isoformat()
elif isinstance(value, Decimal):
# Convert Decimal type to float
serialized_data[key] = float(value)
elif isinstance(value, (list, tuple)):
# Recursively process elements in list or tuple
serialized_data[key] = [
_serialize_row_data(item) if isinstance(item, dict) else item
for item in value
]
elif isinstance(value, dict):
# Recursively process nested dictionaries
serialized_data[key] = _serialize_row_data(value)
else:
serialized_data[key] = value
return serialized_data