[fix]fix token auth (#69)

* fix token auth

* Further fixes to the token overwriting issue and restoration of hot reloading of tokens.json.
This commit is contained in:
bingquanzhao
2025-12-24 20:39:16 +08:00
committed by GitHub
parent 43143f0b30
commit 81305ffbf9
6 changed files with 384 additions and 40 deletions

View File

@@ -644,10 +644,10 @@ class DorisServer:
# FIX for Issue #62 Bug 1: Set auth_context in context variable # FIX for Issue #62 Bug 1: Set auth_context in context variable
# This allows tools to access token information for token-bound database configuration # This allows tools to access token information for token-bound database configuration
# CRITICAL: Use the global ContextVar from security.py to ensure same instance is used everywhere
try: try:
from contextvars import ContextVar from .utils.security import mcp_auth_context_var
auth_context_var: ContextVar = ContextVar('mcp_auth_context', default=None) mcp_auth_context_var.set(auth_context)
auth_context_var.set(auth_context)
self.logger.debug(f"Set auth_context in context variable with token: {bool(hasattr(auth_context, 'token') and auth_context.token)}") self.logger.debug(f"Set auth_context in context variable with token: {bool(hasattr(auth_context, 'token') and auth_context.token)}")
except Exception as ctx_error: except Exception as ctx_error:
self.logger.warning(f"Failed to set auth_context in context variable: {ctx_error}") self.logger.warning(f"Failed to set auth_context in context variable: {ctx_error}")

View File

@@ -397,6 +397,14 @@ class SQLAnalyzer:
logger.info(f"Generating SQL explain for query ID: {query_id}") logger.info(f"Generating SQL explain for query ID: {query_id}")
# 🔧 FIX: Get auth_context for token-bound database configuration
auth_context = None
try:
from .security import mcp_auth_context_var
auth_context = mcp_auth_context_var.get()
except Exception:
pass
# Switch database if specified # Switch database if specified
# SECURITY FIX: Validate and quote db_name # SECURITY FIX: Validate and quote db_name
if db_name: if db_name:
@@ -405,7 +413,7 @@ class SQLAnalyzer:
except SQLSecurityError as e: except SQLSecurityError as e:
return {"success": False, "error": f"Invalid database name: {e}"} return {"success": False, "error": f"Invalid database name: {e}"}
safe_db = quote_identifier(db_name, "database name") safe_db = quote_identifier(db_name, "database name")
await self.connection_manager.execute_query("explain_session", f"USE {safe_db}") await self.connection_manager.execute_query("explain_session", f"USE {safe_db}", None, auth_context)
# Construct EXPLAIN query # Construct EXPLAIN query
explain_type = "EXPLAIN VERBOSE" if verbose else "EXPLAIN" explain_type = "EXPLAIN VERBOSE" if verbose else "EXPLAIN"
@@ -414,7 +422,7 @@ class SQLAnalyzer:
logger.info(f"Executing explain query: {explain_sql}") logger.info(f"Executing explain query: {explain_sql}")
# Execute explain query # Execute explain query
result = await self.connection_manager.execute_query("explain_session", explain_sql) result = await self.connection_manager.execute_query("explain_session", explain_sql, None, auth_context)
# Format explain output # Format explain output
explain_content = [] explain_content = []

View File

@@ -255,6 +255,13 @@ class DorisConnectionManager:
self.security_manager = security_manager self.security_manager = security_manager
self.token_manager = token_manager # Token manager for token-bound DB config self.token_manager = token_manager # Token manager for token-bound DB config
# 🔧 FIX for multi-tenant concurrency: Per-token connection pool isolation
# Each token gets its own connection pool to prevent configuration conflicts
self.token_pools: Dict[str, Pool] = {} # token_hash -> pool
self.token_configs: Dict[str, dict] = {} # token_hash -> db_config
self._token_pool_locks: Dict[str, asyncio.Lock] = {} # token_hash -> lock
self._token_pools_lock = asyncio.Lock() # Lock for managing token_pools dict
# FIX for Issue #58 Problem 1: Disable session caching to prevent connection sharing # FIX for Issue #58 Problem 1: Disable session caching to prevent connection sharing
# Session caching causes multiple threads to share the same MySQL connection, # Session caching causes multiple threads to share the same MySQL connection,
# leading to race conditions and deadlocks in multi-threaded environments # leading to race conditions and deadlocks in multi-threaded environments
@@ -276,6 +283,7 @@ class DorisConnectionManager:
} }
# Current active database config (may be overridden by token-bound config) # Current active database config (may be overridden by token-bound config)
# NOTE: This is kept for backward compatibility with non-token requests
self.active_db_config = self.original_db_config.copy() self.active_db_config = self.original_db_config.copy()
# Connection pool state management # Connection pool state management
@@ -359,6 +367,281 @@ class DorisConnectionManager:
self.logger.error(f"Error finding available token: {e}") self.logger.error(f"Error finding available token: {e}")
return "" return ""
def _get_token_hash(self, token: str) -> str:
    """Derive the dictionary key used for a token's pool bookkeeping.

    The key is the first 16 hexadecimal characters of the SHA-256 digest of
    the UTF-8 encoded token, so the raw token is never used as a key.

    Args:
        token: Authentication token.

    Returns:
        A 16-character lowercase hexadecimal string.
    """
    import hashlib

    digest = hashlib.sha256(token.encode())
    return digest.hexdigest()[:16]
