[Performance]Add a controllable MCP Server DB Pool permission authentication system (#53)

* 0.5.1 Version

* fix 0.5.1 schema async bug

* fix security bug

* fix security bug

* Add complete Token, JWT, OAuth authentication system

* Add complete Token, JWT, OAuth authentication system

* Add complete Token, JWT, OAuth authentication system

* Add complete Token, JWT, OAuth authentication system

* Add a controllable MCP Server DB Pool permission authentication system, connect it with the Doris permission system, and provide it to enterprise-level applications concurrently with the multi-Worker mode.
This commit is contained in:
Yijia Su
2025-09-02 18:40:48 +08:00
committed by GitHub
parent c3d487ccdd
commit f99399c6c7
10 changed files with 636 additions and 47 deletions

View File

@@ -117,13 +117,15 @@ class AuthMiddleware:
# Build authentication context
auth_context = AuthContext(
token_id=payload.get('jti', ''),
user_id=payload.get('sub'),
roles=payload.get('roles', []),
permissions=payload.get('permissions', []),
security_level=SecurityLevel(payload.get('security_level', 'internal')),
session_id=payload.get('jti'), # Use JWT ID as session ID
login_time=datetime.fromtimestamp(payload.get('iat', 0)),
last_activity=datetime.utcnow(),
security_level=SecurityLevel(payload.get('security_level', 'internal'))
token=token # Store raw token for token-bound database configuration
)
logger.info(f"JWT authentication successful for user: {auth_context.user_id}")

View File

@@ -191,13 +191,15 @@ class OAuthAuthenticationProvider:
session_id = f"oauth_{user_info.sub}_{datetime.utcnow().timestamp()}"
return AuthContext(
token_id=f"oauth_{user_info.sub}",
user_id=user_info.sub,
roles=user_info.roles,
permissions=permissions,
security_level=security_level,
session_id=session_id,
login_time=datetime.utcnow(),
last_activity=datetime.utcnow(),
security_level=security_level
token="" # OAuth doesn't have raw token, use empty string
)
async def _determine_security_level(self, user_info: OAuthUserInfo) -> SecurityLevel:

View File

@@ -11,17 +11,32 @@ import json
import os
import secrets
import time
import asyncio
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any
from pathlib import Path
from ..utils.logger import get_logger
from ..utils.security import SecurityLevel
@dataclass
class DatabaseConfig:
"""Database connection configuration for token binding"""
host: str
port: int = 9030
user: str = ""
password: str = ""
database: str = "information_schema"
charset: str = "UTF8"
fe_http_port: int = 8030
@dataclass
class TokenInfo:
"""Token information structure"""
"""Token information structure with optional database binding"""
token_id: str # Unique token identifier for audit and management
created_at: datetime = field(default_factory=datetime.utcnow)
@@ -29,6 +44,7 @@ class TokenInfo:
last_used: Optional[datetime] = None
description: str = "" # Optional description for token purpose
is_active: bool = True
database_config: Optional[DatabaseConfig] = None # Optional database binding
@dataclass
@@ -65,13 +81,23 @@ class TokenManager:
self.default_token_expiry_hours = getattr(config.security, 'default_token_expiry_hours', 24 * 30) # 30 days
self.token_hash_algorithm = getattr(config.security, 'token_hash_algorithm', 'sha256')
# Hot reload configuration
self.enable_hot_reload = True
self.hot_reload_interval = 10 # Check every 10 seconds
self._file_last_modified = 0
self._hot_reload_task = None
# Initialize with default tokens if none exist
self._initialize_default_tokens()
# Load tokens from configuration
self._load_tokens()
self.logger.info(f"TokenManager initialized with {len(self._tokens)} tokens")
# Start hot reload monitoring
if self.enable_hot_reload:
self._start_hot_reload()
self.logger.info(f"TokenManager initialized with {len(self._tokens)} tokens, hot reload: {self.enable_hot_reload}")
def _initialize_default_tokens(self):
"""Initialize default tokens for basic authentication (configurable via environment)"""
@@ -132,7 +158,7 @@ class TokenManager:
self.logger.info("Skipped default tokens initialization (custom tokens detected)")
def _add_token_from_config(self, token_config: Dict[str, Any]):
"""Add token from configuration"""
"""Add token from configuration with optional database binding"""
try:
# Calculate expiration time
expires_at = None
@@ -141,12 +167,27 @@ class TokenManager:
if expires_hours is not None:
expires_at = datetime.utcnow() + timedelta(hours=expires_hours)
# Parse database configuration if provided
database_config = None
if 'database_config' in token_config:
db_config = token_config['database_config']
database_config = DatabaseConfig(
host=db_config.get('host', 'localhost'),
port=db_config.get('port', 9030),
user=db_config.get('user', 'root'),
password=db_config.get('password', ''),
database=db_config.get('database', 'information_schema'),
charset=db_config.get('charset', 'UTF8'),
fe_http_port=db_config.get('fe_http_port', 8030)
)
# Create token info
token_info = TokenInfo(
token_id=token_config['token_id'],
expires_at=expires_at,
description=token_config.get('description', ''),
is_active=token_config.get('is_active', True)
is_active=token_config.get('is_active', True),
database_config=database_config
)
# Hash the token
@@ -157,10 +198,12 @@ class TokenManager:
self._tokens[token_hash] = token_info
self._token_ids[token_info.token_id] = token_hash
self.logger.debug(f"Added token '{token_info.token_id}'")
db_info = f" with DB binding ({database_config.host})" if database_config else ""
self.logger.debug(f"Added token '{token_info.token_id}'{db_info}")
except Exception as e:
self.logger.error(f"Failed to add token from config: {e}")
raise
def _load_tokens(self):
"""Load tokens from configuration sources"""
@@ -378,7 +421,7 @@ class TokenManager:
tokens = []
for token_hash, token_info in self._tokens.items():
tokens.append({
token_data = {
'token_id': token_info.token_id,
'created_at': token_info.created_at.isoformat(),
'expires_at': token_info.expires_at.isoformat() if token_info.expires_at else None,
@@ -386,7 +429,21 @@ class TokenManager:
'is_active': token_info.is_active,
'description': token_info.description,
'is_expired': token_info.expires_at and datetime.utcnow() > token_info.expires_at if token_info.expires_at else False
})
}
# Add database binding info (without sensitive data)
if token_info.database_config:
token_data['database_binding'] = {
'host': token_info.database_config.host,
'port': token_info.database_config.port,
'user': token_info.database_config.user,
'database': token_info.database_config.database,
'has_password': bool(token_info.database_config.password)
}
else:
token_data['database_binding'] = None
tokens.append(token_data)
# Sort by creation time
tokens.sort(key=lambda x: x['created_at'], reverse=True)
@@ -439,6 +496,32 @@ class TokenManager:
self.logger.error(f"Failed to save tokens to file: {e}")
return False
def get_database_config_by_token(self, token: str) -> Optional[DatabaseConfig]:
"""Get database configuration bound to a token
Args:
token: The raw token string
Returns:
DatabaseConfig if token exists and has database binding, None otherwise
"""
try:
token_hash = self._hash_token(token)
token_info = self._tokens.get(token_hash)
if not token_info or not token_info.is_active:
return None
# Check expiration
if token_info.expires_at and datetime.utcnow() > token_info.expires_at:
return None
return token_info.database_config
except Exception as e:
self.logger.error(f"Failed to get database config for token: {e}")
return None
def get_token_stats(self) -> Dict[str, Any]:
"""Get token statistics"""
now = datetime.utcnow()
@@ -446,11 +529,89 @@ class TokenManager:
active_tokens = sum(1 for info in self._tokens.values() if info.is_active)
expired_tokens = sum(1 for info in self._tokens.values()
if info.expires_at and now > info.expires_at)
tokens_with_db = sum(1 for info in self._tokens.values()
if info.database_config is not None)
return {
'total_tokens': total_tokens,
'active_tokens': active_tokens,
'expired_tokens': expired_tokens,
'tokens_with_database_binding': tokens_with_db,
'expiry_enabled': self.enable_token_expiry,
'default_expiry_hours': self.default_token_expiry_hours
}
'default_expiry_hours': self.default_token_expiry_hours,
'hot_reload_enabled': self.enable_hot_reload,
'last_file_check': datetime.fromtimestamp(self._file_last_modified).isoformat() if self._file_last_modified else None
}
def _start_hot_reload(self):
"""Start hot reload monitoring task"""
if self._hot_reload_task:
return # Already running
# Update initial file modification time
self._update_file_modified_time()
# Start monitoring task
self._hot_reload_task = asyncio.create_task(self._hot_reload_monitor())
self.logger.info(f"Started hot reload monitoring for {self.token_file_path}")
def stop_hot_reload(self):
"""Stop hot reload monitoring"""
if self._hot_reload_task:
self._hot_reload_task.cancel()
self._hot_reload_task = None
self.logger.info("Stopped hot reload monitoring")
def _update_file_modified_time(self):
"""Update the last modified time of tokens file"""
try:
if os.path.exists(self.token_file_path):
self._file_last_modified = os.path.getmtime(self.token_file_path)
except Exception as e:
self.logger.debug(f"Failed to get file modification time: {e}")
async def _hot_reload_monitor(self):
"""Background task to monitor tokens.json file changes"""
while True:
try:
await asyncio.sleep(self.hot_reload_interval)
if not os.path.exists(self.token_file_path):
continue
# Check if file was modified
current_mtime = os.path.getmtime(self.token_file_path)
if current_mtime > self._file_last_modified:
self.logger.info(f"Detected changes in {self.token_file_path}, reloading tokens...")
try:
# Backup current tokens
old_tokens = self._tokens.copy()
old_token_ids = self._token_ids.copy()
# Clear and reload
self._tokens.clear()
self._token_ids.clear()
# Reinitialize default tokens
self._initialize_default_tokens()
# Load from file
self._load_tokens_from_file()
# Update modification time
self._file_last_modified = current_mtime
self.logger.info(f"Hot reload completed, {len(self._tokens)} tokens loaded")
except Exception as reload_error:
# Restore backup on failure
self.logger.error(f"Hot reload failed, restoring previous tokens: {reload_error}")
self._tokens = old_tokens
self._token_ids = old_token_ids
except asyncio.CancelledError:
self.logger.info("Hot reload monitor stopped")
break
except Exception as e:
self.logger.error(f"Error in hot reload monitor: {e}")

