Files
doris-mcp-server/doris_mcp_server/utils/config.py

656 lines
26 KiB
Python
Raw Normal View History

2025-06-08 18:44:40 +08:00
#!/usr/bin/env python3
2025-06-08 19:22:13 +08:00
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
2025-06-08 18:44:40 +08:00
"""
Doris Configuration Management Module
Implements configuration loading, validation and management functionality
"""
import json
import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
try:
from dotenv import load_dotenv
except ImportError:
load_dotenv = None
@dataclass
class DatabaseConfig:
    """Settings for reaching Doris and sizing the connection pool.

    Groups the FE query endpoint and credentials, the HTTP ports used for
    profile/BE access, an optional explicit list of BE hosts, and the
    connection-pool limits.
    """

    # --- FE query endpoint and credentials ---
    host: str = "localhost"
    port: int = 9030
    user: str = "root"
    password: str = ""
    database: str = "information_schema"
    charset: str = "UTF8"

    # --- HTTP endpoints ---
    # Port of the FE HTTP API (used for profile and other HTTP APIs).
    fe_http_port: int = 8030
    # Explicit BE hosts for external access. An empty list is intended to mean
    # "discover BE nodes via SHOW BACKENDS" -- that lookup is done by callers,
    # not by this class.
    be_hosts: list[str] = field(default_factory=lambda: [])
    be_webserver_port: int = 8040

    # --- Connection pool ---
    min_connections: int = 5
    max_connections: int = 20
    connection_timeout: int = 30
    health_check_interval: int = 60
    max_connection_age: int = 3600
@dataclass
class SecurityConfig:
    """Authentication, SQL-restriction and data-masking settings."""

    # --- Authentication ---
    auth_type: str = "token"  # one of: token, basic, oauth
    # NOTE(review): placeholder secret; deployments should override it
    # (e.g. via the TOKEN_SECRET environment variable).
    token_secret: str = "default_secret"
    token_expiry: int = 3600

    # --- SQL restrictions ---
    # Keywords listed here are meant to be rejected in incoming SQL; each
    # instance gets its own fresh list.
    blocked_keywords: list[str] = field(
        default_factory=lambda: "DROP DELETE TRUNCATE ALTER CREATE INSERT UPDATE GRANT REVOKE".split()
    )
    max_query_complexity: int = 100
    max_result_rows: int = 10000

    # --- Sensitive data ---
    sensitive_tables: dict[str, str] = field(default_factory=lambda: {})
    enable_masking: bool = True
    masking_rules: list[dict[str, Any]] = field(default_factory=lambda: [])
@dataclass
class PerformanceConfig:
    """Caching, concurrency and response-size tuning knobs."""

    # --- Query result cache ---
    enable_query_cache: bool = True
    cache_ttl: int = 300
    max_cache_size: int = 1_000

    # --- Concurrency ---
    max_concurrent_queries: int = 50
    query_timeout: int = 300

    # --- Connection pool tuning ---
    connection_pool_size: int = 20
    idle_timeout: int = 1_800

    # Upper bound on response content, measured in characters.
    max_response_content_size: int = 4_096
