#!/usr/bin/env python3 # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. """ Doris Configuration Management Module Implements configuration loading, validation and management functionality """ import json import logging import os from dataclasses import dataclass, field from pathlib import Path from typing import Any try: from dotenv import load_dotenv except ImportError: load_dotenv = None @dataclass class DatabaseConfig: """Database connection configuration""" host: str = "localhost" port: int = 9030 user: str = "root" password: str = "" database: str = "information_schema" charset: str = "UTF8" # Connection pool configuration min_connections: int = 5 max_connections: int = 20 connection_timeout: int = 30 health_check_interval: int = 60 max_connection_age: int = 3600 @dataclass class SecurityConfig: """Security configuration""" # Authentication configuration auth_type: str = "token" # token, basic, oauth token_secret: str = "default_secret" token_expiry: int = 3600 # SQL security configuration blocked_keywords: list[str] = field( default_factory=lambda: [ "DROP", "DELETE", "TRUNCATE", "ALTER", "CREATE", "INSERT", "UPDATE", "GRANT", "REVOKE", ] ) max_query_complexity: int = 100 max_result_rows: int = 10000 # Sensitive table configuration sensitive_tables: dict[str, str] = field(default_factory=dict) # Data masking configuration enable_masking: bool = True masking_rules: list[dict[str, Any]] = field(default_factory=list) @dataclass class PerformanceConfig: """Performance configuration""" # Query cache configuration enable_query_cache: bool = True cache_ttl: int = 300 max_cache_size: int = 1000 # Concurrency control configuration max_concurrent_queries: int = 50 query_timeout: int = 300 # Connection pool optimization configuration connection_pool_size: int = 20 idle_timeout: int = 1800 @dataclass class LoggingConfig: """Logging configuration""" level: str = "INFO" format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" file_path: str | None = None max_file_size: int = 10 * 1024 * 1024 # 10MB backup_count: int = 5 # Audit log configuration enable_audit: bool = True audit_file_path: str | None = None @dataclass class MonitoringConfig: """Monitoring configuration""" # Metrics collection configuration enable_metrics: bool = True metrics_port: int = 3001 metrics_path: str = "/metrics" # Health check configuration health_check_port: int = 3002 health_check_path: str = "/health" # Alert configuration enable_alerts: bool = False alert_webhook_url: str | None = None @dataclass class DorisConfig: """Doris MCP Server complete configuration""" # Basic configuration server_name: str = "doris-mcp-server" server_version: str = "0.3.0" server_port: int = 3000 transport: str = "stdio" # Sub-configuration modules database: DatabaseConfig = field(default_factory=DatabaseConfig) security: SecurityConfig = field(default_factory=SecurityConfig) performance: PerformanceConfig = field(default_factory=PerformanceConfig) logging: LoggingConfig = field(default_factory=LoggingConfig) monitoring: MonitoringConfig = field(default_factory=MonitoringConfig) # Custom configuration custom_config: dict[str, Any] = field(default_factory=dict) @classmethod def from_file(cls, config_path: str) -> "DorisConfig": """Load configuration from file""" config_file = Path(config_path) if not config_file.exists(): raise FileNotFoundError(f"Configuration file does not exist: {config_path}") try: with open(config_file, encoding="utf-8") as f: if config_file.suffix.lower() == ".json": config_data = json.load(f) else: # Support other formats (like YAML) raise ValueError(f"Unsupported configuration file format: {config_file.suffix}") return cls._from_dict(config_data) except Exception as e: raise ValueError(f"Failed to load configuration file: {e}") @classmethod def from_env(cls, env_file: str | None = None) -> "DorisConfig": """Load configuration from environment variables Args: env_file: .env file path, if None, search in the following order: .env, .env.local, .env.production, .env.development """ # Load .env file if load_dotenv is not None: if env_file: # Load specified .env file if Path(env_file).exists(): load_dotenv(env_file) logging.getLogger(__name__).info(f"Loaded environment configuration file: {env_file}") else: logging.getLogger(__name__).warning(f"Environment configuration file does not exist: {env_file}") else: # Load .env files in priority order env_files = [".env", ".env.local", ".env.production", ".env.development"] for env_path in env_files: if Path(env_path).exists(): load_dotenv(env_path) logging.getLogger(__name__).info(f"Loaded environment configuration file: {env_path}") break else: logging.getLogger(__name__).info("No .env configuration file found, using system environment variables") else: logging.getLogger(__name__).warning("python-dotenv not installed, cannot load .env files") config = cls() # Database configuration config.database.host = os.getenv("DORIS_HOST", config.database.host) config.database.port = int(os.getenv("DORIS_PORT", str(config.database.port))) config.database.user = os.getenv("DORIS_USER", config.database.user) config.database.password = os.getenv("DORIS_PASSWORD", config.database.password) config.database.database = os.getenv("DORIS_DATABASE", config.database.database) # Connection pool configuration config.database.min_connections = int( os.getenv("DORIS_MIN_CONNECTIONS", str(config.database.min_connections)) ) config.database.max_connections = int( os.getenv("DORIS_MAX_CONNECTIONS", str(config.database.max_connections)) ) config.database.connection_timeout = int( os.getenv("DORIS_CONNECTION_TIMEOUT", str(config.database.connection_timeout)) ) config.database.health_check_interval = int( os.getenv("DORIS_HEALTH_CHECK_INTERVAL", str(config.database.health_check_interval)) ) config.database.max_connection_age = int( os.getenv("DORIS_MAX_CONNECTION_AGE", str(config.database.max_connection_age)) ) # Security configuration config.security.auth_type = os.getenv("AUTH_TYPE", config.security.auth_type) config.security.token_secret = os.getenv("TOKEN_SECRET", config.security.token_secret) config.security.token_expiry = int( os.getenv("TOKEN_EXPIRY", str(config.security.token_expiry)) ) config.security.max_result_rows = int( os.getenv("MAX_RESULT_ROWS", str(config.security.max_result_rows)) ) config.security.max_query_complexity = int( os.getenv("MAX_QUERY_COMPLEXITY", str(config.security.max_query_complexity)) ) config.security.enable_masking = ( os.getenv("ENABLE_MASKING", str(config.security.enable_masking).lower()).lower() == "true" ) # Performance configuration config.performance.enable_query_cache = ( os.getenv("ENABLE_QUERY_CACHE", "true").lower() == "true" ) config.performance.cache_ttl = int( os.getenv("CACHE_TTL", str(config.performance.cache_ttl)) ) config.performance.max_cache_size = int( os.getenv("MAX_CACHE_SIZE", str(config.performance.max_cache_size)) ) config.performance.max_concurrent_queries = int( os.getenv("MAX_CONCURRENT_QUERIES", str(config.performance.max_concurrent_queries)) ) config.performance.query_timeout = int( os.getenv("QUERY_TIMEOUT", str(config.performance.query_timeout)) ) # Logging configuration config.logging.level = os.getenv("LOG_LEVEL", config.logging.level) config.logging.file_path = os.getenv("LOG_FILE_PATH", config.logging.file_path) config.logging.enable_audit = ( os.getenv("ENABLE_AUDIT", str(config.logging.enable_audit).lower()).lower() == "true" ) config.logging.audit_file_path = os.getenv("AUDIT_FILE_PATH", config.logging.audit_file_path) # Monitoring configuration config.monitoring.enable_metrics = ( os.getenv("ENABLE_METRICS", "true").lower() == "true" ) config.monitoring.metrics_port = int( os.getenv("METRICS_PORT", str(config.monitoring.metrics_port)) ) config.monitoring.health_check_port = int( os.getenv("HEALTH_CHECK_PORT", str(config.monitoring.health_check_port)) ) config.monitoring.enable_alerts = ( os.getenv("ENABLE_ALERTS", str(config.monitoring.enable_alerts).lower()).lower() == "true" ) config.monitoring.alert_webhook_url = os.getenv("ALERT_WEBHOOK_URL", config.monitoring.alert_webhook_url) # Server configuration config.server_name = os.getenv("SERVER_NAME", config.server_name) config.server_version = os.getenv("SERVER_VERSION", config.server_version) config.server_port = int(os.getenv("SERVER_PORT", str(config.server_port))) return config @classmethod def _from_dict(cls, config_data: dict[str, Any]) -> "DorisConfig": """Create configuration object from dictionary""" config = cls() # Update basic configuration for key in ["server_name", "server_version", "server_port"]: if key in config_data: setattr(config, key, config_data[key]) # Update database configuration if "database" in config_data: db_config = config_data["database"] for key, value in db_config.items(): if hasattr(config.database, key): setattr(config.database, key, value) # Update security configuration if "security" in config_data: sec_config = config_data["security"] for key, value in sec_config.items(): if hasattr(config.security, key): setattr(config.security, key, value) # Update performance configuration if "performance" in config_data: perf_config = config_data["performance"] for key, value in perf_config.items(): if hasattr(config.performance, key): setattr(config.performance, key, value) # Update logging configuration if "logging" in config_data: log_config = config_data["logging"] for key, value in log_config.items(): if hasattr(config.logging, key): setattr(config.logging, key, value) # Update monitoring configuration if "monitoring" in config_data: mon_config = config_data["monitoring"] for key, value in mon_config.items(): if hasattr(config.monitoring, key): setattr(config.monitoring, key, value) # Custom configuration config.custom_config = config_data.get("custom", {}) return config def to_dict(self) -> dict[str, Any]: """Convert to dictionary format""" return { "server_name": self.server_name, "server_version": self.server_version, "server_port": self.server_port, "database": { "host": self.database.host, "port": self.database.port, "user": self.database.user, "password": "***", # Hide password "database": self.database.database, "charset": self.database.charset, "min_connections": self.database.min_connections, "max_connections": self.database.max_connections, "connection_timeout": self.database.connection_timeout, "health_check_interval": self.database.health_check_interval, "max_connection_age": self.database.max_connection_age, }, "security": { "auth_type": self.security.auth_type, "token_secret": "***", # Hide secret key "token_expiry": self.security.token_expiry, "blocked_keywords": self.security.blocked_keywords, "max_query_complexity": self.security.max_query_complexity, "max_result_rows": self.security.max_result_rows, "sensitive_tables": self.security.sensitive_tables, "enable_masking": self.security.enable_masking, "masking_rules": len(self.security.masking_rules), }, "performance": { "enable_query_cache": self.performance.enable_query_cache, "cache_ttl": self.performance.cache_ttl, "max_cache_size": self.performance.max_cache_size, "max_concurrent_queries": self.performance.max_concurrent_queries, "query_timeout": self.performance.query_timeout, "connection_pool_size": self.performance.connection_pool_size, "idle_timeout": self.performance.idle_timeout, }, "logging": { "level": self.logging.level, "format": self.logging.format, "file_path": self.logging.file_path, "max_file_size": self.logging.max_file_size, "backup_count": self.logging.backup_count, "enable_audit": self.logging.enable_audit, "audit_file_path": self.logging.audit_file_path, }, "monitoring": { "enable_metrics": self.monitoring.enable_metrics, "metrics_port": self.monitoring.metrics_port, "metrics_path": self.monitoring.metrics_path, "health_check_port": self.monitoring.health_check_port, "health_check_path": self.monitoring.health_check_path, "enable_alerts": self.monitoring.enable_alerts, "alert_webhook_url": self.monitoring.alert_webhook_url, }, "custom": self.custom_config, } def save_to_file(self, config_path: str): """Save configuration to file""" config_file = Path(config_path) config_file.parent.mkdir(parents=True, exist_ok=True) try: with open(config_file, "w", encoding="utf-8") as f: if config_file.suffix.lower() == ".json": json.dump(self.to_dict(), f, indent=2, ensure_ascii=False) else: raise ValueError(f"Unsupported configuration file format: {config_file.suffix}") except Exception as e: raise ValueError(f"Failed to save configuration file: {e}") def validate(self) -> list[str]: """Validate configuration validity""" errors = [] # Validate database configuration if not self.database.host: errors.append("Database host address cannot be empty") if not (1 <= self.database.port <= 65535): errors.append("Database port must be in the range 1-65535") if not self.database.user: errors.append("Database username cannot be empty") if self.database.min_connections <= 0: errors.append("Minimum connections must be greater than 0") if self.database.max_connections <= self.database.min_connections: errors.append("Maximum connections must be greater than minimum connections") # Validate security configuration if self.security.auth_type not in ["token", "basic", "oauth"]: errors.append("Authentication type must be one of token, basic, or oauth") if self.security.token_expiry <= 0: errors.append("Token expiry time must be greater than 0") if self.security.max_query_complexity <= 0: errors.append("Maximum query complexity must be greater than 0") if self.security.max_result_rows <= 0: errors.append("Maximum result rows must be greater than 0") # Validate performance configuration if self.performance.cache_ttl <= 0: errors.append("Cache TTL must be greater than 0") if self.performance.max_concurrent_queries <= 0: errors.append("Maximum concurrent queries must be greater than 0") if self.performance.query_timeout <= 0: errors.append("Query timeout must be greater than 0") # Validate logging configuration if self.logging.level not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]: errors.append("Log level must be one of DEBUG, INFO, WARNING, ERROR, or CRITICAL") if self.logging.max_file_size <= 0: errors.append("Maximum log file size must be greater than 0") if self.logging.backup_count < 0: errors.append("Log backup count cannot be negative") # Validate monitoring configuration if not (1 <= self.monitoring.metrics_port <= 65535): errors.append("Monitoring port must be in the range 1-65535") if not (1 <= self.monitoring.health_check_port <= 65535): errors.append("Health check port must be in the range 1-65535") return errors def get_connection_string(self) -> str: """Get database connection string (hide password)""" return f"mysql://{self.database.user}:***@{self.database.host}:{self.database.port}/{self.database.database}" def get_config_summary(self) -> dict[str, Any]: """Get configuration summary information""" return { "server": f"{self.server_name} v{self.server_version}", "database": f"{self.database.host}:{self.database.port}/{self.database.database}", "connection_pool": f"{self.database.min_connections}-{self.database.max_connections}", "security": { "auth_type": self.security.auth_type, "masking_enabled": self.security.enable_masking, "blocked_keywords_count": len(self.security.blocked_keywords), }, "performance": { "cache_enabled": self.performance.enable_query_cache, "max_concurrent": self.performance.max_concurrent_queries, "query_timeout": self.performance.query_timeout, }, "monitoring": { "metrics_enabled": self.monitoring.enable_metrics, "alerts_enabled": self.monitoring.enable_alerts, }, } class ConfigManager: """Configuration manager class""" def __init__(self, config: DorisConfig): self.config = config self.logger = logging.getLogger(__name__) def setup_logging(self): """Setup logging configuration""" # Configure root logger root_logger = logging.getLogger() root_logger.setLevel(getattr(logging, self.config.logging.level.upper())) # Clear existing handlers for handler in root_logger.handlers[:]: root_logger.removeHandler(handler) # Create formatter formatter = logging.Formatter(self.config.logging.format) # Console handler console_handler = logging.StreamHandler() console_handler.setFormatter(formatter) root_logger.addHandler(console_handler) # File handler (if configured) if self.config.logging.file_path: try: from logging.handlers import RotatingFileHandler file_handler = RotatingFileHandler( self.config.logging.file_path, maxBytes=self.config.logging.max_file_size, backupCount=self.config.logging.backup_count, encoding="utf-8", ) file_handler.setFormatter(formatter) root_logger.addHandler(file_handler) except Exception as e: self.logger.warning(f"Failed to setup file logging: {e}") # Audit log handler (if configured) if self.config.logging.enable_audit and self.config.logging.audit_file_path: try: from logging.handlers import RotatingFileHandler audit_logger = logging.getLogger("audit") audit_handler = RotatingFileHandler( self.config.logging.audit_file_path, maxBytes=self.config.logging.max_file_size, backupCount=self.config.logging.backup_count, encoding="utf-8", ) audit_handler.setFormatter(formatter) audit_logger.addHandler(audit_handler) audit_logger.setLevel(logging.INFO) except Exception as e: self.logger.warning(f"Failed to setup audit logging: {e}") def validate_config(self) -> bool: """Validate configuration""" errors = self.config.validate() if errors: self.logger.error("Configuration validation failed:") for error in errors: self.logger.error(f" - {error}") return False self.logger.info("Configuration validation passed") return True def log_config_summary(self): """Log configuration summary""" summary = self.config.get_config_summary() self.logger.info("Configuration Summary:") self.logger.info(f" Server: {summary['server']}") self.logger.info(f" Database: {summary['database']}") self.logger.info(f" Connection Pool: {summary['connection_pool']}") self.logger.info(f" Security: {summary['security']}") self.logger.info(f" Performance: {summary['performance']}") self.logger.info(f" Monitoring: {summary['monitoring']}") def create_default_config_file(config_path: str): """Create default configuration file""" config = DorisConfig() config.save_to_file(config_path) print(f"Default configuration file created: {config_path}") # Example usage if __name__ == "__main__": # Create default configuration config = DorisConfig() # Load from environment variables # config = DorisConfig.from_env() # Load from file # config = DorisConfig.from_file("config.json") # Validate configuration config_manager = ConfigManager(config) if config_manager.validate_config(): config_manager.setup_logging() config_manager.log_config_summary() # Save configuration config.save_to_file("example_config.json") print("Configuration saved to example_config.json") else: print("Configuration validation failed")