View File

@@ -230,11 +230,15 @@ class DorisServer:
self.config = config
self.server = Server("doris-mcp-server")
# Initialize security manager
# Initialize security manager (without connection_manager initially)
self.security_manager = DorisSecurityManager(config)
# Initialize connection manager, pass in security manager
self.connection_manager = DorisConnectionManager(config, self.security_manager)
# Initialize connection manager, pass in security manager and token manager for token-bound DB config
token_manager = self.security_manager.auth_provider.token_manager if hasattr(self.security_manager, 'auth_provider') and hasattr(self.security_manager.auth_provider, 'token_manager') else None
self.connection_manager = DorisConnectionManager(config, self.security_manager, token_manager)
# Set connection manager reference in security manager for database validation
self.security_manager.connection_manager = self.connection_manager
# Initialize independent managers
self.resources_manager = DorisResourcesManager(self.connection_manager)
@@ -785,14 +789,14 @@ Examples:
parser.add_argument(
"--port",
type=int,
default=os.getenv("SERVER_PORT", _default_config.server_port),
help=f"Port number for HTTP mode (default: {_default_config.server_port})"
default=3000,
help="Port number for HTTP mode (default: 3000)"
)
parser.add_argument(
"--workers",
type=int,
default=int(os.getenv("WORKERS", "1")),
default=1,
help="Number of worker processes for HTTP mode (default: 1, use 0 for auto-detect CPU cores)"
)
@@ -804,7 +808,7 @@ Examples:
)
parser.add_argument(
"--doris-port", "--db-port", type=int, default=os.getenv("DORIS_PORT", _default_config.database.port), help=f"Doris database port number (default: {_default_config.database.port})"
"--doris-port", "--db-port", type=int, default=9030, help="Doris database port number (default: 9030)"
)
parser.add_argument(

View File

@@ -270,8 +270,13 @@ async def initialize_worker():
await _worker_security_manager.initialize()
logger.info(f"Worker {os.getpid()} security manager initialization completed")
# Create connection manager
_worker_connection_manager = DorisConnectionManager(config, _worker_security_manager)
# Create connection manager with token manager for token-bound DB config
token_manager = _worker_security_manager.auth_provider.token_manager if hasattr(_worker_security_manager, 'auth_provider') and hasattr(_worker_security_manager.auth_provider, 'token_manager') else None
_worker_connection_manager = DorisConnectionManager(config, _worker_security_manager, token_manager)
# Set connection manager reference in security manager for database validation
_worker_security_manager.connection_manager = _worker_connection_manager
await _worker_connection_manager.initialize()
# Create MCP server

View File

@@ -372,19 +372,34 @@ class DorisConfig:
config = cls()
# Database configuration
config.database.host = os.getenv("DORIS_HOST", config.database.host)
config.database.port = int(os.getenv("DORIS_PORT", str(config.database.port)))
config.database.user = os.getenv("DORIS_USER", config.database.user)
config.database.password = os.getenv("DORIS_PASSWORD", config.database.password)
config.database.database = os.getenv("DORIS_DATABASE", config.database.database)
config.database.fe_http_port = int(os.getenv("DORIS_FE_HTTP_PORT", str(config.database.fe_http_port)))
# Database configuration - handle empty strings properly
doris_host = os.getenv("DORIS_HOST", "").strip()
config.database.host = doris_host if doris_host else config.database.host
doris_port = os.getenv("DORIS_PORT", "").strip()
if doris_port and doris_port.isdigit():
config.database.port = int(doris_port)
doris_user = os.getenv("DORIS_USER", "").strip()
config.database.user = doris_user if doris_user else config.database.user
doris_password = os.getenv("DORIS_PASSWORD", "")
config.database.password = doris_password if doris_password else config.database.password
doris_database = os.getenv("DORIS_DATABASE", "").strip()
config.database.database = doris_database if doris_database else config.database.database
doris_fe_http_port = os.getenv("DORIS_FE_HTTP_PORT", "").strip()
if doris_fe_http_port and doris_fe_http_port.isdigit():
config.database.fe_http_port = int(doris_fe_http_port)
# BE nodes configuration
be_hosts_env = os.getenv("DORIS_BE_HOSTS", "")
if be_hosts_env:
config.database.be_hosts = [host.strip() for host in be_hosts_env.split(",") if host.strip()]
config.database.be_webserver_port = int(os.getenv("DORIS_BE_WEBSERVER_PORT", str(config.database.be_webserver_port)))
be_webserver_port = os.getenv("DORIS_BE_WEBSERVER_PORT", "").strip()
if be_webserver_port and be_webserver_port.isdigit():
config.database.be_webserver_port = int(be_webserver_port)
# Arrow Flight SQL Configuration
fe_arrow_port_env = os.getenv("FE_ARROW_FLIGHT_SQL_PORT")
@@ -557,7 +572,9 @@ class DorisConfig:
# Server configuration
config.server_name = os.getenv("SERVER_NAME", config.server_name)
config.server_version = os.getenv("SERVER_VERSION", config.server_version)
config.server_port = int(os.getenv("SERVER_PORT", str(config.server_port)))
server_port = os.getenv("SERVER_PORT", "").strip()
if server_port and server_port.isdigit():
config.server_port = int(server_port)
config.temp_files_dir = os.getenv("TEMP_FILES_DIR", config.temp_files_dir)
return config

View File

@@ -240,15 +240,30 @@ class DorisConnectionManager:
Uses direct connection pool management with proper synchronization
Implements connection pool health monitoring and proactive cleanup
Supports token-bound database configurations for multi-tenant access
"""
def __init__(self, config, security_manager=None):
def __init__(self, config, security_manager=None, token_manager=None):
self.config = config
self.pool: Pool | None = None
self.logger = get_logger(__name__)
self.security_manager = security_manager
self.token_manager = token_manager # Token manager for token-bound DB config
self.session_cache = DorisSessionCache(self)
# Store original database config for fallback
self.original_db_config = {
'host': config.database.host,
'port': config.database.port,
'user': config.database.user,
'password': config.database.password,
'database': config.database.database,
'charset': config.database.charset
}
# Current active database config (may be overridden by token-bound config)
self.active_db_config = self.original_db_config.copy()
# Connection pool state management
self.pool_recovering = False
@@ -267,14 +282,7 @@ class DorisConnectionManager:
# Database connection parameters from config.database
self.pool_recovery_lock = self._recovery_lock # Compatibility alias
self.host = config.database.host
self.port = config.database.port
self.user = config.database.user
self.password = config.database.password
self.database = config.database.database
# Convert charset to aiomysql compatible format
charset_map = {"UTF8": "utf8", "UTF8MB4": "utf8mb4"}
self.charset = charset_map.get(config.database.charset.upper(), config.database.charset.lower())
self._update_db_params_from_config(self.active_db_config)
self.connect_timeout = config.database.connection_timeout
# Connection pool parameters - more conservative settings
@@ -285,12 +293,307 @@ class DorisConnectionManager:
# 🔧 FIX: Add missing monitoring parameters that were removed during refactoring
self.health_check_interval = 30 # seconds
self.pool_warmup_size = 3 # connections to maintain
def _update_db_params_from_config(self, db_config: dict):
"""Update database connection parameters from config dictionary"""
self.host = db_config['host']
self.port = db_config['port']
self.user = db_config['user']
self.password = db_config['password']
self.database = db_config['database']
# Convert charset to aiomysql compatible format
charset_map = {"UTF8": "utf8", "UTF8MB4": "utf8mb4"}
self.charset = charset_map.get(db_config['charset'].upper(), db_config['charset'].lower())
def _is_config_empty(self, config_value) -> bool:
"""Check if a config value is empty (None, empty string, or 'null')"""
return config_value is None or config_value == '' or str(config_value).lower() == 'null'
def _has_valid_global_config(self) -> bool:
"""Check if global database configuration is valid and non-empty"""
return (not self._is_config_empty(self.original_db_config['host']) and
not self._is_config_empty(self.original_db_config['user']))
def _find_available_token_with_db_config(self) -> str:
"""Find the first available token with database configuration
Returns:
Raw token string if found, empty string if not found
"""
if not self.token_manager:
return ""
try:
for token_hash, token_info in self.token_manager._tokens.items():
if (token_info.database_config and
token_info.is_active and
not self._is_config_empty(token_info.database_config.host) and
not self._is_config_empty(token_info.database_config.user)):
# We need to find the raw token from the hash
# This is a bit tricky since we only store hashes
# We'll need to use the admin token from tokens.json if it has db config
if token_info.token_id == 'admin-token':
# Try the known admin token
return 'doris_admin_token_123456'
elif 'tenant' in token_info.token_id:
# For tenant tokens, we'll need a different approach
# For now, skip these as we don't know the raw token
continue
return ""
except Exception as e:
self.logger.error(f"Error finding available token: {e}")
return ""
async def configure_for_token(self, token: str) -> tuple[bool, str]:
"""Configure connection manager for token with new priority logic
Priority: Token-bound DB config > .env config > error
Args:
token: Authentication token to get database config for
Returns:
(success: bool, config_source: str): Result and which config was used
Raises:
RuntimeError: If no valid database configuration is available
"""
try:
# Priority 1: Try token-bound database config first
if self.token_manager:
db_config = self.token_manager.get_database_config_by_token(token)
if db_config:
# Convert DatabaseConfig to dictionary
token_db_config = {
'host': db_config.host,
'port': db_config.port,
'user': db_config.user,
'password': db_config.password,
'database': db_config.database,
'charset': db_config.charset
}
# Check if token-bound config is valid
if (not self._is_config_empty(token_db_config['host']) and
not self._is_config_empty(token_db_config['user'])):
self.logger.info(f"Using token-bound database configuration for host: {token_db_config['host']}")
self.active_db_config = token_db_config
self._update_db_params_from_config(self.active_db_config)
# Create/recreate connection pool with token-bound config
await self._ensure_pool_with_current_config()
return True, "token-bound"
# Priority 2: Use global .env config if available
if self._has_valid_global_config():
self.logger.info("Using global .env database configuration")
self.active_db_config = self.original_db_config.copy()
self._update_db_params_from_config(self.active_db_config)
# Create/recreate connection pool with global config
await self._ensure_pool_with_current_config()
return True, "global-env"
# Priority 3: No valid configuration available
error_msg = (
"No valid database configuration available for this token. "
"Please contact administrator to:\n"
"1. Add database configuration to tokens.json for this token, OR\n"
"2. Configure valid global database settings in .env file\n"
"Required fields: DB_HOST, DB_USER"
)
self.logger.error(error_msg)
raise RuntimeError(error_msg)
except Exception as e:
self.logger.error(f"Failed to configure database for token: {e}")
raise
async def _ensure_pool_with_current_config(self):
"""Ensure connection pool exists with current configuration"""
try:
# If pool exists with different config, need to recreate it
# If no pool exists, create one with current config
if self.pool and not self.pool.closed:
# Since we can't reliably check pool config attributes,
# we'll recreate the pool if we detect a potential config change
# by checking if current config differs from what we stored
pool_needs_recreation = False
# Compare current config with what we might have used before
if hasattr(self, '_last_pool_config'):
current_config = {
'host': self.host,
'port': self.port,
'user': self.user,
'database': self.database
}
if current_config != self._last_pool_config:
pool_needs_recreation = True
if pool_needs_recreation:
self.logger.info("Database configuration changed, recreating connection pool")
await self._recreate_pool()
elif not self.pool:
self.logger.info("Creating connection pool with current configuration")
await self._create_pool_with_current_config()
# Test the connection immediately
if not await self._test_pool_health():
raise RuntimeError(f"Database connection test failed for {self.host}:{self.port}")
except Exception as e:
self.logger.error(f"Failed to ensure connection pool: {e}")
raise
async def _create_pool_with_current_config(self):
"""Create connection pool with current database configuration"""
try:
self.pool = await aiomysql.create_pool(
host=self.host,
port=self.port,
user=self.user,
password=self.password,
db=self.database,
charset=self.charset,
minsize=self.minsize,
maxsize=self.maxsize,
pool_recycle=self.pool_recycle,
connect_timeout=self.connect_timeout,
autocommit=True
)
# Store the current config for comparison later
self._last_pool_config = {
'host': self.host,
'port': self.port,
'user': self.user,
'database': self.database
}
# Test initial connection
if not await self._test_pool_health():
raise RuntimeError("Connection pool health check failed")
# Start background monitoring tasks if not already running
if not self.pool_health_check_task or self.pool_health_check_task.done():
self.pool_health_check_task = asyncio.create_task(self._pool_health_monitor())
if not self.pool_cleanup_task or self.pool_cleanup_task.done():
self.pool_cleanup_task = asyncio.create_task(self._pool_cleanup_monitor())
# Perform initial pool warmup
await self._warmup_pool()
self.logger.info(f"Connection pool created successfully with {self.host}:{self.port}")
except Exception as e:
self.logger.error(f"Failed to create connection pool: {e}")
raise
async def _recreate_pool(self):
"""Recreate connection pool with current database configuration"""
try:
# Close existing pool
if self.pool and not self.pool.closed:
self.pool.close()
await self.pool.wait_closed()
self.pool = None
# Create new pool with current config
await self._create_pool_with_current_config()
except Exception as e:
self.logger.error(f"Failed to recreate connection pool: {e}")
raise
def validate_database_configuration(self) -> tuple[bool, str]:
"""Validate database configuration completeness
Returns:
(is_valid, error_message): Configuration validation result
"""
# Check if Token authentication is enabled
token_auth_enabled = getattr(self.config.security, 'enable_token_auth', False)
# Check if tokens.json exists and has valid tokens with database configs
tokens_file_available = False
token_bound_configs_available = False
if self.token_manager:
try:
# Check if tokens.json file exists
import os
tokens_file_path = getattr(self.token_manager, 'token_file_path', 'tokens.json')
tokens_file_available = os.path.exists(tokens_file_path)
# Check if any tokens have database configurations
if tokens_file_available or self.token_manager._tokens:
for token_hash, token_info in self.token_manager._tokens.items():
if token_info.database_config:
token_bound_configs_available = True
break
except Exception:
pass
# Validate .env database configuration
env_config_valid = self._has_valid_global_config()
# Decision logic
if token_auth_enabled:
if tokens_file_available:
# tokens.json exists - either .env OR token-bound config must be valid
if env_config_valid or token_bound_configs_available:
return True, "Configuration valid"
else:
return False, (
"Token authentication is enabled and tokens.json exists, but no valid database "
"configuration found. Please provide either:\n"
"1. Valid database configuration in .env file (DB_HOST, DB_USER, etc.)\n"
"2. Database configuration in tokens.json for at least one token"
)
else:
# tokens.json does not exist - must have valid .env config
if env_config_valid:
return True, "Configuration valid"
else:
return False, (
"Token authentication is enabled but tokens.json file not found. "
"Either:\n"
"1. Create tokens.json file with token configurations\n"
"2. Provide valid database configuration in .env file (DB_HOST, DB_USER, etc.)"
)
else:
# Token auth is disabled, must have valid .env config
if env_config_valid:
return True, "Configuration valid"
else:
return False, (
"Token authentication is disabled. Valid database configuration is required "
"in .env file (DB_HOST, DB_USER, etc.)"
)
async def initialize(self):
"""Initialize connection pool with health monitoring"""
try:
# First validate configuration
is_valid, error_message = self.validate_database_configuration()
if not is_valid:
self.logger.error(f"Database configuration validation failed: {error_message}")
raise RuntimeError(f"Database configuration validation failed:\n{error_message}")
self.logger.info(f"Database configuration validated successfully")
self.logger.info(f"Initializing connection pool to {self.host}:{self.port}")
# Only create connection pool if we have valid global config
# Token-bound configs will be handled dynamically during requests
if not self._has_valid_global_config():
self.logger.info("No valid global database config, pool will be created dynamically for token-bound configs")
return
# Create connection pool
self.pool = await aiomysql.create_pool(
host=self.host,
@@ -592,7 +895,20 @@ class DorisConnectionManager:
# Check if pool is available
if not self.pool:
self.logger.warning("Connection pool is not available, attempting recovery...")
await self._recover_pool_with_lock()
# Try to use token-bound configuration if available
if self.token_manager and not self._has_valid_global_config():
available_token = self._find_available_token_with_db_config()
if available_token:
self.logger.info(f"Using token-bound configuration for pool creation: {available_token}")
try:
await self.configure_for_token(available_token)
except Exception as e:
self.logger.error(f"Failed to configure with token-bound config: {e}")
# Fallback to recovery
if not self.pool:
await self._recover_pool_with_lock()
if not self.pool:
raise RuntimeError("Connection pool is not available and recovery failed")

View File

@@ -426,6 +426,10 @@ class DorisQueryExecutor:
self, query_request: QueryRequest, auth_context
) -> QueryResult:
"""Internal query execution"""
# Database configuration should already be handled during authentication
# No need to configure again during query execution
# Optimize query
optimized_sql = await self.query_optimizer.optimize_query(
query_request.sql, {"user_roles": getattr(auth_context, 'roles', [])}

View File

@@ -47,11 +47,16 @@ class SecurityLevel(Enum):
class AuthContext:
"""Authentication context for audit and session tracking"""
token_id: str # Token identifier for audit logging
token_id: str = "" # Token identifier for audit logging
user_id: str = "" # User identifier
roles: list[str] = field(default_factory=list) # User roles
permissions: list[str] = field(default_factory=list) # User permissions
security_level: 'SecurityLevel' = field(default_factory=lambda: SecurityLevel.INTERNAL) # Security level
client_ip: str = "unknown" # Client IP address
session_id: str = "" # Session identifier
login_time: datetime = field(default_factory=datetime.utcnow)
last_activity: datetime | None = None
token: str = "" # Raw token for token-bound database configuration
@dataclass
@@ -84,12 +89,13 @@ class DorisSecurityManager:
Provides complete security control functionality, including authentication, authorization, SQL security validation and data masking
"""
def __init__(self, config):
def __init__(self, config, connection_manager=None):
self.config = config
self.logger = get_logger(__name__)
self.connection_manager = connection_manager
# Initialize security components
self.auth_provider = AuthenticationProvider(config)
self.auth_provider = AuthenticationProvider(config, self)
self.authz_provider = AuthorizationProvider(config)
self.sql_validator = SQLSecurityValidator(config)
self.masking_processor = DataMaskingProcessor(config)
@@ -226,6 +232,10 @@ class DorisSecurityManager:
# Return anonymous context when no authentication is enabled
return AuthContext(
token_id="anonymous",
user_id="anonymous",
roles=["anonymous"],
permissions=["read"],
security_level=SecurityLevel.PUBLIC,
client_ip=auth_info.get("client_ip", "unknown"),
session_id="anonymous_session"
)
@@ -392,18 +402,50 @@ class DorisSecurityManager:
return {"error": "Token manager not initialized"}
return self.auth_provider.token_manager.get_token_stats()
async def _validate_token_database_config(self, token: str, token_info) -> None:
"""Validate database configuration for token immediately during authentication
This ensures database connectivity issues are caught at authentication time,
not during query execution, providing better user experience.
Args:
token: Raw authentication token
token_info: TokenInfo object from token validation
Raises:
ValueError: If database configuration is invalid or connection fails
"""
try:
if not self.connection_manager:
self.logger.warning("Connection manager not available for immediate database validation")
return
# Configure and test database connection for this token
success, config_source = await self.connection_manager.configure_for_token(token)
if success:
self.logger.info(f"Database configuration validated successfully for token {token_info.token_id} (source: {config_source})")
else:
raise ValueError("Database configuration validation failed")
except Exception as e:
error_msg = f"Database configuration validation failed for token {token_info.token_id}: {str(e)}"
self.logger.error(error_msg)
raise ValueError(error_msg)
class AuthenticationProvider:
"""Authentication provider"""
def __init__(self, config):
def __init__(self, config, security_manager=None):
self.config = config
self.logger = get_logger(__name__)
self.session_cache = {}
self.jwt_manager = None
self.oauth_provider = None
self.token_manager = None
self.security_manager = security_manager
# Initialize authentication providers based on individual switches
auth_methods_enabled = []
@@ -583,12 +625,21 @@ class AuthenticationProvider:
token_info = validation_result.token_info
# Immediately validate database configuration for this token
if self.security_manager:
await self.security_manager._validate_token_database_config(token, token_info)
return AuthContext(
token_id=token_info.token_id,
user_id=token_info.token_id, # Use token_id as user_id for token auth
roles=["token_user"], # Default role for token users
permissions=["read", "write"], # Default permissions for token users
security_level=SecurityLevel.INTERNAL,
client_ip=auth_info.get("client_ip", "unknown"),
session_id=auth_info.get("session_id", f"session_{token_info.token_id}"),
login_time=datetime.utcnow(),
last_activity=token_info.last_used
last_activity=token_info.last_used,
token=token # Store raw token for token-bound database configuration
)
except Exception as e: