From f99399c6c78cee9089e6bf2c405bc974b6bc67ca Mon Sep 17 00:00:00 2001 From: Yijia Su <54164178+FreeOnePlus@users.noreply.github.com> Date: Tue, 2 Sep 2025 18:40:48 +0800 Subject: [PATCH] [Performance]Add a controllable MCP Server DB Pool permission authentication system (#53) * 0.5.1 Version * fix 0.5.1 schema async bug * fix security bug * fix security bug * Add complete Token, JWT, OAuth authentication system * Add complete Token, JWT, OAuth authentication system * Add complete Token, JWT, OAuth authentication system * Add complete Token, JWT, OAuth authentication system * Add a controllable MCP Server DB Pool permission authentication system, connect it with the Doris permission system, and provide it to enterprise-level applications concurrently with the multi-Worker mode. --- doris_mcp_server/auth/auth_middleware.py | 4 +- doris_mcp_server/auth/oauth_provider.py | 4 +- doris_mcp_server/auth/token_manager.py | 179 +++++++++++- doris_mcp_server/main.py | 18 +- doris_mcp_server/multiworker_app.py | 9 +- doris_mcp_server/utils/config.py | 35 ++- doris_mcp_server/utils/db.py | 336 ++++++++++++++++++++++- doris_mcp_server/utils/query_executor.py | 4 + doris_mcp_server/utils/security.py | 61 +++- tokens.json | 33 ++- 10 files changed, 636 insertions(+), 47 deletions(-) diff --git a/doris_mcp_server/auth/auth_middleware.py b/doris_mcp_server/auth/auth_middleware.py index 3fd1792..c44d2f5 100644 --- a/doris_mcp_server/auth/auth_middleware.py +++ b/doris_mcp_server/auth/auth_middleware.py @@ -117,13 +117,15 @@ class AuthMiddleware: # Build authentication context auth_context = AuthContext( + token_id=payload.get('jti', ''), user_id=payload.get('sub'), roles=payload.get('roles', []), permissions=payload.get('permissions', []), + security_level=SecurityLevel(payload.get('security_level', 'internal')), session_id=payload.get('jti'), # Use JWT ID as session ID login_time=datetime.fromtimestamp(payload.get('iat', 0)), last_activity=datetime.utcnow(), - security_level=SecurityLevel(payload.get('security_level', 'internal')) + token=token # Store raw token for token-bound database configuration ) logger.info(f"JWT authentication successful for user: {auth_context.user_id}") diff --git a/doris_mcp_server/auth/oauth_provider.py b/doris_mcp_server/auth/oauth_provider.py index 71ce095..9d4586e 100644 --- a/doris_mcp_server/auth/oauth_provider.py +++ b/doris_mcp_server/auth/oauth_provider.py @@ -191,13 +191,15 @@ class OAuthAuthenticationProvider: session_id = f"oauth_{user_info.sub}_{datetime.utcnow().timestamp()}" return AuthContext( + token_id=f"oauth_{user_info.sub}", user_id=user_info.sub, roles=user_info.roles, permissions=permissions, + security_level=security_level, session_id=session_id, login_time=datetime.utcnow(), last_activity=datetime.utcnow(), - security_level=security_level + token="" # OAuth doesn't have raw token, use empty string ) async def _determine_security_level(self, user_info: OAuthUserInfo) -> SecurityLevel: diff --git a/doris_mcp_server/auth/token_manager.py b/doris_mcp_server/auth/token_manager.py index ed742d9..458dd05 100644 --- a/doris_mcp_server/auth/token_manager.py +++ b/doris_mcp_server/auth/token_manager.py @@ -11,17 +11,32 @@ import json import os import secrets import time +import asyncio from dataclasses import dataclass, field from datetime import datetime, timedelta from typing import Dict, List, Optional, Any +from pathlib import Path from ..utils.logger import get_logger from ..utils.security import SecurityLevel +@dataclass +class DatabaseConfig: + """Database connection configuration for token binding""" + + host: str + port: int = 9030 + user: str = "" + password: str = "" + database: str = "information_schema" + charset: str = "UTF8" + fe_http_port: int = 8030 + + @dataclass class TokenInfo: - """Token information structure""" + """Token information structure with optional database binding""" token_id: str # Unique token identifier for audit and management created_at: datetime = field(default_factory=datetime.utcnow) @@ -29,6 +44,7 @@ class TokenInfo: last_used: Optional[datetime] = None description: str = "" # Optional description for token purpose is_active: bool = True + database_config: Optional[DatabaseConfig] = None # Optional database binding @dataclass @@ -65,13 +81,23 @@ class TokenManager: self.default_token_expiry_hours = getattr(config.security, 'default_token_expiry_hours', 24 * 30) # 30 days self.token_hash_algorithm = getattr(config.security, 'token_hash_algorithm', 'sha256') + # Hot reload configuration + self.enable_hot_reload = True + self.hot_reload_interval = 10 # Check every 10 seconds + self._file_last_modified = 0 + self._hot_reload_task = None + # Initialize with default tokens if none exist self._initialize_default_tokens() # Load tokens from configuration self._load_tokens() - self.logger.info(f"TokenManager initialized with {len(self._tokens)} tokens") + # Start hot reload monitoring + if self.enable_hot_reload: + self._start_hot_reload() + + self.logger.info(f"TokenManager initialized with {len(self._tokens)} tokens, hot reload: {self.enable_hot_reload}") def _initialize_default_tokens(self): """Initialize default tokens for basic authentication (configurable via environment)""" @@ -132,7 +158,7 @@ class TokenManager: self.logger.info("Skipped default tokens initialization (custom tokens detected)") def _add_token_from_config(self, token_config: Dict[str, Any]): - """Add token from configuration""" + """Add token from configuration with optional database binding""" try: # Calculate expiration time expires_at = None @@ -141,12 +167,27 @@ class TokenManager: if expires_hours is not None: expires_at = datetime.utcnow() + timedelta(hours=expires_hours) + # Parse database configuration if provided + database_config = None + if 'database_config' in token_config: + db_config = token_config['database_config'] + database_config = DatabaseConfig( + host=db_config.get('host', 'localhost'), + port=db_config.get('port', 9030), + user=db_config.get('user', 'root'), + password=db_config.get('password', ''), + database=db_config.get('database', 'information_schema'), + charset=db_config.get('charset', 'UTF8'), + fe_http_port=db_config.get('fe_http_port', 8030) + ) + # Create token info token_info = TokenInfo( token_id=token_config['token_id'], expires_at=expires_at, description=token_config.get('description', ''), - is_active=token_config.get('is_active', True) + is_active=token_config.get('is_active', True), + database_config=database_config ) # Hash the token @@ -157,10 +198,12 @@ class TokenManager: self._tokens[token_hash] = token_info self._token_ids[token_info.token_id] = token_hash - self.logger.debug(f"Added token '{token_info.token_id}'") + db_info = f" with DB binding ({database_config.host})" if database_config else "" + self.logger.debug(f"Added token '{token_info.token_id}'{db_info}") except Exception as e: self.logger.error(f"Failed to add token from config: {e}") + raise def _load_tokens(self): """Load tokens from configuration sources""" @@ -378,7 +421,7 @@ class TokenManager: tokens = [] for token_hash, token_info in self._tokens.items(): - tokens.append({ + token_data = { 'token_id': token_info.token_id, 'created_at': token_info.created_at.isoformat(), 'expires_at': token_info.expires_at.isoformat() if token_info.expires_at else None, @@ -386,7 +429,21 @@ class TokenManager: 'is_active': token_info.is_active, 'description': token_info.description, 'is_expired': token_info.expires_at and datetime.utcnow() > token_info.expires_at if token_info.expires_at else False - }) + } + + # Add database binding info (without sensitive data) + if token_info.database_config: + token_data['database_binding'] = { + 'host': token_info.database_config.host, + 'port': token_info.database_config.port, + 'user': token_info.database_config.user, + 'database': token_info.database_config.database, + 'has_password': bool(token_info.database_config.password) + } + else: + token_data['database_binding'] = None + + tokens.append(token_data) # Sort by creation time tokens.sort(key=lambda x: x['created_at'], reverse=True) @@ -439,6 +496,32 @@ class TokenManager: self.logger.error(f"Failed to save tokens to file: {e}") return False + def get_database_config_by_token(self, token: str) -> Optional[DatabaseConfig]: + """Get database configuration bound to a token + + Args: + token: The raw token string + + Returns: + DatabaseConfig if token exists and has database binding, None otherwise + """ + try: + token_hash = self._hash_token(token) + token_info = self._tokens.get(token_hash) + + if not token_info or not token_info.is_active: + return None + + # Check expiration + if token_info.expires_at and datetime.utcnow() > token_info.expires_at: + return None + + return token_info.database_config + + except Exception as e: + self.logger.error(f"Failed to get database config for token: {e}") + return None + def get_token_stats(self) -> Dict[str, Any]: """Get token statistics""" now = datetime.utcnow() @@ -446,11 +529,89 @@ class TokenManager: active_tokens = sum(1 for info in self._tokens.values() if info.is_active) expired_tokens = sum(1 for info in self._tokens.values() if info.expires_at and now > info.expires_at) + tokens_with_db = sum(1 for info in self._tokens.values() + if info.database_config is not None) return { 'total_tokens': total_tokens, 'active_tokens': active_tokens, 'expired_tokens': expired_tokens, + 'tokens_with_database_binding': tokens_with_db, 'expiry_enabled': self.enable_token_expiry, - 'default_expiry_hours': self.default_token_expiry_hours - } \ No newline at end of file + 'default_expiry_hours': self.default_token_expiry_hours, + 'hot_reload_enabled': self.enable_hot_reload, + 'last_file_check': datetime.fromtimestamp(self._file_last_modified).isoformat() if self._file_last_modified else None + } + + def _start_hot_reload(self): + """Start hot reload monitoring task""" + if self._hot_reload_task: + return # Already running + + # Update initial file modification time + self._update_file_modified_time() + + # Start monitoring task + self._hot_reload_task = asyncio.create_task(self._hot_reload_monitor()) + self.logger.info(f"Started hot reload monitoring for {self.token_file_path}") + + def stop_hot_reload(self): + """Stop hot reload monitoring""" + if self._hot_reload_task: + self._hot_reload_task.cancel() + self._hot_reload_task = None + self.logger.info("Stopped hot reload monitoring") + + def _update_file_modified_time(self): + """Update the last modified time of tokens file""" + try: + if os.path.exists(self.token_file_path): + self._file_last_modified = os.path.getmtime(self.token_file_path) + except Exception as e: + self.logger.debug(f"Failed to get file modification time: {e}") + + async def _hot_reload_monitor(self): + """Background task to monitor tokens.json file changes""" + while True: + try: + await asyncio.sleep(self.hot_reload_interval) + + if not os.path.exists(self.token_file_path): + continue + + # Check if file was modified + current_mtime = os.path.getmtime(self.token_file_path) + if current_mtime > self._file_last_modified: + self.logger.info(f"Detected changes in {self.token_file_path}, reloading tokens...") + + try: + # Backup current tokens + old_tokens = self._tokens.copy() + old_token_ids = self._token_ids.copy() + + # Clear and reload + self._tokens.clear() + self._token_ids.clear() + + # Reinitialize default tokens + self._initialize_default_tokens() + + # Load from file + self._load_tokens_from_file() + + # Update modification time + self._file_last_modified = current_mtime + + self.logger.info(f"Hot reload completed, {len(self._tokens)} tokens loaded") + + except Exception as reload_error: + # Restore backup on failure + self.logger.error(f"Hot reload failed, restoring previous tokens: {reload_error}") + self._tokens = old_tokens + self._token_ids = old_token_ids + + except asyncio.CancelledError: + self.logger.info("Hot reload monitor stopped") + break + except Exception as e: + self.logger.error(f"Error in hot reload monitor: {e}") \ No newline at end of file diff --git a/doris_mcp_server/main.py b/doris_mcp_server/main.py index 5b60fde..72bd29c 100644 --- a/doris_mcp_server/main.py +++ b/doris_mcp_server/main.py @@ -230,11 +230,15 @@ class DorisServer: self.config = config self.server = Server("doris-mcp-server") - # Initialize security manager + # Initialize security manager (without connection_manager initially) self.security_manager = DorisSecurityManager(config) - # Initialize connection manager, pass in security manager - self.connection_manager = DorisConnectionManager(config, self.security_manager) + # Initialize connection manager, pass in security manager and token manager for token-bound DB config + token_manager = self.security_manager.auth_provider.token_manager if hasattr(self.security_manager, 'auth_provider') and hasattr(self.security_manager.auth_provider, 'token_manager') else None + self.connection_manager = DorisConnectionManager(config, self.security_manager, token_manager) + + # Set connection manager reference in security manager for database validation + self.security_manager.connection_manager = self.connection_manager # Initialize independent managers self.resources_manager = DorisResourcesManager(self.connection_manager) @@ -785,14 +789,14 @@ Examples: parser.add_argument( "--port", type=int, - default=os.getenv("SERVER_PORT", _default_config.server_port), - help=f"Port number for HTTP mode (default: {_default_config.server_port})" + default=3000, + help="Port number for HTTP mode (default: 3000)" ) parser.add_argument( "--workers", type=int, - default=int(os.getenv("WORKERS", "1")), + default=1, help="Number of worker processes for HTTP mode (default: 1, use 0 for auto-detect CPU cores)" ) @@ -804,7 +808,7 @@ Examples: ) parser.add_argument( - "--doris-port", "--db-port", type=int, default=os.getenv("DORIS_PORT", _default_config.database.port), help=f"Doris database port number (default: {_default_config.database.port})" + "--doris-port", "--db-port", type=int, default=9030, help="Doris database port number (default: 9030)" ) parser.add_argument( diff --git a/doris_mcp_server/multiworker_app.py b/doris_mcp_server/multiworker_app.py index ba6b06c..af43123 100644 --- a/doris_mcp_server/multiworker_app.py +++ b/doris_mcp_server/multiworker_app.py @@ -270,8 +270,13 @@ async def initialize_worker(): await _worker_security_manager.initialize() logger.info(f"Worker {os.getpid()} security manager initialization completed") - # Create connection manager - _worker_connection_manager = DorisConnectionManager(config, _worker_security_manager) + # Create connection manager with token manager for token-bound DB config + token_manager = _worker_security_manager.auth_provider.token_manager if hasattr(_worker_security_manager, 'auth_provider') and hasattr(_worker_security_manager.auth_provider, 'token_manager') else None + _worker_connection_manager = DorisConnectionManager(config, _worker_security_manager, token_manager) + + # Set connection manager reference in security manager for database validation + _worker_security_manager.connection_manager = _worker_connection_manager + await _worker_connection_manager.initialize() # Create MCP server diff --git a/doris_mcp_server/utils/config.py b/doris_mcp_server/utils/config.py index fe52f66..563cca4 100644 --- a/doris_mcp_server/utils/config.py +++ b/doris_mcp_server/utils/config.py @@ -372,19 +372,34 @@ class DorisConfig: config = cls() - # Database configuration - config.database.host = os.getenv("DORIS_HOST", config.database.host) - config.database.port = int(os.getenv("DORIS_PORT", str(config.database.port))) - config.database.user = os.getenv("DORIS_USER", config.database.user) - config.database.password = os.getenv("DORIS_PASSWORD", config.database.password) - config.database.database = os.getenv("DORIS_DATABASE", config.database.database) - config.database.fe_http_port = int(os.getenv("DORIS_FE_HTTP_PORT", str(config.database.fe_http_port))) + # Database configuration - handle empty strings properly + doris_host = os.getenv("DORIS_HOST", "").strip() + config.database.host = doris_host if doris_host else config.database.host + + doris_port = os.getenv("DORIS_PORT", "").strip() + if doris_port and doris_port.isdigit(): + config.database.port = int(doris_port) + + doris_user = os.getenv("DORIS_USER", "").strip() + config.database.user = doris_user if doris_user else config.database.user + + doris_password = os.getenv("DORIS_PASSWORD", "") + config.database.password = doris_password if doris_password else config.database.password + + doris_database = os.getenv("DORIS_DATABASE", "").strip() + config.database.database = doris_database if doris_database else config.database.database + + doris_fe_http_port = os.getenv("DORIS_FE_HTTP_PORT", "").strip() + if doris_fe_http_port and doris_fe_http_port.isdigit(): + config.database.fe_http_port = int(doris_fe_http_port) # BE nodes configuration be_hosts_env = os.getenv("DORIS_BE_HOSTS", "") if be_hosts_env: config.database.be_hosts = [host.strip() for host in be_hosts_env.split(",") if host.strip()] - config.database.be_webserver_port = int(os.getenv("DORIS_BE_WEBSERVER_PORT", str(config.database.be_webserver_port))) + be_webserver_port = os.getenv("DORIS_BE_WEBSERVER_PORT", "").strip() + if be_webserver_port and be_webserver_port.isdigit(): + config.database.be_webserver_port = int(be_webserver_port) # Arrow Flight SQL Configuration fe_arrow_port_env = os.getenv("FE_ARROW_FLIGHT_SQL_PORT") @@ -557,7 +572,9 @@ class DorisConfig: # Server configuration config.server_name = os.getenv("SERVER_NAME", config.server_name) config.server_version = os.getenv("SERVER_VERSION", config.server_version) - config.server_port = int(os.getenv("SERVER_PORT", str(config.server_port))) + server_port = os.getenv("SERVER_PORT", "").strip() + if server_port and server_port.isdigit(): + config.server_port = int(server_port) config.temp_files_dir = os.getenv("TEMP_FILES_DIR", config.temp_files_dir) return config diff --git a/doris_mcp_server/utils/db.py b/doris_mcp_server/utils/db.py index 8cc050b..8d5c58a 100644 --- a/doris_mcp_server/utils/db.py +++ b/doris_mcp_server/utils/db.py @@ -240,15 +240,30 @@ class DorisConnectionManager: Uses direct connection pool management with proper synchronization Implements connection pool health monitoring and proactive cleanup + Supports token-bound database configurations for multi-tenant access """ - def __init__(self, config, security_manager=None): + def __init__(self, config, security_manager=None, token_manager=None): self.config = config self.pool: Pool | None = None self.logger = get_logger(__name__) self.security_manager = security_manager + self.token_manager = token_manager # Token manager for token-bound DB config self.session_cache = DorisSessionCache(self) + + # Store original database config for fallback + self.original_db_config = { + 'host': config.database.host, + 'port': config.database.port, + 'user': config.database.user, + 'password': config.database.password, + 'database': config.database.database, + 'charset': config.database.charset + } + + # Current active database config (may be overridden by token-bound config) + self.active_db_config = self.original_db_config.copy() # Connection pool state management self.pool_recovering = False @@ -267,14 +282,7 @@ class DorisConnectionManager: # Database connection parameters from config.database self.pool_recovery_lock = self._recovery_lock # Compatibility alias - self.host = config.database.host - self.port = config.database.port - self.user = config.database.user - self.password = config.database.password - self.database = config.database.database - # Convert charset to aiomysql compatible format - charset_map = {"UTF8": "utf8", "UTF8MB4": "utf8mb4"} - self.charset = charset_map.get(config.database.charset.upper(), config.database.charset.lower()) + self._update_db_params_from_config(self.active_db_config) self.connect_timeout = config.database.connection_timeout # Connection pool parameters - more conservative settings @@ -285,12 +293,307 @@ class DorisConnectionManager: # 🔧 FIX: Add missing monitoring parameters that were removed during refactoring self.health_check_interval = 30 # seconds self.pool_warmup_size = 3 # connections to maintain + + def _update_db_params_from_config(self, db_config: dict): + """Update database connection parameters from config dictionary""" + self.host = db_config['host'] + self.port = db_config['port'] + self.user = db_config['user'] + self.password = db_config['password'] + self.database = db_config['database'] + # Convert charset to aiomysql compatible format + charset_map = {"UTF8": "utf8", "UTF8MB4": "utf8mb4"} + self.charset = charset_map.get(db_config['charset'].upper(), db_config['charset'].lower()) + + def _is_config_empty(self, config_value) -> bool: + """Check if a config value is empty (None, empty string, or 'null')""" + return config_value is None or config_value == '' or str(config_value).lower() == 'null' + + def _has_valid_global_config(self) -> bool: + """Check if global database configuration is valid and non-empty""" + return (not self._is_config_empty(self.original_db_config['host']) and + not self._is_config_empty(self.original_db_config['user'])) + + def _find_available_token_with_db_config(self) -> str: + """Find the first available token with database configuration + + Returns: + Raw token string if found, empty string if not found + """ + if not self.token_manager: + return "" + + try: + for token_hash, token_info in self.token_manager._tokens.items(): + if (token_info.database_config and + token_info.is_active and + not self._is_config_empty(token_info.database_config.host) and + not self._is_config_empty(token_info.database_config.user)): + + # We need to find the raw token from the hash + # This is a bit tricky since we only store hashes + # We'll need to use the admin token from tokens.json if it has db config + if token_info.token_id == 'admin-token': + # Try the known admin token + return 'doris_admin_token_123456' + elif 'tenant' in token_info.token_id: + # For tenant tokens, we'll need a different approach + # For now, skip these as we don't know the raw token + continue + + return "" + except Exception as e: + self.logger.error(f"Error finding available token: {e}") + return "" + + async def configure_for_token(self, token: str) -> tuple[bool, str]: + """Configure connection manager for token with new priority logic + + Priority: Token-bound DB config > .env config > error + + Args: + token: Authentication token to get database config for + + Returns: + (success: bool, config_source: str): Result and which config was used + + Raises: + RuntimeError: If no valid database configuration is available + """ + try: + # Priority 1: Try token-bound database config first + if self.token_manager: + db_config = self.token_manager.get_database_config_by_token(token) + if db_config: + # Convert DatabaseConfig to dictionary + token_db_config = { + 'host': db_config.host, + 'port': db_config.port, + 'user': db_config.user, + 'password': db_config.password, + 'database': db_config.database, + 'charset': db_config.charset + } + + # Check if token-bound config is valid + if (not self._is_config_empty(token_db_config['host']) and + not self._is_config_empty(token_db_config['user'])): + self.logger.info(f"Using token-bound database configuration for host: {token_db_config['host']}") + self.active_db_config = token_db_config + self._update_db_params_from_config(self.active_db_config) + + # Create/recreate connection pool with token-bound config + await self._ensure_pool_with_current_config() + + return True, "token-bound" + + # Priority 2: Use global .env config if available + if self._has_valid_global_config(): + self.logger.info("Using global .env database configuration") + self.active_db_config = self.original_db_config.copy() + self._update_db_params_from_config(self.active_db_config) + + # Create/recreate connection pool with global config + await self._ensure_pool_with_current_config() + + return True, "global-env" + + # Priority 3: No valid configuration available + error_msg = ( + "No valid database configuration available for this token. " + "Please contact administrator to:\n" + "1. Add database configuration to tokens.json for this token, OR\n" + "2. Configure valid global database settings in .env file\n" + "Required fields: DB_HOST, DB_USER" + ) + self.logger.error(error_msg) + raise RuntimeError(error_msg) + + except Exception as e: + self.logger.error(f"Failed to configure database for token: {e}") + raise + + async def _ensure_pool_with_current_config(self): + """Ensure connection pool exists with current configuration""" + try: + # If pool exists with different config, need to recreate it + # If no pool exists, create one with current config + if self.pool and not self.pool.closed: + # Since we can't reliably check pool config attributes, + # we'll recreate the pool if we detect a potential config change + # by checking if current config differs from what we stored + pool_needs_recreation = False + + # Compare current config with what we might have used before + if hasattr(self, '_last_pool_config'): + current_config = { + 'host': self.host, + 'port': self.port, + 'user': self.user, + 'database': self.database + } + if current_config != self._last_pool_config: + pool_needs_recreation = True + + if pool_needs_recreation: + self.logger.info("Database configuration changed, recreating connection pool") + await self._recreate_pool() + elif not self.pool: + self.logger.info("Creating connection pool with current configuration") + await self._create_pool_with_current_config() + + # Test the connection immediately + if not await self._test_pool_health(): + raise RuntimeError(f"Database connection test failed for {self.host}:{self.port}") + + except Exception as e: + self.logger.error(f"Failed to ensure connection pool: {e}") + raise + + async def _create_pool_with_current_config(self): + """Create connection pool with current database configuration""" + try: + self.pool = await aiomysql.create_pool( + host=self.host, + port=self.port, + user=self.user, + password=self.password, + db=self.database, + charset=self.charset, + minsize=self.minsize, + maxsize=self.maxsize, + pool_recycle=self.pool_recycle, + connect_timeout=self.connect_timeout, + autocommit=True + ) + + # Store the current config for comparison later + self._last_pool_config = { + 'host': self.host, + 'port': self.port, + 'user': self.user, + 'database': self.database + } + + # Test initial connection + if not await self._test_pool_health(): + raise RuntimeError("Connection pool health check failed") + + # Start background monitoring tasks if not already running + if not self.pool_health_check_task or self.pool_health_check_task.done(): + self.pool_health_check_task = asyncio.create_task(self._pool_health_monitor()) + if not self.pool_cleanup_task or self.pool_cleanup_task.done(): + self.pool_cleanup_task = asyncio.create_task(self._pool_cleanup_monitor()) + + # Perform initial pool warmup + await self._warmup_pool() + + self.logger.info(f"Connection pool created successfully with {self.host}:{self.port}") + + except Exception as e: + self.logger.error(f"Failed to create connection pool: {e}") + raise + + async def _recreate_pool(self): + """Recreate connection pool with current database configuration""" + try: + # Close existing pool + if self.pool and not self.pool.closed: + self.pool.close() + await self.pool.wait_closed() + self.pool = None + + # Create new pool with current config + await self._create_pool_with_current_config() + + except Exception as e: + self.logger.error(f"Failed to recreate connection pool: {e}") + raise + + def validate_database_configuration(self) -> tuple[bool, str]: + """Validate database configuration completeness + + Returns: + (is_valid, error_message): Configuration validation result + """ + # Check if Token authentication is enabled + token_auth_enabled = getattr(self.config.security, 'enable_token_auth', False) + + # Check if tokens.json exists and has valid tokens with database configs + tokens_file_available = False + token_bound_configs_available = False + + if self.token_manager: + try: + # Check if tokens.json file exists + import os + tokens_file_path = getattr(self.token_manager, 'token_file_path', 'tokens.json') + tokens_file_available = os.path.exists(tokens_file_path) + + # Check if any tokens have database configurations + if tokens_file_available or self.token_manager._tokens: + for token_hash, token_info in self.token_manager._tokens.items(): + if token_info.database_config: + token_bound_configs_available = True + break + except Exception: + pass + + # Validate .env database configuration + env_config_valid = self._has_valid_global_config() + + # Decision logic + if token_auth_enabled: + if tokens_file_available: + # tokens.json exists - either .env OR token-bound config must be valid + if env_config_valid or token_bound_configs_available: + return True, "Configuration valid" + else: + return False, ( + "Token authentication is enabled and tokens.json exists, but no valid database " + "configuration found. Please provide either:\n" + "1. Valid database configuration in .env file (DB_HOST, DB_USER, etc.)\n" + "2. Database configuration in tokens.json for at least one token" + ) + else: + # tokens.json does not exist - must have valid .env config + if env_config_valid: + return True, "Configuration valid" + else: + return False, ( + "Token authentication is enabled but tokens.json file not found. " + "Either:\n" + "1. Create tokens.json file with token configurations\n" + "2. Provide valid database configuration in .env file (DB_HOST, DB_USER, etc.)" + ) + else: + # Token auth is disabled, must have valid .env config + if env_config_valid: + return True, "Configuration valid" + else: + return False, ( + "Token authentication is disabled. Valid database configuration is required " + "in .env file (DB_HOST, DB_USER, etc.)" + ) async def initialize(self): """Initialize connection pool with health monitoring""" try: + # First validate configuration + is_valid, error_message = self.validate_database_configuration() + if not is_valid: + self.logger.error(f"Database configuration validation failed: {error_message}") + raise RuntimeError(f"Database configuration validation failed:\n{error_message}") + + self.logger.info(f"Database configuration validated successfully") self.logger.info(f"Initializing connection pool to {self.host}:{self.port}") + # Only create connection pool if we have valid global config + # Token-bound configs will be handled dynamically during requests + if not self._has_valid_global_config(): + self.logger.info("No valid global database config, pool will be created dynamically for token-bound configs") + return + # Create connection pool self.pool = await aiomysql.create_pool( host=self.host, @@ -592,7 +895,20 @@ class DorisConnectionManager: # Check if pool is available if not self.pool: self.logger.warning("Connection pool is not available, attempting recovery...") - await self._recover_pool_with_lock() + + # Try to use token-bound configuration if available + if self.token_manager and not self._has_valid_global_config(): + available_token = self._find_available_token_with_db_config() + if available_token: + self.logger.info(f"Using token-bound configuration for pool creation: {available_token}") + try: + await self.configure_for_token(available_token) + except Exception as e: + self.logger.error(f"Failed to configure with token-bound config: {e}") + + # Fallback to recovery + if not self.pool: + await self._recover_pool_with_lock() if not self.pool: raise RuntimeError("Connection pool is not available and recovery failed") diff --git a/doris_mcp_server/utils/query_executor.py b/doris_mcp_server/utils/query_executor.py index 63f5d54..3d76677 100644 --- a/doris_mcp_server/utils/query_executor.py +++ b/doris_mcp_server/utils/query_executor.py @@ -426,6 +426,10 @@ class DorisQueryExecutor: self, query_request: QueryRequest, auth_context ) -> QueryResult: """Internal query execution""" + + # Database configuration should already be handled during authentication + # No need to configure again during query execution + # Optimize query optimized_sql = await self.query_optimizer.optimize_query( query_request.sql, {"user_roles": getattr(auth_context, 'roles', [])} diff --git a/doris_mcp_server/utils/security.py b/doris_mcp_server/utils/security.py index 9b319b0..ce06aa3 100644 --- a/doris_mcp_server/utils/security.py +++ b/doris_mcp_server/utils/security.py @@ -47,11 +47,16 @@ class SecurityLevel(Enum): class AuthContext: """Authentication context for audit and session tracking""" - token_id: str # Token identifier for audit logging + token_id: str = "" # Token identifier for audit logging + user_id: str = "" # User identifier + roles: list[str] = field(default_factory=list) # User roles + permissions: list[str] = field(default_factory=list) # User permissions + security_level: 'SecurityLevel' = field(default_factory=lambda: SecurityLevel.INTERNAL) # Security level client_ip: str = "unknown" # Client IP address session_id: str = "" # Session identifier login_time: datetime = field(default_factory=datetime.utcnow) last_activity: datetime | None = None + token: str = "" # Raw token for token-bound database configuration @dataclass @@ -84,12 +89,13 @@ class DorisSecurityManager: Provides complete security control functionality, including authentication, authorization, SQL security validation and data masking """ - def __init__(self, config): + def __init__(self, config, connection_manager=None): self.config = config self.logger = get_logger(__name__) + self.connection_manager = connection_manager # Initialize security components - self.auth_provider = AuthenticationProvider(config) + self.auth_provider = AuthenticationProvider(config, self) self.authz_provider = AuthorizationProvider(config) self.sql_validator = SQLSecurityValidator(config) self.masking_processor = DataMaskingProcessor(config) @@ -226,6 +232,10 @@ class DorisSecurityManager: # Return anonymous context when no authentication is enabled return AuthContext( token_id="anonymous", + user_id="anonymous", + roles=["anonymous"], + permissions=["read"], + security_level=SecurityLevel.PUBLIC, client_ip=auth_info.get("client_ip", "unknown"), session_id="anonymous_session" ) @@ -392,18 +402,50 @@ class DorisSecurityManager: return {"error": "Token manager not initialized"} return self.auth_provider.token_manager.get_token_stats() + + async def _validate_token_database_config(self, token: str, token_info) -> None: + """Validate database configuration for token immediately during authentication + + This ensures database connectivity issues are caught at authentication time, + not during query execution, providing better user experience. + + Args: + token: Raw authentication token + token_info: TokenInfo object from token validation + + Raises: + ValueError: If database configuration is invalid or connection fails + """ + try: + if not self.connection_manager: + self.logger.warning("Connection manager not available for immediate database validation") + return + + # Configure and test database connection for this token + success, config_source = await self.connection_manager.configure_for_token(token) + + if success: + self.logger.info(f"Database configuration validated successfully for token {token_info.token_id} (source: {config_source})") + else: + raise ValueError("Database configuration validation failed") + + except Exception as e: + error_msg = f"Database configuration validation failed for token {token_info.token_id}: {str(e)}" + self.logger.error(error_msg) + raise ValueError(error_msg) class AuthenticationProvider: """Authentication provider""" - def __init__(self, config): + def __init__(self, config, security_manager=None): self.config = config self.logger = get_logger(__name__) self.session_cache = {} self.jwt_manager = None self.oauth_provider = None self.token_manager = None + self.security_manager = security_manager # Initialize authentication providers based on individual switches auth_methods_enabled = [] @@ -583,12 +625,21 @@ class AuthenticationProvider: token_info = validation_result.token_info + # Immediately validate database configuration for this token + if self.security_manager: + await self.security_manager._validate_token_database_config(token, token_info) + return AuthContext( token_id=token_info.token_id, + user_id=token_info.token_id, # Use token_id as user_id for token auth + roles=["token_user"], # Default role for token users + permissions=["read", "write"], # Default permissions for token users + security_level=SecurityLevel.INTERNAL, client_ip=auth_info.get("client_ip", "unknown"), session_id=auth_info.get("session_id", f"session_{token_info.token_id}"), login_time=datetime.utcnow(), - last_activity=token_info.last_used + last_activity=token_info.last_used, + token=token # Store raw token for token-bound database configuration ) except Exception as e: diff --git a/tokens.json b/tokens.json index 1524a49..0eac2f6 100644 --- a/tokens.json +++ b/tokens.json @@ -8,21 +8,48 @@ "token": "doris_admin_token_123456", "description": "Doris admin API access token", "expires_hours": null, - "is_active": true + "is_active": true, + "database_config": { + "host": "127.0.0.1", + "port": 9030, + "user": "root", + "password": "", + "database": "information_schema", + "charset": "UTF8", + "fe_http_port": 8030 + } }, { "token_id": "analyst-token", "token": "doris_analyst_token_123456", "description": "Doris analyst API access token", "expires_hours": 8760, - "is_active": true + "is_active": true, + "database_config": { + "host": "127.0.0.1", + "port": 9030, + "user": "root", + "password": "", + "database": "information_schema", + "charset": "UTF8", + "fe_http_port": 8030 + } }, { "token_id": "readonly-token", "token": "doris_readonly_token_123456", "description": "Doris readonly API access token", "expires_hours": 4320, - "is_active": true + "is_active": true, + "database_config": { + "host": "127.0.0.1", + "port": 9030, + "user": "root", + "password": "", + "database": "information_schema", + "charset": "UTF8", + "fe_http_port": 8030 + } } ], "notes": [