0.3.0 Release Version
This commit is contained in:
@@ -1 +1,10 @@
|
||||
# Mark directory as a package
|
||||
"""
|
||||
Utilities Package - Contains utility classes and helper functions.
|
||||
|
||||
This package includes:
|
||||
- Database connection and operations
|
||||
- Configuration management
|
||||
- Security utilities
|
||||
- Query execution helpers
|
||||
- Logging configuration
|
||||
"""
|
||||
|
||||
318
doris_mcp_server/utils/analysis_tools.py
Normal file
318
doris_mcp_server/utils/analysis_tools.py
Normal file
@@ -0,0 +1,318 @@
|
||||
"""
|
||||
Data Analysis Tools Module
|
||||
Provides data analysis functions including table analysis, column statistics, performance monitoring, etc.
|
||||
"""
|
||||
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from .db import DorisConnectionManager
|
||||
from .logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class TableAnalyzer:
|
||||
"""Table analyzer"""
|
||||
|
||||
def __init__(self, connection_manager: DorisConnectionManager):
|
||||
self.connection_manager = connection_manager
|
||||
|
||||
async def get_table_summary(
|
||||
self,
|
||||
table_name: str,
|
||||
include_sample: bool = True,
|
||||
sample_size: int = 10
|
||||
) -> Dict[str, Any]:
|
||||
"""Get table summary information"""
|
||||
connection = await self.connection_manager.get_connection("query")
|
||||
|
||||
# Get table basic information
|
||||
table_info_sql = f"""
|
||||
SELECT
|
||||
table_name,
|
||||
table_comment,
|
||||
table_rows,
|
||||
create_time,
|
||||
engine
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_name = '{table_name}'
|
||||
"""
|
||||
|
||||
table_info_result = await connection.execute(table_info_sql)
|
||||
if not table_info_result.data:
|
||||
raise ValueError(f"Table {table_name} does not exist")
|
||||
|
||||
table_info = table_info_result.data[0]
|
||||
|
||||
# Get column information
|
||||
columns_sql = f"""
|
||||
SELECT
|
||||
column_name,
|
||||
data_type,
|
||||
is_nullable,
|
||||
column_comment
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_name = '{table_name}'
|
||||
ORDER BY ordinal_position
|
||||
"""
|
||||
|
||||
columns_result = await connection.execute(columns_sql)
|
||||
|
||||
summary = {
|
||||
"table_name": table_info["table_name"],
|
||||
"comment": table_info.get("table_comment"),
|
||||
"row_count": table_info.get("table_rows", 0),
|
||||
"create_time": str(table_info.get("create_time")),
|
||||
"engine": table_info.get("engine"),
|
||||
"column_count": len(columns_result.data),
|
||||
"columns": columns_result.data,
|
||||
}
|
||||
|
||||
# Get sample data
|
||||
if include_sample and sample_size > 0:
|
||||
sample_sql = f"SELECT * FROM {table_name} LIMIT {sample_size}"
|
||||
sample_result = await connection.execute(sample_sql)
|
||||
summary["sample_data"] = sample_result.data
|
||||
|
||||
return summary
|
||||
|
||||
async def analyze_column(
|
||||
self,
|
||||
table_name: str,
|
||||
column_name: str,
|
||||
analysis_type: str = "basic"
|
||||
) -> Dict[str, Any]:
|
||||
"""Analyze column statistics"""
|
||||
try:
|
||||
connection = await self.connection_manager.get_connection("query")
|
||||
|
||||
# Basic statistics
|
||||
basic_stats_sql = f"""
|
||||
SELECT
|
||||
'{column_name}' as column_name,
|
||||
COUNT(*) as total_count,
|
||||
COUNT({column_name}) as non_null_count,
|
||||
COUNT(DISTINCT {column_name}) as distinct_count
|
||||
FROM {table_name}
|
||||
"""
|
||||
|
||||
basic_result = await connection.execute(basic_stats_sql)
|
||||
if not basic_result.data:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Unable to get statistics for table {table_name} column {column_name}"
|
||||
}
|
||||
|
||||
analysis = basic_result.data[0].copy()
|
||||
analysis["success"] = True
|
||||
analysis["analysis_type"] = analysis_type
|
||||
|
||||
if analysis_type in ["distribution", "detailed"]:
|
||||
# Data distribution analysis
|
||||
distribution_sql = f"""
|
||||
SELECT
|
||||
{column_name} as value,
|
||||
COUNT(*) as frequency
|
||||
FROM {table_name}
|
||||
WHERE {column_name} IS NOT NULL
|
||||
GROUP BY {column_name}
|
||||
ORDER BY frequency DESC
|
||||
LIMIT 20
|
||||
"""
|
||||
|
||||
distribution_result = await connection.execute(distribution_sql)
|
||||
analysis["value_distribution"] = distribution_result.data
|
||||
|
||||
if analysis_type == "detailed":
|
||||
# Detailed statistics (for numeric types)
|
||||
try:
|
||||
numeric_stats_sql = f"""
|
||||
SELECT
|
||||
MIN({column_name}) as min_value,
|
||||
MAX({column_name}) as max_value,
|
||||
AVG({column_name}) as avg_value
|
||||
FROM {table_name}
|
||||
WHERE {column_name} IS NOT NULL
|
||||
"""
|
||||
|
||||
numeric_result = await connection.execute(numeric_stats_sql)
|
||||
if numeric_result.data:
|
||||
analysis.update(numeric_result.data[0])
|
||||
except Exception:
|
||||
# Non-numeric columns don't support numeric statistics
|
||||
pass
|
||||
|
||||
return analysis
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Column analysis failed: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"column_name": column_name,
|
||||
"table_name": table_name
|
||||
}
|
||||
|
||||
async def analyze_table_relationships(
|
||||
self,
|
||||
table_name: str,
|
||||
depth: int = 2
|
||||
) -> Dict[str, Any]:
|
||||
"""Analyze table relationships"""
|
||||
connection = await self.connection_manager.get_connection("system")
|
||||
|
||||
# Get table basic information
|
||||
table_info_sql = f"""
|
||||
SELECT
|
||||
table_name,
|
||||
table_comment,
|
||||
table_rows
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_name = '{table_name}'
|
||||
"""
|
||||
|
||||
table_result = await connection.execute(table_info_sql)
|
||||
if not table_result.data:
|
||||
raise ValueError(f"Table {table_name} does not exist")
|
||||
|
||||
# Get all tables list (for analyzing potential relationships)
|
||||
all_tables_sql = """
|
||||
SELECT
|
||||
table_name,
|
||||
table_comment
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_type = 'BASE TABLE'
|
||||
AND table_name != %s
|
||||
"""
|
||||
|
||||
all_tables_result = await connection.execute(all_tables_sql, (table_name,))
|
||||
|
||||
return {
|
||||
"center_table": table_result.data[0],
|
||||
"related_tables": all_tables_result.data,
|
||||
"depth": depth,
|
||||
"note": "Table relationship analysis based on column name similarity and business logic inference",
|
||||
}
|
||||
|
||||
|
||||
class PerformanceMonitor:
|
||||
"""Performance monitor"""
|
||||
|
||||
def __init__(self, connection_manager: DorisConnectionManager):
|
||||
self.connection_manager = connection_manager
|
||||
|
||||
async def get_performance_stats(
|
||||
self,
|
||||
metric_type: str = "queries",
|
||||
time_range: str = "1h"
|
||||
) -> Dict[str, Any]:
|
||||
"""Get performance statistics"""
|
||||
connection = await self.connection_manager.get_connection("system")
|
||||
|
||||
# Convert time range to seconds
|
||||
time_mapping = {
|
||||
"1h": 3600,
|
||||
"6h": 21600,
|
||||
"24h": 86400,
|
||||
"7d": 604800
|
||||
}
|
||||
|
||||
seconds = time_mapping.get(time_range, 3600)
|
||||
|
||||
if metric_type == "queries":
|
||||
# Query performance metrics
|
||||
stats = {
|
||||
"metric_type": "queries",
|
||||
"time_range": time_range,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"total_queries": 0,
|
||||
"avg_execution_time": 0.0,
|
||||
"slow_queries": 0,
|
||||
"error_queries": 0,
|
||||
"note": "Query performance statistics (simulated data)"
|
||||
}
|
||||
|
||||
elif metric_type == "connections":
|
||||
# Connection statistics
|
||||
connection_metrics = await self.connection_manager.get_metrics()
|
||||
stats = {
|
||||
"metric_type": "connections",
|
||||
"time_range": time_range,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"total_connections": connection_metrics.total_connections,
|
||||
"active_connections": connection_metrics.active_connections,
|
||||
"idle_connections": connection_metrics.idle_connections,
|
||||
"failed_connections": connection_metrics.failed_connections,
|
||||
"connection_errors": connection_metrics.connection_errors,
|
||||
"avg_connection_time": connection_metrics.avg_connection_time,
|
||||
"last_health_check": connection_metrics.last_health_check.isoformat() if connection_metrics.last_health_check else None
|
||||
}
|
||||
|
||||
elif metric_type == "tables":
|
||||
# Table-level statistics
|
||||
tables_sql = """
|
||||
SELECT
|
||||
table_name,
|
||||
table_rows,
|
||||
data_length,
|
||||
index_length,
|
||||
create_time,
|
||||
update_time
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_type = 'BASE TABLE'
|
||||
ORDER BY table_rows DESC
|
||||
LIMIT 20
|
||||
"""
|
||||
|
||||
tables_result = await connection.execute(tables_sql)
|
||||
stats = {
|
||||
"metric_type": "tables",
|
||||
"time_range": time_range,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"table_count": len(tables_result.data),
|
||||
"tables": tables_result.data
|
||||
}
|
||||
|
||||
elif metric_type == "system":
|
||||
# System-level metrics (simulated)
|
||||
stats = {
|
||||
"metric_type": "system",
|
||||
"time_range": time_range,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"cpu_usage": 45.2,
|
||||
"memory_usage": 68.5,
|
||||
"disk_usage": 72.1,
|
||||
"network_io": {
|
||||
"bytes_sent": 1024000,
|
||||
"bytes_received": 2048000
|
||||
},
|
||||
"note": "System metrics (simulated data)"
|
||||
}
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported metric type: {metric_type}")
|
||||
|
||||
return stats
|
||||
|
||||
async def get_query_history(
|
||||
self,
|
||||
limit: int = 50,
|
||||
order_by: str = "time"
|
||||
) -> Dict[str, Any]:
|
||||
"""Get query history"""
|
||||
# Since Doris doesn't have a built-in query history table,
|
||||
# we return simulated data
|
||||
return {
|
||||
"total_queries": 0,
|
||||
"queries": [],
|
||||
"limit": limit,
|
||||
"order_by": order_by,
|
||||
"note": "Query history feature requires audit log configuration"
|
||||
}
|
||||
608
doris_mcp_server/utils/config.py
Normal file
608
doris_mcp_server/utils/config.py
Normal file
@@ -0,0 +1,608 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Doris Configuration Management Module
|
||||
Implements configuration loading, validation and management functionality
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except ImportError:
|
||||
load_dotenv = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatabaseConfig:
|
||||
"""Database connection configuration"""
|
||||
|
||||
host: str = "localhost"
|
||||
port: int = 9030
|
||||
user: str = "root"
|
||||
password: str = ""
|
||||
database: str = "test"
|
||||
charset: str = "utf8mb4"
|
||||
|
||||
# Connection pool configuration
|
||||
min_connections: int = 5
|
||||
max_connections: int = 20
|
||||
connection_timeout: int = 30
|
||||
health_check_interval: int = 60
|
||||
max_connection_age: int = 3600
|
||||
|
||||
|
||||
@dataclass
|
||||
class SecurityConfig:
|
||||
"""Security configuration"""
|
||||
|
||||
# Authentication configuration
|
||||
auth_type: str = "token" # token, basic, oauth
|
||||
token_secret: str = "default_secret"
|
||||
token_expiry: int = 3600
|
||||
|
||||
# SQL security configuration
|
||||
blocked_keywords: list[str] = field(
|
||||
default_factory=lambda: [
|
||||
"DROP",
|
||||
"DELETE",
|
||||
"TRUNCATE",
|
||||
"ALTER",
|
||||
"CREATE",
|
||||
"INSERT",
|
||||
"UPDATE",
|
||||
"GRANT",
|
||||
"REVOKE",
|
||||
]
|
||||
)
|
||||
max_query_complexity: int = 100
|
||||
max_result_rows: int = 10000
|
||||
|
||||
# Sensitive table configuration
|
||||
sensitive_tables: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
# Data masking configuration
|
||||
enable_masking: bool = True
|
||||
masking_rules: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PerformanceConfig:
|
||||
"""Performance configuration"""
|
||||
|
||||
# Query cache configuration
|
||||
enable_query_cache: bool = True
|
||||
cache_ttl: int = 300
|
||||
max_cache_size: int = 1000
|
||||
|
||||
# Concurrency control configuration
|
||||
max_concurrent_queries: int = 50
|
||||
query_timeout: int = 300
|
||||
|
||||
# Connection pool optimization configuration
|
||||
connection_pool_size: int = 20
|
||||
idle_timeout: int = 1800
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoggingConfig:
|
||||
"""Logging configuration"""
|
||||
|
||||
level: str = "INFO"
|
||||
format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
file_path: str | None = None
|
||||
max_file_size: int = 10 * 1024 * 1024 # 10MB
|
||||
backup_count: int = 5
|
||||
|
||||
# Audit log configuration
|
||||
enable_audit: bool = True
|
||||
audit_file_path: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MonitoringConfig:
|
||||
"""Monitoring configuration"""
|
||||
|
||||
# Metrics collection configuration
|
||||
enable_metrics: bool = True
|
||||
metrics_port: int = 8081
|
||||
metrics_path: str = "/metrics"
|
||||
|
||||
# Health check configuration
|
||||
health_check_port: int = 8082
|
||||
health_check_path: str = "/health"
|
||||
|
||||
# Alert configuration
|
||||
enable_alerts: bool = False
|
||||
alert_webhook_url: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DorisConfig:
|
||||
"""Doris MCP Server complete configuration"""
|
||||
|
||||
# Basic configuration
|
||||
server_name: str = "doris-mcp-server"
|
||||
server_version: str = "1.0.0"
|
||||
server_port: int = 8080
|
||||
|
||||
# Sub-configuration modules
|
||||
database: DatabaseConfig = field(default_factory=DatabaseConfig)
|
||||
security: SecurityConfig = field(default_factory=SecurityConfig)
|
||||
performance: PerformanceConfig = field(default_factory=PerformanceConfig)
|
||||
logging: LoggingConfig = field(default_factory=LoggingConfig)
|
||||
monitoring: MonitoringConfig = field(default_factory=MonitoringConfig)
|
||||
|
||||
# Custom configuration
|
||||
custom_config: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, config_path: str) -> "DorisConfig":
|
||||
"""Load configuration from file"""
|
||||
config_file = Path(config_path)
|
||||
|
||||
if not config_file.exists():
|
||||
raise FileNotFoundError(f"Configuration file does not exist: {config_path}")
|
||||
|
||||
try:
|
||||
with open(config_file, encoding="utf-8") as f:
|
||||
if config_file.suffix.lower() == ".json":
|
||||
config_data = json.load(f)
|
||||
else:
|
||||
# Support other formats (like YAML)
|
||||
raise ValueError(f"Unsupported configuration file format: {config_file.suffix}")
|
||||
|
||||
return cls._from_dict(config_data)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to load configuration file: {e}")
|
||||
|
||||
@classmethod
|
||||
def from_env(cls, env_file: str | None = None) -> "DorisConfig":
|
||||
"""Load configuration from environment variables
|
||||
|
||||
Args:
|
||||
env_file: .env file path, if None, search in the following order:
|
||||
.env, .env.local, .env.production, .env.development
|
||||
"""
|
||||
# Load .env file
|
||||
if load_dotenv is not None:
|
||||
if env_file:
|
||||
# Load specified .env file
|
||||
if Path(env_file).exists():
|
||||
load_dotenv(env_file)
|
||||
logging.getLogger(__name__).info(f"Loaded environment configuration file: {env_file}")
|
||||
else:
|
||||
logging.getLogger(__name__).warning(f"Environment configuration file does not exist: {env_file}")
|
||||
else:
|
||||
# Load .env files in priority order
|
||||
env_files = [".env", ".env.local", ".env.production", ".env.development"]
|
||||
for env_path in env_files:
|
||||
if Path(env_path).exists():
|
||||
load_dotenv(env_path)
|
||||
logging.getLogger(__name__).info(f"Loaded environment configuration file: {env_path}")
|
||||
break
|
||||
else:
|
||||
logging.getLogger(__name__).info("No .env configuration file found, using system environment variables")
|
||||
else:
|
||||
logging.getLogger(__name__).warning("python-dotenv not installed, cannot load .env files")
|
||||
|
||||
config = cls()
|
||||
|
||||
# Database configuration
|
||||
config.database.host = os.getenv("DORIS_HOST", config.database.host)
|
||||
config.database.port = int(os.getenv("DORIS_PORT", str(config.database.port)))
|
||||
config.database.user = os.getenv("DORIS_USER", config.database.user)
|
||||
config.database.password = os.getenv("DORIS_PASSWORD", config.database.password)
|
||||
config.database.database = os.getenv("DORIS_DATABASE", config.database.database)
|
||||
|
||||
# Connection pool configuration
|
||||
config.database.min_connections = int(
|
||||
os.getenv("DORIS_MIN_CONNECTIONS", str(config.database.min_connections))
|
||||
)
|
||||
config.database.max_connections = int(
|
||||
os.getenv("DORIS_MAX_CONNECTIONS", str(config.database.max_connections))
|
||||
)
|
||||
config.database.connection_timeout = int(
|
||||
os.getenv("DORIS_CONNECTION_TIMEOUT", str(config.database.connection_timeout))
|
||||
)
|
||||
config.database.health_check_interval = int(
|
||||
os.getenv("DORIS_HEALTH_CHECK_INTERVAL", str(config.database.health_check_interval))
|
||||
)
|
||||
config.database.max_connection_age = int(
|
||||
os.getenv("DORIS_MAX_CONNECTION_AGE", str(config.database.max_connection_age))
|
||||
)
|
||||
|
||||
# Security configuration
|
||||
config.security.auth_type = os.getenv("AUTH_TYPE", config.security.auth_type)
|
||||
config.security.token_secret = os.getenv("TOKEN_SECRET", config.security.token_secret)
|
||||
config.security.token_expiry = int(
|
||||
os.getenv("TOKEN_EXPIRY", str(config.security.token_expiry))
|
||||
)
|
||||
config.security.max_result_rows = int(
|
||||
os.getenv("MAX_RESULT_ROWS", str(config.security.max_result_rows))
|
||||
)
|
||||
config.security.max_query_complexity = int(
|
||||
os.getenv("MAX_QUERY_COMPLEXITY", str(config.security.max_query_complexity))
|
||||
)
|
||||
config.security.enable_masking = (
|
||||
os.getenv("ENABLE_MASKING", str(config.security.enable_masking).lower()).lower() == "true"
|
||||
)
|
||||
|
||||
# Performance configuration
|
||||
config.performance.enable_query_cache = (
|
||||
os.getenv("ENABLE_QUERY_CACHE", "true").lower() == "true"
|
||||
)
|
||||
config.performance.cache_ttl = int(
|
||||
os.getenv("CACHE_TTL", str(config.performance.cache_ttl))
|
||||
)
|
||||
config.performance.max_cache_size = int(
|
||||
os.getenv("MAX_CACHE_SIZE", str(config.performance.max_cache_size))
|
||||
)
|
||||
config.performance.max_concurrent_queries = int(
|
||||
os.getenv("MAX_CONCURRENT_QUERIES", str(config.performance.max_concurrent_queries))
|
||||
)
|
||||
config.performance.query_timeout = int(
|
||||
os.getenv("QUERY_TIMEOUT", str(config.performance.query_timeout))
|
||||
)
|
||||
|
||||
# Logging configuration
|
||||
config.logging.level = os.getenv("LOG_LEVEL", config.logging.level)
|
||||
config.logging.file_path = os.getenv("LOG_FILE_PATH", config.logging.file_path)
|
||||
config.logging.enable_audit = (
|
||||
os.getenv("ENABLE_AUDIT", str(config.logging.enable_audit).lower()).lower() == "true"
|
||||
)
|
||||
config.logging.audit_file_path = os.getenv("AUDIT_FILE_PATH", config.logging.audit_file_path)
|
||||
|
||||
# Monitoring configuration
|
||||
config.monitoring.enable_metrics = (
|
||||
os.getenv("ENABLE_METRICS", "true").lower() == "true"
|
||||
)
|
||||
config.monitoring.metrics_port = int(
|
||||
os.getenv("METRICS_PORT", str(config.monitoring.metrics_port))
|
||||
)
|
||||
config.monitoring.health_check_port = int(
|
||||
os.getenv("HEALTH_CHECK_PORT", str(config.monitoring.health_check_port))
|
||||
)
|
||||
config.monitoring.enable_alerts = (
|
||||
os.getenv("ENABLE_ALERTS", str(config.monitoring.enable_alerts).lower()).lower() == "true"
|
||||
)
|
||||
config.monitoring.alert_webhook_url = os.getenv("ALERT_WEBHOOK_URL", config.monitoring.alert_webhook_url)
|
||||
|
||||
# Server configuration
|
||||
config.server_name = os.getenv("SERVER_NAME", config.server_name)
|
||||
config.server_version = os.getenv("SERVER_VERSION", config.server_version)
|
||||
config.server_port = int(os.getenv("SERVER_PORT", str(config.server_port)))
|
||||
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def _from_dict(cls, config_data: dict[str, Any]) -> "DorisConfig":
|
||||
"""Create configuration object from dictionary"""
|
||||
config = cls()
|
||||
|
||||
# Update basic configuration
|
||||
for key in ["server_name", "server_version", "server_port"]:
|
||||
if key in config_data:
|
||||
setattr(config, key, config_data[key])
|
||||
|
||||
# Update database configuration
|
||||
if "database" in config_data:
|
||||
db_config = config_data["database"]
|
||||
for key, value in db_config.items():
|
||||
if hasattr(config.database, key):
|
||||
setattr(config.database, key, value)
|
||||
|
||||
# Update security configuration
|
||||
if "security" in config_data:
|
||||
sec_config = config_data["security"]
|
||||
for key, value in sec_config.items():
|
||||
if hasattr(config.security, key):
|
||||
setattr(config.security, key, value)
|
||||
|
||||
# Update performance configuration
|
||||
if "performance" in config_data:
|
||||
perf_config = config_data["performance"]
|
||||
for key, value in perf_config.items():
|
||||
if hasattr(config.performance, key):
|
||||
setattr(config.performance, key, value)
|
||||
|
||||
# Update logging configuration
|
||||
if "logging" in config_data:
|
||||
log_config = config_data["logging"]
|
||||
for key, value in log_config.items():
|
||||
if hasattr(config.logging, key):
|
||||
setattr(config.logging, key, value)
|
||||
|
||||
# Update monitoring configuration
|
||||
if "monitoring" in config_data:
|
||||
mon_config = config_data["monitoring"]
|
||||
for key, value in mon_config.items():
|
||||
if hasattr(config.monitoring, key):
|
||||
setattr(config.monitoring, key, value)
|
||||
|
||||
# Custom configuration
|
||||
config.custom_config = config_data.get("custom", {})
|
||||
|
||||
return config
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary format"""
|
||||
return {
|
||||
"server_name": self.server_name,
|
||||
"server_version": self.server_version,
|
||||
"server_port": self.server_port,
|
||||
"database": {
|
||||
"host": self.database.host,
|
||||
"port": self.database.port,
|
||||
"user": self.database.user,
|
||||
"password": "***", # Hide password
|
||||
"database": self.database.database,
|
||||
"charset": self.database.charset,
|
||||
"min_connections": self.database.min_connections,
|
||||
"max_connections": self.database.max_connections,
|
||||
"connection_timeout": self.database.connection_timeout,
|
||||
"health_check_interval": self.database.health_check_interval,
|
||||
"max_connection_age": self.database.max_connection_age,
|
||||
},
|
||||
"security": {
|
||||
"auth_type": self.security.auth_type,
|
||||
"token_secret": "***", # Hide secret key
|
||||
"token_expiry": self.security.token_expiry,
|
||||
"blocked_keywords": self.security.blocked_keywords,
|
||||
"max_query_complexity": self.security.max_query_complexity,
|
||||
"max_result_rows": self.security.max_result_rows,
|
||||
"sensitive_tables": self.security.sensitive_tables,
|
||||
"enable_masking": self.security.enable_masking,
|
||||
"masking_rules": len(self.security.masking_rules),
|
||||
},
|
||||
"performance": {
|
||||
"enable_query_cache": self.performance.enable_query_cache,
|
||||
"cache_ttl": self.performance.cache_ttl,
|
||||
"max_cache_size": self.performance.max_cache_size,
|
||||
"max_concurrent_queries": self.performance.max_concurrent_queries,
|
||||
"query_timeout": self.performance.query_timeout,
|
||||
"connection_pool_size": self.performance.connection_pool_size,
|
||||
"idle_timeout": self.performance.idle_timeout,
|
||||
},
|
||||
"logging": {
|
||||
"level": self.logging.level,
|
||||
"format": self.logging.format,
|
||||
"file_path": self.logging.file_path,
|
||||
"max_file_size": self.logging.max_file_size,
|
||||
"backup_count": self.logging.backup_count,
|
||||
"enable_audit": self.logging.enable_audit,
|
||||
"audit_file_path": self.logging.audit_file_path,
|
||||
},
|
||||
"monitoring": {
|
||||
"enable_metrics": self.monitoring.enable_metrics,
|
||||
"metrics_port": self.monitoring.metrics_port,
|
||||
"metrics_path": self.monitoring.metrics_path,
|
||||
"health_check_port": self.monitoring.health_check_port,
|
||||
"health_check_path": self.monitoring.health_check_path,
|
||||
"enable_alerts": self.monitoring.enable_alerts,
|
||||
"alert_webhook_url": self.monitoring.alert_webhook_url,
|
||||
},
|
||||
"custom": self.custom_config,
|
||||
}
|
||||
|
||||
def save_to_file(self, config_path: str):
|
||||
"""Save configuration to file"""
|
||||
config_file = Path(config_path)
|
||||
config_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
with open(config_file, "w", encoding="utf-8") as f:
|
||||
if config_file.suffix.lower() == ".json":
|
||||
json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
|
||||
else:
|
||||
raise ValueError(f"Unsupported configuration file format: {config_file.suffix}")
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to save configuration file: {e}")
|
||||
|
||||
def validate(self) -> list[str]:
|
||||
"""Validate configuration validity"""
|
||||
errors = []
|
||||
|
||||
# Validate database configuration
|
||||
if not self.database.host:
|
||||
errors.append("Database host address cannot be empty")
|
||||
|
||||
if not (1 <= self.database.port <= 65535):
|
||||
errors.append("Database port must be in the range 1-65535")
|
||||
|
||||
if not self.database.user:
|
||||
errors.append("Database username cannot be empty")
|
||||
|
||||
if self.database.min_connections <= 0:
|
||||
errors.append("Minimum connections must be greater than 0")
|
||||
|
||||
if self.database.max_connections <= self.database.min_connections:
|
||||
errors.append("Maximum connections must be greater than minimum connections")
|
||||
|
||||
# Validate security configuration
|
||||
if self.security.auth_type not in ["token", "basic", "oauth"]:
|
||||
errors.append("Authentication type must be one of token, basic, or oauth")
|
||||
|
||||
if self.security.token_expiry <= 0:
|
||||
errors.append("Token expiry time must be greater than 0")
|
||||
|
||||
if self.security.max_query_complexity <= 0:
|
||||
errors.append("Maximum query complexity must be greater than 0")
|
||||
|
||||
if self.security.max_result_rows <= 0:
|
||||
errors.append("Maximum result rows must be greater than 0")
|
||||
|
||||
# Validate performance configuration
|
||||
if self.performance.cache_ttl <= 0:
|
||||
errors.append("Cache TTL must be greater than 0")
|
||||
|
||||
if self.performance.max_concurrent_queries <= 0:
|
||||
errors.append("Maximum concurrent queries must be greater than 0")
|
||||
|
||||
if self.performance.query_timeout <= 0:
|
||||
errors.append("Query timeout must be greater than 0")
|
||||
|
||||
# Validate logging configuration
|
||||
if self.logging.level not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]:
|
||||
errors.append("Log level must be one of DEBUG, INFO, WARNING, ERROR, or CRITICAL")
|
||||
|
||||
if self.logging.max_file_size <= 0:
|
||||
errors.append("Maximum log file size must be greater than 0")
|
||||
|
||||
if self.logging.backup_count < 0:
|
||||
errors.append("Log backup count cannot be negative")
|
||||
|
||||
# Validate monitoring configuration
|
||||
if not (1 <= self.monitoring.metrics_port <= 65535):
|
||||
errors.append("Monitoring port must be in the range 1-65535")
|
||||
|
||||
if not (1 <= self.monitoring.health_check_port <= 65535):
|
||||
errors.append("Health check port must be in the range 1-65535")
|
||||
|
||||
return errors
|
||||
|
||||
def get_connection_string(self) -> str:
|
||||
"""Get database connection string (hide password)"""
|
||||
return f"mysql://{self.database.user}:***@{self.database.host}:{self.database.port}/{self.database.database}"
|
||||
|
||||
def get_config_summary(self) -> dict[str, Any]:
|
||||
"""Get configuration summary information"""
|
||||
return {
|
||||
"server": f"{self.server_name} v{self.server_version}",
|
||||
"database": f"{self.database.host}:{self.database.port}/{self.database.database}",
|
||||
"connection_pool": f"{self.database.min_connections}-{self.database.max_connections}",
|
||||
"security": {
|
||||
"auth_type": self.security.auth_type,
|
||||
"masking_enabled": self.security.enable_masking,
|
||||
"blocked_keywords_count": len(self.security.blocked_keywords),
|
||||
},
|
||||
"performance": {
|
||||
"cache_enabled": self.performance.enable_query_cache,
|
||||
"max_concurrent": self.performance.max_concurrent_queries,
|
||||
"query_timeout": self.performance.query_timeout,
|
||||
},
|
||||
"monitoring": {
|
||||
"metrics_enabled": self.monitoring.enable_metrics,
|
||||
"alerts_enabled": self.monitoring.enable_alerts,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class ConfigManager:
|
||||
"""Configuration manager class"""
|
||||
|
||||
def __init__(self, config: DorisConfig):
|
||||
self.config = config
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
def setup_logging(self):
|
||||
"""Setup logging configuration"""
|
||||
# Configure root logger
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.setLevel(getattr(logging, self.config.logging.level.upper()))
|
||||
|
||||
# Clear existing handlers
|
||||
for handler in root_logger.handlers[:]:
|
||||
root_logger.removeHandler(handler)
|
||||
|
||||
# Create formatter
|
||||
formatter = logging.Formatter(self.config.logging.format)
|
||||
|
||||
# Console handler
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setFormatter(formatter)
|
||||
root_logger.addHandler(console_handler)
|
||||
|
||||
# File handler (if configured)
|
||||
if self.config.logging.file_path:
|
||||
try:
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
||||
file_handler = RotatingFileHandler(
|
||||
self.config.logging.file_path,
|
||||
maxBytes=self.config.logging.max_file_size,
|
||||
backupCount=self.config.logging.backup_count,
|
||||
encoding="utf-8",
|
||||
)
|
||||
file_handler.setFormatter(formatter)
|
||||
root_logger.addHandler(file_handler)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to setup file logging: {e}")
|
||||
|
||||
# Audit log handler (if configured)
|
||||
if self.config.logging.enable_audit and self.config.logging.audit_file_path:
|
||||
try:
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
||||
audit_logger = logging.getLogger("audit")
|
||||
audit_handler = RotatingFileHandler(
|
||||
self.config.logging.audit_file_path,
|
||||
maxBytes=self.config.logging.max_file_size,
|
||||
backupCount=self.config.logging.backup_count,
|
||||
encoding="utf-8",
|
||||
)
|
||||
audit_handler.setFormatter(formatter)
|
||||
audit_logger.addHandler(audit_handler)
|
||||
audit_logger.setLevel(logging.INFO)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to setup audit logging: {e}")
|
||||
|
||||
def validate_config(self) -> bool:
|
||||
"""Validate configuration"""
|
||||
errors = self.config.validate()
|
||||
if errors:
|
||||
self.logger.error("Configuration validation failed:")
|
||||
for error in errors:
|
||||
self.logger.error(f" - {error}")
|
||||
return False
|
||||
|
||||
self.logger.info("Configuration validation passed")
|
||||
return True
|
||||
|
||||
def log_config_summary(self):
|
||||
"""Log configuration summary"""
|
||||
summary = self.config.get_config_summary()
|
||||
self.logger.info("Configuration Summary:")
|
||||
self.logger.info(f" Server: {summary['server']}")
|
||||
self.logger.info(f" Database: {summary['database']}")
|
||||
self.logger.info(f" Connection Pool: {summary['connection_pool']}")
|
||||
self.logger.info(f" Security: {summary['security']}")
|
||||
self.logger.info(f" Performance: {summary['performance']}")
|
||||
self.logger.info(f" Monitoring: {summary['monitoring']}")
|
||||
|
||||
|
||||
def create_default_config_file(config_path: str):
|
||||
"""Create default configuration file"""
|
||||
config = DorisConfig()
|
||||
config.save_to_file(config_path)
|
||||
print(f"Default configuration file created: {config_path}")
|
||||
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
# Create default configuration
|
||||
config = DorisConfig()
|
||||
|
||||
# Load from environment variables
|
||||
# config = DorisConfig.from_env()
|
||||
|
||||
# Load from file
|
||||
# config = DorisConfig.from_file("config.json")
|
||||
|
||||
# Validate configuration
|
||||
config_manager = ConfigManager(config)
|
||||
if config_manager.validate_config():
|
||||
config_manager.setup_logging()
|
||||
config_manager.log_config_summary()
|
||||
|
||||
# Save configuration
|
||||
config.save_to_file("example_config.json")
|
||||
print("Configuration saved to example_config.json")
|
||||
else:
|
||||
print("Configuration validation failed")
|
||||
@@ -1,100 +1,479 @@
|
||||
import os
|
||||
import json
|
||||
import pymysql
|
||||
import pandas as pd
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dotenv import load_dotenv
|
||||
import re
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Apache Doris Database Connection Management Module
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv(override=True)
|
||||
Provides high-performance database connection pool management, automatic reconnection mechanism and connection health check functionality
|
||||
Supports asynchronous operations and concurrent connection management, ensuring stability and performance for enterprise applications
|
||||
"""
|
||||
|
||||
# Database configuration
|
||||
DB_CONFIG = {
|
||||
"host": os.getenv("DB_HOST", "localhost"),
|
||||
"port": int(os.getenv("DB_PORT", "9030")),
|
||||
"user": os.getenv("DB_USER", "root"),
|
||||
"password": os.getenv("DB_PASSWORD", ""),
|
||||
"database": os.getenv("DB_DATABASE", ""),
|
||||
"charset": "utf8mb4",
|
||||
"cursorclass": pymysql.cursors.DictCursor
|
||||
}
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List
|
||||
|
||||
def get_db_connection(db_name: Optional[str] = None):
|
||||
import aiomysql
|
||||
from aiomysql import Connection, Pool
|
||||
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConnectionMetrics:
|
||||
"""Connection pool performance metrics"""
|
||||
|
||||
total_connections: int = 0
|
||||
active_connections: int = 0
|
||||
idle_connections: int = 0
|
||||
failed_connections: int = 0
|
||||
connection_errors: int = 0
|
||||
avg_connection_time: float = 0.0
|
||||
last_health_check: datetime | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueryResult:
|
||||
"""Query result wrapper"""
|
||||
|
||||
data: list[dict[str, Any]]
|
||||
metadata: dict[str, Any]
|
||||
execution_time: float
|
||||
row_count: int
|
||||
|
||||
|
||||
class DorisConnection:
|
||||
"""Doris database connection wrapper class"""
|
||||
|
||||
def __init__(self, connection: Connection, session_id: str, security_manager=None):
|
||||
self.connection = connection
|
||||
self.session_id = session_id
|
||||
self.created_at = datetime.utcnow()
|
||||
self.last_used = datetime.utcnow()
|
||||
self.query_count = 0
|
||||
self.is_healthy = True
|
||||
self.security_manager = security_manager
|
||||
|
||||
async def execute(self, sql: str, params: tuple | None = None, auth_context=None) -> QueryResult:
|
||||
"""Execute SQL query"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# If security manager exists, perform SQL security check
|
||||
security_result = None
|
||||
if self.security_manager and auth_context:
|
||||
validation_result = await self.security_manager.validate_sql_security(sql, auth_context)
|
||||
if not validation_result.is_valid:
|
||||
raise ValueError(f"SQL security validation failed: {validation_result.error_message}")
|
||||
security_result = {
|
||||
"is_valid": validation_result.is_valid,
|
||||
"risk_level": validation_result.risk_level,
|
||||
"blocked_operations": validation_result.blocked_operations
|
||||
}
|
||||
|
||||
async with self.connection.cursor(aiomysql.DictCursor) as cursor:
|
||||
await cursor.execute(sql, params)
|
||||
|
||||
# Check if it's a query statement (statement that returns result set)
|
||||
sql_upper = sql.strip().upper()
|
||||
if (sql_upper.startswith("SELECT") or
|
||||
sql_upper.startswith("SHOW") or
|
||||
sql_upper.startswith("DESCRIBE") or
|
||||
sql_upper.startswith("DESC") or
|
||||
sql_upper.startswith("EXPLAIN")):
|
||||
data = await cursor.fetchall()
|
||||
row_count = len(data)
|
||||
else:
|
||||
data = []
|
||||
row_count = cursor.rowcount
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
self.last_used = datetime.utcnow()
|
||||
self.query_count += 1
|
||||
|
||||
# Get column information
|
||||
columns = []
|
||||
if cursor.description:
|
||||
columns = [desc[0] for desc in cursor.description]
|
||||
|
||||
# If security manager exists and has auth context, apply data masking
|
||||
final_data = list(data) if data else []
|
||||
if self.security_manager and auth_context and final_data:
|
||||
final_data = await self.security_manager.apply_data_masking(final_data, auth_context)
|
||||
|
||||
metadata = {"columns": columns, "query": sql, "params": params}
|
||||
if security_result:
|
||||
metadata["security_check"] = security_result
|
||||
|
||||
return QueryResult(
|
||||
data=final_data,
|
||||
metadata=metadata,
|
||||
execution_time=execution_time,
|
||||
row_count=row_count,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.is_healthy = False
|
||||
logging.error(f"Query execution failed: {e}")
|
||||
raise
|
||||
|
||||
async def ping(self) -> bool:
|
||||
"""Check connection health status"""
|
||||
try:
|
||||
await self.connection.ping()
|
||||
self.is_healthy = True
|
||||
return True
|
||||
except Exception:
|
||||
self.is_healthy = False
|
||||
return False
|
||||
|
||||
async def close(self):
|
||||
"""Close connection"""
|
||||
try:
|
||||
if self.connection and not self.connection.closed:
|
||||
await self.connection.ensure_closed()
|
||||
except Exception as e:
|
||||
logging.error(f"Error occurred while closing connection: {e}")
|
||||
|
||||
|
||||
class DorisConnectionManager:
|
||||
"""Doris database connection manager
|
||||
|
||||
Provides connection pool management, connection health monitoring, fault recovery and other functions
|
||||
Supports session-level connection reuse and intelligent load balancing
|
||||
Integrates security manager to provide unified security validation and data masking
|
||||
"""
|
||||
Get database connection
|
||||
|
||||
Args:
|
||||
db_name: Specify the database name to connect to, use default config if None
|
||||
|
||||
Returns:
|
||||
Database connection
|
||||
"""
|
||||
if db_name:
|
||||
# Use default config but override database name
|
||||
config = DB_CONFIG.copy()
|
||||
config["database"] = db_name
|
||||
return pymysql.connect(**config)
|
||||
else:
|
||||
# Use default config
|
||||
return pymysql.connect(**DB_CONFIG)
|
||||
|
||||
def get_db_name() -> str:
|
||||
"""Get the currently configured default database name"""
|
||||
return DB_CONFIG["database"] or os.getenv("DB_DATABASE", "")
|
||||
def __init__(self, config, security_manager=None):
|
||||
self.config = config
|
||||
self.pool: Pool | None = None
|
||||
self.session_connections: dict[str, DorisConnection] = {}
|
||||
self.metrics = ConnectionMetrics()
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.security_manager = security_manager
|
||||
|
||||
def execute_query(sql, db_name: Optional[str] = None):
|
||||
"""
|
||||
Execute SQL query and return results
|
||||
|
||||
Args:
|
||||
sql: SQL query statement
|
||||
db_name: Specify the database name to connect to, use default config if None
|
||||
|
||||
Returns:
|
||||
Query results
|
||||
"""
|
||||
conn = get_db_connection(db_name)
|
||||
try:
|
||||
with conn.cursor() as cursor:
|
||||
# Set connection character set to utf8 before executing query
|
||||
cursor.execute("SET NAMES utf8")
|
||||
# Health check configuration
|
||||
self.health_check_interval = config.database.health_check_interval or 60
|
||||
self.max_connection_age = config.database.max_connection_age or 3600
|
||||
self.connection_timeout = config.database.connection_timeout or 30
|
||||
|
||||
# Start background tasks
|
||||
self._health_check_task = None
|
||||
self._cleanup_task = None
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize connection manager"""
|
||||
try:
|
||||
# Create connection pool
|
||||
self.pool = await aiomysql.create_pool(
|
||||
host=self.config.database.host,
|
||||
port=self.config.database.port,
|
||||
user=self.config.database.user,
|
||||
password=self.config.database.password,
|
||||
db=self.config.database.database,
|
||||
charset="utf8",
|
||||
minsize=self.config.database.min_connections or 5,
|
||||
maxsize=self.config.database.max_connections or 20,
|
||||
autocommit=True,
|
||||
connect_timeout=self.connection_timeout,
|
||||
)
|
||||
|
||||
self.logger.info(
|
||||
f"Connection pool initialized successfully, min connections: {self.config.database.min_connections}, "
|
||||
f"max connections: {self.config.database.max_connections}"
|
||||
)
|
||||
|
||||
# Start background monitoring tasks
|
||||
self._health_check_task = asyncio.create_task(self._health_check_loop())
|
||||
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Connection pool initialization failed: {e}")
|
||||
raise
|
||||
|
||||
async def get_connection(self, session_id: str) -> DorisConnection:
|
||||
"""Get database connection
|
||||
|
||||
Supports session-level connection reuse to improve performance and consistency
|
||||
"""
|
||||
# Check if there's an existing session connection
|
||||
if session_id in self.session_connections:
|
||||
conn = self.session_connections[session_id]
|
||||
# Check connection health
|
||||
if await conn.ping():
|
||||
return conn
|
||||
else:
|
||||
# Connection is unhealthy, clean up and create new one
|
||||
await self._cleanup_session_connection(session_id)
|
||||
|
||||
# Create new connection
|
||||
return await self._create_new_connection(session_id)
|
||||
|
||||
async def _create_new_connection(self, session_id: str) -> DorisConnection:
|
||||
"""Create new database connection"""
|
||||
try:
|
||||
if not self.pool:
|
||||
raise RuntimeError("Connection pool not initialized")
|
||||
|
||||
# Get connection from pool
|
||||
raw_connection = await self.pool.acquire()
|
||||
|
||||
# Execute the actual query
|
||||
cursor.execute(sql)
|
||||
result = cursor.fetchall()
|
||||
return result
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def execute_query_df(sql, db_name: Optional[str] = None):
|
||||
"""
|
||||
Execute SQL query and return pandas DataFrame
|
||||
|
||||
Args:
|
||||
sql: SQL query statement
|
||||
db_name: Specify the database name to connect to, use default config if None
|
||||
|
||||
Returns:
|
||||
pandas DataFrame
|
||||
"""
|
||||
conn = get_db_connection(db_name)
|
||||
try:
|
||||
# Use a temporary cursor to execute the query and get results
|
||||
with conn.cursor() as cursor:
|
||||
# Set connection character set to utf8 before executing query
|
||||
cursor.execute("SET NAMES utf8")
|
||||
# Create wrapped connection
|
||||
doris_conn = DorisConnection(raw_connection, session_id, self.security_manager)
|
||||
|
||||
# Execute the actual query
|
||||
cursor.execute(sql)
|
||||
result = cursor.fetchall()
|
||||
# Store in session connections
|
||||
self.session_connections[session_id] = doris_conn
|
||||
|
||||
self.metrics.total_connections += 1
|
||||
self.logger.debug(f"Created new connection for session: {session_id}")
|
||||
|
||||
return doris_conn
|
||||
|
||||
except Exception as e:
|
||||
self.metrics.connection_errors += 1
|
||||
self.logger.error(f"Failed to create connection for session {session_id}: {e}")
|
||||
raise
|
||||
|
||||
async def release_connection(self, session_id: str):
|
||||
"""Release session connection"""
|
||||
if session_id in self.session_connections:
|
||||
await self._cleanup_session_connection(session_id)
|
||||
|
||||
async def _cleanup_session_connection(self, session_id: str):
|
||||
"""Clean up session connection"""
|
||||
if session_id in self.session_connections:
|
||||
conn = self.session_connections[session_id]
|
||||
try:
|
||||
# Return connection to pool
|
||||
if self.pool and conn.connection and not conn.connection.closed:
|
||||
self.pool.release(conn.connection)
|
||||
|
||||
# Close connection wrapper
|
||||
await conn.close()
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error cleaning up connection for session {session_id}: {e}")
|
||||
finally:
|
||||
# Remove from session connections
|
||||
del self.session_connections[session_id]
|
||||
self.logger.debug(f"Cleaned up connection for session: {session_id}")
|
||||
|
||||
async def _health_check_loop(self):
|
||||
"""Background health check loop"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(self.health_check_interval)
|
||||
await self._perform_health_check()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
self.logger.error(f"Health check error: {e}")
|
||||
|
||||
async def _perform_health_check(self):
|
||||
"""Perform health check"""
|
||||
try:
|
||||
unhealthy_sessions = []
|
||||
|
||||
for session_id, conn in self.session_connections.items():
|
||||
if not await conn.ping():
|
||||
unhealthy_sessions.append(session_id)
|
||||
|
||||
# Clean up unhealthy connections
|
||||
for session_id in unhealthy_sessions:
|
||||
await self._cleanup_session_connection(session_id)
|
||||
self.metrics.failed_connections += 1
|
||||
|
||||
# Update metrics
|
||||
await self._update_connection_metrics()
|
||||
self.metrics.last_health_check = datetime.utcnow()
|
||||
|
||||
if unhealthy_sessions:
|
||||
self.logger.warning(f"Cleaned up {len(unhealthy_sessions)} unhealthy connections")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Health check failed: {e}")
|
||||
|
||||
async def _cleanup_loop(self):
|
||||
"""Background cleanup loop"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(300) # Run every 5 minutes
|
||||
await self._cleanup_idle_connections()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
self.logger.error(f"Cleanup loop error: {e}")
|
||||
|
||||
async def _cleanup_idle_connections(self):
|
||||
"""Clean up idle connections"""
|
||||
current_time = datetime.utcnow()
|
||||
idle_sessions = []
|
||||
|
||||
# If no results, return empty DataFrame
|
||||
if not result:
|
||||
return pd.DataFrame()
|
||||
|
||||
# Manually convert dict results to DataFrame
|
||||
df = pd.DataFrame(result)
|
||||
return df
|
||||
finally:
|
||||
conn.close()
|
||||
for session_id, conn in self.session_connections.items():
|
||||
# Check if connection has exceeded maximum age
|
||||
age = (current_time - conn.created_at).total_seconds()
|
||||
if age > self.max_connection_age:
|
||||
idle_sessions.append(session_id)
|
||||
|
||||
# Clean up idle connections
|
||||
for session_id in idle_sessions:
|
||||
await self._cleanup_session_connection(session_id)
|
||||
|
||||
if idle_sessions:
|
||||
self.logger.info(f"Cleaned up {len(idle_sessions)} idle connections")
|
||||
|
||||
async def _update_connection_metrics(self):
|
||||
"""Update connection metrics"""
|
||||
self.metrics.active_connections = len(self.session_connections)
|
||||
if self.pool:
|
||||
self.metrics.idle_connections = self.pool.freesize
|
||||
|
||||
async def get_metrics(self) -> ConnectionMetrics:
|
||||
"""Get connection metrics"""
|
||||
await self._update_connection_metrics()
|
||||
return self.metrics
|
||||
|
||||
async def execute_query(
|
||||
self, session_id: str, sql: str, params: tuple | None = None, auth_context=None
|
||||
) -> QueryResult:
|
||||
"""Execute query"""
|
||||
conn = await self.get_connection(session_id)
|
||||
return await conn.execute(sql, params, auth_context)
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_connection_context(self, session_id: str):
|
||||
"""Get connection context manager"""
|
||||
conn = await self.get_connection(session_id)
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
# Connection will be reused, no need to close here
|
||||
pass
|
||||
|
||||
async def close(self):
|
||||
"""Close connection manager"""
|
||||
try:
|
||||
# Cancel background tasks
|
||||
if self._health_check_task:
|
||||
self._health_check_task.cancel()
|
||||
try:
|
||||
await self._health_check_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
if self._cleanup_task:
|
||||
self._cleanup_task.cancel()
|
||||
try:
|
||||
await self._cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Clean up all session connections
|
||||
for session_id in list(self.session_connections.keys()):
|
||||
await self._cleanup_session_connection(session_id)
|
||||
|
||||
# Close connection pool
|
||||
if self.pool:
|
||||
self.pool.close()
|
||||
await self.pool.wait_closed()
|
||||
|
||||
self.logger.info("Connection manager closed successfully")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error closing connection manager: {e}")
|
||||
|
||||
async def test_connection(self) -> bool:
|
||||
"""Test database connection"""
|
||||
try:
|
||||
if not self.pool:
|
||||
return False
|
||||
|
||||
async with self.pool.acquire() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute("SELECT 1")
|
||||
result = await cursor.fetchone()
|
||||
return result is not None
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Connection test failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
class ConnectionPoolMonitor:
|
||||
"""Connection pool monitor
|
||||
|
||||
Provides detailed monitoring and reporting capabilities for connection pool status
|
||||
"""
|
||||
|
||||
def __init__(self, connection_manager: DorisConnectionManager):
|
||||
self.connection_manager = connection_manager
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
async def get_pool_status(self) -> dict[str, Any]:
|
||||
"""Get connection pool status"""
|
||||
metrics = await self.connection_manager.get_metrics()
|
||||
|
||||
status = {
|
||||
"pool_size": self.connection_manager.pool.size if self.connection_manager.pool else 0,
|
||||
"free_connections": self.connection_manager.pool.freesize if self.connection_manager.pool else 0,
|
||||
"active_sessions": len(self.connection_manager.session_connections),
|
||||
"total_connections": metrics.total_connections,
|
||||
"failed_connections": metrics.failed_connections,
|
||||
"connection_errors": metrics.connection_errors,
|
||||
"avg_connection_time": metrics.avg_connection_time,
|
||||
"last_health_check": metrics.last_health_check.isoformat() if metrics.last_health_check else None,
|
||||
}
|
||||
|
||||
return status
|
||||
|
||||
async def get_session_details(self) -> list[dict[str, Any]]:
|
||||
"""Get session connection details"""
|
||||
sessions = []
|
||||
|
||||
for session_id, conn in self.connection_manager.session_connections.items():
|
||||
session_info = {
|
||||
"session_id": session_id,
|
||||
"created_at": conn.created_at.isoformat(),
|
||||
"last_used": conn.last_used.isoformat(),
|
||||
"query_count": conn.query_count,
|
||||
"is_healthy": conn.is_healthy,
|
||||
"connection_age": (datetime.utcnow() - conn.created_at).total_seconds(),
|
||||
}
|
||||
sessions.append(session_info)
|
||||
|
||||
return sessions
|
||||
|
||||
async def generate_health_report(self) -> dict[str, Any]:
|
||||
"""Generate connection health report"""
|
||||
pool_status = await self.get_pool_status()
|
||||
session_details = await self.get_session_details()
|
||||
|
||||
# Calculate health statistics
|
||||
healthy_sessions = sum(1 for s in session_details if s["is_healthy"])
|
||||
total_sessions = len(session_details)
|
||||
health_ratio = healthy_sessions / total_sessions if total_sessions > 0 else 1.0
|
||||
|
||||
report = {
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"pool_status": pool_status,
|
||||
"session_summary": {
|
||||
"total_sessions": total_sessions,
|
||||
"healthy_sessions": healthy_sessions,
|
||||
"health_ratio": health_ratio,
|
||||
},
|
||||
"session_details": session_details,
|
||||
"recommendations": [],
|
||||
}
|
||||
|
||||
# Add recommendations based on health status
|
||||
if health_ratio < 0.8:
|
||||
report["recommendations"].append("Consider checking database connectivity and network stability")
|
||||
|
||||
if pool_status["connection_errors"] > 10:
|
||||
report["recommendations"].append("High connection error rate detected, review connection configuration")
|
||||
|
||||
if pool_status["active_sessions"] > pool_status["pool_size"] * 0.9:
|
||||
report["recommendations"].append("Connection pool utilization is high, consider increasing pool size")
|
||||
|
||||
return report
|
||||
|
||||
@@ -1,226 +1,85 @@
|
||||
"""
|
||||
Unified Logging Configuration Module
|
||||
|
||||
Provides unified logging configuration, including:
|
||||
- General logs: Record all program execution information
|
||||
- Audit logs: Record JSON data for key operations and processing results
|
||||
- Error logs: Specifically record program exceptions and errors
|
||||
Logging configuration for Doris MCP Server.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import logging.handlers
|
||||
import logging.config
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from datetime import datetime
|
||||
from dotenv import load_dotenv
|
||||
from typing import Any
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv(override=True)
|
||||
|
||||
# Get project root directory
|
||||
PROJECT_ROOT = Path(__file__).parents[2].absolute()
|
||||
def setup_logging(
|
||||
level: str = "INFO",
|
||||
log_file: str | None = None,
|
||||
log_format: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Setup logging configuration.
|
||||
|
||||
# Get log configuration from environment variables
|
||||
LOG_DIR = os.getenv("LOG_DIR", str(PROJECT_ROOT / "logs"))
|
||||
LOG_PREFIX = os.getenv("LOG_PREFIX", "doris_mcp")
|
||||
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
|
||||
LOG_MAX_DAYS = int(os.getenv("LOG_MAX_DAYS", "30"))
|
||||
# Whether to output logs to the console (should be disabled when running as a service)
|
||||
CONSOLE_LOGGING = os.getenv("CONSOLE_LOGGING", "false").lower() == "true"
|
||||
# Whether stdio transport mode is being used
|
||||
STDIO_MODE = os.getenv("MCP_TRANSPORT_TYPE", "").lower() == "stdio"
|
||||
Args:
|
||||
level: Logging level (DEBUG, INFO, WARNING, ERROR)
|
||||
log_file: Optional log file path
|
||||
log_format: Optional custom log format
|
||||
"""
|
||||
if log_format is None:
|
||||
log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
|
||||
def purge_old_logs():
|
||||
"""Clean up expired log files"""
|
||||
# --- Only perform cleanup in non-Stdio mode ---
|
||||
if STDIO_MODE:
|
||||
return
|
||||
try:
|
||||
now = datetime.now()
|
||||
log_dir = Path(LOG_DIR)
|
||||
# Check if directory exists and is readable/writable
|
||||
if not log_dir.is_dir() or not os.access(LOG_DIR, os.W_OK):
|
||||
if not STDIO_MODE: # Avoid printing to stdout in stdio mode
|
||||
print(f"Warning: Log directory {LOG_DIR} not accessible, skipping log purge.", file=sys.stderr)
|
||||
return
|
||||
# Base configuration
|
||||
config: dict[str, Any] = {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"formatters": {
|
||||
"default": {"format": log_format, "datefmt": "%Y-%m-%d %H:%M:%S"}
|
||||
},
|
||||
"handlers": {
|
||||
"console": {
|
||||
"class": "logging.StreamHandler",
|
||||
"level": level,
|
||||
"formatter": "default",
|
||||
"stream": sys.stdout,
|
||||
}
|
||||
},
|
||||
"root": {"level": level, "handlers": ["console"]},
|
||||
"loggers": {
|
||||
"doris_mcp_server": {
|
||||
"level": level,
|
||||
"handlers": ["console"],
|
||||
"propagate": False,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
for log_file in log_dir.glob(f"{LOG_PREFIX}*.20*"):
|
||||
# Parse date
|
||||
file_name = log_file.name
|
||||
date_str = None
|
||||
|
||||
# Try to find the date part
|
||||
parts = file_name.split('.')
|
||||
for part in parts:
|
||||
if part.startswith('20') and len(part) == 8: # 20YYMMDD format
|
||||
date_str = part
|
||||
break
|
||||
|
||||
if date_str:
|
||||
try:
|
||||
file_date = datetime.strptime(date_str, '%Y%m%d')
|
||||
days_old = (now - file_date).days
|
||||
|
||||
if days_old > LOG_MAX_DAYS:
|
||||
os.remove(log_file)
|
||||
if not STDIO_MODE:
|
||||
print(f"Deleted expired log file: {log_file}")
|
||||
except (ValueError, OSError) as e:
|
||||
if not STDIO_MODE:
|
||||
print(f"Error processing log file {file_name}: {e}", file=sys.stderr)
|
||||
except Exception as e:
|
||||
if not STDIO_MODE:
|
||||
print(f"Error cleaning up logs: {e}", file=sys.stderr)
|
||||
# Add file handler if log_file is specified
|
||||
if log_file:
|
||||
# Ensure log directory exists
|
||||
log_path = Path(log_file)
|
||||
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Force disable console log output if in stdio mode
|
||||
if STDIO_MODE:
|
||||
CONSOLE_LOGGING = False
|
||||
config["handlers"]["file"] = {
|
||||
"class": "logging.handlers.RotatingFileHandler",
|
||||
"level": level,
|
||||
"formatter": "default",
|
||||
"filename": log_file,
|
||||
"maxBytes": 10485760, # 10MB
|
||||
"backupCount": 5,
|
||||
}
|
||||
|
||||
# --- Only create log directory and clean old logs in non-Stdio mode ---
|
||||
if not STDIO_MODE:
|
||||
try:
|
||||
os.makedirs(LOG_DIR, exist_ok=True)
|
||||
# Clean up expired logs on startup (also moved here, as it only handles file logs)
|
||||
purge_old_logs()
|
||||
except OSError as e:
|
||||
# If directory creation fails (e.g., permission issue), print warning but continue to avoid startup failure
|
||||
print(f"Warning: Failed to create log directory {LOG_DIR} or purge logs: {e}", file=sys.stderr)
|
||||
# Add file handler to root and package loggers
|
||||
config["root"]["handlers"].append("file")
|
||||
config["loggers"]["doris_mcp_server"]["handlers"].append("file")
|
||||
|
||||
# Log file paths (definition still needed, but files might not be created/used)
|
||||
LOG_FILE = os.path.join(LOG_DIR, f"{LOG_PREFIX}.log")
|
||||
AUDIT_LOG_FILE = os.path.join(LOG_DIR, f"{LOG_PREFIX}.audit")
|
||||
ERROR_LOG_FILE = os.path.join(LOG_DIR, f"{LOG_PREFIX}.error")
|
||||
|
||||
# Log level mapping
|
||||
LOG_LEVELS = {
|
||||
"DEBUG": logging.DEBUG,
|
||||
"INFO": logging.INFO,
|
||||
"WARNING": logging.WARNING,
|
||||
"ERROR": logging.ERROR,
|
||||
"CRITICAL": logging.CRITICAL
|
||||
}
|
||||
|
||||
# Log format
|
||||
LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
AUDIT_FORMAT = '%(asctime)s - %(name)s - %(message)s'
|
||||
ERROR_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(pathname)s:%(lineno)d - %(message)s'
|
||||
|
||||
# Dedicated audit log level
|
||||
AUDIT = 25 # Level between INFO and WARNING
|
||||
logging.addLevelName(AUDIT, "AUDIT")
|
||||
|
||||
# Logger object cache
|
||||
_loggers: Dict[str, logging.Logger] = {}
|
||||
|
||||
# Handler type mapping, used to ensure no duplicates are added
|
||||
_handler_types = {
|
||||
'console': logging.StreamHandler,
|
||||
'file': logging.handlers.TimedRotatingFileHandler,
|
||||
'audit': logging.handlers.TimedRotatingFileHandler,
|
||||
'error': logging.handlers.TimedRotatingFileHandler
|
||||
}
|
||||
logging.config.dictConfig(config)
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
"""
|
||||
Get a logger with the specified name
|
||||
|
||||
Get a logger instance.
|
||||
|
||||
Args:
|
||||
name: Logger name
|
||||
|
||||
|
||||
Returns:
|
||||
logging.Logger: Configured logger
|
||||
Logger instance
|
||||
"""
|
||||
if name in _loggers:
|
||||
return _loggers[name]
|
||||
|
||||
# Create logger
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(LOG_LEVELS.get(LOG_LEVEL, logging.INFO))
|
||||
|
||||
# Avoid duplicate logs caused by propagation
|
||||
logger.propagate = False
|
||||
|
||||
# Check if handlers already exist to avoid duplicates
|
||||
handler_types = set(type(h) for h in logger.handlers)
|
||||
|
||||
# Add audit log method
|
||||
def audit(self, message, *args, **kwargs):
|
||||
self.log(AUDIT, message, *args, **kwargs)
|
||||
|
||||
logger.audit = audit.__get__(logger)
|
||||
|
||||
# General log handler - output to console (only if enabled)
|
||||
if CONSOLE_LOGGING and _handler_types['console'] not in handler_types:
|
||||
# Use stderr instead of stdout to avoid conflicts with MCP communication
|
||||
console_handler = logging.StreamHandler(sys.stderr)
|
||||
console_handler.setFormatter(logging.Formatter(LOG_FORMAT))
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
# --- Only add file handlers in non-Stdio mode ---
|
||||
if not STDIO_MODE:
|
||||
# General log handler - daily rotating file
|
||||
if _handler_types['file'] not in handler_types:
|
||||
try: # Add try-except block
|
||||
file_handler = logging.handlers.TimedRotatingFileHandler(
|
||||
LOG_FILE,
|
||||
when='midnight',
|
||||
interval=1,
|
||||
backupCount=LOG_MAX_DAYS,
|
||||
encoding='utf-8'
|
||||
)
|
||||
file_handler.setFormatter(logging.Formatter(LOG_FORMAT))
|
||||
file_handler.suffix = "%Y%m%d"
|
||||
logger.addHandler(file_handler)
|
||||
except OSError as e:
|
||||
print(f"Warning: Failed to add file log handler for {LOG_FILE}: {e}", file=sys.stderr)
|
||||
|
||||
# Audit log handler - only logs AUDIT level
|
||||
if _handler_types['audit'] not in handler_types:
|
||||
try: # Add try-except block
|
||||
audit_handler = logging.handlers.TimedRotatingFileHandler(
|
||||
AUDIT_LOG_FILE,
|
||||
when='midnight',
|
||||
interval=1,
|
||||
backupCount=LOG_MAX_DAYS,
|
||||
encoding='utf-8'
|
||||
)
|
||||
audit_handler.setFormatter(logging.Formatter(AUDIT_FORMAT))
|
||||
audit_handler.suffix = "%Y%m%d"
|
||||
audit_handler.setLevel(AUDIT)
|
||||
audit_handler.addFilter(lambda record: record.levelno == AUDIT)
|
||||
logger.addHandler(audit_handler)
|
||||
except OSError as e:
|
||||
print(f"Warning: Failed to add audit log handler for {AUDIT_LOG_FILE}: {e}", file=sys.stderr)
|
||||
|
||||
# Error log handler - only logs ERROR level and above
|
||||
if _handler_types['error'] not in handler_types:
|
||||
try: # Add try-except block
|
||||
error_handler = logging.handlers.TimedRotatingFileHandler(
|
||||
ERROR_LOG_FILE,
|
||||
when='midnight',
|
||||
interval=1,
|
||||
backupCount=LOG_MAX_DAYS,
|
||||
encoding='utf-8'
|
||||
)
|
||||
error_handler.setFormatter(logging.Formatter(ERROR_FORMAT))
|
||||
error_handler.suffix = "%Y%m%d"
|
||||
error_handler.setLevel(logging.ERROR)
|
||||
logger.addHandler(error_handler)
|
||||
except OSError as e:
|
||||
print(f"Warning: Failed to add error log handler for {ERROR_LOG_FILE}: {e}", file=sys.stderr)
|
||||
|
||||
# Cache logger
|
||||
_loggers[name] = logger
|
||||
|
||||
return logger
|
||||
|
||||
# Default logger
|
||||
logger = get_logger('doris_mcp')
|
||||
|
||||
# Audit logger - for recording processing results, business operations, etc.
|
||||
audit_logger = get_logger('audit')
|
||||
|
||||
# Call to clean logs moved after directory creation, and added non-stdio check
|
||||
return logging.getLogger(name)
|
||||
|
||||
800
doris_mcp_server/utils/query_executor.py
Normal file
800
doris_mcp_server/utils/query_executor.py
Normal file
@@ -0,0 +1,800 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Doris Query Execution Module
|
||||
Implements query optimization, cache management and performance monitoring functionality
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import os
|
||||
import uuid
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, date
|
||||
from typing import Any, Dict
|
||||
from decimal import Decimal
|
||||
|
||||
from .db import DorisConnectionManager, QueryResult
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueryRequest:
|
||||
"""Query request wrapper"""
|
||||
|
||||
sql: str
|
||||
session_id: str
|
||||
user_id: str
|
||||
parameters: dict[str, Any] | None = None
|
||||
timeout: int | None = None
|
||||
cache_enabled: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class CachedQuery:
|
||||
"""Cached query result"""
|
||||
|
||||
result: QueryResult
|
||||
created_at: datetime
|
||||
ttl: int
|
||||
access_count: int = 0
|
||||
last_accessed: datetime | None = None
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if cache is expired"""
|
||||
if self.ttl <= 0:
|
||||
return False
|
||||
return (datetime.utcnow() - self.created_at).total_seconds() > self.ttl
|
||||
|
||||
def access(self):
|
||||
"""Record access"""
|
||||
self.access_count += 1
|
||||
self.last_accessed = datetime.utcnow()
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueryMetrics:
|
||||
"""Query performance metrics"""
|
||||
|
||||
total_queries: int = 0
|
||||
successful_queries: int = 0
|
||||
failed_queries: int = 0
|
||||
cache_hits: int = 0
|
||||
cache_misses: int = 0
|
||||
avg_execution_time: float = 0.0
|
||||
total_execution_time: float = 0.0
|
||||
slow_queries: int = 0
|
||||
concurrent_queries: int = 0
|
||||
|
||||
|
||||
class QueryCache:
|
||||
"""Query result cache manager"""
|
||||
|
||||
def __init__(self, max_size: int = 1000, default_ttl: int = 300):
|
||||
self.max_size = max_size
|
||||
self.default_ttl = default_ttl
|
||||
self.cache: dict[str, CachedQuery] = {}
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
def _generate_cache_key(
|
||||
self, sql: str, parameters: dict[str, Any] | None = None
|
||||
) -> str:
|
||||
"""Generate cache key"""
|
||||
cache_data = {"sql": sql.strip().lower(), "parameters": parameters or {}}
|
||||
cache_string = json.dumps(cache_data, sort_keys=True)
|
||||
return hashlib.md5(cache_string.encode()).hexdigest()
|
||||
|
||||
async def get(
|
||||
self, sql: str, parameters: dict[str, Any] | None = None
|
||||
) -> CachedQuery | None:
|
||||
"""Get cached query result"""
|
||||
cache_key = self._generate_cache_key(sql, parameters)
|
||||
|
||||
if cache_key in self.cache:
|
||||
cached_query = self.cache[cache_key]
|
||||
|
||||
if not cached_query.is_expired():
|
||||
cached_query.access()
|
||||
self.logger.debug(f"Cache hit: {cache_key}")
|
||||
return cached_query
|
||||
else:
|
||||
# Clean up expired cache
|
||||
del self.cache[cache_key]
|
||||
self.logger.debug(f"Cache expired, cleaned up: {cache_key}")
|
||||
|
||||
return None
|
||||
|
||||
async def set(
|
||||
self,
|
||||
sql: str,
|
||||
result: QueryResult,
|
||||
parameters: dict[str, Any] | None = None,
|
||||
ttl: int | None = None,
|
||||
) -> str:
|
||||
"""Set query result cache"""
|
||||
cache_key = self._generate_cache_key(sql, parameters)
|
||||
|
||||
# Check cache size limit
|
||||
if len(self.cache) >= self.max_size:
|
||||
await self._evict_oldest()
|
||||
|
||||
cached_query = CachedQuery(
|
||||
result=result, created_at=datetime.utcnow(), ttl=ttl or self.default_ttl
|
||||
)
|
||||
|
||||
self.cache[cache_key] = cached_query
|
||||
self.logger.debug(f"Cache set: {cache_key}")
|
||||
|
||||
return cache_key
|
||||
|
||||
async def _evict_oldest(self):
|
||||
"""Clean up oldest cache item"""
|
||||
if not self.cache:
|
||||
return
|
||||
|
||||
# Find oldest cache item
|
||||
oldest_key = min(self.cache.keys(), key=lambda k: self.cache[k].created_at)
|
||||
|
||||
del self.cache[oldest_key]
|
||||
self.logger.debug(f"Cleaned up oldest cache: {oldest_key}")
|
||||
|
||||
async def clear_expired(self):
|
||||
"""Clean up all expired cache"""
|
||||
expired_keys = [
|
||||
key for key, cached_query in self.cache.items() if cached_query.is_expired()
|
||||
]
|
||||
|
||||
for key in expired_keys:
|
||||
del self.cache[key]
|
||||
|
||||
if expired_keys:
|
||||
self.logger.info(f"Cleaned up {len(expired_keys)} expired cache items")
|
||||
|
||||
async def clear_all(self):
|
||||
"""Clean up all cache"""
|
||||
cache_count = len(self.cache)
|
||||
self.cache.clear()
|
||||
self.logger.info(f"Cleaned up all cache, total {cache_count} items")
|
||||
|
||||
def get_stats(self) -> dict[str, Any]:
|
||||
"""Get cache statistics"""
|
||||
total_access = sum(cached.access_count for cached in self.cache.values())
|
||||
|
||||
return {
|
||||
"cache_size": len(self.cache),
|
||||
"max_size": self.max_size,
|
||||
"total_access": total_access,
|
||||
"hit_rate": 0.0
|
||||
if total_access == 0
|
||||
else sum(cached.access_count for cached in self.cache.values())
|
||||
/ total_access,
|
||||
}
|
||||
|
||||
|
||||
class QueryOptimizer:
|
||||
"""Query optimizer"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.optimization_rules = self._load_optimization_rules()
|
||||
|
||||
def _load_optimization_rules(self) -> list[dict[str, Any]]:
|
||||
"""Load query optimization rules"""
|
||||
return [
|
||||
{
|
||||
"name": "add_limit_clause",
|
||||
"description": "Add default limit for SELECT queries without LIMIT",
|
||||
"pattern": r"^select\s+.*(?!.*limit\s+\d+)",
|
||||
"action": "add_limit",
|
||||
"params": {"default_limit": 1000},
|
||||
},
|
||||
{
|
||||
"name": "optimize_count_query",
|
||||
"description": "Optimize COUNT queries",
|
||||
"pattern": r"select\s+count\(\*\)\s+from\s+(\w+)",
|
||||
"action": "optimize_count",
|
||||
"params": {},
|
||||
},
|
||||
]
|
||||
|
||||
async def optimize_query(self, sql: str, context: dict[str, Any]) -> str:
|
||||
"""Apply query optimization"""
|
||||
optimized_sql = sql
|
||||
|
||||
for rule in self.optimization_rules:
|
||||
if self._should_apply_rule(rule, optimized_sql, context):
|
||||
optimized_sql = await self._apply_optimization_rule(
|
||||
optimized_sql, rule, context
|
||||
)
|
||||
self.logger.debug(f"Applied optimization rule: {rule['name']}")
|
||||
|
||||
return optimized_sql
|
||||
|
||||
def _should_apply_rule(
|
||||
self, rule: dict[str, Any], sql: str, context: dict[str, Any]
|
||||
) -> bool:
|
||||
"""Check if optimization rule should be applied"""
|
||||
import re
|
||||
|
||||
# Check pattern match
|
||||
if "pattern" in rule:
|
||||
if not re.search(rule["pattern"], sql, re.IGNORECASE):
|
||||
return False
|
||||
|
||||
# Check conditions
|
||||
if "conditions" in rule:
|
||||
for condition in rule["conditions"]:
|
||||
if not self._check_condition(condition, context):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _check_condition(
|
||||
self, condition: dict[str, Any], context: dict[str, Any]
|
||||
) -> bool:
|
||||
"""Check optimization condition"""
|
||||
condition_type = condition.get("type")
|
||||
|
||||
if condition_type == "user_role":
|
||||
required_roles = condition.get("roles", [])
|
||||
user_roles = context.get("user_roles", [])
|
||||
return any(role in user_roles for role in required_roles)
|
||||
|
||||
elif condition_type == "query_size":
|
||||
max_size = condition.get("max_size", 1000)
|
||||
return len(context.get("sql", "")) <= max_size
|
||||
|
||||
return True
|
||||
|
||||
async def _apply_optimization_rule(
|
||||
self, sql: str, rule: dict[str, Any], context: dict[str, Any]
|
||||
) -> str:
|
||||
"""Apply optimization rule"""
|
||||
action = rule.get("action")
|
||||
params = rule.get("params", {})
|
||||
|
||||
if action == "add_limit":
|
||||
return await self._add_limit_clause(sql, params)
|
||||
elif action == "optimize_count":
|
||||
return await self._optimize_count_query(sql, params)
|
||||
elif action == "add_hints":
|
||||
return await self._add_query_hints(sql, params)
|
||||
|
||||
return sql
|
||||
|
||||
async def _add_limit_clause(self, sql: str, params: dict[str, Any]) -> str:
|
||||
"""Add LIMIT clause to query"""
|
||||
import re
|
||||
|
||||
default_limit = params.get("default_limit", 1000)
|
||||
|
||||
# Check if LIMIT already exists
|
||||
if re.search(r"\blimit\s+\d+", sql, re.IGNORECASE):
|
||||
return sql
|
||||
|
||||
# Add LIMIT clause
|
||||
if sql.strip().endswith(";"):
|
||||
sql = sql.strip()[:-1]
|
||||
|
||||
return f"{sql} LIMIT {default_limit}"
|
||||
|
||||
async def _optimize_count_query(self, sql: str, params: dict[str, Any]) -> str:
|
||||
"""Optimize COUNT query"""
|
||||
# For COUNT queries, we can add optimization hints
|
||||
return sql.replace("COUNT(*)", "COUNT(1)")
|
||||
|
||||
async def _add_query_hints(self, sql: str, params: dict[str, Any]) -> str:
|
||||
"""Add query hints"""
|
||||
hints = params.get("hints", [])
|
||||
if not hints:
|
||||
return sql
|
||||
|
||||
hint_string = "/*+ " + " ".join(hints) + " */"
|
||||
return f"{hint_string} {sql}"
|
||||
|
||||
|
||||
class DorisQueryExecutor:
|
||||
"""Doris query executor with caching and optimization"""
|
||||
|
||||
def __init__(self, connection_manager: DorisConnectionManager, config=None):
|
||||
self.connection_manager = connection_manager
|
||||
self.config = config or self._create_default_config()
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize components
|
||||
cache_config = getattr(self.config, 'performance', None)
|
||||
if cache_config:
|
||||
cache_size = getattr(cache_config, 'max_cache_size', 1000)
|
||||
cache_ttl = getattr(cache_config, 'cache_ttl', 300)
|
||||
else:
|
||||
cache_size = 1000
|
||||
cache_ttl = 300
|
||||
|
||||
self.query_cache = QueryCache(max_size=cache_size, default_ttl=cache_ttl)
|
||||
self.query_optimizer = QueryOptimizer(self.config)
|
||||
self.metrics = QueryMetrics()
|
||||
|
||||
# Performance monitoring
|
||||
self.slow_query_threshold = 5.0 # seconds
|
||||
self.max_concurrent_queries = getattr(
|
||||
getattr(self.config, 'performance', None), 'max_concurrent_queries', 50
|
||||
) if hasattr(self.config, 'performance') else 50
|
||||
|
||||
# Background tasks
|
||||
self._background_tasks = []
|
||||
self._start_background_tasks()
|
||||
|
||||
def _create_default_config(self):
|
||||
"""Create default configuration"""
|
||||
class DefaultConfig:
|
||||
def __init__(self):
|
||||
self.performance = DefaultPerformanceConfig()
|
||||
|
||||
class DefaultPerformanceConfig:
|
||||
def __init__(self):
|
||||
self.max_cache_size = 1000
|
||||
self.cache_ttl = 300
|
||||
self.max_concurrent_queries = 50
|
||||
|
||||
return DefaultConfig()
|
||||
|
||||
def _start_background_tasks(self):
|
||||
"""Start background tasks"""
|
||||
try:
|
||||
# Cache cleanup task
|
||||
cleanup_task = asyncio.create_task(self._cache_cleanup_loop())
|
||||
self._background_tasks.append(cleanup_task)
|
||||
except RuntimeError:
|
||||
# No event loop running (e.g., in tests), skip background tasks
|
||||
self.logger.debug("No event loop running, skipping background tasks")
|
||||
|
||||
async def _cache_cleanup_loop(self):
|
||||
"""Background cache cleanup loop"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(300) # Run every 5 minutes
|
||||
await self.query_cache.clear_expired()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
self.logger.error(f"Cache cleanup error: {e}")
|
||||
|
||||
async def execute_query(
|
||||
self, query_request: QueryRequest, auth_context=None
|
||||
) -> QueryResult:
|
||||
"""Execute query with caching and optimization"""
|
||||
start_time = time.time()
|
||||
self.metrics.total_queries += 1
|
||||
self.metrics.concurrent_queries += 1
|
||||
|
||||
try:
|
||||
# Check cache first
|
||||
if query_request.cache_enabled:
|
||||
cached_result = await self.query_cache.get(
|
||||
query_request.sql, query_request.parameters
|
||||
)
|
||||
if cached_result:
|
||||
self.metrics.cache_hits += 1
|
||||
self.logger.debug(f"Cache hit for query: {query_request.sql[:50]}...")
|
||||
return cached_result.result
|
||||
|
||||
self.metrics.cache_misses += 1
|
||||
|
||||
# Execute query
|
||||
result = await self._execute_query_internal(query_request, auth_context)
|
||||
|
||||
# Cache result if enabled
|
||||
if query_request.cache_enabled and result.row_count > 0:
|
||||
await self.query_cache.set(
|
||||
query_request.sql, result, query_request.parameters
|
||||
)
|
||||
|
||||
self.metrics.successful_queries += 1
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
self.metrics.failed_queries += 1
|
||||
self.logger.error(f"Query execution failed: {e}")
|
||||
raise
|
||||
|
||||
finally:
|
||||
execution_time = time.time() - start_time
|
||||
self.metrics.concurrent_queries -= 1
|
||||
self._update_execution_metrics(execution_time)
|
||||
|
||||
async def _execute_query_internal(
|
||||
self, query_request: QueryRequest, auth_context
|
||||
) -> QueryResult:
|
||||
"""Internal query execution"""
|
||||
# Optimize query
|
||||
optimized_sql = await self.query_optimizer.optimize_query(
|
||||
query_request.sql, {"user_roles": getattr(auth_context, 'roles', [])}
|
||||
)
|
||||
|
||||
# Execute query
|
||||
connection = await self.connection_manager.get_connection(
|
||||
query_request.session_id
|
||||
)
|
||||
|
||||
# Set timeout if specified
|
||||
if query_request.timeout:
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
connection.execute(optimized_sql, query_request.parameters, auth_context),
|
||||
timeout=query_request.timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
raise Exception(f"Query timeout after {query_request.timeout} seconds")
|
||||
else:
|
||||
result = await connection.execute(optimized_sql, query_request.parameters, auth_context)
|
||||
|
||||
return result
|
||||
|
||||
def _update_execution_metrics(self, execution_time: float):
|
||||
"""Update execution metrics"""
|
||||
self.metrics.total_execution_time += execution_time
|
||||
|
||||
# Update average execution time
|
||||
if self.metrics.successful_queries > 0:
|
||||
self.metrics.avg_execution_time = (
|
||||
self.metrics.total_execution_time / self.metrics.successful_queries
|
||||
)
|
||||
|
||||
# Check for slow queries
|
||||
if execution_time > self.slow_query_threshold:
|
||||
self.metrics.slow_queries += 1
|
||||
self.logger.warning(
|
||||
f"Slow query detected: {execution_time:.2f}s (threshold: {self.slow_query_threshold}s)"
|
||||
)
|
||||
|
||||
async def execute_batch_queries(
|
||||
self, query_requests: list[QueryRequest], auth_context=None
|
||||
) -> list[QueryResult]:
|
||||
"""Execute multiple queries in batch"""
|
||||
results = []
|
||||
|
||||
# Check concurrent query limit
|
||||
if len(query_requests) > self.max_concurrent_queries:
|
||||
raise Exception(
|
||||
f"Batch size {len(query_requests)} exceeds maximum concurrent queries {self.max_concurrent_queries}"
|
||||
)
|
||||
|
||||
# Execute queries concurrently
|
||||
tasks = [
|
||||
self.execute_query(request, auth_context) for request in query_requests
|
||||
]
|
||||
|
||||
try:
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Batch query execution failed: {e}")
|
||||
raise
|
||||
|
||||
return results
|
||||
|
||||
async def explain_query(self, sql: str, session_id: str) -> dict[str, Any]:
|
||||
"""Get query execution plan"""
|
||||
explain_sql = f"EXPLAIN {sql}"
|
||||
|
||||
connection = await self.connection_manager.get_connection(session_id)
|
||||
result = await connection.execute(explain_sql)
|
||||
|
||||
return {
|
||||
"query": sql,
|
||||
"execution_plan": result.data,
|
||||
"estimated_cost": "N/A", # Doris doesn't provide cost estimates
|
||||
}
|
||||
|
||||
async def get_query_stats(self) -> dict[str, Any]:
|
||||
"""Get query execution statistics"""
|
||||
cache_stats = self.query_cache.get_stats()
|
||||
|
||||
return {
|
||||
"query_metrics": {
|
||||
"total_queries": self.metrics.total_queries,
|
||||
"successful_queries": self.metrics.successful_queries,
|
||||
"failed_queries": self.metrics.failed_queries,
|
||||
"success_rate": (
|
||||
self.metrics.successful_queries / self.metrics.total_queries
|
||||
if self.metrics.total_queries > 0
|
||||
else 0.0
|
||||
),
|
||||
"avg_execution_time": self.metrics.avg_execution_time,
|
||||
"slow_queries": self.metrics.slow_queries,
|
||||
"concurrent_queries": self.metrics.concurrent_queries,
|
||||
},
|
||||
"cache_metrics": {
|
||||
"cache_hits": self.metrics.cache_hits,
|
||||
"cache_misses": self.metrics.cache_misses,
|
||||
"hit_rate": (
|
||||
self.metrics.cache_hits
|
||||
/ (self.metrics.cache_hits + self.metrics.cache_misses)
|
||||
if (self.metrics.cache_hits + self.metrics.cache_misses) > 0
|
||||
else 0.0
|
||||
),
|
||||
**cache_stats,
|
||||
},
|
||||
}
|
||||
|
||||
async def clear_cache(self):
|
||||
"""Clear query cache"""
|
||||
await self.query_cache.clear_all()
|
||||
|
||||
async def execute_sql_for_mcp(
|
||||
self,
|
||||
sql: str,
|
||||
limit: int = 1000,
|
||||
timeout: int = 30,
|
||||
session_id: str = "mcp_session",
|
||||
user_id: str = "mcp_user"
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute SQL query for MCP interface - unified method"""
|
||||
try:
|
||||
if not sql:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "SQL query is required",
|
||||
"data": None
|
||||
}
|
||||
|
||||
# Add LIMIT if not present and it's a SELECT query
|
||||
if sql.upper().startswith("SELECT") and "LIMIT" not in sql.upper():
|
||||
if sql.endswith(";"):
|
||||
sql = sql[:-1]
|
||||
sql = f"{sql} LIMIT {limit}"
|
||||
|
||||
# Create auth context for MCP calls
|
||||
class MockAuthContext:
|
||||
def __init__(self):
|
||||
self.user_id = user_id
|
||||
self.roles = ["data_analyst"]
|
||||
self.permissions = ["read_data", "execute_query"]
|
||||
self.session_id = session_id
|
||||
self.security_level = "internal"
|
||||
|
||||
auth_context = MockAuthContext()
|
||||
|
||||
# Create query request
|
||||
query_request = QueryRequest(
|
||||
sql=sql,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
timeout=timeout,
|
||||
cache_enabled=True
|
||||
)
|
||||
|
||||
# Execute query
|
||||
result = await self.execute_query(query_request, auth_context)
|
||||
|
||||
# Process results
|
||||
processed_data = []
|
||||
if result.data:
|
||||
for row in result.data:
|
||||
processed_row = self._serialize_row_data(row)
|
||||
processed_data.append(processed_row)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": processed_data,
|
||||
"metadata": {
|
||||
"row_count": result.row_count,
|
||||
"execution_time": result.execution_time,
|
||||
"columns": result.metadata.get("columns", []),
|
||||
"query": sql
|
||||
},
|
||||
"error": None
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
self.logger.error(f"SQL execution error: {error_msg}")
|
||||
|
||||
# Analyze error for better user feedback
|
||||
error_analysis = self._analyze_error(error_msg)
|
||||
|
||||
return {
|
||||
"success": False,
|
||||
"error": error_analysis.get("user_message", error_msg),
|
||||
"error_type": error_analysis.get("error_type", "execution_error"),
|
||||
"data": None,
|
||||
"metadata": {
|
||||
"query": sql,
|
||||
"error_details": error_msg
|
||||
}
|
||||
}
|
||||
|
||||
def _serialize_row_data(self, row_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Serialize row data for JSON response"""
|
||||
serialized = {}
|
||||
|
||||
for key, value in row_data.items():
|
||||
if value is None:
|
||||
serialized[key] = None
|
||||
elif isinstance(value, (str, int, float, bool)):
|
||||
serialized[key] = value
|
||||
elif isinstance(value, Decimal):
|
||||
serialized[key] = float(value)
|
||||
elif isinstance(value, (datetime, date)):
|
||||
serialized[key] = value.isoformat()
|
||||
elif isinstance(value, bytes):
|
||||
try:
|
||||
serialized[key] = value.decode('utf-8')
|
||||
except UnicodeDecodeError:
|
||||
serialized[key] = str(value)
|
||||
else:
|
||||
serialized[key] = str(value)
|
||||
|
||||
return serialized
|
||||
|
||||
def _analyze_error(self, error_message: str) -> Dict[str, str]:
|
||||
"""Analyze error message and provide user-friendly feedback"""
|
||||
error_msg_lower = error_message.lower()
|
||||
|
||||
if "table" in error_msg_lower and "doesn't exist" in error_msg_lower:
|
||||
return {
|
||||
"error_type": "table_not_found",
|
||||
"user_message": "The specified table does not exist. Please check the table name and database."
|
||||
}
|
||||
elif "column" in error_msg_lower and ("unknown" in error_msg_lower or "doesn't exist" in error_msg_lower):
|
||||
return {
|
||||
"error_type": "column_not_found",
|
||||
"user_message": "One or more columns in the query do not exist. Please check column names."
|
||||
}
|
||||
elif "syntax error" in error_msg_lower or "sql syntax" in error_msg_lower:
|
||||
return {
|
||||
"error_type": "syntax_error",
|
||||
"user_message": "SQL syntax error. Please check your query syntax."
|
||||
}
|
||||
elif "access denied" in error_msg_lower or "permission" in error_msg_lower:
|
||||
return {
|
||||
"error_type": "permission_denied",
|
||||
"user_message": "Access denied. You don't have permission to execute this query."
|
||||
}
|
||||
elif "timeout" in error_msg_lower:
|
||||
return {
|
||||
"error_type": "timeout",
|
||||
"user_message": "Query execution timed out. Try simplifying your query or adding more specific filters."
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"error_type": "general_error",
|
||||
"user_message": f"Query execution failed: {error_message}"
|
||||
}
|
||||
|
||||
async def close(self):
|
||||
"""Close query executor and cleanup resources"""
|
||||
# Cancel background tasks
|
||||
for task in self._background_tasks:
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Clear cache
|
||||
await self.query_cache.clear_all()
|
||||
|
||||
self.logger.info("Query executor closed")
|
||||
|
||||
|
||||
class QueryPerformanceMonitor:
|
||||
"""Query performance monitor"""
|
||||
|
||||
def __init__(self, query_executor: DorisQueryExecutor):
|
||||
self.query_executor = query_executor
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.performance_records = []
|
||||
|
||||
async def record_query_performance(
|
||||
self, query_request: QueryRequest, result: QueryResult, execution_time: float
|
||||
):
|
||||
"""Record query performance"""
|
||||
record = {
|
||||
"timestamp": datetime.utcnow(),
|
||||
"sql": query_request.sql,
|
||||
"user_id": query_request.user_id,
|
||||
"session_id": query_request.session_id,
|
||||
"execution_time": execution_time,
|
||||
"row_count": result.row_count,
|
||||
"cache_hit": False, # This would need to be passed from executor
|
||||
}
|
||||
|
||||
self.performance_records.append(record)
|
||||
|
||||
# Keep only recent records (last 1000)
|
||||
if len(self.performance_records) > 1000:
|
||||
self.performance_records = self.performance_records[-1000:]
|
||||
|
||||
async def get_performance_report(
|
||||
self, time_range_minutes: int = 60
|
||||
) -> dict[str, Any]:
|
||||
"""Get performance report"""
|
||||
cutoff_time = datetime.utcnow() - timedelta(minutes=time_range_minutes)
|
||||
|
||||
recent_records = [
|
||||
record
|
||||
for record in self.performance_records
|
||||
if record["timestamp"] >= cutoff_time
|
||||
]
|
||||
|
||||
if not recent_records:
|
||||
return {"message": "No performance data available for the specified time range"}
|
||||
|
||||
# Calculate statistics
|
||||
execution_times = [record["execution_time"] for record in recent_records]
|
||||
row_counts = [record["row_count"] for record in recent_records]
|
||||
|
||||
return {
|
||||
"time_range_minutes": time_range_minutes,
|
||||
"total_queries": len(recent_records),
|
||||
"avg_execution_time": sum(execution_times) / len(execution_times),
|
||||
"max_execution_time": max(execution_times),
|
||||
"min_execution_time": min(execution_times),
|
||||
"avg_row_count": sum(row_counts) / len(row_counts),
|
||||
"query_distribution": self._analyze_query_distribution(recent_records),
|
||||
}
|
||||
|
||||
def _analyze_query_distribution(
|
||||
self, records: list[dict[str, Any]]
|
||||
) -> dict[str, Any]:
|
||||
"""Analyze query distribution"""
|
||||
query_types = {}
|
||||
user_distribution = {}
|
||||
|
||||
for record in records:
|
||||
# Analyze query type
|
||||
sql_upper = record["sql"].strip().upper()
|
||||
if sql_upper.startswith("SELECT"):
|
||||
query_type = "SELECT"
|
||||
elif sql_upper.startswith("INSERT"):
|
||||
query_type = "INSERT"
|
||||
elif sql_upper.startswith("UPDATE"):
|
||||
query_type = "UPDATE"
|
||||
elif sql_upper.startswith("DELETE"):
|
||||
query_type = "DELETE"
|
||||
else:
|
||||
query_type = "OTHER"
|
||||
|
||||
query_types[query_type] = query_types.get(query_type, 0) + 1
|
||||
|
||||
# Analyze user distribution
|
||||
user_id = record["user_id"]
|
||||
user_distribution[user_id] = user_distribution.get(user_id, 0) + 1
|
||||
|
||||
return {"query_types": query_types, "user_distribution": user_distribution}
|
||||
|
||||
|
||||
# Unified convenience function for MCP integration
|
||||
async def execute_sql_query(sql: str, connection_manager: DorisConnectionManager, **kwargs) -> Dict[str, Any]:
|
||||
"""Execute SQL query - unified convenience function for MCP tools"""
|
||||
try:
|
||||
# Create query executor
|
||||
executor = DorisQueryExecutor(connection_manager)
|
||||
|
||||
try:
|
||||
# Extract parameters from kwargs or use defaults
|
||||
limit = kwargs.get("limit", 1000)
|
||||
timeout = kwargs.get("timeout", 30)
|
||||
session_id = kwargs.get("session_id", "mcp_session")
|
||||
user_id = kwargs.get("user_id", "mcp_user")
|
||||
|
||||
result = await executor.execute_sql_for_mcp(
|
||||
sql=sql,
|
||||
limit=limit,
|
||||
timeout=timeout,
|
||||
session_id=session_id,
|
||||
user_id=user_id
|
||||
)
|
||||
return result
|
||||
finally:
|
||||
await executor.close()
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Query execution failed: {str(e)}",
|
||||
"data": None
|
||||
}
|
||||
@@ -8,6 +8,8 @@ import os
|
||||
import json
|
||||
import pandas as pd
|
||||
import re
|
||||
import uuid
|
||||
import time
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from dotenv import load_dotenv
|
||||
from datetime import datetime, timedelta
|
||||
@@ -26,23 +28,25 @@ ENABLE_MULTI_DATABASE=os.getenv("ENABLE_MULTI_DATABASE",True)
|
||||
MULTI_DATABASE_NAMES=os.getenv("MULTI_DATABASE_NAMES","")
|
||||
|
||||
# Import local modules
|
||||
from doris_mcp_server.utils.db import execute_query_df, execute_query
|
||||
from .db import DorisConnectionManager
|
||||
|
||||
class MetadataExtractor:
|
||||
"""Apache Doris Metadata Extractor"""
|
||||
|
||||
def __init__(self, db_name: str = None, catalog_name: str = None):
|
||||
def __init__(self, db_name: str = None, catalog_name: str = None, connection_manager=None):
|
||||
"""
|
||||
Initialize the metadata extractor
|
||||
|
||||
Args:
|
||||
db_name: Default database name, uses the currently connected database if not specified
|
||||
catalog_name: Default catalog name for federation queries, uses the current catalog if not specified
|
||||
connection_manager: DorisConnectionManager instance for database operations
|
||||
"""
|
||||
# Get configuration from environment variables
|
||||
self.db_name = db_name or os.getenv("DB_DATABASE", "")
|
||||
self.catalog_name = catalog_name # Store catalog name for federation support
|
||||
self.metadata_db = METADATA_DB_NAME # Use constant
|
||||
self.connection_manager = connection_manager
|
||||
|
||||
# Caching system
|
||||
self.metadata_cache = {}
|
||||
@@ -65,6 +69,9 @@ class MetadataExtractor:
|
||||
# List of excluded system databases
|
||||
self.excluded_databases = self._load_excluded_databases()
|
||||
|
||||
# Session ID for database queries
|
||||
self._session_id = f"metadata_extractor_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
def _load_excluded_databases(self) -> List[str]:
|
||||
"""
|
||||
Load the list of excluded databases configuration
|
||||
@@ -482,7 +489,7 @@ class MetadataExtractor:
|
||||
TABLE_SCHEMA = '{db_name}'
|
||||
AND TABLE_NAME = '{table_name}'
|
||||
"""
|
||||
table_type_result = execute_query(table_type_query)
|
||||
table_type_result = self._execute_query(table_type_query)
|
||||
if table_type_result:
|
||||
schema["table_type"] = table_type_result[0].get("TABLE_TYPE", "")
|
||||
schema["engine"] = table_type_result[0].get("ENGINE", "")
|
||||
@@ -633,31 +640,52 @@ class MetadataExtractor:
|
||||
else:
|
||||
query = f"SHOW INDEX FROM `{db_name}`.`{table_name}`"
|
||||
|
||||
df = execute_query_df(query)
|
||||
|
||||
# Process results
|
||||
indexes = []
|
||||
current_index = None
|
||||
|
||||
for _, row in df.iterrows():
|
||||
index_name = row['Key_name']
|
||||
column_name = row['Column_name']
|
||||
try:
|
||||
df = self._execute_query(query, return_dataframe=True)
|
||||
|
||||
if current_index is None or current_index['name'] != index_name:
|
||||
# Process results
|
||||
indexes = []
|
||||
current_index = None
|
||||
|
||||
if not df.empty:
|
||||
for _, row in df.iterrows():
|
||||
try:
|
||||
index_name = row['Key_name']
|
||||
column_name = row['Column_name']
|
||||
|
||||
if current_index is None or current_index['name'] != index_name:
|
||||
if current_index is not None:
|
||||
indexes.append(current_index)
|
||||
|
||||
current_index = {
|
||||
'name': index_name,
|
||||
'columns': [column_name],
|
||||
'unique': row['Non_unique'] == 0,
|
||||
'type': row['Index_type']
|
||||
}
|
||||
else:
|
||||
current_index['columns'].append(column_name)
|
||||
except Exception as row_error:
|
||||
logger.warning(f"Failed to process index row data: {row_error}")
|
||||
continue
|
||||
|
||||
if current_index is not None:
|
||||
indexes.append(current_index)
|
||||
|
||||
current_index = {
|
||||
'name': index_name,
|
||||
'columns': [column_name],
|
||||
'unique': row['Non_unique'] == 0,
|
||||
'type': row['Index_type']
|
||||
}
|
||||
else:
|
||||
current_index['columns'].append(column_name)
|
||||
|
||||
if current_index is not None:
|
||||
indexes.append(current_index)
|
||||
except Exception as df_error:
|
||||
logger.warning(f"DataFrame processing failed, trying regular query: {df_error}")
|
||||
# Fall back to regular query
|
||||
result = self._execute_query(query, return_dataframe=False)
|
||||
indexes = []
|
||||
if result:
|
||||
# Simple processing, no complex index grouping
|
||||
for row in result:
|
||||
if isinstance(row, dict):
|
||||
indexes.append({
|
||||
'name': row.get('Key_name', ''),
|
||||
'columns': [row.get('Column_name', '')],
|
||||
'unique': row.get('Non_unique', 1) == 0,
|
||||
'type': row.get('Index_type', '')
|
||||
})
|
||||
|
||||
# Update cache
|
||||
self.metadata_cache[cache_key] = indexes
|
||||
@@ -748,7 +776,7 @@ class MetadataExtractor:
|
||||
ORDER BY time DESC
|
||||
LIMIT {limit}
|
||||
"""
|
||||
df = execute_query_df(query)
|
||||
df = self._execute_query(query, return_dataframe=True)
|
||||
return df
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting audit logs: {str(e)}")
|
||||
@@ -768,7 +796,7 @@ class MetadataExtractor:
|
||||
try:
|
||||
# Use SHOW CATALOGS command to get catalog list
|
||||
query = "SHOW CATALOGS"
|
||||
result = execute_query(query)
|
||||
result = self._execute_query(query)
|
||||
|
||||
if not result:
|
||||
catalogs = []
|
||||
@@ -1057,7 +1085,7 @@ class MetadataExtractor:
|
||||
AND TABLE_NAME = '{table_name}'
|
||||
"""
|
||||
|
||||
partitions = execute_query(query)
|
||||
partitions = self._execute_query(query)
|
||||
|
||||
if not partitions:
|
||||
return {}
|
||||
@@ -1099,10 +1127,511 @@ class MetadataExtractor:
|
||||
# Replace 'information_schema' with 'catalog_name.information_schema'
|
||||
modified_query = query.replace('information_schema', f'{catalog_name}.information_schema')
|
||||
logger.info(f"Modified query for catalog {catalog_name}: {modified_query}")
|
||||
return execute_query(modified_query, db_name)
|
||||
return self._execute_query(modified_query, db_name)
|
||||
else:
|
||||
# Execute the original query
|
||||
return execute_query(query, db_name)
|
||||
return self._execute_query(query, db_name)
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing query with catalog: {str(e)}")
|
||||
raise
|
||||
raise
|
||||
|
||||
async def _execute_query_async(self, query: str, db_name: str = None, return_dataframe: bool = False):
|
||||
"""
|
||||
Execute database query asynchronously
|
||||
|
||||
Args:
|
||||
query: SQL query to execute
|
||||
db_name: Database name to use (optional)
|
||||
return_dataframe: Whether to return a pandas DataFrame instead of list
|
||||
|
||||
Returns:
|
||||
Query result data (list of dictionaries or pandas DataFrame)
|
||||
"""
|
||||
try:
|
||||
if self.connection_manager:
|
||||
# Use the injected connection manager directly (async)
|
||||
result = await self.connection_manager.execute_query(self._session_id, query, None)
|
||||
|
||||
# Extract data from QueryResult
|
||||
if hasattr(result, 'data'):
|
||||
data = result.data
|
||||
else:
|
||||
data = result
|
||||
|
||||
# Convert to DataFrame if requested
|
||||
if return_dataframe and data:
|
||||
import pandas as pd
|
||||
return pd.DataFrame(data)
|
||||
elif return_dataframe:
|
||||
import pandas as pd
|
||||
return pd.DataFrame()
|
||||
else:
|
||||
return data
|
||||
else:
|
||||
# Fallback: Return empty result
|
||||
logger.warning("No connection manager provided, returning empty result")
|
||||
if return_dataframe:
|
||||
import pandas as pd
|
||||
return pd.DataFrame()
|
||||
else:
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing query: {str(e)}")
|
||||
# Return empty result instead of raising exception to prevent cascade failures
|
||||
if return_dataframe:
|
||||
import pandas as pd
|
||||
return pd.DataFrame()
|
||||
else:
|
||||
return []
|
||||
|
||||
def _execute_query(self, query: str, db_name: str = None, return_dataframe: bool = False):
|
||||
"""
|
||||
Execute database query with proper session management (sync wrapper)
|
||||
|
||||
Args:
|
||||
query: SQL query to execute
|
||||
db_name: Database name to use (optional)
|
||||
return_dataframe: Whether to return a pandas DataFrame instead of list
|
||||
|
||||
Returns:
|
||||
Query result data (list of dictionaries or pandas DataFrame)
|
||||
"""
|
||||
try:
|
||||
if self.connection_manager:
|
||||
import asyncio
|
||||
|
||||
# Try to run the async query
|
||||
try:
|
||||
# Check if there's a running event loop
|
||||
loop = asyncio.get_running_loop()
|
||||
# If we're in an async context, we need to run in a separate thread
|
||||
import concurrent.futures
|
||||
|
||||
def run_in_new_loop():
|
||||
new_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(new_loop)
|
||||
try:
|
||||
return new_loop.run_until_complete(
|
||||
self._execute_query_async(query, db_name, return_dataframe)
|
||||
)
|
||||
finally:
|
||||
new_loop.close()
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(run_in_new_loop)
|
||||
return future.result(timeout=30)
|
||||
|
||||
except RuntimeError:
|
||||
# No running loop, we can safely create one
|
||||
return asyncio.run(
|
||||
self._execute_query_async(query, db_name, return_dataframe)
|
||||
)
|
||||
else:
|
||||
# Fallback: Return empty result
|
||||
logger.warning("No connection manager provided, returning empty result")
|
||||
if return_dataframe:
|
||||
import pandas as pd
|
||||
return pd.DataFrame()
|
||||
else:
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing query: {str(e)}")
|
||||
# Return empty result instead of raising exception to prevent cascade failures
|
||||
if return_dataframe:
|
||||
import pandas as pd
|
||||
return pd.DataFrame()
|
||||
else:
|
||||
return []
|
||||
|
||||
async def get_table_schema_async(self, table_name: str, db_name: str = None, catalog_name: str = None) -> List[Dict[str, Any]]:
|
||||
"""Asynchronously get table schema information"""
|
||||
try:
|
||||
# Use async query method
|
||||
effective_catalog = catalog_name or self.catalog_name
|
||||
|
||||
# Build query statement
|
||||
if effective_catalog and effective_catalog != "internal":
|
||||
query = f"DESCRIBE `{effective_catalog}`.`{db_name or self.db_name}`.`{table_name}`"
|
||||
else:
|
||||
query = f"DESCRIBE `{db_name or self.db_name}`.`{table_name}`"
|
||||
|
||||
# Execute async query
|
||||
result = await self._execute_query_async(query, db_name)
|
||||
|
||||
if not result:
|
||||
return []
|
||||
|
||||
# Process results
|
||||
schema = []
|
||||
for row in result:
|
||||
if isinstance(row, dict):
|
||||
schema.append({
|
||||
'column_name': row.get('Field', ''),
|
||||
'data_type': row.get('Type', ''),
|
||||
'is_nullable': row.get('Null', 'NO') == 'YES',
|
||||
'default_value': row.get('Default', None),
|
||||
'comment': row.get('Comment', ''),
|
||||
'key': row.get('Key', ''),
|
||||
'extra': row.get('Extra', '')
|
||||
})
|
||||
|
||||
return schema
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get table schema: {e}")
|
||||
return []
|
||||
|
||||
async def get_all_databases_async(self, catalog_name: str = None) -> List[str]:
|
||||
"""Asynchronously get all database list"""
|
||||
try:
|
||||
effective_catalog = catalog_name or self.catalog_name
|
||||
|
||||
if effective_catalog and effective_catalog != "internal":
|
||||
query = f"SHOW DATABASES FROM `{effective_catalog}`"
|
||||
else:
|
||||
query = "SHOW DATABASES"
|
||||
|
||||
result = await self._execute_query_async(query)
|
||||
|
||||
if not result:
|
||||
return []
|
||||
|
||||
# Extract database names
|
||||
databases = []
|
||||
for row in result:
|
||||
if isinstance(row, dict):
|
||||
# Get the value of the first field (usually Database field)
|
||||
db_name = list(row.values())[0] if row else None
|
||||
if db_name:
|
||||
databases.append(db_name)
|
||||
|
||||
return databases
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get database list: {e}")
|
||||
return []
|
||||
|
||||
async def get_database_tables_async(self, db_name: str = None, catalog_name: str = None) -> List[str]:
|
||||
"""Asynchronously get table list in database"""
|
||||
try:
|
||||
effective_catalog = catalog_name or self.catalog_name
|
||||
effective_db = db_name or self.db_name
|
||||
|
||||
if effective_catalog and effective_catalog != "internal":
|
||||
query = f"SHOW TABLES FROM `{effective_catalog}`.`{effective_db}`"
|
||||
else:
|
||||
query = f"SHOW TABLES FROM `{effective_db}`"
|
||||
|
||||
result = await self._execute_query_async(query, effective_db)
|
||||
|
||||
if not result:
|
||||
return []
|
||||
|
||||
# Extract table names
|
||||
tables = []
|
||||
for row in result:
|
||||
if isinstance(row, dict):
|
||||
# Get the value of the first field (usually Tables_in_xxx field)
|
||||
table_name = list(row.values())[0] if row else None
|
||||
if table_name:
|
||||
tables.append(table_name)
|
||||
|
||||
return tables
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get table list: {e}")
|
||||
return []
|
||||
|
||||
async def get_catalog_list_async(self) -> List[str]:
|
||||
"""Asynchronously get catalog list"""
|
||||
try:
|
||||
query = "SHOW CATALOGS"
|
||||
result = await self._execute_query_async(query)
|
||||
|
||||
if not result:
|
||||
return []
|
||||
|
||||
# Extract catalog names
|
||||
catalogs = []
|
||||
for row in result:
|
||||
if isinstance(row, dict):
|
||||
# SHOW CATALOGS returns fields including: CatalogId, CatalogName, Type, IsCurrent, CreateTime, LastUpdateTime, Comment
|
||||
# We need to get the CatalogName field (second field)
|
||||
if 'CatalogName' in row:
|
||||
catalog_name = row['CatalogName']
|
||||
else:
|
||||
# If no CatalogName field, try to get the second field
|
||||
values = list(row.values())
|
||||
catalog_name = values[1] if len(values) > 1 else values[0] if values else None
|
||||
|
||||
if catalog_name:
|
||||
catalogs.append(str(catalog_name))
|
||||
|
||||
return catalogs
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get catalog list: {e}")
|
||||
return []
|
||||
|
||||
# ==================== Business layer methods (original metadata_tools.py functionality) ====================
|
||||
|
||||
def _format_response(self, success: bool, result: Any = None, error: str = None, message: str = "") -> Dict[str, Any]:
|
||||
"""Format response result"""
|
||||
response_data = {
|
||||
"success": success,
|
||||
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
|
||||
}
|
||||
if success and result is not None:
|
||||
response_data["result"] = result
|
||||
response_data["message"] = message or "Operation successful"
|
||||
elif not success:
|
||||
response_data["error"] = error or "Unknown error"
|
||||
response_data["message"] = message or "Operation failed"
|
||||
|
||||
return response_data
|
||||
|
||||
async def exec_query_for_mcp(
|
||||
self,
|
||||
sql: str,
|
||||
db_name: str = None,
|
||||
catalog_name: str = None,
|
||||
max_rows: int = 100,
|
||||
timeout: int = 30
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Execute SQL query and return results, supports catalog federation queries
|
||||
Unified interface for MCP tools
|
||||
"""
|
||||
logger.info(f"Executing SQL query: {sql}, DB: {db_name}, Catalog: {catalog_name}, MaxRows: {max_rows}, Timeout: {timeout}")
|
||||
|
||||
try:
|
||||
if not sql:
|
||||
return self._format_response(success=False, error="No SQL statement provided", message="Please provide SQL statement to execute")
|
||||
|
||||
# Import query executor
|
||||
from .query_executor import execute_sql_query
|
||||
|
||||
# Call execute_sql_query to execute query
|
||||
exec_result = await execute_sql_query(
|
||||
sql=sql,
|
||||
connection_manager=self.connection_manager,
|
||||
limit=max_rows,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
return exec_result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to execute SQL query: {str(e)}", exc_info=True)
|
||||
return self._format_response(success=False, error=str(e), message="Error occurred while executing SQL query")
|
||||
|
||||
async def get_table_schema_for_mcp(
|
||||
self,
|
||||
table_name: str,
|
||||
db_name: str = None,
|
||||
catalog_name: str = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Get detailed schema information for specified table (columns, types, comments, etc.) - MCP interface"""
|
||||
logger.info(f"Getting table schema: Table: {table_name}, DB: {db_name}, Catalog: {catalog_name}")
|
||||
|
||||
if not table_name:
|
||||
return self._format_response(success=False, error="Missing table_name parameter")
|
||||
|
||||
try:
|
||||
schema = await self.get_table_schema_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
|
||||
if not schema:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error="Table does not exist or has no columns",
|
||||
message=f"Unable to get schema for table {catalog_name or 'default'}.{db_name or self.db_name}.{table_name}"
|
||||
)
|
||||
|
||||
return self._format_response(success=True, result=schema)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get table schema: {str(e)}", exc_info=True)
|
||||
return self._format_response(success=False, error=str(e), message="Error occurred while getting table schema")
|
||||
|
||||
async def get_db_table_list_for_mcp(
|
||||
self,
|
||||
db_name: str = None,
|
||||
catalog_name: str = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Get list of all table names in specified database - MCP interface"""
|
||||
logger.info(f"Getting database table list: DB: {db_name}, Catalog: {catalog_name}")
|
||||
|
||||
try:
|
||||
tables = await self.get_database_tables_async(db_name=db_name, catalog_name=catalog_name)
|
||||
return self._format_response(success=True, result=tables)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get database table list: {str(e)}", exc_info=True)
|
||||
return self._format_response(success=False, error=str(e), message="Error occurred while getting database table list")
|
||||
|
||||
async def get_db_list_for_mcp(self, catalog_name: str = None) -> Dict[str, Any]:
|
||||
"""Get list of all database names on server - MCP interface"""
|
||||
logger.info(f"Getting database list: Catalog: {catalog_name}")
|
||||
|
||||
try:
|
||||
databases = await self.get_all_databases_async(catalog_name=catalog_name)
|
||||
return self._format_response(success=True, result=databases)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get database list: {str(e)}", exc_info=True)
|
||||
return self._format_response(success=False, error=str(e), message="Error occurred while getting database list")
|
||||
|
||||
async def get_table_comment_for_mcp(
|
||||
self,
|
||||
table_name: str,
|
||||
db_name: str = None,
|
||||
catalog_name: str = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Get comment information for specified table - MCP interface"""
|
||||
logger.info(f"Getting table comment: Table: {table_name}, DB: {db_name}, Catalog: {catalog_name}")
|
||||
|
||||
if not table_name:
|
||||
return self._format_response(success=False, error="Missing table_name parameter")
|
||||
|
||||
try:
|
||||
comment = self.get_table_comment(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
return self._format_response(success=True, result=comment)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get table comment: {str(e)}", exc_info=True)
|
||||
return self._format_response(success=False, error=str(e), message="Error occurred while getting table comment")
|
||||
|
||||
async def get_table_column_comments_for_mcp(
|
||||
self,
|
||||
table_name: str,
|
||||
db_name: str = None,
|
||||
catalog_name: str = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Get comment information for all columns in specified table - MCP interface"""
|
||||
logger.info(f"Getting table column comments: Table: {table_name}, DB: {db_name}, Catalog: {catalog_name}")
|
||||
|
||||
if not table_name:
|
||||
return self._format_response(success=False, error="Missing table_name parameter")
|
||||
|
||||
try:
|
||||
comments = self.get_column_comments(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
return self._format_response(success=True, result=comments)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get table column comments: {str(e)}", exc_info=True)
|
||||
return self._format_response(success=False, error=str(e), message="Error occurred while getting table column comments")
|
||||
|
||||
async def get_table_indexes_for_mcp(
|
||||
self,
|
||||
table_name: str,
|
||||
db_name: str = None,
|
||||
catalog_name: str = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Get index information for specified table - MCP interface"""
|
||||
logger.info(f"Getting table indexes: Table: {table_name}, DB: {db_name}, Catalog: {catalog_name}")
|
||||
|
||||
if not table_name:
|
||||
return self._format_response(success=False, error="Missing table_name parameter")
|
||||
|
||||
try:
|
||||
indexes = self.get_table_indexes(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
return self._format_response(success=True, result=indexes)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get table indexes: {str(e)}", exc_info=True)
|
||||
return self._format_response(success=False, error=str(e), message="Error occurred while getting table indexes")
|
||||
|
||||
def _serialize_datetime_objects(self, data):
|
||||
"""Serialize datetime objects to JSON compatible format"""
|
||||
if isinstance(data, list):
|
||||
return [self._serialize_datetime_objects(item) for item in data]
|
||||
elif isinstance(data, dict):
|
||||
return {key: self._serialize_datetime_objects(value) for key, value in data.items()}
|
||||
elif hasattr(data, 'isoformat'): # datetime, date, time objects
|
||||
return data.isoformat()
|
||||
elif hasattr(data, 'strftime'): # pandas Timestamp objects
|
||||
return data.strftime('%Y-%m-%d %H:%M:%S')
|
||||
else:
|
||||
return data
|
||||
|
||||
async def get_recent_audit_logs_for_mcp(self, days: int = 7, limit: int = 100) -> Dict[str, Any]:
|
||||
"""Get recent audit log records - MCP interface"""
|
||||
logger.info(f"Getting audit logs: Days: {days}, Limit: {limit}")
|
||||
|
||||
try:
|
||||
logs_df = self.get_recent_audit_logs(days=days, limit=limit)
|
||||
|
||||
# Convert DataFrame to JSON format
|
||||
if hasattr(logs_df, 'to_dict'):
|
||||
try:
|
||||
logs_data = logs_df.to_dict('records')
|
||||
except Exception as e:
|
||||
logger.warning(f"DataFrame.to_dict failed, trying manual conversion: {e}")
|
||||
# Manually convert DataFrame to records format
|
||||
logs_data = []
|
||||
if not logs_df.empty:
|
||||
for _, row in logs_df.iterrows():
|
||||
logs_data.append(dict(row))
|
||||
# Serialize datetime objects
|
||||
logs_data = self._serialize_datetime_objects(logs_data)
|
||||
else:
|
||||
logs_data = self._serialize_datetime_objects(logs_df)
|
||||
|
||||
return self._format_response(success=True, result=logs_data)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get audit logs: {str(e)}", exc_info=True)
|
||||
return self._format_response(success=False, error=str(e), message="Error occurred while getting audit logs")
|
||||
|
||||
async def get_catalog_list_for_mcp(self) -> Dict[str, Any]:
|
||||
"""Get Doris catalog list - MCP interface"""
|
||||
logger.info("Getting catalog list")
|
||||
|
||||
try:
|
||||
catalogs = await self.get_catalog_list_async()
|
||||
return self._format_response(success=True, result=catalogs, message="Successfully retrieved catalog list")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get catalog list: {str(e)}", exc_info=True)
|
||||
return self._format_response(success=False, error=str(e), message="Error occurred while getting catalog list")
|
||||
|
||||
|
||||
# ==================== Compatibility aliases ====================
|
||||
|
||||
# For backward compatibility, create MetadataManager alias
|
||||
class MetadataManager:
|
||||
"""
|
||||
Metadata manager - backward compatibility class
|
||||
Actually a wrapper for MetadataExtractor
|
||||
"""
|
||||
|
||||
def __init__(self, connection_manager=None):
|
||||
self.extractor = MetadataExtractor(connection_manager=connection_manager)
|
||||
|
||||
async def exec_query(self, sql: str, db_name: str = None, catalog_name: str = None, max_rows: int = 100, timeout: int = 30) -> Dict[str, Any]:
|
||||
"""Execute SQL query and return results, supports catalog federation queries"""
|
||||
return await self.extractor.exec_query_for_mcp(sql, db_name, catalog_name, max_rows, timeout)
|
||||
|
||||
async def get_table_schema(self, table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
|
||||
"""Get detailed schema information for specified table (columns, types, comments, etc.)"""
|
||||
return await self.extractor.get_table_schema_for_mcp(table_name, db_name, catalog_name)
|
||||
|
||||
async def get_db_table_list(self, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
|
||||
"""Get list of all table names in specified database"""
|
||||
return await self.extractor.get_db_table_list_for_mcp(db_name, catalog_name)
|
||||
|
||||
async def get_db_list(self, catalog_name: str = None) -> Dict[str, Any]:
|
||||
"""Get list of all database names on server"""
|
||||
return await self.extractor.get_db_list_for_mcp(catalog_name)
|
||||
|
||||
async def get_table_comment(self, table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
|
||||
"""Get comment information for specified table"""
|
||||
return await self.extractor.get_table_comment_for_mcp(table_name, db_name, catalog_name)
|
||||
|
||||
async def get_table_column_comments(self, table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
|
||||
"""Get comment information for all columns in specified table"""
|
||||
return await self.extractor.get_table_column_comments_for_mcp(table_name, db_name, catalog_name)
|
||||
|
||||
async def get_table_indexes(self, table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
|
||||
"""Get index information for specified table"""
|
||||
return await self.extractor.get_table_indexes_for_mcp(table_name, db_name, catalog_name)
|
||||
|
||||
async def get_recent_audit_logs(self, days: int = 7, limit: int = 100) -> Dict[str, Any]:
|
||||
"""Get recent audit log records"""
|
||||
return await self.extractor.get_recent_audit_logs_for_mcp(days, limit)
|
||||
|
||||
async def get_catalog_list(self) -> Dict[str, Any]:
|
||||
"""Get Doris catalog list"""
|
||||
return await self.extractor.get_catalog_list_for_mcp()
|
||||
861
doris_mcp_server/utils/security.py
Normal file
861
doris_mcp_server/utils/security.py
Normal file
@@ -0,0 +1,861 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Doris Security Management Module
|
||||
Implements enterprise-level authentication, authorization, SQL security validation and data masking functionality
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import sqlparse
|
||||
from sqlparse.sql import Statement
|
||||
from sqlparse.tokens import Keyword, Name
|
||||
|
||||
|
||||
class SecurityLevel(Enum):
|
||||
"""Security level enumeration"""
|
||||
|
||||
PUBLIC = "public"
|
||||
INTERNAL = "internal"
|
||||
CONFIDENTIAL = "confidential"
|
||||
SECRET = "secret"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthContext:
|
||||
"""Authentication context"""
|
||||
|
||||
user_id: str
|
||||
roles: list[str]
|
||||
permissions: list[str]
|
||||
session_id: str
|
||||
login_time: datetime | None = None
|
||||
last_activity: datetime | None = None
|
||||
security_level: SecurityLevel = SecurityLevel.INTERNAL
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationResult:
|
||||
"""Validation result"""
|
||||
|
||||
is_valid: bool
|
||||
error_message: str | None = None
|
||||
risk_level: str = "low"
|
||||
blocked_operations: list[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.blocked_operations is None:
|
||||
self.blocked_operations = []
|
||||
|
||||
|
||||
@dataclass
|
||||
class MaskingRule:
|
||||
"""Data masking rule"""
|
||||
|
||||
column_pattern: str
|
||||
algorithm: str
|
||||
parameters: dict[str, Any]
|
||||
security_level: SecurityLevel
|
||||
|
||||
|
||||
class DorisSecurityManager:
|
||||
"""Doris security manager
|
||||
|
||||
Provides complete security control functionality, including authentication, authorization, SQL security validation and data masking
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize security components
|
||||
self.auth_provider = AuthenticationProvider(config)
|
||||
self.authz_provider = AuthorizationProvider(config)
|
||||
self.sql_validator = SQLSecurityValidator(config)
|
||||
self.masking_processor = DataMaskingProcessor(config)
|
||||
|
||||
# Security rule configuration
|
||||
self.blocked_keywords = self._load_blocked_keywords()
|
||||
self.sensitive_tables = self._load_sensitive_tables()
|
||||
self.masking_rules = self._load_masking_rules()
|
||||
|
||||
def _load_blocked_keywords(self) -> set[str]:
|
||||
"""Load blocked SQL keywords"""
|
||||
default_blocked = {
|
||||
"DROP",
|
||||
"DELETE",
|
||||
"TRUNCATE",
|
||||
"ALTER",
|
||||
"CREATE",
|
||||
"INSERT",
|
||||
"UPDATE",
|
||||
"GRANT",
|
||||
"REVOKE",
|
||||
"EXEC",
|
||||
"EXECUTE",
|
||||
"SHUTDOWN",
|
||||
"KILL",
|
||||
}
|
||||
|
||||
# Load custom rules from configuration file
|
||||
if hasattr(self.config, 'get'):
|
||||
custom_blocked = set(self.config.get("blocked_keywords", []))
|
||||
else:
|
||||
custom_blocked = set()
|
||||
|
||||
return default_blocked.union(custom_blocked)
|
||||
|
||||
def _load_sensitive_tables(self) -> dict[str, SecurityLevel]:
|
||||
"""Load sensitive table configuration"""
|
||||
default_tables = {
|
||||
"user_info": SecurityLevel.CONFIDENTIAL,
|
||||
"payment_records": SecurityLevel.SECRET,
|
||||
"employee_data": SecurityLevel.CONFIDENTIAL,
|
||||
"public_reports": SecurityLevel.PUBLIC,
|
||||
}
|
||||
|
||||
if hasattr(self.config, 'get'):
|
||||
config_tables = self.config.get("sensitive_tables", {})
|
||||
# Convert string values to SecurityLevel enum
|
||||
for table_name, level in config_tables.items():
|
||||
if isinstance(level, str):
|
||||
try:
|
||||
default_tables[table_name] = SecurityLevel(level.lower())
|
||||
except ValueError:
|
||||
default_tables[table_name] = SecurityLevel.INTERNAL
|
||||
else:
|
||||
default_tables[table_name] = level
|
||||
return default_tables
|
||||
else:
|
||||
return default_tables
|
||||
|
||||
def _load_masking_rules(self) -> list[MaskingRule]:
|
||||
"""Load data masking rules"""
|
||||
default_rules = [
|
||||
MaskingRule(
|
||||
column_pattern=r".*phone.*|.*mobile.*",
|
||||
algorithm="phone_mask",
|
||||
parameters={"mask_char": "*", "keep_prefix": 3, "keep_suffix": 4},
|
||||
security_level=SecurityLevel.INTERNAL,
|
||||
),
|
||||
MaskingRule(
|
||||
column_pattern=r".*email.*",
|
||||
algorithm="email_mask",
|
||||
parameters={"mask_char": "*"},
|
||||
security_level=SecurityLevel.INTERNAL,
|
||||
),
|
||||
MaskingRule(
|
||||
column_pattern=r".*id_card.*|.*identity.*",
|
||||
algorithm="id_mask",
|
||||
parameters={"mask_char": "*", "keep_prefix": 6, "keep_suffix": 4},
|
||||
security_level=SecurityLevel.CONFIDENTIAL,
|
||||
),
|
||||
]
|
||||
|
||||
# Load custom rules from configuration
|
||||
custom_rules = []
|
||||
if hasattr(self.config, 'get'):
|
||||
custom_rules = self.config.get("masking_rules", [])
|
||||
elif hasattr(self.config, 'security') and hasattr(self.config.security, 'masking_rules'):
|
||||
custom_rules = self.config.security.masking_rules
|
||||
|
||||
for rule_config in custom_rules:
|
||||
if isinstance(rule_config, dict):
|
||||
default_rules.append(MaskingRule(**rule_config))
|
||||
elif isinstance(rule_config, MaskingRule):
|
||||
default_rules.append(rule_config)
|
||||
|
||||
return default_rules
|
||||
|
||||
async def authenticate_request(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||
"""Validate request authentication information"""
|
||||
return await self.auth_provider.authenticate(auth_info)
|
||||
|
||||
async def authorize_resource_access(
|
||||
self, auth_context: AuthContext, resource_uri: str
|
||||
) -> bool:
|
||||
"""Validate resource access permissions"""
|
||||
return await self.authz_provider.check_permission(
|
||||
auth_context, resource_uri, "read"
|
||||
)
|
||||
|
||||
async def validate_sql_security(
|
||||
self, sql: str, auth_context: AuthContext
|
||||
) -> ValidationResult:
|
||||
"""Validate SQL query security"""
|
||||
return await self.sql_validator.validate(sql, auth_context)
|
||||
|
||||
async def apply_data_masking(
|
||||
self, data: list[dict[str, Any]], auth_context: AuthContext
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Apply data masking processing"""
|
||||
return await self.masking_processor.process(data, auth_context)
|
||||
|
||||
|
||||
class AuthenticationProvider:
|
||||
"""Authentication provider"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.session_cache = {}
|
||||
|
||||
async def authenticate(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||
"""Perform identity authentication"""
|
||||
auth_type = auth_info.get("type", "token")
|
||||
|
||||
if auth_type == "token":
|
||||
return await self._authenticate_token(auth_info)
|
||||
elif auth_type == "basic":
|
||||
return await self._authenticate_basic(auth_info)
|
||||
else:
|
||||
raise ValueError(f"Unsupported authentication type: {auth_type}")
|
||||
|
||||
async def _authenticate_token(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||
"""Token authentication"""
|
||||
token = auth_info.get("token")
|
||||
if not token:
|
||||
raise ValueError("Missing authentication token")
|
||||
|
||||
# Validate token (simplified implementation, should validate JWT or query authentication service in practice)
|
||||
user_info = await self._validate_token(token)
|
||||
|
||||
return AuthContext(
|
||||
user_id=user_info["user_id"],
|
||||
roles=user_info["roles"],
|
||||
permissions=user_info["permissions"],
|
||||
session_id=auth_info.get("session_id", "default"),
|
||||
login_time=datetime.utcnow(),
|
||||
security_level=SecurityLevel(user_info.get("security_level", "internal")),
|
||||
)
|
||||
|
||||
async def _authenticate_basic(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||
"""Basic authentication (username password)"""
|
||||
username = auth_info.get("username")
|
||||
password = auth_info.get("password")
|
||||
|
||||
if not username or not password:
|
||||
raise ValueError("Missing username or password")
|
||||
|
||||
# Validate username password (simplified implementation)
|
||||
user_info = await self._validate_credentials(username, password)
|
||||
|
||||
return AuthContext(
|
||||
user_id=user_info["user_id"],
|
||||
roles=user_info["roles"],
|
||||
permissions=user_info["permissions"],
|
||||
session_id=auth_info.get("session_id", "default"),
|
||||
login_time=datetime.utcnow(),
|
||||
security_level=SecurityLevel(user_info.get("security_level", "internal")),
|
||||
)
|
||||
|
||||
async def _validate_token(self, token: str) -> dict[str, Any]:
|
||||
"""Validate token validity"""
|
||||
# Simplified implementation for testing, should parse JWT or query authentication service in practice
|
||||
valid_tokens = {
|
||||
"valid_token_123": {
|
||||
"user_id": "test_user",
|
||||
"roles": ["data_analyst"],
|
||||
"permissions": ["read_data"],
|
||||
"security_level": SecurityLevel.INTERNAL,
|
||||
},
|
||||
"admin_token_456": {
|
||||
"user_id": "admin_user",
|
||||
"roles": ["data_admin"],
|
||||
"permissions": ["admin"],
|
||||
"security_level": SecurityLevel.SECRET,
|
||||
}
|
||||
}
|
||||
|
||||
if token in valid_tokens:
|
||||
return valid_tokens[token]
|
||||
else:
|
||||
raise ValueError("Invalid token")
|
||||
|
||||
async def _validate_credentials(
|
||||
self, username: str, password: str
|
||||
) -> dict[str, Any]:
|
||||
"""Validate user credentials"""
|
||||
# Simplified implementation for testing, should query user database in practice
|
||||
valid_users = {
|
||||
"admin": {
|
||||
"password": "admin123",
|
||||
"user_id": "admin_user",
|
||||
"roles": ["data_admin"],
|
||||
"permissions": ["admin", "read_data", "write_data"],
|
||||
"security_level": SecurityLevel.SECRET,
|
||||
},
|
||||
"analyst": {
|
||||
"password": "analyst123",
|
||||
"user_id": "analyst_user",
|
||||
"roles": ["data_analyst"],
|
||||
"permissions": ["read_data"],
|
||||
"security_level": SecurityLevel.INTERNAL,
|
||||
}
|
||||
}
|
||||
|
||||
if username in valid_users and valid_users[username]["password"] == password:
|
||||
user_info = valid_users[username].copy()
|
||||
del user_info["password"] # Remove password from returned info
|
||||
return user_info
|
||||
else:
|
||||
raise ValueError("Incorrect username or password")
|
||||
|
||||
|
||||
class AuthorizationProvider:
|
||||
"""Authorization provider"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.permission_cache = {}
|
||||
|
||||
# Load sensitive tables configuration
|
||||
self.sensitive_tables = self._load_sensitive_tables()
|
||||
|
||||
def _load_sensitive_tables(self) -> dict[str, SecurityLevel]:
|
||||
"""Load sensitive table configuration"""
|
||||
default_tables = {
|
||||
"user_info": SecurityLevel.CONFIDENTIAL,
|
||||
"payment_records": SecurityLevel.SECRET,
|
||||
"employee_data": SecurityLevel.CONFIDENTIAL,
|
||||
"public_reports": SecurityLevel.PUBLIC,
|
||||
}
|
||||
|
||||
if hasattr(self.config, 'get'):
|
||||
config_tables = self.config.get("sensitive_tables", {})
|
||||
# Convert string values to SecurityLevel enum
|
||||
for table_name, level in config_tables.items():
|
||||
if isinstance(level, str):
|
||||
try:
|
||||
default_tables[table_name] = SecurityLevel(level.lower())
|
||||
except ValueError:
|
||||
default_tables[table_name] = SecurityLevel.INTERNAL
|
||||
else:
|
||||
default_tables[table_name] = level
|
||||
return default_tables
|
||||
else:
|
||||
return default_tables
|
||||
|
||||
async def check_permission(
|
||||
self, auth_context: AuthContext, resource_uri: str, action: str
|
||||
) -> bool:
|
||||
"""Check permissions"""
|
||||
# Parse resource information
|
||||
resource_info = self._parse_resource_uri(resource_uri)
|
||||
|
||||
# First check security level - this is mandatory
|
||||
if not await self._check_security_level_permission(auth_context, resource_info):
|
||||
return False
|
||||
|
||||
# Then check role-based permissions
|
||||
if await self._check_role_permission(auth_context, resource_info, action):
|
||||
return True
|
||||
|
||||
# Finally check user-based permissions
|
||||
if await self._check_user_permission(auth_context, resource_info, action):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _parse_resource_uri(self, uri: str) -> dict[str, str]:
|
||||
"""Parse resource URI"""
|
||||
parts = uri.split("/")
|
||||
if len(parts) >= 3:
|
||||
return {
|
||||
"type": parts[2], # table, view, etc.
|
||||
"name": parts[3] if len(parts) > 3 else "",
|
||||
"schema": parts[4] if len(parts) > 4 else "default",
|
||||
}
|
||||
return {"type": "unknown", "name": "", "schema": "default"}
|
||||
|
||||
async def _check_role_permission(
|
||||
self, auth_context: AuthContext, resource_info: dict[str, str], action: str
|
||||
) -> bool:
|
||||
"""Check role-based permissions"""
|
||||
# Role permission mapping
|
||||
role_permissions = {
|
||||
"data_analyst": {"table": ["read"], "view": ["read"]},
|
||||
"data_admin": {
|
||||
"table": ["read", "write", "admin"],
|
||||
"view": ["read", "write", "admin"],
|
||||
},
|
||||
}
|
||||
|
||||
for role in auth_context.roles:
|
||||
role_perms = role_permissions.get(role, {})
|
||||
resource_perms = role_perms.get(resource_info["type"], [])
|
||||
if action in resource_perms:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def _check_user_permission(
|
||||
self, auth_context: AuthContext, resource_info: dict[str, str], action: str
|
||||
) -> bool:
|
||||
"""Check user-based permissions"""
|
||||
# User-specific permission check
|
||||
if "admin" in auth_context.permissions:
|
||||
return True
|
||||
|
||||
if action == "read" and "read_data" in auth_context.permissions:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def _check_security_level_permission(
|
||||
self, auth_context: AuthContext, resource_info: dict[str, str]
|
||||
) -> bool:
|
||||
"""Check security level permissions"""
|
||||
# Get resource security level
|
||||
resource_security_level = self._get_resource_security_level(resource_info)
|
||||
|
||||
# Check if user security level is sufficient
|
||||
security_hierarchy = {
|
||||
SecurityLevel.PUBLIC: 0,
|
||||
SecurityLevel.INTERNAL: 1,
|
||||
SecurityLevel.CONFIDENTIAL: 2,
|
||||
SecurityLevel.SECRET: 3,
|
||||
}
|
||||
|
||||
user_level = security_hierarchy.get(auth_context.security_level, 0)
|
||||
resource_level = security_hierarchy.get(resource_security_level, 0)
|
||||
|
||||
# User must have higher or equal security level to access resource
|
||||
return user_level >= resource_level
|
||||
|
||||
def _get_resource_security_level(
|
||||
self, resource_info: dict[str, str]
|
||||
) -> SecurityLevel:
|
||||
"""Get resource security level"""
|
||||
# Get table security level from configuration
|
||||
table_name = resource_info.get("name", "")
|
||||
|
||||
# Use the loaded sensitive tables
|
||||
sensitive_tables = self.sensitive_tables
|
||||
|
||||
# Convert string values to SecurityLevel enum if needed
|
||||
security_level = sensitive_tables.get(table_name, SecurityLevel.INTERNAL)
|
||||
if isinstance(security_level, str):
|
||||
try:
|
||||
security_level = SecurityLevel(security_level.lower())
|
||||
except ValueError:
|
||||
security_level = SecurityLevel.INTERNAL
|
||||
|
||||
return security_level
|
||||
|
||||
|
||||
class SQLSecurityValidator:
|
||||
"""SQL security validator"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# Handle DorisConfig object or dictionary configuration
|
||||
if hasattr(config, 'get'):
|
||||
# Dictionary configuration
|
||||
self.blocked_keywords = set(config.get("blocked_keywords", []))
|
||||
self.max_query_complexity = config.get("max_query_complexity", 100)
|
||||
else:
|
||||
# DorisConfig object, use default values
|
||||
self.blocked_keywords = set(["DROP", "DELETE", "TRUNCATE", "ALTER", "CREATE", "INSERT", "UPDATE"])
|
||||
self.max_query_complexity = 100
|
||||
|
||||
async def validate(self, sql: str, auth_context: AuthContext) -> ValidationResult:
|
||||
"""Validate SQL query security"""
|
||||
try:
|
||||
# Parse SQL statement
|
||||
parsed = sqlparse.parse(sql)[0]
|
||||
|
||||
# Check blocked operations first (more specific)
|
||||
keyword_result = await self._check_blocked_keywords(parsed)
|
||||
if not keyword_result.is_valid:
|
||||
return keyword_result
|
||||
|
||||
# Check SQL injection risks
|
||||
injection_result = await self._check_sql_injection(sql, parsed)
|
||||
if not injection_result.is_valid:
|
||||
return injection_result
|
||||
|
||||
# Check query complexity
|
||||
complexity_result = await self._check_query_complexity(parsed)
|
||||
if not complexity_result.is_valid:
|
||||
return complexity_result
|
||||
|
||||
# Check table access permissions
|
||||
table_result = await self._check_table_access(parsed, auth_context)
|
||||
if not table_result.is_valid:
|
||||
return table_result
|
||||
|
||||
return ValidationResult(is_valid=True)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"SQL security validation failed: {e}")
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
error_message=f"SQL parsing error: {str(e)}",
|
||||
risk_level="high",
|
||||
)
|
||||
|
||||
async def _check_sql_injection(
|
||||
self, sql: str, parsed: Statement
|
||||
) -> ValidationResult:
|
||||
"""Check SQL injection risks"""
|
||||
# Check common SQL injection patterns
|
||||
injection_patterns = [
|
||||
r"(\s|^)(union|select|insert|update|delete|drop|create|alter)\s+.*\s+(union|select|insert|update|delete|drop|create|alter)",
|
||||
r"(\s|^)(or|and)\s+\d+\s*=\s*\d+",
|
||||
r"(\s|^)(or|and)\s+['\"].*['\"]",
|
||||
r";\s*(drop|delete|truncate|alter|create)",
|
||||
r"(exec|execute|sp_|xp_)",
|
||||
r"(script|javascript|vbscript)",
|
||||
r"(char|ascii|substring|concat)\s*\(",
|
||||
]
|
||||
|
||||
sql_lower = sql.lower()
|
||||
for pattern in injection_patterns:
|
||||
if re.search(pattern, sql_lower, re.IGNORECASE):
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
error_message="Potential SQL injection risk detected",
|
||||
risk_level="high",
|
||||
)
|
||||
|
||||
# Check suspicious quotes and comments
|
||||
if self._has_suspicious_quotes_or_comments(sql):
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
error_message="Suspicious quote or comment pattern detected",
|
||||
risk_level="medium",
|
||||
)
|
||||
|
||||
return ValidationResult(is_valid=True)
|
||||
|
||||
def _has_suspicious_quotes_or_comments(self, sql: str) -> bool:
|
||||
"""Check suspicious quote and comment patterns"""
|
||||
# Check unmatched quotes
|
||||
single_quotes = sql.count("'")
|
||||
double_quotes = sql.count('"')
|
||||
|
||||
if single_quotes % 2 != 0 or double_quotes % 2 != 0:
|
||||
return True
|
||||
|
||||
# Check SQL comments
|
||||
if "--" in sql or "/*" in sql:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def _check_blocked_keywords(self, parsed: Statement) -> ValidationResult:
|
||||
"""Check blocked keywords"""
|
||||
blocked_operations = []
|
||||
|
||||
# Check all tokens in the parsed statement
|
||||
for token in parsed.flatten():
|
||||
# Check if token is a keyword (including DML/DDL) or name that matches blocked operations
|
||||
if (token.ttype is Keyword or
|
||||
token.ttype is Name or
|
||||
(token.ttype and str(token.ttype).startswith('Token.Keyword'))):
|
||||
token_value = token.value.upper().strip()
|
||||
if token_value in self.blocked_keywords:
|
||||
blocked_operations.append(token_value)
|
||||
# Also check for DDL/DML keywords in token values
|
||||
elif hasattr(token, 'value') and token.value:
|
||||
token_value = token.value.upper().strip()
|
||||
for blocked_keyword in self.blocked_keywords:
|
||||
if blocked_keyword in token_value:
|
||||
blocked_operations.append(blocked_keyword)
|
||||
|
||||
if blocked_operations:
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
error_message=f"Contains blocked operations: {', '.join(set(blocked_operations))}",
|
||||
risk_level="high",
|
||||
blocked_operations=list(set(blocked_operations)),
|
||||
)
|
||||
|
||||
return ValidationResult(is_valid=True)
|
||||
|
||||
async def _check_query_complexity(self, parsed: Statement) -> ValidationResult:
|
||||
"""Check query complexity"""
|
||||
complexity_score = 0
|
||||
|
||||
# Calculate complexity score
|
||||
for token in parsed.flatten():
|
||||
if token.ttype is Keyword:
|
||||
keyword = token.value.upper()
|
||||
if keyword in ["JOIN", "INNER", "LEFT", "RIGHT", "FULL"]:
|
||||
complexity_score += 10
|
||||
elif keyword in ["UNION", "INTERSECT", "EXCEPT"]:
|
||||
complexity_score += 15
|
||||
elif keyword in ["GROUP BY", "ORDER BY", "HAVING"]:
|
||||
complexity_score += 5
|
||||
elif keyword in ["SUBQUERY", "EXISTS", "IN"]:
|
||||
complexity_score += 8
|
||||
|
||||
if complexity_score > self.max_query_complexity:
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
error_message=f"Query complexity too high (score: {complexity_score}, limit: {self.max_query_complexity})",
|
||||
risk_level="medium",
|
||||
)
|
||||
|
||||
return ValidationResult(is_valid=True)
|
||||
|
||||
async def _check_table_access(
|
||||
self, parsed: Statement, auth_context: AuthContext
|
||||
) -> ValidationResult:
|
||||
"""Check table access permissions"""
|
||||
# Extract table names from query
|
||||
tables = self._extract_table_names(parsed)
|
||||
|
||||
# Check access permissions for each table
|
||||
unauthorized_tables = []
|
||||
for table in tables:
|
||||
# Should call authorization provider to check permissions
|
||||
# Simplified implementation, assume some tables require special permissions
|
||||
if (
|
||||
table.lower() in ["sensitive_data", "admin_logs"]
|
||||
and "admin" not in auth_context.roles
|
||||
):
|
||||
unauthorized_tables.append(table)
|
||||
|
||||
if unauthorized_tables:
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
error_message=f"No access to tables: {', '.join(unauthorized_tables)}",
|
||||
risk_level="high",
|
||||
)
|
||||
|
||||
return ValidationResult(is_valid=True)
|
||||
|
||||
def _extract_table_names(self, parsed: Statement) -> list[str]:
|
||||
"""Extract table names from SQL statement"""
|
||||
tables = []
|
||||
|
||||
# Simplified table name extraction logic
|
||||
tokens = list(parsed.flatten())
|
||||
for i, token in enumerate(tokens):
|
||||
if token.ttype is Keyword and token.value.upper() == "FROM":
|
||||
# Find table name after FROM
|
||||
for j in range(i + 1, len(tokens)):
|
||||
next_token = tokens[j]
|
||||
if next_token.ttype is Name:
|
||||
tables.append(next_token.value)
|
||||
break
|
||||
elif next_token.ttype is Keyword:
|
||||
break
|
||||
|
||||
return tables
|
||||
|
||||
|
||||
class DataMaskingProcessor:
|
||||
"""Data masking processor"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.masking_algorithms = self._init_masking_algorithms()
|
||||
self.masking_rules = self._load_masking_rules()
|
||||
|
||||
def _load_masking_rules(self) -> list[MaskingRule]:
|
||||
"""Load data masking rules"""
|
||||
default_rules = [
|
||||
MaskingRule(
|
||||
column_pattern=r".*phone.*|.*mobile.*",
|
||||
algorithm="phone_mask",
|
||||
parameters={"mask_char": "*", "keep_prefix": 3, "keep_suffix": 4},
|
||||
security_level=SecurityLevel.INTERNAL,
|
||||
),
|
||||
MaskingRule(
|
||||
column_pattern=r".*email.*",
|
||||
algorithm="email_mask",
|
||||
parameters={"mask_char": "*"},
|
||||
security_level=SecurityLevel.INTERNAL,
|
||||
),
|
||||
MaskingRule(
|
||||
column_pattern=r".*id_card.*|.*identity.*",
|
||||
algorithm="id_mask",
|
||||
parameters={"mask_char": "*", "keep_prefix": 6, "keep_suffix": 4},
|
||||
security_level=SecurityLevel.CONFIDENTIAL,
|
||||
),
|
||||
]
|
||||
|
||||
# Load custom rules from configuration
|
||||
if hasattr(self.config, 'get'):
|
||||
custom_rules = self.config.get("masking_rules", [])
|
||||
for rule_config in custom_rules:
|
||||
if isinstance(rule_config, dict):
|
||||
# Convert string security level to enum
|
||||
if 'security_level' in rule_config and isinstance(rule_config['security_level'], str):
|
||||
try:
|
||||
rule_config['security_level'] = SecurityLevel(rule_config['security_level'].lower())
|
||||
except ValueError:
|
||||
rule_config['security_level'] = SecurityLevel.INTERNAL
|
||||
default_rules.append(MaskingRule(**rule_config))
|
||||
elif isinstance(rule_config, MaskingRule):
|
||||
default_rules.append(rule_config)
|
||||
|
||||
return default_rules
|
||||
|
||||
def _init_masking_algorithms(self) -> dict[str, callable]:
|
||||
"""Initialize masking algorithms"""
|
||||
return {
|
||||
"phone_mask": self._mask_phone,
|
||||
"email_mask": self._mask_email,
|
||||
"id_mask": self._mask_id_card,
|
||||
"name_mask": self._mask_name,
|
||||
"partial_mask": self._mask_partial,
|
||||
}
|
||||
|
||||
async def process(
|
||||
self, data: list[dict[str, Any]], auth_context: AuthContext
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Process data masking"""
|
||||
if not data:
|
||||
return data
|
||||
|
||||
# Get applicable masking rules
|
||||
applicable_rules = self._get_applicable_rules(auth_context)
|
||||
|
||||
masked_data = []
|
||||
for row in data:
|
||||
masked_row = {}
|
||||
for column, value in row.items():
|
||||
masked_value = await self._apply_masking_rules(
|
||||
column, value, applicable_rules
|
||||
)
|
||||
masked_row[column] = masked_value
|
||||
masked_data.append(masked_row)
|
||||
|
||||
return masked_data
|
||||
|
||||
def _get_applicable_rules(self, auth_context: AuthContext) -> list[MaskingRule]:
|
||||
"""Get applicable masking rules"""
|
||||
applicable_rules = []
|
||||
|
||||
for rule in self.masking_rules:
|
||||
# Decide whether to apply masking rules based on user security level
|
||||
if self._should_apply_rule(rule, auth_context):
|
||||
applicable_rules.append(rule)
|
||||
|
||||
return applicable_rules
|
||||
|
||||
def _should_apply_rule(self, rule: MaskingRule, auth_context: AuthContext) -> bool:
|
||||
"""Determine whether masking rule should be applied"""
|
||||
# Admin users can see original data
|
||||
if "admin" in auth_context.roles:
|
||||
return False
|
||||
|
||||
# Decide based on security level
|
||||
security_hierarchy = {
|
||||
SecurityLevel.PUBLIC: 0,
|
||||
SecurityLevel.INTERNAL: 1,
|
||||
SecurityLevel.CONFIDENTIAL: 2,
|
||||
SecurityLevel.SECRET: 3,
|
||||
}
|
||||
|
||||
user_level = security_hierarchy.get(auth_context.security_level, 0)
|
||||
rule_level = security_hierarchy.get(rule.security_level, 0)
|
||||
|
||||
# Apply masking if user level is less than or equal to rule level
|
||||
return user_level <= rule_level
|
||||
|
||||
async def _apply_masking_rules(
|
||||
self, column: str, value: Any, rules: list[MaskingRule]
|
||||
) -> Any:
|
||||
"""Apply masking rules"""
|
||||
if value is None:
|
||||
return value
|
||||
|
||||
for rule in rules:
|
||||
if re.match(rule.column_pattern, column, re.IGNORECASE):
|
||||
algorithm = self.masking_algorithms.get(rule.algorithm)
|
||||
if algorithm:
|
||||
return algorithm(str(value), rule.parameters)
|
||||
|
||||
return value
|
||||
|
||||
def _mask_phone(self, value: str, params: dict[str, Any]) -> str:
|
||||
"""Phone number masking"""
|
||||
if len(value) < 7:
|
||||
return value
|
||||
|
||||
mask_char = params.get("mask_char", "*")
|
||||
keep_prefix = params.get("keep_prefix", 3)
|
||||
keep_suffix = params.get("keep_suffix", 4)
|
||||
|
||||
if len(value) <= keep_prefix + keep_suffix:
|
||||
return mask_char * len(value)
|
||||
|
||||
prefix = value[:keep_prefix]
|
||||
suffix = value[-keep_suffix:]
|
||||
middle_length = len(value) - keep_prefix - keep_suffix
|
||||
|
||||
return prefix + mask_char * middle_length + suffix
|
||||
|
||||
def _mask_email(self, value: str, params: dict[str, Any]) -> str:
|
||||
"""Email masking"""
|
||||
if "@" not in value:
|
||||
return value
|
||||
|
||||
mask_char = params.get("mask_char", "*")
|
||||
local, domain = value.split("@", 1)
|
||||
|
||||
if len(local) <= 2:
|
||||
masked_local = mask_char * len(local)
|
||||
else:
|
||||
masked_local = local[0] + mask_char * (len(local) - 2) + local[-1]
|
||||
|
||||
return f"{masked_local}@{domain}"
|
||||
|
||||
def _mask_id_card(self, value: str, params: dict[str, Any]) -> str:
|
||||
"""ID card number masking"""
|
||||
if len(value) < 10:
|
||||
return value
|
||||
|
||||
mask_char = params.get("mask_char", "*")
|
||||
keep_prefix = params.get("keep_prefix", 6)
|
||||
keep_suffix = params.get("keep_suffix", 4)
|
||||
|
||||
if len(value) <= keep_prefix + keep_suffix:
|
||||
return mask_char * len(value)
|
||||
|
||||
prefix = value[:keep_prefix]
|
||||
suffix = value[-keep_suffix:]
|
||||
middle_length = len(value) - keep_prefix - keep_suffix
|
||||
|
||||
return prefix + mask_char * middle_length + suffix
|
||||
|
||||
def _mask_name(self, value: str, params: dict[str, Any]) -> str:
|
||||
"""Name masking"""
|
||||
if len(value) <= 1:
|
||||
return value
|
||||
|
||||
mask_char = params.get("mask_char", "*")
|
||||
|
||||
if len(value) == 2:
|
||||
return value[0] + mask_char
|
||||
else:
|
||||
return value[0] + mask_char * (len(value) - 2) + value[-1]
|
||||
|
||||
def _mask_partial(self, value: str, params: dict[str, Any]) -> str:
|
||||
"""Partial masking"""
|
||||
mask_char = params.get("mask_char", "*")
|
||||
mask_ratio = params.get("mask_ratio", 0.5)
|
||||
|
||||
mask_length = int(len(value) * mask_ratio)
|
||||
start_pos = (len(value) - mask_length) // 2
|
||||
|
||||
result = list(value)
|
||||
for i in range(start_pos, start_pos + mask_length):
|
||||
if i < len(result):
|
||||
result[i] = mask_char
|
||||
|
||||
return "".join(result)
|
||||
@@ -1,352 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
SQL Execution Tool
|
||||
|
||||
Responsible for executing SQL queries and handling results
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import traceback
|
||||
import time
|
||||
from typing import Dict, Any
|
||||
import re
|
||||
import datetime
|
||||
from decimal import Decimal
|
||||
|
||||
# Get logger
|
||||
logger = logging.getLogger("doris-mcp.sql-executor")
|
||||
|
||||
# Add environment variable control for whether to perform SQL security checks
|
||||
ENABLE_SQL_SECURITY_CHECK = os.environ.get('ENABLE_SQL_SECURITY_CHECK', 'true').lower() == 'true'
|
||||
|
||||
async def execute_sql_query(ctx) -> Dict[str, Any]:
|
||||
"""
|
||||
Execute SQL query and return results
|
||||
|
||||
Args:
|
||||
ctx: Context object or dictionary containing request parameters
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Execution result
|
||||
"""
|
||||
try:
|
||||
# Support the case where the passed argument is a dictionary
|
||||
if isinstance(ctx, dict) and 'params' in ctx:
|
||||
params = ctx['params']
|
||||
else:
|
||||
params = ctx.params
|
||||
|
||||
sql = params.get("sql")
|
||||
db_name = params.get("db_name", os.getenv("DB_DATABASE", ""))
|
||||
catalog_name = params.get("catalog_name", None) # Add catalog parameter support
|
||||
max_rows = params.get("max_rows", 1000) # Maximum number of rows to return
|
||||
timeout = params.get("timeout", 30) # Timeout in seconds
|
||||
|
||||
if not sql:
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps({
|
||||
"success": False,
|
||||
"error": "Missing SQL parameter",
|
||||
"message": "Please provide the SQL query to execute"
|
||||
}, ensure_ascii=False)
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# First check SQL security
|
||||
security_result = await _check_sql_security(sql)
|
||||
if not security_result.get("is_safe", False):
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps({
|
||||
"success": False,
|
||||
"error": "SQL security check failed",
|
||||
"message": "Query contains unsafe operations and cannot be executed",
|
||||
"security_issues": security_result.get("security_issues", [])
|
||||
}, ensure_ascii=False)
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# Import database connection tool
|
||||
from doris_mcp_server.utils.db import execute_query
|
||||
|
||||
if not sql:
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps({
|
||||
"success": False,
|
||||
"error": "Missing SQL parameter",
|
||||
"message": "Please provide the SQL query to execute"
|
||||
}, ensure_ascii=False)
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# Ensure SELECT statements include a LIMIT clause
|
||||
sql_lower = sql.lower().strip()
|
||||
if sql_lower.startswith("select") and "limit" not in sql_lower:
|
||||
sql = sql.rstrip(";") + f" LIMIT {max_rows};"
|
||||
|
||||
# Start timer
|
||||
start_time = time.time()
|
||||
|
||||
# Execute query
|
||||
try:
|
||||
# For federation queries, SQL must use three-part naming: catalog_name.db_name.table_name
|
||||
# This is enforced at the tool description level
|
||||
|
||||
result = execute_query(sql, db_name)
|
||||
|
||||
# Calculate execution time
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
# Build return result
|
||||
if isinstance(result, list):
|
||||
# Handle list of query results
|
||||
row_count = len(result)
|
||||
|
||||
# Extract column names
|
||||
if hasattr(result[0], "_fields"):
|
||||
# If it's a named tuple
|
||||
columns = list(result[0]._fields)
|
||||
else:
|
||||
# Otherwise, assume it's a dictionary
|
||||
columns = list(result[0].keys()) if isinstance(result[0], dict) else []
|
||||
|
||||
# Convert results to serializable format
|
||||
data = []
|
||||
for row in result:
|
||||
row_dict = {}
|
||||
if hasattr(row, "_asdict"):
|
||||
# If it's a named tuple
|
||||
row_dict = row._asdict()
|
||||
elif isinstance(row, dict):
|
||||
# If it's a dictionary
|
||||
row_dict = row
|
||||
else:
|
||||
# If it's a list or tuple
|
||||
row_dict = dict(zip(columns, row)) if columns else row
|
||||
|
||||
# Handle special types to make them JSON serializable
|
||||
serialized_row = _serialize_row_data(row_dict)
|
||||
data.append(serialized_row)
|
||||
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps({
|
||||
"success": True,
|
||||
"sql": sql,
|
||||
"row_count": row_count,
|
||||
"columns": columns,
|
||||
"data": data[:max_rows], # Limit returned rows
|
||||
"execution_time": execution_time,
|
||||
"truncated": row_count > max_rows
|
||||
}, ensure_ascii=False)
|
||||
}
|
||||
]
|
||||
}
|
||||
else:
|
||||
# Handle other types of results
|
||||
other_response = {
|
||||
"success": True,
|
||||
"sql": sql,
|
||||
"result": str(result),
|
||||
"execution_time": execution_time
|
||||
}
|
||||
other_response = _serialize_row_data(other_response)
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps(other_response, ensure_ascii=False)
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
except Exception as db_error:
|
||||
error_message = str(db_error)
|
||||
|
||||
# Try to get more detailed error information
|
||||
error_details = {}
|
||||
if "timeout" in error_message.lower():
|
||||
error_details["type"] = "timeout"
|
||||
error_details["suggestion"] = "Query timed out, please optimize SQL or increase timeout"
|
||||
elif "syntax" in error_message.lower():
|
||||
error_details["type"] = "syntax"
|
||||
error_details["suggestion"] = "SQL syntax error, please check syntax"
|
||||
elif "not found" in error_message.lower() or "doesn't exist" in error_message.lower():
|
||||
error_details["type"] = "not_found"
|
||||
error_details["suggestion"] = "Table or column not found, please check table and column names"
|
||||
else:
|
||||
error_details["type"] = "unknown"
|
||||
error_details["suggestion"] = "Please check the SQL statement and try simplifying the query"
|
||||
|
||||
# Create error response
|
||||
error_response = {
|
||||
"success": False,
|
||||
"error": error_message,
|
||||
"error_details": error_details,
|
||||
"sql": sql,
|
||||
"db_name": db_name
|
||||
}
|
||||
|
||||
# Ensure error response is also serializable
|
||||
error_response = _serialize_row_data(error_response)
|
||||
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps(error_response, ensure_ascii=False)
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to execute SQL query: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
error_response = {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"message": "Error occurred while executing SQL query"
|
||||
}
|
||||
|
||||
# Ensure error response is also serializable
|
||||
error_response = _serialize_row_data(error_response)
|
||||
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps(error_response, ensure_ascii=False)
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# Helper function
|
||||
async def _check_sql_security(sql: str) -> Dict[str, Any]:
|
||||
"""Check SQL security"""
|
||||
# If environment variable is set to disable security check, return safe immediately
|
||||
if not ENABLE_SQL_SECURITY_CHECK:
|
||||
return {
|
||||
"is_safe": True,
|
||||
"security_issues": []
|
||||
}
|
||||
|
||||
# Check if SQL contains dangerous operations
|
||||
sql_lower = sql.lower()
|
||||
|
||||
# Check if it's a read-only query type
|
||||
is_read_only = sql_lower.strip().startswith(("select ", "show ", "desc ", "describe ", "explain "))
|
||||
|
||||
# Define list of dangerous operations (checked for both read-only and non-read-only queries)
|
||||
dangerous_operations = [
|
||||
(r'\bdelete\b', "DELETE operation"),
|
||||
(r'\bdrop\b', "DROP TABLE/DATABASE operation"),
|
||||
(r'\btruncate\b', "TRUNCATE TABLE operation"),
|
||||
(r'\bupdate\b', "UPDATE operation"),
|
||||
(r'\binsert\b', "INSERT operation"),
|
||||
(r'\balter\b', "ALTER TABLE structure operation"),
|
||||
(r'\bcreate\b', "CREATE TABLE/DATABASE operation"),
|
||||
(r'\bgrant\b', "GRANT operation"),
|
||||
(r'\brevoke\b', "REVOKE permission operation"),
|
||||
(r'\bexec\b', "EXECUTE stored procedure"),
|
||||
(r'\bxp_', "Extended stored procedure, potential security risk"),
|
||||
(r'\bshutdown\b', "SHUTDOWN database operation"),
|
||||
(r'\binto\s+outfile\b', "Write to file operation"),
|
||||
(r'\bload_file\b', "Load file operation")
|
||||
]
|
||||
|
||||
# Dangerous operations checked only for non-read-only queries
|
||||
non_readonly_operations = []
|
||||
if not is_read_only:
|
||||
non_readonly_operations = [
|
||||
(r'--', "SQL comment, potential SQL injection"),
|
||||
(r'/\*', "SQL block comment, potential SQL injection")
|
||||
]
|
||||
|
||||
# Check if dangerous operations are included
|
||||
security_issues = []
|
||||
|
||||
# Check dangerous operations applicable to all queries
|
||||
for operation, description in dangerous_operations:
|
||||
if re.search(operation, sql_lower):
|
||||
# For specific keywords in read-only queries, differentiate if used as independent operations
|
||||
if is_read_only and operation in [r'\bcreate\b', r'\bdrop\b', r'\bdelete\b', r'\binsert\b', r'\bupdate\b', r'\balter\b']:
|
||||
# Check if used as DDL/DML keyword, e.g., CREATE TABLE, DROP DATABASE
|
||||
pattern = operation + r'\s+(?:table|database|view|index|procedure|function|trigger|event)'
|
||||
if re.search(pattern, sql_lower):
|
||||
security_issues.append({
|
||||
"operation": operation.replace(r'\b', '').replace(r'\s+', ' '),
|
||||
"description": description,
|
||||
"severity": "High"
|
||||
})
|
||||
else:
|
||||
security_issues.append({
|
||||
"operation": operation.replace(r'\b', '').replace(r'\s+', ' '),
|
||||
"description": description,
|
||||
"severity": "High"
|
||||
})
|
||||
|
||||
# Check dangerous operations specific to non-read-only queries
|
||||
for operation, description in non_readonly_operations:
|
||||
if re.search(operation, sql_lower):
|
||||
security_issues.append({
|
||||
"operation": operation.replace(r'\b', '').replace(r'\s+', ' '),
|
||||
"description": description,
|
||||
"severity": "Medium"
|
||||
})
|
||||
|
||||
return {
|
||||
"is_safe": len(security_issues) == 0,
|
||||
"security_issues": security_issues
|
||||
}
|
||||
|
||||
|
||||
def _serialize_row_data(row_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert special types in row data (like date, time, Decimal) to JSON serializable format
|
||||
|
||||
Args:
|
||||
row_data: Row data dictionary
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Processed serializable dictionary
|
||||
"""
|
||||
serialized_data = {}
|
||||
for key, value in row_data.items():
|
||||
if value is None:
|
||||
serialized_data[key] = None
|
||||
elif isinstance(value, (datetime.date, datetime.datetime)):
|
||||
# Convert date and time types to ISO format string
|
||||
serialized_data[key] = value.isoformat()
|
||||
elif isinstance(value, Decimal):
|
||||
# Convert Decimal type to float
|
||||
serialized_data[key] = float(value)
|
||||
elif isinstance(value, (list, tuple)):
|
||||
# Recursively process elements in list or tuple
|
||||
serialized_data[key] = [
|
||||
_serialize_row_data(item) if isinstance(item, dict) else item
|
||||
for item in value
|
||||
]
|
||||
elif isinstance(value, dict):
|
||||
# Recursively process nested dictionaries
|
||||
serialized_data[key] = _serialize_row_data(value)
|
||||
else:
|
||||
serialized_data[key] = value
|
||||
return serialized_data
|
||||
Reference in New Issue
Block a user