[fix]fix token auth (#69)
* fix tocken auth * Further fixes to the token overwriting issue and restoration of hot reloading of tokens.json.
This commit is contained in:
@@ -644,10 +644,10 @@ class DorisServer:
|
|||||||
|
|
||||||
# FIX for Issue #62 Bug 1: Set auth_context in context variable
|
# FIX for Issue #62 Bug 1: Set auth_context in context variable
|
||||||
# This allows tools to access token information for token-bound database configuration
|
# This allows tools to access token information for token-bound database configuration
|
||||||
|
# CRITICAL: Use the global ContextVar from security.py to ensure same instance is used everywhere
|
||||||
try:
|
try:
|
||||||
from contextvars import ContextVar
|
from .utils.security import mcp_auth_context_var
|
||||||
auth_context_var: ContextVar = ContextVar('mcp_auth_context', default=None)
|
mcp_auth_context_var.set(auth_context)
|
||||||
auth_context_var.set(auth_context)
|
|
||||||
self.logger.debug(f"Set auth_context in context variable with token: {bool(hasattr(auth_context, 'token') and auth_context.token)}")
|
self.logger.debug(f"Set auth_context in context variable with token: {bool(hasattr(auth_context, 'token') and auth_context.token)}")
|
||||||
except Exception as ctx_error:
|
except Exception as ctx_error:
|
||||||
self.logger.warning(f"Failed to set auth_context in context variable: {ctx_error}")
|
self.logger.warning(f"Failed to set auth_context in context variable: {ctx_error}")
|
||||||
|
|||||||
@@ -397,6 +397,14 @@ class SQLAnalyzer:
|
|||||||
|
|
||||||
logger.info(f"Generating SQL explain for query ID: {query_id}")
|
logger.info(f"Generating SQL explain for query ID: {query_id}")
|
||||||
|
|
||||||
|
# 🔧 FIX: Get auth_context for token-bound database configuration
|
||||||
|
auth_context = None
|
||||||
|
try:
|
||||||
|
from .security import mcp_auth_context_var
|
||||||
|
auth_context = mcp_auth_context_var.get()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
# Switch database if specified
|
# Switch database if specified
|
||||||
# SECURITY FIX: Validate and quote db_name
|
# SECURITY FIX: Validate and quote db_name
|
||||||
if db_name:
|
if db_name:
|
||||||
@@ -405,7 +413,7 @@ class SQLAnalyzer:
|
|||||||
except SQLSecurityError as e:
|
except SQLSecurityError as e:
|
||||||
return {"success": False, "error": f"Invalid database name: {e}"}
|
return {"success": False, "error": f"Invalid database name: {e}"}
|
||||||
safe_db = quote_identifier(db_name, "database name")
|
safe_db = quote_identifier(db_name, "database name")
|
||||||
await self.connection_manager.execute_query("explain_session", f"USE {safe_db}")
|
await self.connection_manager.execute_query("explain_session", f"USE {safe_db}", None, auth_context)
|
||||||
|
|
||||||
# Construct EXPLAIN query
|
# Construct EXPLAIN query
|
||||||
explain_type = "EXPLAIN VERBOSE" if verbose else "EXPLAIN"
|
explain_type = "EXPLAIN VERBOSE" if verbose else "EXPLAIN"
|
||||||
@@ -414,7 +422,7 @@ class SQLAnalyzer:
|
|||||||
logger.info(f"Executing explain query: {explain_sql}")
|
logger.info(f"Executing explain query: {explain_sql}")
|
||||||
|
|
||||||
# Execute explain query
|
# Execute explain query
|
||||||
result = await self.connection_manager.execute_query("explain_session", explain_sql)
|
result = await self.connection_manager.execute_query("explain_session", explain_sql, None, auth_context)
|
||||||
|
|
||||||
# Format explain output
|
# Format explain output
|
||||||
explain_content = []
|
explain_content = []
|
||||||
|
|||||||
@@ -255,6 +255,13 @@ class DorisConnectionManager:
|
|||||||
self.security_manager = security_manager
|
self.security_manager = security_manager
|
||||||
self.token_manager = token_manager # Token manager for token-bound DB config
|
self.token_manager = token_manager # Token manager for token-bound DB config
|
||||||
|
|
||||||
|
# 🔧 FIX for multi-tenant concurrency: Per-token connection pool isolation
|
||||||
|
# Each token gets its own connection pool to prevent configuration conflicts
|
||||||
|
self.token_pools: Dict[str, Pool] = {} # token_hash -> pool
|
||||||
|
self.token_configs: Dict[str, dict] = {} # token_hash -> db_config
|
||||||
|
self._token_pool_locks: Dict[str, asyncio.Lock] = {} # token_hash -> lock
|
||||||
|
self._token_pools_lock = asyncio.Lock() # Lock for managing token_pools dict
|
||||||
|
|
||||||
# FIX for Issue #58 Problem 1: Disable session caching to prevent connection sharing
|
# FIX for Issue #58 Problem 1: Disable session caching to prevent connection sharing
|
||||||
# Session caching causes multiple threads to share the same MySQL connection,
|
# Session caching causes multiple threads to share the same MySQL connection,
|
||||||
# leading to race conditions and deadlocks in multi-threaded environments
|
# leading to race conditions and deadlocks in multi-threaded environments
|
||||||
@@ -276,6 +283,7 @@ class DorisConnectionManager:
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Current active database config (may be overridden by token-bound config)
|
# Current active database config (may be overridden by token-bound config)
|
||||||
|
# NOTE: This is kept for backward compatibility with non-token requests
|
||||||
self.active_db_config = self.original_db_config.copy()
|
self.active_db_config = self.original_db_config.copy()
|
||||||
|
|
||||||
# Connection pool state management
|
# Connection pool state management
|
||||||
@@ -359,6 +367,281 @@ class DorisConnectionManager:
|
|||||||
self.logger.error(f"Error finding available token: {e}")
|
self.logger.error(f"Error finding available token: {e}")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
def _get_token_hash(self, token: str) -> str:
|
||||||
|
"""Get hash of token for use as dictionary key"""
|
||||||
|
import hashlib
|
||||||
|
return hashlib.sha256(token.encode()).hexdigest()[:16]
|
||||||
|
|
||||||
|
def _get_current_token_db_config(self, token: str) -> dict | None:
|
||||||
|
"""Get current database config for token from TokenManager
|
||||||
|
|
||||||
|
This is used to check if config has changed for hot reload support.
|
||||||
|
"""
|
||||||
|
if not self.token_manager:
|
||||||
|
return None
|
||||||
|
|
||||||
|
token_db_config = self.token_manager.get_database_config_by_token(token)
|
||||||
|
if token_db_config:
|
||||||
|
return {
|
||||||
|
'host': token_db_config.host,
|
||||||
|
'port': token_db_config.port,
|
||||||
|
'user': token_db_config.user,
|
||||||
|
'password': token_db_config.password,
|
||||||
|
'database': token_db_config.database,
|
||||||
|
'charset': token_db_config.charset
|
||||||
|
}
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _config_changed(self, old_config: dict, new_config: dict) -> bool:
|
||||||
|
"""Check if database configuration has changed"""
|
||||||
|
if old_config is None or new_config is None:
|
||||||
|
return old_config != new_config
|
||||||
|
|
||||||
|
# Compare key fields
|
||||||
|
for key in ['host', 'port', 'user', 'password', 'database']:
|
||||||
|
if old_config.get(key) != new_config.get(key):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def get_pool_for_token(self, token: str) -> tuple[Pool, dict]:
|
||||||
|
"""Get or create a dedicated connection pool for a specific token
|
||||||
|
|
||||||
|
This method implements per-token connection pool isolation to prevent
|
||||||
|
concurrent requests from different tokens interfering with each other.
|
||||||
|
|
||||||
|
🔧 FIX: Supports hot reload - if tokens.json config changes,
|
||||||
|
the old pool is closed and a new one is created automatically.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: Authentication token
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(pool, db_config): The dedicated pool and its configuration
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If no valid database configuration is available
|
||||||
|
"""
|
||||||
|
token_hash = self._get_token_hash(token)
|
||||||
|
|
||||||
|
# Fast path: pool already exists
|
||||||
|
if token_hash in self.token_pools:
|
||||||
|
pool = self.token_pools[token_hash]
|
||||||
|
cached_config = self.token_configs.get(token_hash)
|
||||||
|
|
||||||
|
# 🔧 FIX: Check if config has changed (hot reload support)
|
||||||
|
current_config = self._get_current_token_db_config(token)
|
||||||
|
if current_config and cached_config and self._config_changed(cached_config, current_config):
|
||||||
|
self.logger.info(f"Token config changed (hash: {token_hash[:8]}...), recreating pool...")
|
||||||
|
# Config changed, need to recreate pool
|
||||||
|
async with self._token_pools_lock:
|
||||||
|
# Close old pool
|
||||||
|
old_pool = self.token_pools.pop(token_hash, None)
|
||||||
|
if old_pool and not old_pool.closed:
|
||||||
|
try:
|
||||||
|
old_pool.close()
|
||||||
|
await asyncio.wait_for(old_pool.wait_closed(), timeout=2.0)
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warning(f"Error closing old pool during hot reload: {e}")
|
||||||
|
self.token_configs.pop(token_hash, None)
|
||||||
|
# Continue to slow path to create new pool
|
||||||
|
elif pool and not pool.closed:
|
||||||
|
return pool, cached_config
|
||||||
|
|
||||||
|
# Slow path: need to create pool (with lock to prevent race conditions)
|
||||||
|
async with self._token_pools_lock:
|
||||||
|
# Double-check after acquiring lock
|
||||||
|
if token_hash in self.token_pools:
|
||||||
|
pool = self.token_pools[token_hash]
|
||||||
|
if pool and not pool.closed:
|
||||||
|
return pool, self.token_configs[token_hash]
|
||||||
|
|
||||||
|
# Get database config for this token
|
||||||
|
db_config = None
|
||||||
|
config_source = "unknown"
|
||||||
|
|
||||||
|
if self.token_manager:
|
||||||
|
token_db_config = self.token_manager.get_database_config_by_token(token)
|
||||||
|
if token_db_config:
|
||||||
|
db_config = {
|
||||||
|
'host': token_db_config.host,
|
||||||
|
'port': token_db_config.port,
|
||||||
|
'user': token_db_config.user,
|
||||||
|
'password': token_db_config.password,
|
||||||
|
'database': token_db_config.database,
|
||||||
|
'charset': token_db_config.charset
|
||||||
|
}
|
||||||
|
config_source = "token-bound"
|
||||||
|
|
||||||
|
# Fallback to global config if token has no specific config
|
||||||
|
if not db_config or self._is_config_empty(db_config.get('host')) or self._is_config_empty(db_config.get('user')):
|
||||||
|
if self._has_valid_global_config():
|
||||||
|
db_config = self.original_db_config.copy()
|
||||||
|
config_source = "global-env"
|
||||||
|
else:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"No valid database configuration available for token. "
|
||||||
|
f"Please configure database in tokens.json or .env file."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create dedicated pool for this token
|
||||||
|
self.logger.info(f"Creating dedicated connection pool for token (hash: {token_hash[:8]}...) "
|
||||||
|
f"using {config_source} config: {db_config['user']}@{db_config['host']}:{db_config['port']}")
|
||||||
|
|
||||||
|
pool = await self._create_pool_with_config(db_config)
|
||||||
|
|
||||||
|
# Store pool and config
|
||||||
|
self.token_pools[token_hash] = pool
|
||||||
|
self.token_configs[token_hash] = db_config
|
||||||
|
|
||||||
|
# Create lock for this token if not exists
|
||||||
|
if token_hash not in self._token_pool_locks:
|
||||||
|
self._token_pool_locks[token_hash] = asyncio.Lock()
|
||||||
|
|
||||||
|
return pool, db_config
|
||||||
|
|
||||||
|
async def _create_pool_with_config(self, db_config: dict) -> Pool:
|
||||||
|
"""Create a connection pool with specified configuration
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_config: Database configuration dictionary
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Created connection pool
|
||||||
|
"""
|
||||||
|
# Convert charset to aiomysql compatible format
|
||||||
|
charset_map = {"UTF8": "utf8", "UTF8MB4": "utf8mb4"}
|
||||||
|
charset = charset_map.get(db_config['charset'].upper(), db_config['charset'].lower())
|
||||||
|
|
||||||
|
self.logger.debug(f"Creating pool for {db_config['user']}@{db_config['host']}:{db_config['port']}/{db_config['database']}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
pool = await asyncio.wait_for(
|
||||||
|
aiomysql.create_pool(
|
||||||
|
host=db_config['host'],
|
||||||
|
port=db_config['port'],
|
||||||
|
user=db_config['user'],
|
||||||
|
password=db_config['password'],
|
||||||
|
db=db_config['database'],
|
||||||
|
charset=charset,
|
||||||
|
minsize=0, # Don't pre-create connections
|
||||||
|
maxsize=self.maxsize,
|
||||||
|
connect_timeout=self.connect_timeout,
|
||||||
|
autocommit=True,
|
||||||
|
pool_recycle=self.pool_recycle
|
||||||
|
),
|
||||||
|
timeout=self.connect_timeout + 5 # Give extra time for pool creation
|
||||||
|
)
|
||||||
|
self.logger.info(f"Successfully created pool for {db_config['user']}@{db_config['host']}:{db_config['port']}")
|
||||||
|
return pool
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
self.logger.error(f"Timeout creating pool for {db_config['user']}@{db_config['host']}:{db_config['port']}")
|
||||||
|
raise RuntimeError(f"Timeout creating connection pool for {db_config['user']}@{db_config['host']}:{db_config['port']}")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Failed to create pool for {db_config['user']}@{db_config['host']}:{db_config['port']}: {type(e).__name__}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_connection_for_token(self, token: str, session_id: str) -> 'DorisConnection':
|
||||||
|
"""Get a connection from the token's dedicated pool
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: Authentication token
|
||||||
|
session_id: Session identifier for logging
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DorisConnection wrapper
|
||||||
|
"""
|
||||||
|
pool, db_config = await self.get_pool_for_token(token)
|
||||||
|
|
||||||
|
try:
|
||||||
|
connection = await asyncio.wait_for(
|
||||||
|
pool.acquire(),
|
||||||
|
timeout=self.connect_timeout
|
||||||
|
)
|
||||||
|
|
||||||
|
self.logger.debug(f"Session {session_id}: Acquired connection from token pool "
|
||||||
|
f"(user: {db_config['user']}@{db_config['host']})")
|
||||||
|
|
||||||
|
return DorisConnection(connection, session_id, self.security_manager)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Session {session_id}: Failed to acquire connection from token pool: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def release_connection_for_token(self, token: str, connection: 'DorisConnection'):
|
||||||
|
"""Release a connection back to the token's dedicated pool
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: Authentication token
|
||||||
|
connection: DorisConnection wrapper to release
|
||||||
|
"""
|
||||||
|
token_hash = self._get_token_hash(token)
|
||||||
|
|
||||||
|
if token_hash in self.token_pools:
|
||||||
|
pool = self.token_pools[token_hash]
|
||||||
|
if pool and not pool.closed:
|
||||||
|
try:
|
||||||
|
pool.release(connection.connection)
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warning(f"Failed to release connection to token pool: {e}")
|
||||||
|
|
||||||
|
async def cleanup_token_pools(self, max_idle_time: int = 3600):
|
||||||
|
"""Clean up idle token connection pools
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_idle_time: Maximum idle time in seconds before closing a pool
|
||||||
|
"""
|
||||||
|
async with self._token_pools_lock:
|
||||||
|
pools_to_remove = []
|
||||||
|
|
||||||
|
for token_hash, pool in self.token_pools.items():
|
||||||
|
if pool and not pool.closed:
|
||||||
|
# Check if pool is idle (no active connections)
|
||||||
|
if pool.size == 0 and pool.freesize == 0:
|
||||||
|
pools_to_remove.append(token_hash)
|
||||||
|
elif pool and pool.closed:
|
||||||
|
pools_to_remove.append(token_hash)
|
||||||
|
|
||||||
|
for token_hash in pools_to_remove:
|
||||||
|
try:
|
||||||
|
pool = self.token_pools.pop(token_hash, None)
|
||||||
|
if pool and not pool.closed:
|
||||||
|
pool.close()
|
||||||
|
await pool.wait_closed()
|
||||||
|
self.token_configs.pop(token_hash, None)
|
||||||
|
self._token_pool_locks.pop(token_hash, None)
|
||||||
|
self.logger.info(f"Cleaned up idle token pool (hash: {token_hash[:8]}...)")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warning(f"Error cleaning up token pool: {e}")
|
||||||
|
|
||||||
|
async def close_all_token_pools(self):
|
||||||
|
"""Close all token connection pools (for shutdown)"""
|
||||||
|
# Use timeout to prevent blocking on lock acquisition during shutdown
|
||||||
|
try:
|
||||||
|
async with asyncio.timeout(5): # 5 second timeout for lock
|
||||||
|
async with self._token_pools_lock:
|
||||||
|
for token_hash, pool in list(self.token_pools.items()):
|
||||||
|
try:
|
||||||
|
if pool and not pool.closed:
|
||||||
|
pool.close()
|
||||||
|
# Use timeout for wait_closed to prevent hanging
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(pool.wait_closed(), timeout=2.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
self.logger.warning(f"Timeout waiting for token pool to close (hash: {token_hash[:8]}...)")
|
||||||
|
self.logger.info(f"Closed token pool (hash: {token_hash[:8]}...)")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warning(f"Error closing token pool: {e}")
|
||||||
|
|
||||||
|
self.token_pools.clear()
|
||||||
|
self.token_configs.clear()
|
||||||
|
self._token_pool_locks.clear()
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
self.logger.warning("Timeout acquiring lock for token pool cleanup, forcing clear")
|
||||||
|
# Force clear without lock
|
||||||
|
self.token_pools.clear()
|
||||||
|
self.token_configs.clear()
|
||||||
|
self._token_pool_locks.clear()
|
||||||
|
|
||||||
async def configure_for_token(self, token: str) -> tuple[bool, str]:
|
async def configure_for_token(self, token: str) -> tuple[bool, str]:
|
||||||
"""Configure connection manager for token with new priority logic
|
"""Configure connection manager for token with new priority logic
|
||||||
|
|
||||||
@@ -1092,7 +1375,26 @@ class DorisConnectionManager:
|
|||||||
|
|
||||||
Uses only semaphore to prevent too many concurrent acquisitions.
|
Uses only semaphore to prevent too many concurrent acquisitions.
|
||||||
If the connection is successfully obtained, it will be added to the connection pool cache.
|
If the connection is successfully obtained, it will be added to the connection pool cache.
|
||||||
|
|
||||||
|
🔧 FIX for token isolation: Now automatically checks for auth_context from ContextVar
|
||||||
|
and uses token-specific connection pool if available.
|
||||||
"""
|
"""
|
||||||
|
# 🔧 FIX: Check for auth_context from global ContextVar
|
||||||
|
# This ensures all tools using get_connection respect token-bound database configuration
|
||||||
|
auth_context = None
|
||||||
|
try:
|
||||||
|
from .security import mcp_auth_context_var
|
||||||
|
auth_context = mcp_auth_context_var.get()
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.debug(f"get_connection: Could not get auth_context: {e}")
|
||||||
|
|
||||||
|
if auth_context and hasattr(auth_context, 'token') and auth_context.token:
|
||||||
|
# Use token-specific connection pool
|
||||||
|
# SECURITY: Do NOT catch exceptions here - if token pool fails, don't fallback to global pool
|
||||||
|
# This prevents privilege escalation
|
||||||
|
self.logger.debug(f"get_connection: Using token-specific pool for session {session_id}")
|
||||||
|
return await self.get_connection_for_token(auth_context.token, session_id)
|
||||||
|
|
||||||
cached_conn = self.session_cache.get(session_id)
|
cached_conn = self.session_cache.get(session_id)
|
||||||
if cached_conn:
|
if cached_conn:
|
||||||
return cached_conn
|
return cached_conn
|
||||||
@@ -1239,10 +1541,16 @@ class DorisConnectionManager:
|
|||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Close connection pool
|
# 🔧 FIX: Close all per-token connection pools
|
||||||
|
await self.close_all_token_pools()
|
||||||
|
|
||||||
|
# Close global connection pool with timeout
|
||||||
if self.pool:
|
if self.pool:
|
||||||
self.pool.close()
|
self.pool.close()
|
||||||
await self.pool.wait_closed()
|
try:
|
||||||
|
await asyncio.wait_for(self.pool.wait_closed(), timeout=5.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
self.logger.warning("Timeout waiting for global pool to close")
|
||||||
|
|
||||||
self.logger.info("Connection manager closed successfully")
|
self.logger.info("Connection manager closed successfully")
|
||||||
|
|
||||||
@@ -1271,37 +1579,43 @@ class DorisConnectionManager:
|
|||||||
async def execute_query(
|
async def execute_query(
|
||||||
self, session_id: str, sql: str, params: tuple | None = None, auth_context=None
|
self, session_id: str, sql: str, params: tuple | None = None, auth_context=None
|
||||||
) -> QueryResult:
|
) -> QueryResult:
|
||||||
"""Execute query - Simplified Strategy with automatic connection management
|
"""Execute query - Enhanced Strategy with per-token connection pool isolation
|
||||||
|
|
||||||
FIX for Issue #62 Bug 1: Configure token-bound database before query execution
|
FIX for multi-tenant concurrency: Each token now uses its own dedicated connection pool
|
||||||
|
to prevent configuration conflicts between concurrent requests from different tokens.
|
||||||
"""
|
"""
|
||||||
connection = None
|
connection = None
|
||||||
try:
|
token = None
|
||||||
# FIX: Configure database for token BEFORE getting connection
|
|
||||||
# This ensures token-bound database configuration is used instead of global config
|
|
||||||
if auth_context and hasattr(auth_context, 'token') and auth_context.token:
|
|
||||||
try:
|
|
||||||
success, config_source = await self.configure_for_token(auth_context.token)
|
|
||||||
if success:
|
|
||||||
self.logger.info(f"Session {session_id}: Using {config_source} database configuration")
|
|
||||||
else:
|
|
||||||
self.logger.warning(f"Session {session_id}: Token configuration failed, may use global config")
|
|
||||||
except Exception as token_config_error:
|
|
||||||
# SECURITY: If token should have config but configuration fails, don't fallback
|
|
||||||
# This prevents privilege escalation (using high-privilege default user)
|
|
||||||
if self.token_manager:
|
|
||||||
self.logger.error(f"Session {session_id}: Token database configuration failed: {token_config_error}")
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Failed to configure database for authenticated token. "
|
|
||||||
f"This is a security measure to prevent using default high-privilege credentials. "
|
|
||||||
f"Error: {token_config_error}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# No token manager, can use global config
|
|
||||||
self.logger.warning(f"Session {session_id}: No token manager, using global config")
|
|
||||||
|
|
||||||
# Always get fresh connection from pool (with configured database)
|
try:
|
||||||
connection = await self.get_connection(session_id)
|
# Check if we have a token for per-token pool isolation
|
||||||
|
if auth_context and hasattr(auth_context, 'token') and auth_context.token:
|
||||||
|
token = auth_context.token
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 🔧 FIX: Use dedicated connection pool for this token
|
||||||
|
# This prevents concurrent requests from different tokens interfering
|
||||||
|
connection = await self.get_connection_for_token(token, session_id)
|
||||||
|
|
||||||
|
# Get the config for logging
|
||||||
|
token_hash = self._get_token_hash(token)
|
||||||
|
if token_hash in self.token_configs:
|
||||||
|
db_config = self.token_configs[token_hash]
|
||||||
|
self.logger.info(f"Session {session_id}: Using dedicated pool for {db_config['user']}@{db_config['host']}")
|
||||||
|
|
||||||
|
except Exception as token_pool_error:
|
||||||
|
# SECURITY: If token should have pool but creation fails, don't fallback
|
||||||
|
# This prevents privilege escalation (using high-privilege default user)
|
||||||
|
self.logger.error(f"Session {session_id}: Token pool error: {token_pool_error}")
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Failed to get connection for authenticated token. "
|
||||||
|
f"This is a security measure to prevent using default high-privilege credentials. "
|
||||||
|
f"Error: {token_pool_error}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# No token - use global pool (backward compatibility)
|
||||||
|
self.logger.debug(f"Session {session_id}: No token, using global connection pool")
|
||||||
|
connection = await self.get_connection(session_id)
|
||||||
|
|
||||||
# Execute query
|
# Execute query
|
||||||
result = await connection.execute(sql, params, auth_context)
|
result = await connection.execute(sql, params, auth_context)
|
||||||
@@ -1312,9 +1626,12 @@ class DorisConnectionManager:
|
|||||||
self.logger.error(f"Query execution failed for session {session_id}: {e}")
|
self.logger.error(f"Query execution failed for session {session_id}: {e}")
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
# Always release connection back to pool
|
# Always release connection back to the appropriate pool
|
||||||
if connection:
|
if connection:
|
||||||
await self.release_connection(session_id, connection)
|
if token:
|
||||||
|
await self.release_connection_for_token(token, connection)
|
||||||
|
else:
|
||||||
|
await self.release_connection(session_id, connection)
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def get_connection_context(self, session_id: str):
|
async def get_connection_context(self, session_id: str):
|
||||||
|
|||||||
@@ -639,7 +639,13 @@ class DorisQueryExecutor:
|
|||||||
if hasattr(self.connection_manager, 'config') and hasattr(self.connection_manager.config, 'security'):
|
if hasattr(self.connection_manager, 'config') and hasattr(self.connection_manager.config, 'security'):
|
||||||
if self.connection_manager.config.security.enable_security_check:
|
if self.connection_manager.config.security.enable_security_check:
|
||||||
try:
|
try:
|
||||||
security_manager = DorisSecurityManager(self.connection_manager.config)
|
# 🔧 FIX: Use existing security_manager to avoid creating multiple TokenManager instances
|
||||||
|
# Creating new DorisSecurityManager each time causes multiple hot reload monitors
|
||||||
|
security_manager = getattr(self.connection_manager, 'security_manager', None)
|
||||||
|
if not security_manager:
|
||||||
|
# Fallback: create new one only if not available (should rarely happen)
|
||||||
|
self.logger.warning("No existing security_manager, creating new instance")
|
||||||
|
security_manager = DorisSecurityManager(self.connection_manager.config)
|
||||||
validation_result = await security_manager.validate_sql_security(sql, auth_context)
|
validation_result = await security_manager.validate_sql_security(sql, auth_context)
|
||||||
|
|
||||||
if not validation_result.is_valid:
|
if not validation_result.is_valid:
|
||||||
|
|||||||
@@ -1194,8 +1194,17 @@ class MetadataExtractor:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if self.connection_manager:
|
if self.connection_manager:
|
||||||
|
# FIX: Get auth_context from global ContextVar for token-bound database configuration
|
||||||
|
# This ensures all query methods use the correct user's connection pool
|
||||||
|
auth_context = None
|
||||||
|
try:
|
||||||
|
from .security import mcp_auth_context_var
|
||||||
|
auth_context = mcp_auth_context_var.get()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
# Use the injected connection manager directly (async)
|
# Use the injected connection manager directly (async)
|
||||||
result = await self.connection_manager.execute_query(self._session_id, query, None)
|
result = await self.connection_manager.execute_query(self._session_id, query, None, auth_context)
|
||||||
|
|
||||||
# Extract data from QueryResult
|
# Extract data from QueryResult
|
||||||
if hasattr(result, 'data'):
|
if hasattr(result, 'data'):
|
||||||
@@ -1644,15 +1653,14 @@ class MetadataExtractor:
|
|||||||
|
|
||||||
# FIX: Try to get auth_context from context variable (set by HTTP middleware)
|
# FIX: Try to get auth_context from context variable (set by HTTP middleware)
|
||||||
# This allows token-bound database configuration to work
|
# This allows token-bound database configuration to work
|
||||||
|
# CRITICAL: Use the global ContextVar from security.py to ensure same instance is used everywhere
|
||||||
auth_context = None
|
auth_context = None
|
||||||
try:
|
try:
|
||||||
from contextvars import ContextVar
|
from .security import mcp_auth_context_var
|
||||||
from .security import AuthContext
|
|
||||||
|
|
||||||
# Try to get auth_context from context variable
|
# Get auth_context from the global context variable
|
||||||
# This will be set by the HTTP request handler in main.py
|
# This will be set by the HTTP request handler in main.py
|
||||||
auth_context_var: ContextVar = ContextVar('mcp_auth_context', default=None)
|
auth_context = mcp_auth_context_var.get()
|
||||||
auth_context = auth_context_var.get()
|
|
||||||
|
|
||||||
if auth_context:
|
if auth_context:
|
||||||
logger.debug(f"Retrieved auth_context from context variable with token: {bool(hasattr(auth_context, 'token') and auth_context.token)}")
|
logger.debug(f"Retrieved auth_context from context variable with token: {bool(hasattr(auth_context, 'token') and auth_context.token)}")
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ Implements enterprise-level authentication, authorization, SQL security validati
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
from contextvars import ContextVar
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@@ -34,6 +35,10 @@ from sqlparse.tokens import Keyword, Name
|
|||||||
from .logger import get_logger
|
from .logger import get_logger
|
||||||
from .config import DatabaseConfig
|
from .config import DatabaseConfig
|
||||||
|
|
||||||
|
# Global ContextVar for auth_context - must be a single instance shared across all modules
|
||||||
|
# This allows token-bound database configuration to work correctly in concurrent requests
|
||||||
|
mcp_auth_context_var: ContextVar['AuthContext'] = ContextVar('mcp_auth_context', default=None)
|
||||||
|
|
||||||
|
|
||||||
class SecurityLevel(Enum):
|
class SecurityLevel(Enum):
|
||||||
"""Security level enumeration"""
|
"""Security level enumeration"""
|
||||||
|
|||||||
Reference in New Issue
Block a user