[fix]fix token auth (#69)

* fix tocken auth

* Further fixes to the token overwriting issue and restoration of hot reloading of tokens.json.
This commit is contained in:
bingquanzhao
2025-12-24 20:39:16 +08:00
committed by GitHub
parent 43143f0b30
commit 81305ffbf9
6 changed files with 384 additions and 40 deletions

View File

@@ -644,10 +644,10 @@ class DorisServer:
# FIX for Issue #62 Bug 1: Set auth_context in context variable
# This allows tools to access token information for token-bound database configuration
# CRITICAL: Use the global ContextVar from security.py to ensure same instance is used everywhere
try:
from contextvars import ContextVar
auth_context_var: ContextVar = ContextVar('mcp_auth_context', default=None)
auth_context_var.set(auth_context)
from .utils.security import mcp_auth_context_var
mcp_auth_context_var.set(auth_context)
self.logger.debug(f"Set auth_context in context variable with token: {bool(hasattr(auth_context, 'token') and auth_context.token)}")
except Exception as ctx_error:
self.logger.warning(f"Failed to set auth_context in context variable: {ctx_error}")

View File

@@ -397,6 +397,14 @@ class SQLAnalyzer:
logger.info(f"Generating SQL explain for query ID: {query_id}")
# 🔧 FIX: Get auth_context for token-bound database configuration
auth_context = None
try:
from .security import mcp_auth_context_var
auth_context = mcp_auth_context_var.get()
except Exception:
pass
# Switch database if specified
# SECURITY FIX: Validate and quote db_name
if db_name:
@@ -405,7 +413,7 @@ class SQLAnalyzer:
except SQLSecurityError as e:
return {"success": False, "error": f"Invalid database name: {e}"}
safe_db = quote_identifier(db_name, "database name")
await self.connection_manager.execute_query("explain_session", f"USE {safe_db}")
await self.connection_manager.execute_query("explain_session", f"USE {safe_db}", None, auth_context)
# Construct EXPLAIN query
explain_type = "EXPLAIN VERBOSE" if verbose else "EXPLAIN"
@@ -414,7 +422,7 @@ class SQLAnalyzer:
logger.info(f"Executing explain query: {explain_sql}")
# Execute explain query
result = await self.connection_manager.execute_query("explain_session", explain_sql)
result = await self.connection_manager.execute_query("explain_session", explain_sql, None, auth_context)
# Format explain output
explain_content = []

View File

@@ -255,6 +255,13 @@ class DorisConnectionManager:
self.security_manager = security_manager
self.token_manager = token_manager # Token manager for token-bound DB config
# 🔧 FIX for multi-tenant concurrency: Per-token connection pool isolation
# Each token gets its own connection pool to prevent configuration conflicts
self.token_pools: Dict[str, Pool] = {} # token_hash -> pool
self.token_configs: Dict[str, dict] = {} # token_hash -> db_config
self._token_pool_locks: Dict[str, asyncio.Lock] = {} # token_hash -> lock
self._token_pools_lock = asyncio.Lock() # Lock for managing token_pools dict
# FIX for Issue #58 Problem 1: Disable session caching to prevent connection sharing
# Session caching causes multiple threads to share the same MySQL connection,
# leading to race conditions and deadlocks in multi-threaded environments
@@ -276,6 +283,7 @@ class DorisConnectionManager:
}
# Current active database config (may be overridden by token-bound config)
# NOTE: This is kept for backward compatibility with non-token requests
self.active_db_config = self.original_db_config.copy()
# Connection pool state management
@@ -359,6 +367,281 @@ class DorisConnectionManager:
self.logger.error(f"Error finding available token: {e}")
return ""
def _get_token_hash(self, token: str) -> str:
"""Get hash of token for use as dictionary key"""
import hashlib
return hashlib.sha256(token.encode()).hexdigest()[:16]
def _get_current_token_db_config(self, token: str) -> dict | None:
"""Get current database config for token from TokenManager
This is used to check if config has changed for hot reload support.
"""
if not self.token_manager:
return None
token_db_config = self.token_manager.get_database_config_by_token(token)
if token_db_config:
return {
'host': token_db_config.host,
'port': token_db_config.port,
'user': token_db_config.user,
'password': token_db_config.password,
'database': token_db_config.database,
'charset': token_db_config.charset
}
return None
def _config_changed(self, old_config: dict, new_config: dict) -> bool:
"""Check if database configuration has changed"""
if old_config is None or new_config is None:
return old_config != new_config
# Compare key fields
for key in ['host', 'port', 'user', 'password', 'database']:
if old_config.get(key) != new_config.get(key):
return True
return False
async def get_pool_for_token(self, token: str) -> tuple[Pool, dict]:
"""Get or create a dedicated connection pool for a specific token
This method implements per-token connection pool isolation to prevent
concurrent requests from different tokens interfering with each other.
🔧 FIX: Supports hot reload - if tokens.json config changes,
the old pool is closed and a new one is created automatically.
Args:
token: Authentication token
Returns:
(pool, db_config): The dedicated pool and its configuration
Raises:
RuntimeError: If no valid database configuration is available
"""
token_hash = self._get_token_hash(token)
# Fast path: pool already exists
if token_hash in self.token_pools:
pool = self.token_pools[token_hash]
cached_config = self.token_configs.get(token_hash)
# 🔧 FIX: Check if config has changed (hot reload support)
current_config = self._get_current_token_db_config(token)
if current_config and cached_config and self._config_changed(cached_config, current_config):
self.logger.info(f"Token config changed (hash: {token_hash[:8]}...), recreating pool...")
# Config changed, need to recreate pool
async with self._token_pools_lock:
# Close old pool
old_pool = self.token_pools.pop(token_hash, None)
if old_pool and not old_pool.closed:
try:
old_pool.close()
await asyncio.wait_for(old_pool.wait_closed(), timeout=2.0)
except Exception as e:
self.logger.warning(f"Error closing old pool during hot reload: {e}")
self.token_configs.pop(token_hash, None)
# Continue to slow path to create new pool
elif pool and not pool.closed:
return pool, cached_config
# Slow path: need to create pool (with lock to prevent race conditions)
async with self._token_pools_lock:
# Double-check after acquiring lock
if token_hash in self.token_pools:
pool = self.token_pools[token_hash]
if pool and not pool.closed:
return pool, self.token_configs[token_hash]
# Get database config for this token
db_config = None
config_source = "unknown"
if self.token_manager:
token_db_config = self.token_manager.get_database_config_by_token(token)
if token_db_config:
db_config = {
'host': token_db_config.host,
'port': token_db_config.port,
'user': token_db_config.user,
'password': token_db_config.password,
'database': token_db_config.database,
'charset': token_db_config.charset
}
config_source = "token-bound"
# Fallback to global config if token has no specific config
if not db_config or self._is_config_empty(db_config.get('host')) or self._is_config_empty(db_config.get('user')):
if self._has_valid_global_config():
db_config = self.original_db_config.copy()
config_source = "global-env"
else:
raise RuntimeError(
f"No valid database configuration available for token. "
f"Please configure database in tokens.json or .env file."
)
# Create dedicated pool for this token
self.logger.info(f"Creating dedicated connection pool for token (hash: {token_hash[:8]}...) "
f"using {config_source} config: {db_config['user']}@{db_config['host']}:{db_config['port']}")
pool = await self._create_pool_with_config(db_config)
# Store pool and config
self.token_pools[token_hash] = pool
self.token_configs[token_hash] = db_config
# Create lock for this token if not exists
if token_hash not in self._token_pool_locks:
self._token_pool_locks[token_hash] = asyncio.Lock()
return pool, db_config
async def _create_pool_with_config(self, db_config: dict) -> Pool:
"""Create a connection pool with specified configuration
Args:
db_config: Database configuration dictionary
Returns:
Created connection pool
"""
# Convert charset to aiomysql compatible format
charset_map = {"UTF8": "utf8", "UTF8MB4": "utf8mb4"}
charset = charset_map.get(db_config['charset'].upper(), db_config['charset'].lower())
self.logger.debug(f"Creating pool for {db_config['user']}@{db_config['host']}:{db_config['port']}/{db_config['database']}")
try:
pool = await asyncio.wait_for(
aiomysql.create_pool(
host=db_config['host'],
port=db_config['port'],
user=db_config['user'],
password=db_config['password'],
db=db_config['database'],
charset=charset,
minsize=0, # Don't pre-create connections
maxsize=self.maxsize,
connect_timeout=self.connect_timeout,
autocommit=True,
pool_recycle=self.pool_recycle
),
timeout=self.connect_timeout + 5 # Give extra time for pool creation
)
self.logger.info(f"Successfully created pool for {db_config['user']}@{db_config['host']}:{db_config['port']}")
return pool
except asyncio.TimeoutError:
self.logger.error(f"Timeout creating pool for {db_config['user']}@{db_config['host']}:{db_config['port']}")
raise RuntimeError(f"Timeout creating connection pool for {db_config['user']}@{db_config['host']}:{db_config['port']}")
except Exception as e:
self.logger.error(f"Failed to create pool for {db_config['user']}@{db_config['host']}:{db_config['port']}: {type(e).__name__}: {e}")
raise
async def get_connection_for_token(self, token: str, session_id: str) -> 'DorisConnection':
"""Get a connection from the token's dedicated pool
Args:
token: Authentication token
session_id: Session identifier for logging
Returns:
DorisConnection wrapper
"""
pool, db_config = await self.get_pool_for_token(token)
try:
connection = await asyncio.wait_for(
pool.acquire(),
timeout=self.connect_timeout
)
self.logger.debug(f"Session {session_id}: Acquired connection from token pool "
f"(user: {db_config['user']}@{db_config['host']})")
return DorisConnection(connection, session_id, self.security_manager)
except Exception as e:
self.logger.error(f"Session {session_id}: Failed to acquire connection from token pool: {e}")
raise
async def release_connection_for_token(self, token: str, connection: 'DorisConnection'):
"""Release a connection back to the token's dedicated pool
Args:
token: Authentication token
connection: DorisConnection wrapper to release
"""
token_hash = self._get_token_hash(token)
if token_hash in self.token_pools:
pool = self.token_pools[token_hash]
if pool and not pool.closed:
try:
pool.release(connection.connection)
except Exception as e:
self.logger.warning(f"Failed to release connection to token pool: {e}")
async def cleanup_token_pools(self, max_idle_time: int = 3600):
"""Clean up idle token connection pools
Args:
max_idle_time: Maximum idle time in seconds before closing a pool
"""
async with self._token_pools_lock:
pools_to_remove = []
for token_hash, pool in self.token_pools.items():
if pool and not pool.closed:
# Check if pool is idle (no active connections)
if pool.size == 0 and pool.freesize == 0:
pools_to_remove.append(token_hash)
elif pool and pool.closed:
pools_to_remove.append(token_hash)
for token_hash in pools_to_remove:
try:
pool = self.token_pools.pop(token_hash, None)
if pool and not pool.closed:
pool.close()
await pool.wait_closed()
self.token_configs.pop(token_hash, None)
self._token_pool_locks.pop(token_hash, None)
self.logger.info(f"Cleaned up idle token pool (hash: {token_hash[:8]}...)")
except Exception as e:
self.logger.warning(f"Error cleaning up token pool: {e}")
async def close_all_token_pools(self):
"""Close all token connection pools (for shutdown)"""
# Use timeout to prevent blocking on lock acquisition during shutdown
try:
async with asyncio.timeout(5): # 5 second timeout for lock
async with self._token_pools_lock:
for token_hash, pool in list(self.token_pools.items()):
try:
if pool and not pool.closed:
pool.close()
# Use timeout for wait_closed to prevent hanging
try:
await asyncio.wait_for(pool.wait_closed(), timeout=2.0)
except asyncio.TimeoutError:
self.logger.warning(f"Timeout waiting for token pool to close (hash: {token_hash[:8]}...)")
self.logger.info(f"Closed token pool (hash: {token_hash[:8]}...)")
except Exception as e:
self.logger.warning(f"Error closing token pool: {e}")
self.token_pools.clear()
self.token_configs.clear()
self._token_pool_locks.clear()
except asyncio.TimeoutError:
self.logger.warning("Timeout acquiring lock for token pool cleanup, forcing clear")
# Force clear without lock
self.token_pools.clear()
self.token_configs.clear()
self._token_pool_locks.clear()
async def configure_for_token(self, token: str) -> tuple[bool, str]:
"""Configure connection manager for token with new priority logic
@@ -1092,7 +1375,26 @@ class DorisConnectionManager:
Uses only semaphore to prevent too many concurrent acquisitions.
If the connection is successfully obtained, it will be added to the connection pool cache.
🔧 FIX for token isolation: Now automatically checks for auth_context from ContextVar
and uses token-specific connection pool if available.
"""
# 🔧 FIX: Check for auth_context from global ContextVar
# This ensures all tools using get_connection respect token-bound database configuration
auth_context = None
try:
from .security import mcp_auth_context_var
auth_context = mcp_auth_context_var.get()
except Exception as e:
self.logger.debug(f"get_connection: Could not get auth_context: {e}")
if auth_context and hasattr(auth_context, 'token') and auth_context.token:
# Use token-specific connection pool
# SECURITY: Do NOT catch exceptions here - if token pool fails, don't fallback to global pool
# This prevents privilege escalation
self.logger.debug(f"get_connection: Using token-specific pool for session {session_id}")
return await self.get_connection_for_token(auth_context.token, session_id)
cached_conn = self.session_cache.get(session_id)
if cached_conn:
return cached_conn
@@ -1239,10 +1541,16 @@ class DorisConnectionManager:
except asyncio.CancelledError:
pass
# Close connection pool
# 🔧 FIX: Close all per-token connection pools
await self.close_all_token_pools()
# Close global connection pool with timeout
if self.pool:
self.pool.close()
await self.pool.wait_closed()
try:
await asyncio.wait_for(self.pool.wait_closed(), timeout=5.0)
except asyncio.TimeoutError:
self.logger.warning("Timeout waiting for global pool to close")
self.logger.info("Connection manager closed successfully")
@@ -1271,36 +1579,42 @@ class DorisConnectionManager:
async def execute_query(
self, session_id: str, sql: str, params: tuple | None = None, auth_context=None
) -> QueryResult:
"""Execute query - Simplified Strategy with automatic connection management
"""Execute query - Enhanced Strategy with per-token connection pool isolation
FIX for Issue #62 Bug 1: Configure token-bound database before query execution
FIX for multi-tenant concurrency: Each token now uses its own dedicated connection pool
to prevent configuration conflicts between concurrent requests from different tokens.
"""
connection = None
token = None
try:
# FIX: Configure database for token BEFORE getting connection
# This ensures token-bound database configuration is used instead of global config
# Check if we have a token for per-token pool isolation
if auth_context and hasattr(auth_context, 'token') and auth_context.token:
token = auth_context.token
try:
success, config_source = await self.configure_for_token(auth_context.token)
if success:
self.logger.info(f"Session {session_id}: Using {config_source} database configuration")
else:
self.logger.warning(f"Session {session_id}: Token configuration failed, may use global config")
except Exception as token_config_error:
# SECURITY: If token should have config but configuration fails, don't fallback
# 🔧 FIX: Use dedicated connection pool for this token
# This prevents concurrent requests from different tokens interfering
connection = await self.get_connection_for_token(token, session_id)
# Get the config for logging
token_hash = self._get_token_hash(token)
if token_hash in self.token_configs:
db_config = self.token_configs[token_hash]
self.logger.info(f"Session {session_id}: Using dedicated pool for {db_config['user']}@{db_config['host']}")
except Exception as token_pool_error:
# SECURITY: If token should have pool but creation fails, don't fallback
# This prevents privilege escalation (using high-privilege default user)
if self.token_manager:
self.logger.error(f"Session {session_id}: Token database configuration failed: {token_config_error}")
self.logger.error(f"Session {session_id}: Token pool error: {token_pool_error}")
raise RuntimeError(
f"Failed to configure database for authenticated token. "
f"Failed to get connection for authenticated token. "
f"This is a security measure to prevent using default high-privilege credentials. "
f"Error: {token_config_error}"
f"Error: {token_pool_error}"
)
else:
# No token manager, can use global config
self.logger.warning(f"Session {session_id}: No token manager, using global config")
# Always get fresh connection from pool (with configured database)
# No token - use global pool (backward compatibility)
self.logger.debug(f"Session {session_id}: No token, using global connection pool")
connection = await self.get_connection(session_id)
# Execute query
@@ -1312,8 +1626,11 @@ class DorisConnectionManager:
self.logger.error(f"Query execution failed for session {session_id}: {e}")
raise
finally:
# Always release connection back to pool
# Always release connection back to the appropriate pool
if connection:
if token:
await self.release_connection_for_token(token, connection)
else:
await self.release_connection(session_id, connection)
@asynccontextmanager

View File

@@ -639,6 +639,12 @@ class DorisQueryExecutor:
if hasattr(self.connection_manager, 'config') and hasattr(self.connection_manager.config, 'security'):
if self.connection_manager.config.security.enable_security_check:
try:
# 🔧 FIX: Use existing security_manager to avoid creating multiple TokenManager instances
# Creating new DorisSecurityManager each time causes multiple hot reload monitors
security_manager = getattr(self.connection_manager, 'security_manager', None)
if not security_manager:
# Fallback: create new one only if not available (should rarely happen)
self.logger.warning("No existing security_manager, creating new instance")
security_manager = DorisSecurityManager(self.connection_manager.config)
validation_result = await security_manager.validate_sql_security(sql, auth_context)

View File

@@ -1194,8 +1194,17 @@ class MetadataExtractor:
"""
try:
if self.connection_manager:
# FIX: Get auth_context from global ContextVar for token-bound database configuration
# This ensures all query methods use the correct user's connection pool
auth_context = None
try:
from .security import mcp_auth_context_var
auth_context = mcp_auth_context_var.get()
except Exception:
pass
# Use the injected connection manager directly (async)
result = await self.connection_manager.execute_query(self._session_id, query, None)
result = await self.connection_manager.execute_query(self._session_id, query, None, auth_context)
# Extract data from QueryResult
if hasattr(result, 'data'):
@@ -1644,15 +1653,14 @@ class MetadataExtractor:
# FIX: Try to get auth_context from context variable (set by HTTP middleware)
# This allows token-bound database configuration to work
# CRITICAL: Use the global ContextVar from security.py to ensure same instance is used everywhere
auth_context = None
try:
from contextvars import ContextVar
from .security import AuthContext
from .security import mcp_auth_context_var
# Try to get auth_context from context variable
# Get auth_context from the global context variable
# This will be set by the HTTP request handler in main.py
auth_context_var: ContextVar = ContextVar('mcp_auth_context', default=None)
auth_context = auth_context_var.get()
auth_context = mcp_auth_context_var.get()
if auth_context:
logger.debug(f"Retrieved auth_context from context variable with token: {bool(hasattr(auth_context, 'token') and auth_context.token)}")

View File

@@ -22,6 +22,7 @@ Implements enterprise-level authentication, authorization, SQL security validati
import logging
import re
from contextvars import ContextVar
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
@@ -34,6 +35,10 @@ from sqlparse.tokens import Keyword, Name
from .logger import get_logger
from .config import DatabaseConfig
# Global ContextVar for auth_context - must be a single instance shared across all modules
# This allows token-bound database configuration to work correctly in concurrent requests
mcp_auth_context_var: ContextVar['AuthContext'] = ContextVar('mcp_auth_context', default=None)
class SecurityLevel(Enum):
"""Security level enumeration"""