@dataclass
class LoggingConfig:
"""Logging configuration"""
level: str = "INFO"
format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
file_path: str | None = None
max_file_size: int = 10 * 1024 * 1024 # 10MB
backup_count: int = 5
# Audit log configuration
enable_audit: bool = True
audit_file_path: str | None = None
@dataclass
class MonitoringConfig:
"""Monitoring configuration"""
# Metrics collection configuration
enable_metrics: bool = True
metrics_port: int = 3001
2025-06-08 18:44:40 +08:00
metrics_path: str = "/metrics"
# Health check configuration
health_check_port: int = 3002
2025-06-08 18:44:40 +08:00
health_check_path: str = "/health"
# Alert configuration
enable_alerts: bool = False
alert_webhook_url: str | None = None
@dataclass
class DorisConfig:
    """Complete configuration for the Doris MCP Server.

    Bundles server-level settings, the per-area sub-configurations
    (database, security, performance, logging, monitoring) and a free-form
    ``custom_config`` dict. Build it from defaults, from a JSON file
    (:meth:`from_file`) or from environment variables (:meth:`from_env`).
    """

    # Placeholder that ``to_dict`` emits instead of secrets. Unannotated, so it
    # is a plain class attribute rather than a dataclass field.
    _REDACTED = "***"

    # Basic configuration
    server_name: str = "doris-mcp-server"
    server_version: str = "0.4.0"
    server_port: int = 3000
    transport: str = "stdio"

    # Temporary files directory for Explain and Profile outputs
    temp_files_dir: str = "tmp"

    # Sub-configuration modules
    database: DatabaseConfig = field(default_factory=DatabaseConfig)
    security: SecurityConfig = field(default_factory=SecurityConfig)
    performance: PerformanceConfig = field(default_factory=PerformanceConfig)
    logging: LoggingConfig = field(default_factory=LoggingConfig)
    monitoring: MonitoringConfig = field(default_factory=MonitoringConfig)

    # Custom configuration (the "custom" object of a configuration file)
    custom_config: dict[str, Any] = field(default_factory=dict)

    # ------------------------------------------------------------------ #
    # Loading
    # ------------------------------------------------------------------ #
    @classmethod
    def from_file(cls, config_path: str) -> "DorisConfig":
        """Load configuration from a JSON file.

        Args:
            config_path: Path to the configuration file; only ``.json`` is supported.

        Returns:
            A new ``DorisConfig`` with the file's values applied over the defaults.

        Raises:
            FileNotFoundError: If ``config_path`` does not exist.
            ValueError: If the format is unsupported, or the file cannot be
                read, parsed or applied.
        """
        config_file = Path(config_path)
        if not config_file.exists():
            raise FileNotFoundError(f"Configuration file does not exist: {config_path}")
        # Only JSON is supported today (formats such as YAML are not).
        if config_file.suffix.lower() != ".json":
            raise ValueError(f"Unsupported configuration file format: {config_file.suffix}")
        try:
            with open(config_file, encoding="utf-8") as f:
                config_data = json.load(f)
            if not isinstance(config_data, dict):
                raise ValueError("top-level JSON value must be an object")
            return cls._from_dict(config_data)
        except Exception as e:
            # Normalise every read/parse/apply failure to ValueError (callers
            # rely on that), keeping the original exception as the cause.
            raise ValueError(f"Failed to load configuration file: {e}") from e

    @staticmethod
    def _load_env_files(env_file: str | None) -> None:
        """Load a .env file into the process environment if python-dotenv is available.

        With ``env_file`` given, only that file is considered. Otherwise the
        first existing file among .env, .env.local, .env.production and
        .env.development is loaded, and the rest are ignored.
        """
        logger = logging.getLogger(__name__)
        if load_dotenv is None:
            logger.warning("python-dotenv not installed, cannot load .env files")
            return
        if env_file:
            if Path(env_file).exists():
                load_dotenv(env_file)
                logger.info(f"Loaded environment configuration file: {env_file}")
            else:
                logger.warning(f"Environment configuration file does not exist: {env_file}")
            return
        for env_path in (".env", ".env.local", ".env.production", ".env.development"):
            if Path(env_path).exists():
                load_dotenv(env_path)
                logger.info(f"Loaded environment configuration file: {env_path}")
                return
        logger.info("No .env configuration file found, using system environment variables")

    @staticmethod
    def _env_int(name: str, default: int) -> int:
        """Return env var ``name`` as an int, or ``default`` when it is unset.

        Raises:
            ValueError: If the variable is set but is not a valid integer; the
                message names the offending variable.
        """
        raw = os.getenv(name)
        if raw is None:
            return default
        try:
            return int(raw)
        except ValueError:
            raise ValueError(f"Environment variable {name} must be an integer, got {raw!r}") from None

    @staticmethod
    def _env_bool(name: str, default: bool) -> bool:
        """Return env var ``name`` as a bool, or ``default`` when it is unset.

        "true", "1", "yes" and "on" (case-insensitive, surrounding whitespace
        ignored) mean True; any other value means False.
        """
        raw = os.getenv(name)
        if raw is None:
            return default
        return raw.strip().lower() in {"true", "1", "yes", "on"}

    @classmethod
    def from_env(cls, env_file: str | None = None) -> "DorisConfig":
        """Load configuration from environment variables.

        A .env file is loaded first when python-dotenv is installed. Every
        setting whose variable is unset keeps its dataclass default.

        Args:
            env_file: .env file path. If None, search in the following order and
                load the first one found: .env, .env.local, .env.production,
                .env.development.

        Returns:
            A new ``DorisConfig``.

        Raises:
            ValueError: If an integer-valued variable holds a non-integer value.
        """
        cls._load_env_files(env_file)
        config = cls()
        env_int, env_bool = cls._env_int, cls._env_bool

        # Database configuration
        db = config.database
        db.host = os.getenv("DORIS_HOST", db.host)
        db.port = env_int("DORIS_PORT", db.port)
        db.user = os.getenv("DORIS_USER", db.user)
        db.password = os.getenv("DORIS_PASSWORD", db.password)
        db.database = os.getenv("DORIS_DATABASE", db.database)
        db.fe_http_port = env_int("DORIS_FE_HTTP_PORT", db.fe_http_port)

        # BE nodes: comma-separated hosts; blank entries are dropped and an
        # empty/unset variable leaves the default list untouched.
        be_hosts_env = os.getenv("DORIS_BE_HOSTS", "")
        if be_hosts_env:
            db.be_hosts = [host.strip() for host in be_hosts_env.split(",") if host.strip()]
        db.be_webserver_port = env_int("DORIS_BE_WEBSERVER_PORT", db.be_webserver_port)

        # Connection pool configuration
        db.min_connections = env_int("DORIS_MIN_CONNECTIONS", db.min_connections)
        db.max_connections = env_int("DORIS_MAX_CONNECTIONS", db.max_connections)
        db.connection_timeout = env_int("DORIS_CONNECTION_TIMEOUT", db.connection_timeout)
        db.health_check_interval = env_int("DORIS_HEALTH_CHECK_INTERVAL", db.health_check_interval)
        db.max_connection_age = env_int("DORIS_MAX_CONNECTION_AGE", db.max_connection_age)

        # Security configuration
        sec = config.security
        sec.auth_type = os.getenv("AUTH_TYPE", sec.auth_type)
        sec.token_secret = os.getenv("TOKEN_SECRET", sec.token_secret)
        sec.token_expiry = env_int("TOKEN_EXPIRY", sec.token_expiry)
        sec.max_result_rows = env_int("MAX_RESULT_ROWS", sec.max_result_rows)
        sec.max_query_complexity = env_int("MAX_QUERY_COMPLEXITY", sec.max_query_complexity)
        sec.enable_masking = env_bool("ENABLE_MASKING", sec.enable_masking)

        # Performance configuration
        perf = config.performance
        perf.enable_query_cache = env_bool("ENABLE_QUERY_CACHE", perf.enable_query_cache)
        perf.cache_ttl = env_int("CACHE_TTL", perf.cache_ttl)
        perf.max_cache_size = env_int("MAX_CACHE_SIZE", perf.max_cache_size)
        perf.max_concurrent_queries = env_int("MAX_CONCURRENT_QUERIES", perf.max_concurrent_queries)
        perf.query_timeout = env_int("QUERY_TIMEOUT", perf.query_timeout)
        perf.max_response_content_size = env_int(
            "MAX_RESPONSE_CONTENT_SIZE", perf.max_response_content_size
        )

        # Logging configuration
        log_cfg = config.logging
        log_cfg.level = os.getenv("LOG_LEVEL", log_cfg.level)
        log_cfg.file_path = os.getenv("LOG_FILE_PATH", log_cfg.file_path)
        log_cfg.enable_audit = env_bool("ENABLE_AUDIT", log_cfg.enable_audit)
        log_cfg.audit_file_path = os.getenv("AUDIT_FILE_PATH", log_cfg.audit_file_path)

        # Monitoring configuration
        mon = config.monitoring
        mon.enable_metrics = env_bool("ENABLE_METRICS", mon.enable_metrics)
        mon.metrics_port = env_int("METRICS_PORT", mon.metrics_port)
        mon.health_check_port = env_int("HEALTH_CHECK_PORT", mon.health_check_port)
        mon.enable_alerts = env_bool("ENABLE_ALERTS", mon.enable_alerts)
        mon.alert_webhook_url = os.getenv("ALERT_WEBHOOK_URL", mon.alert_webhook_url)

        # Server configuration
        config.server_name = os.getenv("SERVER_NAME", config.server_name)
        config.server_version = os.getenv("SERVER_VERSION", config.server_version)
        config.server_port = env_int("SERVER_PORT", config.server_port)
        config.temp_files_dir = os.getenv("TEMP_FILES_DIR", config.temp_files_dir)

        return config

    @classmethod
    def _apply_section(cls, section: str, target: Any, values: Any) -> None:
        """Copy known keys from the ``values`` mapping onto sub-config ``target``.

        Unknown keys are ignored. Display-only values that ``to_dict`` writes
        are skipped: the ``"***"`` placeholder for ``password``/``token_secret``
        and the integer count it emits for ``masking_rules``. This lets a file
        produced by ``save_to_file`` be loaded back without clobbering real
        values.

        Raises:
            ValueError: If ``values`` is not a dict.
        """
        if not isinstance(values, dict):
            raise ValueError(
                f"Configuration section '{section}' must be an object, got {type(values).__name__}"
            )
        for key, value in values.items():
            if not hasattr(target, key):
                continue
            if key in ("password", "token_secret") and value == cls._REDACTED:
                continue
            if key == "masking_rules" and not isinstance(value, list):
                continue
            setattr(target, key, value)

    @classmethod
    def _from_dict(cls, config_data: dict[str, Any]) -> "DorisConfig":
        """Create a configuration object from a dictionary (e.g. parsed JSON).

        Values are applied over the defaults; unknown keys are ignored.
        """
        config = cls()
        # Basic configuration
        for key in ("server_name", "server_version", "server_port", "transport", "temp_files_dir"):
            if key in config_data:
                setattr(config, key, config_data[key])
        # Sub-configurations
        for section in ("database", "security", "performance", "logging", "monitoring"):
            if section in config_data:
                cls._apply_section(section, getattr(config, section), config_data[section])
        # Custom configuration
        config.custom_config = config_data.get("custom", {})
        return config

    # ------------------------------------------------------------------ #
    # Export
    # ------------------------------------------------------------------ #
    def to_dict(self) -> dict[str, Any]:
        """Convert to a JSON-serialisable dictionary for display or saving.

        Secrets (password, token secret) are replaced by ``"***"``, and
        ``masking_rules`` is reported as a count, not the rules themselves.
        """
        return {
            "server_name": self.server_name,
            "server_version": self.server_version,
            "server_port": self.server_port,
            "transport": self.transport,
            "temp_files_dir": self.temp_files_dir,
            "database": {
                "host": self.database.host,
                "port": self.database.port,
                "user": self.database.user,
                "password": self._REDACTED,  # Hide password
                "database": self.database.database,
                "charset": self.database.charset,
                "fe_http_port": self.database.fe_http_port,
                "be_hosts": self.database.be_hosts,
                "be_webserver_port": self.database.be_webserver_port,
                "min_connections": self.database.min_connections,
                "max_connections": self.database.max_connections,
                "connection_timeout": self.database.connection_timeout,
                "health_check_interval": self.database.health_check_interval,
                "max_connection_age": self.database.max_connection_age,
            },
            "security": {
                "auth_type": self.security.auth_type,
                "token_secret": self._REDACTED,  # Hide secret key
                "token_expiry": self.security.token_expiry,
                "blocked_keywords": self.security.blocked_keywords,
                "max_query_complexity": self.security.max_query_complexity,
                "max_result_rows": self.security.max_result_rows,
                "sensitive_tables": self.security.sensitive_tables,
                "enable_masking": self.security.enable_masking,
                "masking_rules": len(self.security.masking_rules),
            },
            "performance": {
                "enable_query_cache": self.performance.enable_query_cache,
                "cache_ttl": self.performance.cache_ttl,
                "max_cache_size": self.performance.max_cache_size,
                "max_concurrent_queries": self.performance.max_concurrent_queries,
                "query_timeout": self.performance.query_timeout,
                "connection_pool_size": self.performance.connection_pool_size,
                "idle_timeout": self.performance.idle_timeout,
                "max_response_content_size": self.performance.max_response_content_size,
            },
            "logging": {
                "level": self.logging.level,
                "format": self.logging.format,
                "file_path": self.logging.file_path,
                "max_file_size": self.logging.max_file_size,
                "backup_count": self.logging.backup_count,
                "enable_audit": self.logging.enable_audit,
                "audit_file_path": self.logging.audit_file_path,
            },
            "monitoring": {
                "enable_metrics": self.monitoring.enable_metrics,
                "metrics_port": self.monitoring.metrics_port,
                "metrics_path": self.monitoring.metrics_path,
                "health_check_port": self.monitoring.health_check_port,
                "health_check_path": self.monitoring.health_check_path,
                "enable_alerts": self.monitoring.enable_alerts,
                "alert_webhook_url": self.monitoring.alert_webhook_url,
            },
            "custom": self.custom_config,
        }

    def save_to_file(self, config_path: str):
        """Save the configuration (as produced by ``to_dict``) to a JSON file.

        Secrets are written redacted. Parent directories are created as needed.

        Raises:
            ValueError: If the suffix is not ``.json`` (checked before anything
                is created or truncated) or writing fails.
        """
        config_file = Path(config_path)
        # Validate the format first: opening with "w" before this check would
        # truncate/create the target even though we then refuse to write it.
        if config_file.suffix.lower() != ".json":
            raise ValueError(f"Unsupported configuration file format: {config_file.suffix}")
        config_file.parent.mkdir(parents=True, exist_ok=True)
        try:
            with open(config_file, "w", encoding="utf-8") as f:
                json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
        except Exception as e:
            raise ValueError(f"Failed to save configuration file: {e}") from e

    # ------------------------------------------------------------------ #
    # Validation and reporting
    # ------------------------------------------------------------------ #
    def validate(self) -> list[str]:
        """Validate the configuration.

        Returns:
            Human-readable error messages; an empty list means the configuration is valid.
        """
        errors: list[str] = []

        # Validate database configuration
        db = self.database
        if not db.host:
            errors.append("Database host address cannot be empty")
        if not (1 <= db.port <= 65535):
            errors.append("Database port must be in the range 1-65535")
        if not db.user:
            errors.append("Database username cannot be empty")
        if db.min_connections <= 0:
            errors.append("Minimum connections must be greater than 0")
        if db.max_connections <= db.min_connections:
            errors.append("Maximum connections must be greater than minimum connections")
        if not (1 <= db.fe_http_port <= 65535):
            errors.append("FE HTTP port must be in the range 1-65535")
        if not (1 <= db.be_webserver_port <= 65535):
            errors.append("BE webserver port must be in the range 1-65535")

        # Validate security configuration
        sec = self.security
        if sec.auth_type not in ["token", "basic", "oauth"]:
            errors.append("Authentication type must be one of token, basic, or oauth")
        if sec.token_expiry <= 0:
            errors.append("Token expiry time must be greater than 0")
        if sec.max_query_complexity <= 0:
            errors.append("Maximum query complexity must be greater than 0")
        if sec.max_result_rows <= 0:
            errors.append("Maximum result rows must be greater than 0")

        # Validate performance configuration
        perf = self.performance
        if perf.cache_ttl <= 0:
            errors.append("Cache TTL must be greater than 0")
        if perf.max_concurrent_queries <= 0:
            errors.append("Maximum concurrent queries must be greater than 0")
        if perf.query_timeout <= 0:
            errors.append("Query timeout must be greater than 0")

        # Validate logging configuration. The level is compared
        # case-insensitively because ConfigManager.setup_logging upper-cases it.
        if str(self.logging.level).upper() not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]:
            errors.append("Log level must be one of DEBUG, INFO, WARNING, ERROR, or CRITICAL")
        if self.logging.max_file_size <= 0:
            errors.append("Maximum log file size must be greater than 0")
        if self.logging.backup_count < 0:
            errors.append("Log backup count cannot be negative")

        # Validate monitoring configuration
        if not (1 <= self.monitoring.metrics_port <= 65535):
            errors.append("Monitoring port must be in the range 1-65535")
        if not (1 <= self.monitoring.health_check_port <= 65535):
            errors.append("Health check port must be in the range 1-65535")

        return errors

    def get_connection_string(self) -> str:
        """Return a MySQL-style connection URL with the password replaced by ``***``."""
        return (
            f"mysql://{self.database.user}:{self._REDACTED}"
            f"@{self.database.host}:{self.database.port}/{self.database.database}"
        )

    def get_config_summary(self) -> dict[str, Any]:
        """Return a short, secret-free summary of the main settings."""
        return {
            "server": f"{self.server_name} v{self.server_version}",
            "database": f"{self.database.host}:{self.database.port}/{self.database.database}",
            "connection_pool": f"{self.database.min_connections}-{self.database.max_connections}",
            "security": {
                "auth_type": self.security.auth_type,
                "masking_enabled": self.security.enable_masking,
                "blocked_keywords_count": len(self.security.blocked_keywords),
            },
            "performance": {
                "cache_enabled": self.performance.enable_query_cache,
                "max_concurrent": self.performance.max_concurrent_queries,
                "query_timeout": self.performance.query_timeout,
            },
            "monitoring": {
                "metrics_enabled": self.monitoring.enable_metrics,
                "alerts_enabled": self.monitoring.enable_alerts,
            },
        }