def _get_current_token_db_config(self, token: str) -> dict | None:
    """Look up the database settings currently bound to a token.

    Used by the hot-reload check to compare the live TokenManager state with
    the configuration a pool was created from.

    Args:
        token: Authentication token.

    Returns:
        A dict with the keys host, port, user, password, database and
        charset, or None when there is no token manager or the manager has
        no database configuration for this token.
    """
    manager = self.token_manager
    if not manager:
        return None

    bound = manager.get_database_config_by_token(token)
    if not bound:
        return None

    field_names = ('host', 'port', 'user', 'password', 'database', 'charset')
    return {name: getattr(bound, name) for name in field_names}
def _config_changed(self, old_config: dict, new_config: dict) -> bool:
    """Report whether two database configurations differ.

    Only host, port, user, password and database are compared; other keys
    (for example charset) are ignored. If either argument is None the result
    is simply whether the two arguments are unequal.

    Args:
        old_config: Configuration the existing pool was created with.
        new_config: Configuration currently in effect.

    Returns:
        True when any compared field differs, False otherwise.
    """
    if old_config is None or new_config is None:
        return old_config != new_config

    compared_keys = ('host', 'port', 'user', 'password', 'database')
    return any(
        old_config.get(name) != new_config.get(name) for name in compared_keys
    )
async def get_pool_for_token(self, token: str) -> tuple[Pool, dict]:
    """Get or create the dedicated connection pool for a token.

    Each token gets its own pool so concurrent requests made with different
    tokens never share database credentials.

    Hot reload: when the token's database configuration in the TokenManager
    no longer matches the configuration the pool was created with, the old
    pool is closed and a new one is created.

    The staleness check compares the cached configuration against the
    *effective* configuration (token-bound settings, or the global settings
    when the token entry has no usable host/user). Comparing against the raw
    token entry would flag a pool built from the global fallback as changed
    on every call and recreate it each time.

    Args:
        token: Authentication token.

    Returns:
        (pool, db_config): the dedicated pool and the configuration it was
        created with.

    Raises:
        RuntimeError: If neither the token nor the global settings provide a
            usable database configuration.
    """
    token_hash = self._get_token_hash(token)

    # Fast path: an open pool whose configuration is still current.
    pool = self.token_pools.get(token_hash)
    cached_config = self.token_configs.get(token_hash)
    if self._token_pool_usable(token, pool, cached_config):
        return pool, cached_config

    # Slow path: create or replace the pool while holding the registry lock.
    async with self._token_pools_lock:
        # Re-read the registry after acquiring the lock. Another coroutine
        # may already have created or replaced the pool while we waited; in
        # that case its fresh pool must be reused, not closed.
        pool = self.token_pools.get(token_hash)
        cached_config = self.token_configs.get(token_hash)
        if self._token_pool_usable(token, pool, cached_config):
            return pool, cached_config

        # Retire whatever is registered (stale configuration or closed pool).
        if pool is not None:
            self.token_pools.pop(token_hash, None)
            self.token_configs.pop(token_hash, None)
            if not pool.closed:
                self.logger.info(f"Token config changed (hash: {token_hash[:8]}...), recreating pool...")
                try:
                    pool.close()
                    await asyncio.wait_for(pool.wait_closed(), timeout=2.0)
                except Exception as e:
                    self.logger.warning(f"Error closing old pool during hot reload: {e}")

        db_config, config_source, _ = self._resolve_token_db_config(token)
        if db_config is None:
            raise RuntimeError(
                "No valid database configuration available for token. "
                "Please configure database in tokens.json or .env file."
            )

        self.logger.info(f"Creating dedicated connection pool for token (hash: {token_hash[:8]}...) "
                         f"using {config_source} config: {db_config['user']}@{db_config['host']}:{db_config['port']}")

        pool = await self._create_pool_with_config(db_config)

        # Register the pool together with the configuration it was built from.
        self.token_pools[token_hash] = pool
        self.token_configs[token_hash] = db_config

        if token_hash not in self._token_pool_locks:
            self._token_pool_locks[token_hash] = asyncio.Lock()

        return pool, db_config

def _resolve_token_db_config(self, token: str) -> tuple[dict | None, str, bool]:
    """Resolve the configuration a pool for this token should be built from.

    Returns:
        (db_config, config_source, has_token_entry) where db_config is the
        token-bound configuration when it has a usable host and user, a copy
        of the global configuration otherwise, or None when neither is
        available; config_source is "token-bound", "global-env" or
        "unknown"; has_token_entry tells whether the TokenManager returned
        any configuration for the token at all.
    """
    token_config = self._get_current_token_db_config(token)
    has_token_entry = bool(token_config)

    if (
        has_token_entry
        and not self._is_config_empty(token_config.get('host'))
        and not self._is_config_empty(token_config.get('user'))
    ):
        return token_config, "token-bound", True

    if self._has_valid_global_config():
        return self.original_db_config.copy(), "global-env", has_token_entry

    return None, "unknown", has_token_entry

def _token_pool_usable(self, token: str, pool, cached_config: dict | None) -> bool:
    """Tell whether a registered pool can be handed out as-is.

    A pool is usable when it exists, is open, has a recorded configuration
    and that configuration still matches the token's effective
    configuration. When the TokenManager has no entry for the token, the
    existing pool is kept.
    """
    if pool is None or pool.closed or cached_config is None:
        return False

    effective_config, _, has_token_entry = self._resolve_token_db_config(token)
    if not has_token_entry:
        return True

    return not self._config_changed(cached_config, effective_config)
async def _create_pool_with_config(self, db_config: dict) -> Pool:
    """Build an aiomysql connection pool from a configuration dictionary.

    The pool starts with no connections (minsize=0); the remaining pool
    limits come from this manager's maxsize, connect_timeout and
    pool_recycle settings. Creation is bounded by connect_timeout + 5
    seconds.

    Args:
        db_config: Database configuration with the keys host, port, user,
            password, database and charset.

    Returns:
        The created connection pool.

    Raises:
        RuntimeError: If pool creation exceeds the time limit.
        Exception: Any other creation error is logged and re-raised as-is.
    """
    # aiomysql expects lower-case charset names.
    raw_charset = db_config['charset']
    known_charsets = {"UTF8": "utf8", "UTF8MB4": "utf8mb4"}
    charset = known_charsets.get(raw_charset.upper(), raw_charset.lower())

    endpoint = f"{db_config['user']}@{db_config['host']}:{db_config['port']}"
    self.logger.debug(f"Creating pool for {endpoint}/{db_config['database']}")

    try:
        pending_pool = aiomysql.create_pool(
            host=db_config['host'],
            port=db_config['port'],
            user=db_config['user'],
            password=db_config['password'],
            db=db_config['database'],
            charset=charset,
            minsize=0,  # Connections are opened lazily on first acquire
            maxsize=self.maxsize,
            connect_timeout=self.connect_timeout,
            autocommit=True,
            pool_recycle=self.pool_recycle,
        )
        # Allow a little more than the connect timeout for the pool itself.
        created_pool = await asyncio.wait_for(
            pending_pool, timeout=self.connect_timeout + 5
        )
        self.logger.info(f"Successfully created pool for {endpoint}")
        return created_pool
    except asyncio.TimeoutError:
        self.logger.error(f"Timeout creating pool for {endpoint}")
        raise RuntimeError(f"Timeout creating connection pool for {endpoint}")
    except Exception as e:
        self.logger.error(f"Failed to create pool for {endpoint}: {type(e).__name__}: {e}")
        raise
