From 81305ffbf97482e29abf70dc0ba9be928ed39df5 Mon Sep 17 00:00:00 2001 From: bingquanzhao Date: Wed, 24 Dec 2025 20:39:16 +0800 Subject: [PATCH] [fix]fix token auth (#69) * fix tocken auth * Further fixes to the token overwriting issue and restoration of hot reloading of tokens.json. --- doris_mcp_server/main.py | 6 +- doris_mcp_server/utils/analysis_tools.py | 12 +- doris_mcp_server/utils/db.py | 373 +++++++++++++++++++-- doris_mcp_server/utils/query_executor.py | 8 +- doris_mcp_server/utils/schema_extractor.py | 20 +- doris_mcp_server/utils/security.py | 5 + 6 files changed, 384 insertions(+), 40 deletions(-) diff --git a/doris_mcp_server/main.py b/doris_mcp_server/main.py index 2630aeb..7a27e6e 100644 --- a/doris_mcp_server/main.py +++ b/doris_mcp_server/main.py @@ -644,10 +644,10 @@ class DorisServer: # FIX for Issue #62 Bug 1: Set auth_context in context variable # This allows tools to access token information for token-bound database configuration + # CRITICAL: Use the global ContextVar from security.py to ensure same instance is used everywhere try: - from contextvars import ContextVar - auth_context_var: ContextVar = ContextVar('mcp_auth_context', default=None) - auth_context_var.set(auth_context) + from .utils.security import mcp_auth_context_var + mcp_auth_context_var.set(auth_context) self.logger.debug(f"Set auth_context in context variable with token: {bool(hasattr(auth_context, 'token') and auth_context.token)}") except Exception as ctx_error: self.logger.warning(f"Failed to set auth_context in context variable: {ctx_error}") diff --git a/doris_mcp_server/utils/analysis_tools.py b/doris_mcp_server/utils/analysis_tools.py index b3fb7e4..5983df7 100644 --- a/doris_mcp_server/utils/analysis_tools.py +++ b/doris_mcp_server/utils/analysis_tools.py @@ -397,6 +397,14 @@ class SQLAnalyzer: logger.info(f"Generating SQL explain for query ID: {query_id}") + # 🔧 FIX: Get auth_context for token-bound database configuration + auth_context = None + try: + from .security import mcp_auth_context_var + auth_context = mcp_auth_context_var.get() + except Exception: + pass + # Switch database if specified # SECURITY FIX: Validate and quote db_name if db_name: @@ -405,7 +413,7 @@ class SQLAnalyzer: except SQLSecurityError as e: return {"success": False, "error": f"Invalid database name: {e}"} safe_db = quote_identifier(db_name, "database name") - await self.connection_manager.execute_query("explain_session", f"USE {safe_db}") + await self.connection_manager.execute_query("explain_session", f"USE {safe_db}", None, auth_context) # Construct EXPLAIN query explain_type = "EXPLAIN VERBOSE" if verbose else "EXPLAIN" @@ -414,7 +422,7 @@ class SQLAnalyzer: logger.info(f"Executing explain query: {explain_sql}") # Execute explain query - result = await self.connection_manager.execute_query("explain_session", explain_sql) + result = await self.connection_manager.execute_query("explain_session", explain_sql, None, auth_context) # Format explain output explain_content = [] diff --git a/doris_mcp_server/utils/db.py b/doris_mcp_server/utils/db.py index 8ddf3f3..cd7b0ce 100644 --- a/doris_mcp_server/utils/db.py +++ b/doris_mcp_server/utils/db.py @@ -255,6 +255,13 @@ class DorisConnectionManager: self.security_manager = security_manager self.token_manager = token_manager # Token manager for token-bound DB config + # 🔧 FIX for multi-tenant concurrency: Per-token connection pool isolation + # Each token gets its own connection pool to prevent configuration conflicts + self.token_pools: Dict[str, Pool] = {} # token_hash -> pool + self.token_configs: Dict[str, dict] = {} # token_hash -> db_config + self._token_pool_locks: Dict[str, asyncio.Lock] = {} # token_hash -> lock + self._token_pools_lock = asyncio.Lock() # Lock for managing token_pools dict + # FIX for Issue #58 Problem 1: Disable session caching to prevent connection sharing # Session caching causes multiple threads to share the same MySQL connection, # leading to race conditions and deadlocks in multi-threaded environments @@ -276,6 +283,7 @@ class DorisConnectionManager: } # Current active database config (may be overridden by token-bound config) + # NOTE: This is kept for backward compatibility with non-token requests self.active_db_config = self.original_db_config.copy() # Connection pool state management @@ -359,6 +367,281 @@ class DorisConnectionManager: self.logger.error(f"Error finding available token: {e}") return "" + def _get_token_hash(self, token: str) -> str: + """Get hash of token for use as dictionary key""" + import hashlib + return hashlib.sha256(token.encode()).hexdigest()[:16] + + def _get_current_token_db_config(self, token: str) -> dict | None: + """Get current database config for token from TokenManager + + This is used to check if config has changed for hot reload support. + """ + if not self.token_manager: + return None + + token_db_config = self.token_manager.get_database_config_by_token(token) + if token_db_config: + return { + 'host': token_db_config.host, + 'port': token_db_config.port, + 'user': token_db_config.user, + 'password': token_db_config.password, + 'database': token_db_config.database, + 'charset': token_db_config.charset + } + return None + + def _config_changed(self, old_config: dict, new_config: dict) -> bool: + """Check if database configuration has changed""" + if old_config is None or new_config is None: + return old_config != new_config + + # Compare key fields + for key in ['host', 'port', 'user', 'password', 'database']: + if old_config.get(key) != new_config.get(key): + return True + return False + + async def get_pool_for_token(self, token: str) -> tuple[Pool, dict]: + """Get or create a dedicated connection pool for a specific token + + This method implements per-token connection pool isolation to prevent + concurrent requests from different tokens interfering with each other. + + 🔧 FIX: Supports hot reload - if tokens.json config changes, + the old pool is closed and a new one is created automatically. + + Args: + token: Authentication token + + Returns: + (pool, db_config): The dedicated pool and its configuration + + Raises: + RuntimeError: If no valid database configuration is available + """ + token_hash = self._get_token_hash(token) + + # Fast path: pool already exists + if token_hash in self.token_pools: + pool = self.token_pools[token_hash] + cached_config = self.token_configs.get(token_hash) + + # 🔧 FIX: Check if config has changed (hot reload support) + current_config = self._get_current_token_db_config(token) + if current_config and cached_config and self._config_changed(cached_config, current_config): + self.logger.info(f"Token config changed (hash: {token_hash[:8]}...), recreating pool...") + # Config changed, need to recreate pool + async with self._token_pools_lock: + # Close old pool + old_pool = self.token_pools.pop(token_hash, None) + if old_pool and not old_pool.closed: + try: + old_pool.close() + await asyncio.wait_for(old_pool.wait_closed(), timeout=2.0) + except Exception as e: + self.logger.warning(f"Error closing old pool during hot reload: {e}") + self.token_configs.pop(token_hash, None) + # Continue to slow path to create new pool + elif pool and not pool.closed: + return pool, cached_config + + # Slow path: need to create pool (with lock to prevent race conditions) + async with self._token_pools_lock: + # Double-check after acquiring lock + if token_hash in self.token_pools: + pool = self.token_pools[token_hash] + if pool and not pool.closed: + return pool, self.token_configs[token_hash] + + # Get database config for this token + db_config = None + config_source = "unknown" + + if self.token_manager: + token_db_config = self.token_manager.get_database_config_by_token(token) + if token_db_config: + db_config = { + 'host': token_db_config.host, + 'port': token_db_config.port, + 'user': token_db_config.user, + 'password': token_db_config.password, + 'database': token_db_config.database, + 'charset': token_db_config.charset + } + config_source = "token-bound" + + # Fallback to global config if token has no specific config + if not db_config or self._is_config_empty(db_config.get('host')) or self._is_config_empty(db_config.get('user')): + if self._has_valid_global_config(): + db_config = self.original_db_config.copy() + config_source = "global-env" + else: + raise RuntimeError( + f"No valid database configuration available for token. " + f"Please configure database in tokens.json or .env file." + ) + + # Create dedicated pool for this token + self.logger.info(f"Creating dedicated connection pool for token (hash: {token_hash[:8]}...) " + f"using {config_source} config: {db_config['user']}@{db_config['host']}:{db_config['port']}") + + pool = await self._create_pool_with_config(db_config) + + # Store pool and config + self.token_pools[token_hash] = pool + self.token_configs[token_hash] = db_config + + # Create lock for this token if not exists + if token_hash not in self._token_pool_locks: + self._token_pool_locks[token_hash] = asyncio.Lock() + + return pool, db_config + + async def _create_pool_with_config(self, db_config: dict) -> Pool: + """Create a connection pool with specified configuration + + Args: + db_config: Database configuration dictionary + + Returns: + Created connection pool + """ + # Convert charset to aiomysql compatible format + charset_map = {"UTF8": "utf8", "UTF8MB4": "utf8mb4"} + charset = charset_map.get(db_config['charset'].upper(), db_config['charset'].lower()) + + self.logger.debug(f"Creating pool for {db_config['user']}@{db_config['host']}:{db_config['port']}/{db_config['database']}") + + try: + pool = await asyncio.wait_for( + aiomysql.create_pool( + host=db_config['host'], + port=db_config['port'], + user=db_config['user'], + password=db_config['password'], + db=db_config['database'], + charset=charset, + minsize=0, # Don't pre-create connections + maxsize=self.maxsize, + connect_timeout=self.connect_timeout, + autocommit=True, + pool_recycle=self.pool_recycle + ), + timeout=self.connect_timeout + 5 # Give extra time for pool creation + ) + self.logger.info(f"Successfully created pool for {db_config['user']}@{db_config['host']}:{db_config['port']}") + return pool + except asyncio.TimeoutError: + self.logger.error(f"Timeout creating pool for {db_config['user']}@{db_config['host']}:{db_config['port']}") + raise RuntimeError(f"Timeout creating connection pool for {db_config['user']}@{db_config['host']}:{db_config['port']}") + except Exception as e: + self.logger.error(f"Failed to create pool for {db_config['user']}@{db_config['host']}:{db_config['port']}: {type(e).__name__}: {e}") + raise + + async def get_connection_for_token(self, token: str, session_id: str) -> 'DorisConnection': + """Get a connection from the token's dedicated pool + + Args: + token: Authentication token + session_id: Session identifier for logging + + Returns: + DorisConnection wrapper + """ + pool, db_config = await self.get_pool_for_token(token) + + try: + connection = await asyncio.wait_for( + pool.acquire(), + timeout=self.connect_timeout + ) + + self.logger.debug(f"Session {session_id}: Acquired connection from token pool " + f"(user: {db_config['user']}@{db_config['host']})") + + return DorisConnection(connection, session_id, self.security_manager) + + except Exception as e: + self.logger.error(f"Session {session_id}: Failed to acquire connection from token pool: {e}") + raise + + async def release_connection_for_token(self, token: str, connection: 'DorisConnection'): + """Release a connection back to the token's dedicated pool + + Args: + token: Authentication token + connection: DorisConnection wrapper to release + """ + token_hash = self._get_token_hash(token) + + if token_hash in self.token_pools: + pool = self.token_pools[token_hash] + if pool and not pool.closed: + try: + pool.release(connection.connection) + except Exception as e: + self.logger.warning(f"Failed to release connection to token pool: {e}") + + async def cleanup_token_pools(self, max_idle_time: int = 3600): + """Clean up idle token connection pools + + Args: + max_idle_time: Maximum idle time in seconds before closing a pool + """ + async with self._token_pools_lock: + pools_to_remove = [] + + for token_hash, pool in self.token_pools.items(): + if pool and not pool.closed: + # Check if pool is idle (no active connections) + if pool.size == 0 and pool.freesize == 0: + pools_to_remove.append(token_hash) + elif pool and pool.closed: + pools_to_remove.append(token_hash) + + for token_hash in pools_to_remove: + try: + pool = self.token_pools.pop(token_hash, None) + if pool and not pool.closed: + pool.close() + await pool.wait_closed() + self.token_configs.pop(token_hash, None) + self._token_pool_locks.pop(token_hash, None) + self.logger.info(f"Cleaned up idle token pool (hash: {token_hash[:8]}...)") + except Exception as e: + self.logger.warning(f"Error cleaning up token pool: {e}") + + async def close_all_token_pools(self): + """Close all token connection pools (for shutdown)""" + # Use timeout to prevent blocking on lock acquisition during shutdown + try: + async with asyncio.timeout(5): # 5 second timeout for lock + async with self._token_pools_lock: + for token_hash, pool in list(self.token_pools.items()): + try: + if pool and not pool.closed: + pool.close() + # Use timeout for wait_closed to prevent hanging + try: + await asyncio.wait_for(pool.wait_closed(), timeout=2.0) + except asyncio.TimeoutError: + self.logger.warning(f"Timeout waiting for token pool to close (hash: {token_hash[:8]}...)") + self.logger.info(f"Closed token pool (hash: {token_hash[:8]}...)") + except Exception as e: + self.logger.warning(f"Error closing token pool: {e}") + + self.token_pools.clear() + self.token_configs.clear() + self._token_pool_locks.clear() + except asyncio.TimeoutError: + self.logger.warning("Timeout acquiring lock for token pool cleanup, forcing clear") + # Force clear without lock + self.token_pools.clear() + self.token_configs.clear() + self._token_pool_locks.clear() + async def configure_for_token(self, token: str) -> tuple[bool, str]: """Configure connection manager for token with new priority logic @@ -1092,7 +1375,26 @@ class DorisConnectionManager: Uses only semaphore to prevent too many concurrent acquisitions. If the connection is successfully obtained, it will be added to the connection pool cache. + + 🔧 FIX for token isolation: Now automatically checks for auth_context from ContextVar + and uses token-specific connection pool if available. """ + # 🔧 FIX: Check for auth_context from global ContextVar + # This ensures all tools using get_connection respect token-bound database configuration + auth_context = None + try: + from .security import mcp_auth_context_var + auth_context = mcp_auth_context_var.get() + except Exception as e: + self.logger.debug(f"get_connection: Could not get auth_context: {e}") + + if auth_context and hasattr(auth_context, 'token') and auth_context.token: + # Use token-specific connection pool + # SECURITY: Do NOT catch exceptions here - if token pool fails, don't fallback to global pool + # This prevents privilege escalation + self.logger.debug(f"get_connection: Using token-specific pool for session {session_id}") + return await self.get_connection_for_token(auth_context.token, session_id) + cached_conn = self.session_cache.get(session_id) if cached_conn: return cached_conn @@ -1239,10 +1541,16 @@ class DorisConnectionManager: except asyncio.CancelledError: pass - # Close connection pool + # 🔧 FIX: Close all per-token connection pools + await self.close_all_token_pools() + + # Close global connection pool with timeout if self.pool: self.pool.close() - await self.pool.wait_closed() + try: + await asyncio.wait_for(self.pool.wait_closed(), timeout=5.0) + except asyncio.TimeoutError: + self.logger.warning("Timeout waiting for global pool to close") self.logger.info("Connection manager closed successfully") @@ -1271,37 +1579,43 @@ class DorisConnectionManager: async def execute_query( self, session_id: str, sql: str, params: tuple | None = None, auth_context=None ) -> QueryResult: - """Execute query - Simplified Strategy with automatic connection management + """Execute query - Enhanced Strategy with per-token connection pool isolation - FIX for Issue #62 Bug 1: Configure token-bound database before query execution + FIX for multi-tenant concurrency: Each token now uses its own dedicated connection pool + to prevent configuration conflicts between concurrent requests from different tokens. """ connection = None + token = None + try: - # FIX: Configure database for token BEFORE getting connection - # This ensures token-bound database configuration is used instead of global config + # Check if we have a token for per-token pool isolation if auth_context and hasattr(auth_context, 'token') and auth_context.token: + token = auth_context.token + try: - success, config_source = await self.configure_for_token(auth_context.token) - if success: - self.logger.info(f"Session {session_id}: Using {config_source} database configuration") - else: - self.logger.warning(f"Session {session_id}: Token configuration failed, may use global config") - except Exception as token_config_error: - # SECURITY: If token should have config but configuration fails, don't fallback + # 🔧 FIX: Use dedicated connection pool for this token + # This prevents concurrent requests from different tokens interfering + connection = await self.get_connection_for_token(token, session_id) + + # Get the config for logging + token_hash = self._get_token_hash(token) + if token_hash in self.token_configs: + db_config = self.token_configs[token_hash] + self.logger.info(f"Session {session_id}: Using dedicated pool for {db_config['user']}@{db_config['host']}") + + except Exception as token_pool_error: + # SECURITY: If token should have pool but creation fails, don't fallback # This prevents privilege escalation (using high-privilege default user) - if self.token_manager: - self.logger.error(f"Session {session_id}: Token database configuration failed: {token_config_error}") - raise RuntimeError( - f"Failed to configure database for authenticated token. " - f"This is a security measure to prevent using default high-privilege credentials. " - f"Error: {token_config_error}" - ) - else: - # No token manager, can use global config - self.logger.warning(f"Session {session_id}: No token manager, using global config") - - # Always get fresh connection from pool (with configured database) - connection = await self.get_connection(session_id) + self.logger.error(f"Session {session_id}: Token pool error: {token_pool_error}") + raise RuntimeError( + f"Failed to get connection for authenticated token. " + f"This is a security measure to prevent using default high-privilege credentials. " + f"Error: {token_pool_error}" + ) + else: + # No token - use global pool (backward compatibility) + self.logger.debug(f"Session {session_id}: No token, using global connection pool") + connection = await self.get_connection(session_id) # Execute query result = await connection.execute(sql, params, auth_context) @@ -1312,9 +1626,12 @@ class DorisConnectionManager: self.logger.error(f"Query execution failed for session {session_id}: {e}") raise finally: - # Always release connection back to pool + # Always release connection back to the appropriate pool if connection: - await self.release_connection(session_id, connection) + if token: + await self.release_connection_for_token(token, connection) + else: + await self.release_connection(session_id, connection) @asynccontextmanager async def get_connection_context(self, session_id: str): diff --git a/doris_mcp_server/utils/query_executor.py b/doris_mcp_server/utils/query_executor.py index 1e15c2c..75abf34 100644 --- a/doris_mcp_server/utils/query_executor.py +++ b/doris_mcp_server/utils/query_executor.py @@ -639,7 +639,13 @@ class DorisQueryExecutor: if hasattr(self.connection_manager, 'config') and hasattr(self.connection_manager.config, 'security'): if self.connection_manager.config.security.enable_security_check: try: - security_manager = DorisSecurityManager(self.connection_manager.config) + # 🔧 FIX: Use existing security_manager to avoid creating multiple TokenManager instances + # Creating new DorisSecurityManager each time causes multiple hot reload monitors + security_manager = getattr(self.connection_manager, 'security_manager', None) + if not security_manager: + # Fallback: create new one only if not available (should rarely happen) + self.logger.warning("No existing security_manager, creating new instance") + security_manager = DorisSecurityManager(self.connection_manager.config) validation_result = await security_manager.validate_sql_security(sql, auth_context) if not validation_result.is_valid: diff --git a/doris_mcp_server/utils/schema_extractor.py b/doris_mcp_server/utils/schema_extractor.py index f46e659..c4fabf9 100644 --- a/doris_mcp_server/utils/schema_extractor.py +++ b/doris_mcp_server/utils/schema_extractor.py @@ -1194,8 +1194,17 @@ class MetadataExtractor: """ try: if self.connection_manager: + # FIX: Get auth_context from global ContextVar for token-bound database configuration + # This ensures all query methods use the correct user's connection pool + auth_context = None + try: + from .security import mcp_auth_context_var + auth_context = mcp_auth_context_var.get() + except Exception: + pass + # Use the injected connection manager directly (async) - result = await self.connection_manager.execute_query(self._session_id, query, None) + result = await self.connection_manager.execute_query(self._session_id, query, None, auth_context) # Extract data from QueryResult if hasattr(result, 'data'): @@ -1644,15 +1653,14 @@ class MetadataExtractor: # FIX: Try to get auth_context from context variable (set by HTTP middleware) # This allows token-bound database configuration to work + # CRITICAL: Use the global ContextVar from security.py to ensure same instance is used everywhere auth_context = None try: - from contextvars import ContextVar - from .security import AuthContext + from .security import mcp_auth_context_var - # Try to get auth_context from context variable + # Get auth_context from the global context variable # This will be set by the HTTP request handler in main.py - auth_context_var: ContextVar = ContextVar('mcp_auth_context', default=None) - auth_context = auth_context_var.get() + auth_context = mcp_auth_context_var.get() if auth_context: logger.debug(f"Retrieved auth_context from context variable with token: {bool(hasattr(auth_context, 'token') and auth_context.token)}") diff --git a/doris_mcp_server/utils/security.py b/doris_mcp_server/utils/security.py index d1efc43..d7c070d 100644 --- a/doris_mcp_server/utils/security.py +++ b/doris_mcp_server/utils/security.py @@ -22,6 +22,7 @@ Implements enterprise-level authentication, authorization, SQL security validati import logging import re +from contextvars import ContextVar from dataclasses import dataclass, field from datetime import datetime from enum import Enum @@ -34,6 +35,10 @@ from sqlparse.tokens import Keyword, Name from .logger import get_logger from .config import DatabaseConfig +# Global ContextVar for auth_context - must be a single instance shared across all modules +# This allows token-bound database configuration to work correctly in concurrent requests +mcp_auth_context_var: ContextVar['AuthContext'] = ContextVar('mcp_auth_context', default=None) + class SecurityLevel(Enum): """Security level enumeration"""