class ConfigManager:
"""Configuration manager class"""
def __init__(self, config: DorisConfig):
self.config = config
self.logger = logging.getLogger(__name__)
def setup_logging(self):
"""Setup logging configuration"""
# Configure root logger
root_logger = logging.getLogger()
root_logger.setLevel(getattr(logging, self.config.logging.level.upper()))
# Clear existing handlers
for handler in root_logger.handlers[:]:
root_logger.removeHandler(handler)
# Create formatter
formatter = logging.Formatter(self.config.logging.format)
# Console handler
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
root_logger.addHandler(console_handler)
# File handler (if configured)
if self.config.logging.file_path:
try:
from logging.handlers import RotatingFileHandler
file_handler = RotatingFileHandler(
self.config.logging.file_path,
maxBytes=self.config.logging.max_file_size,
backupCount=self.config.logging.backup_count,
encoding="utf-8",
)
file_handler.setFormatter(formatter)
root_logger.addHandler(file_handler)
except Exception as e:
self.logger.warning(f"Failed to setup file logging: {e}")
# Audit log handler (if configured)
if self.config.logging.enable_audit and self.config.logging.audit_file_path:
try:
from logging.handlers import RotatingFileHandler
audit_logger = logging.getLogger("audit")
audit_handler = RotatingFileHandler(
self.config.logging.audit_file_path,
maxBytes=self.config.logging.max_file_size,
backupCount=self.config.logging.backup_count,
encoding="utf-8",
)
audit_handler.setFormatter(formatter)
audit_logger.addHandler(audit_handler)
audit_logger.setLevel(logging.INFO)
except Exception as e:
self.logger.warning(f"Failed to setup audit logging: {e}")
def validate_config(self) -> bool:
"""Validate configuration"""
errors = self.config.validate()
if errors:
self.logger.error("Configuration validation failed:")
for error in errors:
self.logger.error(f" - {error}")
return False
self.logger.info("Configuration validation passed")
return True
def log_config_summary(self):
"""Log configuration summary"""
summary = self.config.get_config_summary()
self.logger.info("Configuration Summary:")
self.logger.info(f" Server: {summary['server']}")
self.logger.info(f" Database: {summary['database']}")
self.logger.info(f" Connection Pool: {summary['connection_pool']}")
self.logger.info(f" Security: {summary['security']}")
self.logger.info(f" Performance: {summary['performance']}")
self.logger.info(f" Monitoring: {summary['monitoring']}")
def create_default_config_file(config_path: str):
    """Write a configuration file populated entirely with default values.

    Args:
        config_path: Destination path, forwarded to ``DorisConfig.save_to_file``.
    """
    DorisConfig().save_to_file(config_path)
    print(f"Default configuration file created: {config_path}")
# Example usage
if __name__ == "__main__":
    # Start from the built-in defaults. Alternatives:
    #   DorisConfig.from_env()               -> environment variables / .env file
    #   DorisConfig.from_file("config.json") -> JSON configuration file
    config = DorisConfig()

    config_manager = ConfigManager(config)
    if not config_manager.validate_config():
        print("Configuration validation failed")
    else:
        config_manager.setup_logging()
        config_manager.log_config_summary()
        # Persist the configuration (secrets are written as "***").
        config.save_to_file("example_config.json")
        print("Configuration saved to example_config.json")