async def get_connection_for_token(self, token: str, session_id: str) -> 'DorisConnection':
    """Acquire a connection from the pool dedicated to a token.

    The pool is obtained (and created if necessary) through
    get_pool_for_token; acquiring a connection from it is bounded by
    connect_timeout.

    Args:
        token: Authentication token.
        session_id: Session identifier, used for logging and passed to the
            connection wrapper.

    Returns:
        DorisConnection wrapper around the acquired connection.

    Raises:
        Exception: Any acquisition failure is logged and re-raised.
    """
    token_pool, pool_config = await self.get_pool_for_token(token)

    try:
        raw_connection = await asyncio.wait_for(
            token_pool.acquire(), timeout=self.connect_timeout
        )
        acquired_from = f"(user: {pool_config['user']}@{pool_config['host']})"
        self.logger.debug(
            f"Session {session_id}: Acquired connection from token pool {acquired_from}"
        )
        return DorisConnection(raw_connection, session_id, self.security_manager)
    except Exception as e:
        self.logger.error(f"Session {session_id}: Failed to acquire connection from token pool: {e}")
        raise
async def release_connection_for_token(self, token: str, connection: 'DorisConnection'):
    """Return a connection to the token's dedicated pool, or close it.

    The connection is released to the pool currently registered for the
    token. If no open pool is registered (for example it was replaced by a
    hot reload or removed by cleanup while the query was running), or the
    pool refuses the connection, the underlying connection is closed
    instead so its socket is not leaked.

    Args:
        token: Authentication token.
        connection: DorisConnection wrapper to release; its ``connection``
            attribute is the raw pooled connection.
    """
    raw_connection = getattr(connection, 'connection', None)
    if raw_connection is None:
        return

    pool = self.token_pools.get(self._get_token_hash(token))
    if pool is not None and not pool.closed:
        try:
            pool.release(raw_connection)
            return
        except Exception as e:
            # Typically the connection belongs to a pool that has since been
            # replaced; fall through and close it.
            self.logger.warning(f"Failed to release connection to token pool: {e}")

    # No live pool accepted the connection: close it rather than leak it.
    try:
        raw_connection.close()
    except Exception as e:
        self.logger.warning(f"Failed to close orphaned token connection: {e}")
async def cleanup_token_pools(self, max_idle_time: int = 3600):
    """Remove token pools that are closed or currently hold no connections.

    A pool is removed when it is already closed, or when it is open with
    ``size == 0`` and ``freesize == 0``. The pool's cached configuration and
    per-token lock are dropped together with it, even if closing the pool
    fails, so no credentials are left behind for a pool that is gone.

    Args:
        max_idle_time: Maximum idle time in seconds before closing a pool.
            NOTE: currently not consulted — no last-used timestamp is
            tracked, so selection is based only on the pool's connection
            counts at the time of the call.
    """
    async with self._token_pools_lock:
        removable = [
            token_hash
            for token_hash, pool in self.token_pools.items()
            if pool and (pool.closed or (pool.size == 0 and pool.freesize == 0))
        ]

        for token_hash in removable:
            # Drop all bookkeeping first so a failing close() cannot leave an
            # orphaned config/lock entry that no later cleanup would find.
            pool = self.token_pools.pop(token_hash, None)
            self.token_configs.pop(token_hash, None)
            self._token_pool_locks.pop(token_hash, None)
            try:
                if pool and not pool.closed:
                    pool.close()
                    # Bounded wait, matching the shutdown path, because the
                    # registry lock is held here.
                    await asyncio.wait_for(pool.wait_closed(), timeout=2.0)
                self.logger.info(f"Cleaned up idle token pool (hash: {token_hash[:8]}...)")
            except Exception as e:
                self.logger.warning(f"Error cleaning up token pool: {e}")
async def close_all_token_pools(self, *, lock_timeout: float = 5.0, close_timeout: float = 2.0):
    """Close every token connection pool (for shutdown).

    The registry lock is requested with a bounded wait. Whether or not the
    lock is obtained, every registered pool is closed and all bookkeeping is
    cleared; a lock timeout only means the work proceeds without the lock.
    The lock timeout applies to lock acquisition alone, so slow pool
    shutdowns cannot cause the remaining pools to be dropped unclosed.

    ``asyncio.wait_for`` is used for both waits, so no Python 3.11-only API
    is required.

    Args:
        lock_timeout: Seconds to wait for the registry lock.
        close_timeout: Seconds to wait for each pool's ``wait_closed()``.
    """
    have_lock = False
    try:
        await asyncio.wait_for(self._token_pools_lock.acquire(), timeout=lock_timeout)
        have_lock = True
    except asyncio.TimeoutError:
        self.logger.warning("Timeout acquiring lock for token pool cleanup, forcing clear")

    try:
        # Snapshot and clear first so no other coroutine is handed a pool
        # that is about to be closed.
        pools = list(self.token_pools.items())
        self.token_pools.clear()
        self.token_configs.clear()
        self._token_pool_locks.clear()

        for token_hash, pool in pools:
            try:
                if pool and not pool.closed:
                    pool.close()
                    try:
                        await asyncio.wait_for(pool.wait_closed(), timeout=close_timeout)
                    except asyncio.TimeoutError:
                        self.logger.warning(f"Timeout waiting for token pool to close (hash: {token_hash[:8]}...)")
                    self.logger.info(f"Closed token pool (hash: {token_hash[:8]}...)")
            except Exception as e:
                self.logger.warning(f"Error closing token pool: {e}")
    finally:
        # Only release a lock this call actually acquired.
        if have_lock:
            self._token_pools_lock.release()
