[fix]fix token auth (#69)
* fix tocken auth * Further fixes to the token overwriting issue and restoration of hot reloading of tokens.json.
This commit is contained in:
@@ -255,6 +255,13 @@ class DorisConnectionManager:
|
||||
self.security_manager = security_manager
|
||||
self.token_manager = token_manager # Token manager for token-bound DB config
|
||||
|
||||
# 🔧 FIX for multi-tenant concurrency: Per-token connection pool isolation
|
||||
# Each token gets its own connection pool to prevent configuration conflicts
|
||||
self.token_pools: Dict[str, Pool] = {} # token_hash -> pool
|
||||
self.token_configs: Dict[str, dict] = {} # token_hash -> db_config
|
||||
self._token_pool_locks: Dict[str, asyncio.Lock] = {} # token_hash -> lock
|
||||
self._token_pools_lock = asyncio.Lock() # Lock for managing token_pools dict
|
||||
|
||||
# FIX for Issue #58 Problem 1: Disable session caching to prevent connection sharing
|
||||
# Session caching causes multiple threads to share the same MySQL connection,
|
||||
# leading to race conditions and deadlocks in multi-threaded environments
|
||||
@@ -276,6 +283,7 @@ class DorisConnectionManager:
|
||||
}
|
||||
|
||||
# Current active database config (may be overridden by token-bound config)
|
||||
# NOTE: This is kept for backward compatibility with non-token requests
|
||||
self.active_db_config = self.original_db_config.copy()
|
||||
|
||||
# Connection pool state management
|
||||
@@ -359,6 +367,281 @@ class DorisConnectionManager:
|
||||
self.logger.error(f"Error finding available token: {e}")
|
||||
return ""
|
||||
|
||||
def _get_token_hash(self, token: str) -> str:
|
||||
"""Get hash of token for use as dictionary key"""
|
||||
import hashlib
|
||||
return hashlib.sha256(token.encode()).hexdigest()[:16]
|
||||
|
||||
def _get_current_token_db_config(self, token: str) -> dict | None:
|
||||
"""Get current database config for token from TokenManager
|
||||
|
||||
This is used to check if config has changed for hot reload support.
|
||||
"""
|
||||
if not self.token_manager:
|
||||
return None
|
||||
|
||||
token_db_config = self.token_manager.get_database_config_by_token(token)
|
||||
if token_db_config:
|
||||
return {
|
||||
'host': token_db_config.host,
|
||||
'port': token_db_config.port,
|
||||
'user': token_db_config.user,
|
||||
'password': token_db_config.password,
|
||||
'database': token_db_config.database,
|
||||
'charset': token_db_config.charset
|
||||
}
|
||||
return None
|
||||
|
||||
def _config_changed(self, old_config: dict, new_config: dict) -> bool:
|
||||
"""Check if database configuration has changed"""
|
||||
if old_config is None or new_config is None:
|
||||
return old_config != new_config
|
||||
|
||||
# Compare key fields
|
||||
for key in ['host', 'port', 'user', 'password', 'database']:
|
||||
if old_config.get(key) != new_config.get(key):
|
||||
return True
|
||||
return False
|
||||
|
||||
async def get_pool_for_token(self, token: str) -> tuple[Pool, dict]:
|
||||
"""Get or create a dedicated connection pool for a specific token
|
||||
|
||||
This method implements per-token connection pool isolation to prevent
|
||||
concurrent requests from different tokens interfering with each other.
|
||||
|
||||
🔧 FIX: Supports hot reload - if tokens.json config changes,
|
||||
the old pool is closed and a new one is created automatically.
|
||||
|
||||
Args:
|
||||
token: Authentication token
|
||||
|
||||
Returns:
|
||||
(pool, db_config): The dedicated pool and its configuration
|
||||
|
||||
Raises:
|
||||
RuntimeError: If no valid database configuration is available
|
||||
"""
|
||||
token_hash = self._get_token_hash(token)
|
||||
|
||||
# Fast path: pool already exists
|
||||
if token_hash in self.token_pools:
|
||||
pool = self.token_pools[token_hash]
|
||||
cached_config = self.token_configs.get(token_hash)
|
||||
|
||||
# 🔧 FIX: Check if config has changed (hot reload support)
|
||||
current_config = self._get_current_token_db_config(token)
|
||||
if current_config and cached_config and self._config_changed(cached_config, current_config):
|
||||
self.logger.info(f"Token config changed (hash: {token_hash[:8]}...), recreating pool...")
|
||||
# Config changed, need to recreate pool
|
||||
async with self._token_pools_lock:
|
||||
# Close old pool
|
||||
old_pool = self.token_pools.pop(token_hash, None)
|
||||
if old_pool and not old_pool.closed:
|
||||
try:
|
||||
old_pool.close()
|
||||
await asyncio.wait_for(old_pool.wait_closed(), timeout=2.0)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Error closing old pool during hot reload: {e}")
|
||||
self.token_configs.pop(token_hash, None)
|
||||
# Continue to slow path to create new pool
|
||||
elif pool and not pool.closed:
|
||||
return pool, cached_config
|
||||
|
||||
# Slow path: need to create pool (with lock to prevent race conditions)
|
||||
async with self._token_pools_lock:
|
||||
# Double-check after acquiring lock
|
||||
if token_hash in self.token_pools:
|
||||
pool = self.token_pools[token_hash]
|
||||
if pool and not pool.closed:
|
||||
return pool, self.token_configs[token_hash]
|
||||
|
||||
# Get database config for this token
|
||||
db_config = None
|
||||
config_source = "unknown"
|
||||
|
||||
if self.token_manager:
|
||||
token_db_config = self.token_manager.get_database_config_by_token(token)
|
||||
if token_db_config:
|
||||
db_config = {
|
||||
'host': token_db_config.host,
|
||||
'port': token_db_config.port,
|
||||
'user': token_db_config.user,
|
||||
'password': token_db_config.password,
|
||||
'database': token_db_config.database,
|
||||
'charset': token_db_config.charset
|
||||
}
|
||||
config_source = "token-bound"
|
||||
|
||||
# Fallback to global config if token has no specific config
|
||||
if not db_config or self._is_config_empty(db_config.get('host')) or self._is_config_empty(db_config.get('user')):
|
||||
if self._has_valid_global_config():
|
||||
db_config = self.original_db_config.copy()
|
||||
config_source = "global-env"
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"No valid database configuration available for token. "
|
||||
f"Please configure database in tokens.json or .env file."
|
||||
)
|
||||
|
||||
# Create dedicated pool for this token
|
||||
self.logger.info(f"Creating dedicated connection pool for token (hash: {token_hash[:8]}...) "
|
||||
f"using {config_source} config: {db_config['user']}@{db_config['host']}:{db_config['port']}")
|
||||
|
||||
pool = await self._create_pool_with_config(db_config)
|
||||
|
||||
# Store pool and config
|
||||
self.token_pools[token_hash] = pool
|
||||
self.token_configs[token_hash] = db_config
|
||||
|
||||
# Create lock for this token if not exists
|
||||
if token_hash not in self._token_pool_locks:
|
||||
self._token_pool_locks[token_hash] = asyncio.Lock()
|
||||
|
||||
return pool, db_config
|
||||
|
||||
async def _create_pool_with_config(self, db_config: dict) -> Pool:
|
||||
"""Create a connection pool with specified configuration
|
||||
|
||||
Args:
|
||||
db_config: Database configuration dictionary
|
||||
|
||||
Returns:
|
||||
Created connection pool
|
||||
"""
|
||||
# Convert charset to aiomysql compatible format
|
||||
charset_map = {"UTF8": "utf8", "UTF8MB4": "utf8mb4"}
|
||||
charset = charset_map.get(db_config['charset'].upper(), db_config['charset'].lower())
|
||||
|
||||
self.logger.debug(f"Creating pool for {db_config['user']}@{db_config['host']}:{db_config['port']}/{db_config['database']}")
|
||||
|
||||
try:
|
||||
pool = await asyncio.wait_for(
|
||||
aiomysql.create_pool(
|
||||
host=db_config['host'],
|
||||
port=db_config['port'],
|
||||
user=db_config['user'],
|
||||
password=db_config['password'],
|
||||
db=db_config['database'],
|
||||
charset=charset,
|
||||
minsize=0, # Don't pre-create connections
|
||||
maxsize=self.maxsize,
|
||||
connect_timeout=self.connect_timeout,
|
||||
autocommit=True,
|
||||
pool_recycle=self.pool_recycle
|
||||
),
|
||||
timeout=self.connect_timeout + 5 # Give extra time for pool creation
|
||||
)
|
||||
self.logger.info(f"Successfully created pool for {db_config['user']}@{db_config['host']}:{db_config['port']}")
|
||||
return pool
|
||||
except asyncio.TimeoutError:
|
||||
self.logger.error(f"Timeout creating pool for {db_config['user']}@{db_config['host']}:{db_config['port']}")
|
||||
raise RuntimeError(f"Timeout creating connection pool for {db_config['user']}@{db_config['host']}:{db_config['port']}")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to create pool for {db_config['user']}@{db_config['host']}:{db_config['port']}: {type(e).__name__}: {e}")
|
||||
raise
|
||||
|
||||
async def get_connection_for_token(self, token: str, session_id: str) -> 'DorisConnection':
|
||||
"""Get a connection from the token's dedicated pool
|
||||
|
||||
Args:
|
||||
token: Authentication token
|
||||
session_id: Session identifier for logging
|
||||
|
||||
Returns:
|
||||
DorisConnection wrapper
|
||||
"""
|
||||
pool, db_config = await self.get_pool_for_token(token)
|
||||
|
||||
try:
|
||||
connection = await asyncio.wait_for(
|
||||
pool.acquire(),
|
||||
timeout=self.connect_timeout
|
||||
)
|
||||
|
||||
self.logger.debug(f"Session {session_id}: Acquired connection from token pool "
|
||||
f"(user: {db_config['user']}@{db_config['host']})")
|
||||
|
||||
return DorisConnection(connection, session_id, self.security_manager)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Session {session_id}: Failed to acquire connection from token pool: {e}")
|
||||
raise
|
||||
|
||||
async def release_connection_for_token(self, token: str, connection: 'DorisConnection'):
|
||||
"""Release a connection back to the token's dedicated pool
|
||||
|
||||
Args:
|
||||
token: Authentication token
|
||||
connection: DorisConnection wrapper to release
|
||||
"""
|
||||
token_hash = self._get_token_hash(token)
|
||||
|
||||
if token_hash in self.token_pools:
|
||||
pool = self.token_pools[token_hash]
|
||||
if pool and not pool.closed:
|
||||
try:
|
||||
pool.release(connection.connection)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to release connection to token pool: {e}")
|
||||
|
||||
async def cleanup_token_pools(self, max_idle_time: int = 3600):
|
||||
"""Clean up idle token connection pools
|
||||
|
||||
Args:
|
||||
max_idle_time: Maximum idle time in seconds before closing a pool
|
||||
"""
|
||||
async with self._token_pools_lock:
|
||||
pools_to_remove = []
|
||||
|
||||
for token_hash, pool in self.token_pools.items():
|
||||
if pool and not pool.closed:
|
||||
# Check if pool is idle (no active connections)
|
||||
if pool.size == 0 and pool.freesize == 0:
|
||||
pools_to_remove.append(token_hash)
|
||||
elif pool and pool.closed:
|
||||
pools_to_remove.append(token_hash)
|
||||
|
||||
for token_hash in pools_to_remove:
|
||||
try:
|
||||
pool = self.token_pools.pop(token_hash, None)
|
||||
if pool and not pool.closed:
|
||||
pool.close()
|
||||
await pool.wait_closed()
|
||||
self.token_configs.pop(token_hash, None)
|
||||
self._token_pool_locks.pop(token_hash, None)
|
||||
self.logger.info(f"Cleaned up idle token pool (hash: {token_hash[:8]}...)")
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Error cleaning up token pool: {e}")
|
||||
|
||||
async def close_all_token_pools(self):
|
||||
"""Close all token connection pools (for shutdown)"""
|
||||
# Use timeout to prevent blocking on lock acquisition during shutdown
|
||||
try:
|
||||
async with asyncio.timeout(5): # 5 second timeout for lock
|
||||
async with self._token_pools_lock:
|
||||
for token_hash, pool in list(self.token_pools.items()):
|
||||
try:
|
||||
if pool and not pool.closed:
|
||||
pool.close()
|
||||
# Use timeout for wait_closed to prevent hanging
|
||||
try:
|
||||
await asyncio.wait_for(pool.wait_closed(), timeout=2.0)
|
||||
except asyncio.TimeoutError:
|
||||
self.logger.warning(f"Timeout waiting for token pool to close (hash: {token_hash[:8]}...)")
|
||||
self.logger.info(f"Closed token pool (hash: {token_hash[:8]}...)")
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Error closing token pool: {e}")
|
||||
|
||||
self.token_pools.clear()
|
||||
self.token_configs.clear()
|
||||
self._token_pool_locks.clear()
|
||||
except asyncio.TimeoutError:
|
||||
self.logger.warning("Timeout acquiring lock for token pool cleanup, forcing clear")
|
||||
# Force clear without lock
|
||||
self.token_pools.clear()
|
||||
self.token_configs.clear()
|
||||
self._token_pool_locks.clear()
|
||||
|
||||
async def configure_for_token(self, token: str) -> tuple[bool, str]:
|
||||
"""Configure connection manager for token with new priority logic
|
||||
|
||||
@@ -1092,7 +1375,26 @@ class DorisConnectionManager:
|
||||
|
||||
Uses only semaphore to prevent too many concurrent acquisitions.
|
||||
If the connection is successfully obtained, it will be added to the connection pool cache.
|
||||
|
||||
🔧 FIX for token isolation: Now automatically checks for auth_context from ContextVar
|
||||
and uses token-specific connection pool if available.
|
||||
"""
|
||||
# 🔧 FIX: Check for auth_context from global ContextVar
|
||||
# This ensures all tools using get_connection respect token-bound database configuration
|
||||
auth_context = None
|
||||
try:
|
||||
from .security import mcp_auth_context_var
|
||||
auth_context = mcp_auth_context_var.get()
|
||||
except Exception as e:
|
||||
self.logger.debug(f"get_connection: Could not get auth_context: {e}")
|
||||
|
||||
if auth_context and hasattr(auth_context, 'token') and auth_context.token:
|
||||
# Use token-specific connection pool
|
||||
# SECURITY: Do NOT catch exceptions here - if token pool fails, don't fallback to global pool
|
||||
# This prevents privilege escalation
|
||||
self.logger.debug(f"get_connection: Using token-specific pool for session {session_id}")
|
||||
return await self.get_connection_for_token(auth_context.token, session_id)
|
||||
|
||||
cached_conn = self.session_cache.get(session_id)
|
||||
if cached_conn:
|
||||
return cached_conn
|
||||
@@ -1239,10 +1541,16 @@ class DorisConnectionManager:
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Close connection pool
|
||||
# 🔧 FIX: Close all per-token connection pools
|
||||
await self.close_all_token_pools()
|
||||
|
||||
# Close global connection pool with timeout
|
||||
if self.pool:
|
||||
self.pool.close()
|
||||
await self.pool.wait_closed()
|
||||
try:
|
||||
await asyncio.wait_for(self.pool.wait_closed(), timeout=5.0)
|
||||
except asyncio.TimeoutError:
|
||||
self.logger.warning("Timeout waiting for global pool to close")
|
||||
|
||||
self.logger.info("Connection manager closed successfully")
|
||||
|
||||
@@ -1271,37 +1579,43 @@ class DorisConnectionManager:
|
||||
async def execute_query(
|
||||
self, session_id: str, sql: str, params: tuple | None = None, auth_context=None
|
||||
) -> QueryResult:
|
||||
"""Execute query - Simplified Strategy with automatic connection management
|
||||
"""Execute query - Enhanced Strategy with per-token connection pool isolation
|
||||
|
||||
FIX for Issue #62 Bug 1: Configure token-bound database before query execution
|
||||
FIX for multi-tenant concurrency: Each token now uses its own dedicated connection pool
|
||||
to prevent configuration conflicts between concurrent requests from different tokens.
|
||||
"""
|
||||
connection = None
|
||||
token = None
|
||||
|
||||
try:
|
||||
# FIX: Configure database for token BEFORE getting connection
|
||||
# This ensures token-bound database configuration is used instead of global config
|
||||
# Check if we have a token for per-token pool isolation
|
||||
if auth_context and hasattr(auth_context, 'token') and auth_context.token:
|
||||
token = auth_context.token
|
||||
|
||||
try:
|
||||
success, config_source = await self.configure_for_token(auth_context.token)
|
||||
if success:
|
||||
self.logger.info(f"Session {session_id}: Using {config_source} database configuration")
|
||||
else:
|
||||
self.logger.warning(f"Session {session_id}: Token configuration failed, may use global config")
|
||||
except Exception as token_config_error:
|
||||
# SECURITY: If token should have config but configuration fails, don't fallback
|
||||
# 🔧 FIX: Use dedicated connection pool for this token
|
||||
# This prevents concurrent requests from different tokens interfering
|
||||
connection = await self.get_connection_for_token(token, session_id)
|
||||
|
||||
# Get the config for logging
|
||||
token_hash = self._get_token_hash(token)
|
||||
if token_hash in self.token_configs:
|
||||
db_config = self.token_configs[token_hash]
|
||||
self.logger.info(f"Session {session_id}: Using dedicated pool for {db_config['user']}@{db_config['host']}")
|
||||
|
||||
except Exception as token_pool_error:
|
||||
# SECURITY: If token should have pool but creation fails, don't fallback
|
||||
# This prevents privilege escalation (using high-privilege default user)
|
||||
if self.token_manager:
|
||||
self.logger.error(f"Session {session_id}: Token database configuration failed: {token_config_error}")
|
||||
raise RuntimeError(
|
||||
f"Failed to configure database for authenticated token. "
|
||||
f"This is a security measure to prevent using default high-privilege credentials. "
|
||||
f"Error: {token_config_error}"
|
||||
)
|
||||
else:
|
||||
# No token manager, can use global config
|
||||
self.logger.warning(f"Session {session_id}: No token manager, using global config")
|
||||
|
||||
# Always get fresh connection from pool (with configured database)
|
||||
connection = await self.get_connection(session_id)
|
||||
self.logger.error(f"Session {session_id}: Token pool error: {token_pool_error}")
|
||||
raise RuntimeError(
|
||||
f"Failed to get connection for authenticated token. "
|
||||
f"This is a security measure to prevent using default high-privilege credentials. "
|
||||
f"Error: {token_pool_error}"
|
||||
)
|
||||
else:
|
||||
# No token - use global pool (backward compatibility)
|
||||
self.logger.debug(f"Session {session_id}: No token, using global connection pool")
|
||||
connection = await self.get_connection(session_id)
|
||||
|
||||
# Execute query
|
||||
result = await connection.execute(sql, params, auth_context)
|
||||
@@ -1312,9 +1626,12 @@ class DorisConnectionManager:
|
||||
self.logger.error(f"Query execution failed for session {session_id}: {e}")
|
||||
raise
|
||||
finally:
|
||||
# Always release connection back to pool
|
||||
# Always release connection back to the appropriate pool
|
||||
if connection:
|
||||
await self.release_connection(session_id, connection)
|
||||
if token:
|
||||
await self.release_connection_for_token(token, connection)
|
||||
else:
|
||||
await self.release_connection(session_id, connection)
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_connection_context(self, session_id: str):
|
||||
|
||||
Reference in New Issue
Block a user