async def configure_for_token(self, token: str) -> tuple[bool, str]: async def configure_for_token(self, token: str) -> tuple[bool, str]:
"""Configure connection manager for token with new priority logic """Configure connection manager for token with new priority logic
@@ -1092,7 +1375,26 @@ class DorisConnectionManager:
Uses only semaphore to prevent too many concurrent acquisitions. Uses only semaphore to prevent too many concurrent acquisitions.
If the connection is successfully obtained, it will be added to the connection pool cache. If the connection is successfully obtained, it will be added to the connection pool cache.
🔧 FIX for token isolation: Now automatically checks for auth_context from ContextVar
and uses token-specific connection pool if available.
""" """
# 🔧 FIX: Check for auth_context from global ContextVar
# This ensures all tools using get_connection respect token-bound database configuration
auth_context = None
try:
from .security import mcp_auth_context_var
auth_context = mcp_auth_context_var.get()
except Exception as e:
self.logger.debug(f"get_connection: Could not get auth_context: {e}")
if auth_context and hasattr(auth_context, 'token') and auth_context.token:
# Use token-specific connection pool
# SECURITY: Do NOT catch exceptions here - if token pool fails, don't fallback to global pool
# This prevents privilege escalation
self.logger.debug(f"get_connection: Using token-specific pool for session {session_id}")
return await self.get_connection_for_token(auth_context.token, session_id)
cached_conn = self.session_cache.get(session_id) cached_conn = self.session_cache.get(session_id)
if cached_conn: if cached_conn:
return cached_conn return cached_conn
@@ -1239,10 +1541,16 @@ class DorisConnectionManager:
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
# Close connection pool # 🔧 FIX: Close all per-token connection pools
await self.close_all_token_pools()
# Close global connection pool with timeout
if self.pool: if self.pool:
self.pool.close() self.pool.close()
await self.pool.wait_closed() try:
await asyncio.wait_for(self.pool.wait_closed(), timeout=5.0)
except asyncio.TimeoutError:
self.logger.warning("Timeout waiting for global pool to close")
self.logger.info("Connection manager closed successfully") self.logger.info("Connection manager closed successfully")
@@ -1271,37 +1579,43 @@ class DorisConnectionManager:
async def execute_query( async def execute_query(
self, session_id: str, sql: str, params: tuple | None = None, auth_context=None self, session_id: str, sql: str, params: tuple | None = None, auth_context=None
) -> QueryResult: ) -> QueryResult:
"""Execute query - Simplified Strategy with automatic connection management """Execute query - Enhanced Strategy with per-token connection pool isolation
FIX for Issue #62 Bug 1: Configure token-bound database before query execution FIX for multi-tenant concurrency: Each token now uses its own dedicated connection pool
to prevent configuration conflicts between concurrent requests from different tokens.
""" """
connection = None connection = None
token = None
try: try:
# FIX: Configure database for token BEFORE getting connection # Check if we have a token for per-token pool isolation
# This ensures token-bound database configuration is used instead of global config
if auth_context and hasattr(auth_context, 'token') and auth_context.token: if auth_context and hasattr(auth_context, 'token') and auth_context.token:
token = auth_context.token
try: try:
success, config_source = await self.configure_for_token(auth_context.token) # 🔧 FIX: Use dedicated connection pool for this token
if success: # This prevents concurrent requests from different tokens interfering
self.logger.info(f"Session {session_id}: Using {config_source} database configuration") connection = await self.get_connection_for_token(token, session_id)
else:
self.logger.warning(f"Session {session_id}: Token configuration failed, may use global config") # Get the config for logging
except Exception as token_config_error: token_hash = self._get_token_hash(token)
# SECURITY: If token should have config but configuration fails, don't fallback if token_hash in self.token_configs:
db_config = self.token_configs[token_hash]
self.logger.info(f"Session {session_id}: Using dedicated pool for {db_config['user']}@{db_config['host']}")
except Exception as token_pool_error:
# SECURITY: If token should have pool but creation fails, don't fallback
# This prevents privilege escalation (using high-privilege default user) # This prevents privilege escalation (using high-privilege default user)
if self.token_manager: self.logger.error(f"Session {session_id}: Token pool error: {token_pool_error}")
self.logger.error(f"Session {session_id}: Token database configuration failed: {token_config_error}") raise RuntimeError(
raise RuntimeError( f"Failed to get connection for authenticated token. "
f"Failed to configure database for authenticated token. " f"This is a security measure to prevent using default high-privilege credentials. "
f"This is a security measure to prevent using default high-privilege credentials. " f"Error: {token_pool_error}"
f"Error: {token_config_error}" )
) else:
else: # No token - use global pool (backward compatibility)
# No token manager, can use global config self.logger.debug(f"Session {session_id}: No token, using global connection pool")
self.logger.warning(f"Session {session_id}: No token manager, using global config") connection = await self.get_connection(session_id)
# Always get fresh connection from pool (with configured database)
connection = await self.get_connection(session_id)
# Execute query # Execute query
result = await connection.execute(sql, params, auth_context) result = await connection.execute(sql, params, auth_context)
@@ -1312,9 +1626,12 @@ class DorisConnectionManager:
self.logger.error(f"Query execution failed for session {session_id}: {e}") self.logger.error(f"Query execution failed for session {session_id}: {e}")
raise raise
finally: finally:
# Always release connection back to pool # Always release connection back to the appropriate pool
if connection: if connection:
await self.release_connection(session_id, connection) if token:
await self.release_connection_for_token(token, connection)
else:
await self.release_connection(session_id, connection)
@asynccontextmanager @asynccontextmanager
async def get_connection_context(self, session_id: str): async def get_connection_context(self, session_id: str):

View File

@@ -639,7 +639,13 @@ class DorisQueryExecutor:
if hasattr(self.connection_manager, 'config') and hasattr(self.connection_manager.config, 'security'): if hasattr(self.connection_manager, 'config') and hasattr(self.connection_manager.config, 'security'):
if self.connection_manager.config.security.enable_security_check: if self.connection_manager.config.security.enable_security_check:
try: try:
security_manager = DorisSecurityManager(self.connection_manager.config) # 🔧 FIX: Use existing security_manager to avoid creating multiple TokenManager instances
# Creating new DorisSecurityManager each time causes multiple hot reload monitors
security_manager = getattr(self.connection_manager, 'security_manager', None)
if not security_manager:
# Fallback: create new one only if not available (should rarely happen)
self.logger.warning("No existing security_manager, creating new instance")
security_manager = DorisSecurityManager(self.connection_manager.config)
validation_result = await security_manager.validate_sql_security(sql, auth_context) validation_result = await security_manager.validate_sql_security(sql, auth_context)
if not validation_result.is_valid: if not validation_result.is_valid:

View File

@@ -1194,8 +1194,17 @@ class MetadataExtractor:
""" """
try: try:
if self.connection_manager: if self.connection_manager:
# FIX: Get auth_context from global ContextVar for token-bound database configuration
# This ensures all query methods use the correct user's connection pool
auth_context = None
try:
from .security import mcp_auth_context_var
auth_context = mcp_auth_context_var.get()
except Exception:
pass
# Use the injected connection manager directly (async) # Use the injected connection manager directly (async)
result = await self.connection_manager.execute_query(self._session_id, query, None) result = await self.connection_manager.execute_query(self._session_id, query, None, auth_context)
# Extract data from QueryResult # Extract data from QueryResult
if hasattr(result, 'data'): if hasattr(result, 'data'):
@@ -1644,15 +1653,14 @@ class MetadataExtractor:
# FIX: Try to get auth_context from context variable (set by HTTP middleware) # FIX: Try to get auth_context from context variable (set by HTTP middleware)
# This allows token-bound database configuration to work # This allows token-bound database configuration to work
# CRITICAL: Use the global ContextVar from security.py to ensure same instance is used everywhere
auth_context = None auth_context = None
try: try:
from contextvars import ContextVar from .security import mcp_auth_context_var
from .security import AuthContext
# Try to get auth_context from context variable # Get auth_context from the global context variable
# This will be set by the HTTP request handler in main.py # This will be set by the HTTP request handler in main.py
auth_context_var: ContextVar = ContextVar('mcp_auth_context', default=None) auth_context = mcp_auth_context_var.get()
auth_context = auth_context_var.get()
if auth_context: if auth_context:
logger.debug(f"Retrieved auth_context from context variable with token: {bool(hasattr(auth_context, 'token') and auth_context.token)}") logger.debug(f"Retrieved auth_context from context variable with token: {bool(hasattr(auth_context, 'token') and auth_context.token)}")

View File

@@ -22,6 +22,7 @@ Implements enterprise-level authentication, authorization, SQL security validati
import logging import logging
import re import re
from contextvars import ContextVar
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
@@ -34,6 +35,10 @@ from sqlparse.tokens import Keyword, Name
from .logger import get_logger from .logger import get_logger
from .config import DatabaseConfig from .config import DatabaseConfig
# Global ContextVar for auth_context - must be a single instance shared across all modules
# This allows token-bound database configuration to work correctly in concurrent requests
mcp_auth_context_var: ContextVar['AuthContext'] = ContextVar('mcp_auth_context', default=None)
class SecurityLevel(Enum): class SecurityLevel(Enum):
"""Security level enumeration""" """Security level enumeration"""