[Performance]Add Token Management (#55)

* 0.5.1 Version

* fix 0.5.1 schema async bug

* fix security bug

* fix security bug

* Add complete Token, JWT, OAuth authentication system

* Add complete Token, JWT, OAuth authentication system

* Add complete Token, JWT, OAuth authentication system

* Add complete Token, JWT, OAuth authentication system

* Add a controllable MCP Server DB Pool permission authentication system, connect it with the Doris permission system, and provide it to enterprise-level applications concurrently with the multi-Worker mode.

* Add Tokens Management
This commit is contained in:
Yijia Su
2025-09-03 11:55:38 +08:00
committed by GitHub
parent f99399c6c7
commit 9ba4cc6f45
10 changed files with 1252 additions and 127 deletions

View File

@@ -1,4 +1,20 @@
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Token Authentication HTTP Handlers
@@ -14,17 +30,32 @@ from starlette.responses import JSONResponse, HTMLResponse
from ..utils.logger import get_logger
from ..utils.security import SecurityLevel
from ..utils.config import DatabaseConfig
from .token_security_middleware import TokenSecurityMiddleware
class TokenHandlers:
"""Token Authentication HTTP Handlers"""
def __init__(self, security_manager):
def __init__(self, security_manager, config=None):
self.security_manager = security_manager
self.logger = get_logger(__name__)
# Initialize security middleware if config is provided
if config:
self.security_middleware = TokenSecurityMiddleware(config)
else:
self.security_middleware = None
self.logger.warning("Token handlers initialized without security middleware - access control disabled")
async def handle_create_token(self, request: Request) -> JSONResponse:
"""Handle token creation request"""
# Apply security checks
if self.security_middleware:
security_response = await self.security_middleware.check_token_management_access(request)
if security_response:
return security_response
try:
# Check if token manager is available
if not self.security_manager.auth_provider.token_manager:
@@ -37,13 +68,20 @@ class TokenHandlers:
# GET request with query parameters
query_params = dict(request.query_params)
token_id = query_params.get("token_id")
user_id = query_params.get("user_id")
roles = query_params.get("roles", "").split(",") if query_params.get("roles") else []
permissions = query_params.get("permissions", "").split(",") if query_params.get("permissions") else []
security_level_str = query_params.get("security_level", "internal")
expires_hours_str = query_params.get("expires_hours")
description = query_params.get("description", "")
custom_token = query_params.get("custom_token")
# Database configuration from query params
db_config = None
if query_params.get("db_host"):
db_config = DatabaseConfig(
host=query_params.get("db_host", "localhost"),
port=int(query_params.get("db_port", "9030")),
user=query_params.get("db_user", "root"),
password=query_params.get("db_password", ""),
database=query_params.get("db_database", "information_schema"),
fe_http_port=int(query_params.get("db_fe_http_port", "8030"))
)
else:
# POST request with JSON body
try:
@@ -54,26 +92,33 @@ class TokenHandlers:
}, status_code=400)
token_id = body.get("token_id")
user_id = body.get("user_id")
roles = body.get("roles", [])
permissions = body.get("permissions", [])
security_level_str = body.get("security_level", "internal")
expires_hours_str = body.get("expires_hours")
description = body.get("description", "")
custom_token = body.get("custom_token")
# Database configuration from JSON body
db_config = None
if body.get("database_config"):
db_data = body["database_config"]
try:
db_config = DatabaseConfig(
host=db_data.get("host", "localhost"),
port=int(db_data.get("port", 9030)),
user=db_data.get("user", "root"),
password=db_data.get("password", ""),
database=db_data.get("database", "information_schema"),
fe_http_port=int(db_data.get("fe_http_port", 8030))
)
except (ValueError, TypeError) as e:
return JSONResponse({
"error": f"Invalid database configuration: {str(e)}"
}, status_code=400)
# Validate required fields
if not token_id or not user_id:
if not token_id:
return JSONResponse({
"error": "token_id and user_id are required"
"error": "token_id is required"
}, status_code=400)
# Parse security level
try:
security_level = SecurityLevel(security_level_str.lower())
except ValueError:
security_level = SecurityLevel.INTERNAL
# Parse expires_hours
expires_hours = None
if expires_hours_str:
@@ -84,27 +129,20 @@ class TokenHandlers:
"error": "expires_hours must be an integer"
}, status_code=400)
# Create token
# Create token using the actual API
try:
token = await self.security_manager.create_token(
token_id=token_id,
user_id=user_id,
roles=roles,
permissions=permissions,
security_level=security_level,
expires_hours=expires_hours,
description=description,
custom_token=custom_token
custom_token=custom_token,
database_config=db_config
)
return JSONResponse({
"success": True,
"token_id": token_id,
"user_id": user_id,
"token": token,
"roles": roles,
"permissions": permissions,
"security_level": security_level.value,
"expires_hours": expires_hours,
"description": description,
"message": "Token created successfully"
@@ -124,6 +162,12 @@ class TokenHandlers:
async def handle_revoke_token(self, request: Request) -> JSONResponse:
"""Handle token revocation request"""
# Apply security checks
if self.security_middleware:
security_response = await self.security_middleware.check_token_management_access(request)
if security_response:
return security_response
try:
# Check if token manager is available
if not self.security_manager.auth_provider.token_manager:
@@ -168,6 +212,12 @@ class TokenHandlers:
async def handle_list_tokens(self, request: Request) -> JSONResponse:
"""Handle token listing request"""
# Apply security checks
if self.security_middleware:
security_response = await self.security_middleware.check_token_management_access(request)
if security_response:
return security_response
try:
# Check if token manager is available
if not self.security_manager.auth_provider.token_manager:
@@ -192,6 +242,12 @@ class TokenHandlers:
async def handle_token_stats(self, request: Request) -> JSONResponse:
"""Handle token statistics request"""
# Apply security checks
if self.security_middleware:
security_response = await self.security_middleware.check_token_management_access(request)
if security_response:
return security_response
try:
# Check if token manager is available
if not self.security_manager.auth_provider.token_manager:
@@ -215,6 +271,12 @@ class TokenHandlers:
async def handle_cleanup_tokens(self, request: Request) -> JSONResponse:
"""Handle expired tokens cleanup request"""
# Apply security checks
if self.security_middleware:
security_response = await self.security_middleware.check_token_management_access(request)
if security_response:
return security_response
try:
# Check if token manager is available
if not self.security_manager.auth_provider.token_manager:
@@ -237,8 +299,62 @@ class TokenHandlers:
"error": f"Internal server error: {str(e)}"
}, status_code=500)
async def handle_demo_page(self, request: Request) -> HTMLResponse:
async def handle_management_page(self, request: Request) -> HTMLResponse:
"""Handle token management demo page"""
# Apply security checks
if self.security_middleware:
security_response = await self.security_middleware.check_token_management_access(request)
if security_response:
# Convert JSON response to HTML for demo page
error_data = security_response.body.decode('utf-8') if hasattr(security_response, 'body') else '{"error": "Access denied"}'
try:
error_info = json.loads(error_data)
except:
error_info = {"error": "Access denied"}
error_html = f"""
<!DOCTYPE html>
<html>
<head>
<title>Access Denied - Token Management</title>
<style>
body {{ font-family: Arial, sans-serif; margin: 50px; background: #f5f5f5; }}
.container {{ max-width: 600px; margin: 0 auto; background: white; padding: 30px; border-radius: 8px; }}
.error {{ color: #dc3545; background: #f8d7da; border: 1px solid #f5c6cb; padding: 15px; border-radius: 5px; }}
.security-info {{ background: #d1ecf1; border: 1px solid #bee5eb; padding: 15px; border-radius: 5px; margin-top: 20px; }}
</style>
</head>
<body>
<div class="container">
<h1>🔐 Token Management - Access Denied</h1>
<div class="error">
<h3>Access Denied</h3>
<p><strong>Error:</strong> {error_info.get('error', 'Access denied')}</p>
<p><strong>Message:</strong> {error_info.get('message', 'Token management access is restricted')}</p>
{'<p><strong>Your IP:</strong> ' + str(error_info.get('client_ip', 'Unknown')) + '</p>' if 'client_ip' in error_info else ''}
</div>
<div class="security-info">
<h3>🛡️ Security Information</h3>
<p>Token management endpoints are protected by the following security measures:</p>
<ul>
<li><strong>IP Restrictions:</strong> Only localhost/127.0.0.1 access allowed</li>
<li><strong>Admin Authentication:</strong> Valid admin token required</li>
<li><strong>Configuration Control:</strong> Must be explicitly enabled</li>
</ul>
<p>If you need access, please:</p>
<ol>
<li>Access from the server host (127.0.0.1)</li>
<li>Ensure HTTP token management is enabled in configuration</li>
<li>Provide valid admin authentication</li>
</ol>
</div>
</div>
</body>
</html>
"""
return HTMLResponse(error_html, status_code=security_response.status_code)
try:
# Check if token manager is available
if not self.security_manager.auth_provider.token_manager:
@@ -323,34 +439,51 @@ class TokenHandlers:
<input type="text" id="token_id" name="token_id" placeholder="e.g., my-app-token" required>
</div>
<div class="form-group">
<label for="user_id">User ID (required):</label>
<input type="text" id="user_id" name="user_id" placeholder="e.g., john_doe" required>
<label for="expires_hours">Expires Hours (optional):</label>
<input type="number" id="expires_hours" name="expires_hours" placeholder="e.g., 720 (30 days), leave empty for default">
</div>
<div class="form-group">
<label for="roles">Roles (comma-separated):</label>
<input type="text" id="roles" name="roles" placeholder="e.g., data_analyst,viewer">
</div>
<div class="form-group">
<label for="permissions">Permissions (comma-separated):</label>
<input type="text" id="permissions" name="permissions" placeholder="e.g., read_data,query_database">
</div>
<div class="form-group">
<label for="security_level">Security Level:</label>
<select id="security_level" name="security_level">
<option value="public">Public</option>
<option value="internal" selected>Internal</option>
<option value="confidential">Confidential</option>
<option value="secret">Secret</option>
</select>
</div>
<div class="form-group">
<label for="expires_hours">Expires Hours (leave empty for default):</label>
<input type="number" id="expires_hours" name="expires_hours" placeholder="e.g., 720 (30 days)">
</div>
<div class="form-group">
<label for="description">Description:</label>
<label for="description">Description (optional):</label>
<textarea id="description" name="description" placeholder="Token description"></textarea>
</div>
<div class="form-group">
<label for="custom_token">Custom Token (optional):</label>
<input type="text" id="custom_token" name="custom_token" placeholder="Leave empty to auto-generate">
<small style="color: #666; display: block; margin-top: 5px;">If not provided, a secure token will be generated automatically</small>
</div>
<div class="section" style="margin: 20px 0; padding: 15px; background: #f8f9fa; border-radius: 5px;">
<h3>🗄️ Database Configuration (Optional)</h3>
<p style="color: #666; font-size: 14px; margin-bottom: 15px;">Configure database connection for this token. Leave empty to use system defaults.</p>
<div style="display: grid; grid-template-columns: 1fr 1fr; gap: 10px;">
<div class="form-group">
<label for="db_host">Host:</label>
<input type="text" id="db_host" name="db_host" placeholder="localhost">
</div>
<div class="form-group">
<label for="db_port">Port:</label>
<input type="number" id="db_port" name="db_port" placeholder="9030">
</div>
<div class="form-group">
<label for="db_user">User:</label>
<input type="text" id="db_user" name="db_user" placeholder="root">
</div>
<div class="form-group">
<label for="db_password">Password:</label>
<input type="password" id="db_password" name="db_password" placeholder="(optional)">
</div>
<div class="form-group">
<label for="db_database">Database:</label>
<input type="text" id="db_database" name="db_database" placeholder="information_schema">
</div>
<div class="form-group">
<label for="db_fe_http_port">FE HTTP Port:</label>
<input type="number" id="db_fe_http_port" name="db_fe_http_port" placeholder="8030">
</div>
</div>
</div>
<button type="submit" class="btn-primary">Create Token</button>
</form>
<div id="createTokenResponse"></div>
@@ -384,30 +517,81 @@ class TokenHandlers:
</div>
<script>
// Get admin token from URL parameters
const urlParams = new URLSearchParams(window.location.search);
const adminToken = urlParams.get('admin_token');
// Create request headers with admin token
function getAuthHeaders() {{
if (adminToken) {{
return {{
'Content-Type': 'application/json',
'Authorization': `Bearer ${{adminToken}}`
}};
}} else {{
return {{'Content-Type': 'application/json'}};
}}
}}
// Create URL with admin token parameter
function getAuthURL(baseUrl) {{
if (adminToken) {{
const separator = baseUrl.includes('?') ? '&' : '?';
return `${{baseUrl}}${{separator}}admin_token=${{encodeURIComponent(adminToken)}}`;
}}
return baseUrl;
}}
function showResponse(elementId, data, isSuccess = true) {{
const element = document.getElementById(elementId);
element.innerHTML = '<pre>' + JSON.stringify(data, null, 2) + '</pre>';
element.className = 'response ' + (isSuccess ? 'success' : 'error');
}}
// Create token form
// Create token form - updated to match actual API
document.getElementById('createTokenForm').addEventListener('submit', async (e) => {{
e.preventDefault();
const formData = new FormData(e.target);
const data = Object.fromEntries(formData.entries());
// Convert comma-separated values to arrays
data.roles = data.roles ? data.roles.split(',').map(r => r.trim()).filter(r => r) : [];
data.permissions = data.permissions ? data.permissions.split(',').map(p => p.trim()).filter(p => p) : [];
// Remove empty fields for optional parameters
if (!data.expires_hours) delete data.expires_hours;
if (!data.description) delete data.description;
if (!data.custom_token) delete data.custom_token;
// Handle database configuration
if (data.db_host) {{
data.database_config = {{
host: data.db_host,
port: data.db_port ? parseInt(data.db_port) : 9030,
user: data.db_user || 'root',
password: data.db_password || '',
database: data.db_database || 'information_schema',
fe_http_port: data.db_fe_http_port ? parseInt(data.db_fe_http_port) : 8030
}};
}}
// Remove individual database fields from data
delete data.db_host;
delete data.db_port;
delete data.db_user;
delete data.db_password;
delete data.db_database;
delete data.db_fe_http_port;
try {{
const response = await fetch('/token/create', {{
const response = await fetch(getAuthURL('/token/create'), {{
method: 'POST',
headers: {{'Content-Type': 'application/json'}},
headers: getAuthHeaders(),
body: JSON.stringify(data)
}});
const result = await response.json();
showResponse('createTokenResponse', result, response.ok);
// Refresh token list if creation was successful
if (response.ok) {{
document.getElementById('listTokensBtn').click();
}}
}} catch (error) {{
showResponse('createTokenResponse', {{error: error.message}}, false);
}}
@@ -416,7 +600,9 @@ class TokenHandlers:
// List tokens
document.getElementById('listTokensBtn').addEventListener('click', async () => {{
try {{
const response = await fetch('/token/list');
const response = await fetch(getAuthURL('/token/list'), {{
headers: getAuthHeaders()
}});
const result = await response.json();
showResponse('tokenListResponse', result, response.ok);
}} catch (error) {{
@@ -427,7 +613,10 @@ class TokenHandlers:
// Cleanup tokens
document.getElementById('cleanupTokensBtn').addEventListener('click', async () => {{
try {{
const response = await fetch('/token/cleanup', {{method: 'POST'}});
const response = await fetch(getAuthURL('/token/cleanup'), {{
method: 'POST',
headers: getAuthHeaders()
}});
const result = await response.json();
showResponse('tokenListResponse', result, response.ok);
}} catch (error) {{
@@ -444,11 +633,18 @@ class TokenHandlers:
}}
try {{
const response = await fetch(`/token/revoke?token_id=${{encodeURIComponent(tokenId)}}`, {{
method: 'DELETE'
const response = await fetch(getAuthURL(`/token/revoke?token_id=${{encodeURIComponent(tokenId)}}`), {{
method: 'DELETE',
headers: getAuthHeaders()
}});
const result = await response.json();
showResponse('revokeTokenResponse', result, response.ok);
// Refresh token list if revocation was successful
if (response.ok) {{
document.getElementById('listTokensBtn').click();
document.getElementById('revokeTokenId').value = '';
}}
}} catch (error) {{
showResponse('revokeTokenResponse', {{error: error.message}}, false);
}}

View File

@@ -1,4 +1,20 @@
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Token Authentication Management Module
@@ -355,7 +371,8 @@ class TokenManager:
token_id: str,
expires_hours: Optional[int] = None,
description: str = "",
custom_token: Optional[str] = None
custom_token: Optional[str] = None,
database_config: Optional[DatabaseConfig] = None
) -> str:
"""Create a new token"""
try:
@@ -380,7 +397,8 @@ class TokenManager:
token_info = TokenInfo(
token_id=token_id,
expires_at=expires_at,
description=description
description=description,
database_config=database_config
)
# Hash and store token
@@ -390,6 +408,9 @@ class TokenManager:
self.logger.info(f"Created new token '{token_id}'")
# Save token to file
self._save_token_to_file(token_id, raw_token, token_info)
return raw_token
except Exception as e:
@@ -410,12 +431,201 @@ class TokenManager:
del self._token_ids[token_id]
self.logger.info(f"Revoked token '{token_id}'")
# Save updated tokens to file
self._remove_token_from_file(token_id)
return True
except Exception as e:
self.logger.error(f"Failed to revoke token '{token_id}': {e}")
return False
def _save_tokens_to_file(self):
"""Save current tokens to JSON file"""
try:
# Convert current tokens to file format
tokens_list = []
for token_hash, token_info in self._tokens.items():
# Find the raw token for this token_info
raw_token = None
for tid, thash in self._token_ids.items():
if thash == token_hash and tid == token_info.token_id:
# We can't recover the original token from hash,
# so we'll create a placeholder for existing tokens
raw_token = f"<existing_token_hash_{token_hash[:8]}>"
break
if raw_token is None:
continue
token_config = {
"token_id": token_info.token_id,
"token": raw_token,
"description": token_info.description,
"expires_hours": None,
"is_active": token_info.is_active
}
# Add expiration info
if token_info.expires_at:
# Calculate remaining hours from now
remaining = token_info.expires_at - datetime.utcnow()
if remaining.total_seconds() > 0:
token_config["expires_hours"] = int(remaining.total_seconds() / 3600)
else:
token_config["expires_hours"] = 0
# Add database config if present
if token_info.database_config:
token_config["database_config"] = {
"host": token_info.database_config.host,
"port": token_info.database_config.port,
"user": token_info.database_config.user,
"password": token_info.database_config.password,
"database": token_info.database_config.database,
"charset": token_info.database_config.charset,
"fe_http_port": token_info.database_config.fe_http_port
}
tokens_list.append(token_config)
# Create file content
file_content = {
"version": "1.0",
"description": "Doris MCP Server Token configuration file",
"created_at": datetime.utcnow().isoformat() + "Z",
"tokens": tokens_list,
"notes": [
"This file is automatically updated when tokens are created or revoked",
"Please backup this file before making manual changes",
"Tokens with hash placeholders were loaded from previous configuration"
]
}
# Save to file
with open(self.token_file_path, 'w', encoding='utf-8') as f:
json.dump(file_content, f, indent=2, ensure_ascii=False)
self.logger.info(f"Saved {len(tokens_list)} tokens to file: {self.token_file_path}")
except Exception as e:
self.logger.error(f"Failed to save tokens to file {self.token_file_path}: {e}")
def _save_token_to_file(self, token_id: str, raw_token: str, token_info: TokenInfo):
"""Save a single new token to file (for newly created tokens only)"""
try:
# Load existing file
existing_data = {"tokens": []}
if os.path.exists(self.token_file_path):
try:
with open(self.token_file_path, 'r', encoding='utf-8') as f:
existing_data = json.load(f)
except Exception as e:
self.logger.warning(f"Could not load existing token file: {e}")
# Ensure tokens list exists
if 'tokens' not in existing_data or not isinstance(existing_data['tokens'], list):
existing_data['tokens'] = []
# Check if token already exists in file
token_exists = False
for i, token_config in enumerate(existing_data['tokens']):
if token_config.get('token_id') == token_id:
# Update existing token
existing_data['tokens'][i] = self._token_info_to_config(token_id, raw_token, token_info)
token_exists = True
break
# Add new token if it doesn't exist
if not token_exists:
new_token_config = self._token_info_to_config(token_id, raw_token, token_info)
existing_data['tokens'].append(new_token_config)
# Update metadata
existing_data.update({
"version": "1.0",
"description": "Doris MCP Server Token configuration file",
"updated_at": datetime.utcnow().isoformat() + "Z"
})
# Save to file
with open(self.token_file_path, 'w', encoding='utf-8') as f:
json.dump(existing_data, f, indent=2, ensure_ascii=False)
self.logger.info(f"Saved token '{token_id}' to file: {self.token_file_path}")
except Exception as e:
self.logger.error(f"Failed to save token '{token_id}' to file: {e}")
def _token_info_to_config(self, token_id: str, raw_token: str, token_info: TokenInfo) -> dict:
"""Convert TokenInfo to file configuration format"""
token_config = {
"token_id": token_id,
"token": raw_token,
"description": token_info.description,
"expires_hours": None,
"is_active": token_info.is_active
}
# Add expiration info
if token_info.expires_at:
# Calculate remaining hours from creation time
remaining = token_info.expires_at - token_info.created_at
token_config["expires_hours"] = int(remaining.total_seconds() / 3600) if remaining.total_seconds() > 0 else None
# Add database config if present
if token_info.database_config:
token_config["database_config"] = {
"host": token_info.database_config.host,
"port": token_info.database_config.port,
"user": token_info.database_config.user,
"password": token_info.database_config.password,
"database": token_info.database_config.database,
"charset": token_info.database_config.charset,
"fe_http_port": token_info.database_config.fe_http_port
}
return token_config
def _remove_token_from_file(self, token_id: str):
"""Remove a token from the JSON file"""
try:
if not os.path.exists(self.token_file_path):
return
# Load existing file
with open(self.token_file_path, 'r', encoding='utf-8') as f:
existing_data = json.load(f)
if 'tokens' not in existing_data or not isinstance(existing_data['tokens'], list):
return
# Remove the token
original_count = len(existing_data['tokens'])
existing_data['tokens'] = [
token for token in existing_data['tokens']
if token.get('token_id') != token_id
]
if len(existing_data['tokens']) < original_count:
# Update metadata
existing_data.update({
"version": "1.0",
"description": "Doris MCP Server Token configuration file",
"updated_at": datetime.utcnow().isoformat() + "Z"
})
# Save to file
with open(self.token_file_path, 'w', encoding='utf-8') as f:
json.dump(existing_data, f, indent=2, ensure_ascii=False)
self.logger.info(f"Removed token '{token_id}' from file: {self.token_file_path}")
except Exception as e:
self.logger.error(f"Failed to remove token '{token_id}' from file: {e}")
async def list_tokens(self) -> List[Dict[str, Any]]:
"""List all tokens (without sensitive data)"""
tokens = []

View File

@@ -0,0 +1,227 @@
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Token Management Security Middleware
Provides comprehensive security controls for token management endpoints including
IP restrictions, admin authentication, and configuration-based access control.
"""
import hashlib
import hmac
import ipaddress
import secrets
import time
from typing import Optional, List, Dict, Any
from starlette.requests import Request
from starlette.responses import JSONResponse
from ..utils.logger import get_logger
from ..utils.config import DorisConfig
class TokenSecurityMiddleware:
"""Security middleware for token management endpoints"""
def __init__(self, config: DorisConfig):
self.config = config
self.logger = get_logger(__name__)
# Initialize admin token hash if provided
self._admin_token_hash = None
if config.security.token_management_admin_token:
self._admin_token_hash = self._hash_token(config.security.token_management_admin_token)
# Normalize allowed IPs
self._allowed_networks = self._parse_allowed_networks(
config.security.token_management_allowed_ips
)
self.logger.info(f"Token management security initialized: "
f"HTTP endpoints {'enabled' if config.security.enable_http_token_management else 'disabled'}, "
f"Admin auth {'required' if config.security.require_admin_auth else 'optional'}, "
f"Allowed networks: {len(self._allowed_networks)}")
def _hash_token(self, token: str) -> str:
"""Hash token using SHA-256"""
return hashlib.sha256(token.encode('utf-8')).hexdigest()
def _parse_allowed_networks(self, allowed_ips: List[str]) -> List[ipaddress.IPv4Network | ipaddress.IPv6Network]:
"""Parse allowed IP addresses and networks"""
networks = []
for ip_str in allowed_ips:
try:
# Try to parse as network (CIDR notation)
if '/' in ip_str:
networks.append(ipaddress.ip_network(ip_str, strict=False))
else:
# Parse as single IP and convert to /32 network
ip = ipaddress.ip_address(ip_str)
if isinstance(ip, ipaddress.IPv4Address):
networks.append(ipaddress.IPv4Network(f"{ip}/32"))
else:
networks.append(ipaddress.IPv6Network(f"{ip}/128"))
except ValueError as e:
self.logger.warning(f"Invalid IP/network '{ip_str}': {e}")
return networks
def _get_client_ip(self, request: Request) -> str:
"""Extract client IP from request, considering proxies"""
# Check X-Forwarded-For header first (for proxy setups)
forwarded_for = request.headers.get('X-Forwarded-For')
if forwarded_for:
# Take the first IP (original client)
client_ip = forwarded_for.split(',')[0].strip()
elif request.headers.get('X-Real-IP'):
client_ip = request.headers.get('X-Real-IP')
else:
# Direct connection
client_ip = request.client.host if request.client else "unknown"
return client_ip
def _is_ip_allowed(self, client_ip: str) -> bool:
"""Check if client IP is in allowed networks"""
try:
client_addr = ipaddress.ip_address(client_ip)
for network in self._allowed_networks:
if client_addr in network:
return True
return False
except ValueError:
self.logger.warning(f"Invalid client IP format: {client_ip}")
return False
def _extract_admin_token(self, request: Request) -> Optional[str]:
"""Extract admin token from request headers"""
# Try Authorization header first
auth_header = request.headers.get('Authorization', '')
if auth_header.startswith('Bearer '):
return auth_header[7:]
elif auth_header.startswith('Token '):
return auth_header[6:]
# Try X-Admin-Token header
admin_token = request.headers.get('X-Admin-Token', '')
if admin_token:
return admin_token
# Try query parameter as fallback (not recommended for production)
admin_token = request.query_params.get('admin_token', '')
if admin_token:
self.logger.warning("Admin token passed via query parameter - this is insecure for production")
return admin_token
return None
def _verify_admin_token(self, provided_token: str) -> bool:
"""Verify provided admin token against configured token"""
if not self._admin_token_hash:
self.logger.warning("No admin token configured for token management")
return False
provided_hash = self._hash_token(provided_token)
# Use constant-time comparison to prevent timing attacks
return hmac.compare_digest(self._admin_token_hash, provided_hash)
async def check_token_management_access(self, request: Request) -> Optional[JSONResponse]:
"""
Check if request is authorized for token management operations
Returns:
None if access is granted
JSONResponse with error if access is denied
"""
# Check if HTTP token management is enabled
if not self.config.security.enable_http_token_management:
self.logger.warning(f"Token management endpoint access denied - HTTP management disabled: {request.url.path}")
return JSONResponse({
"error": "Token management endpoints are disabled for security",
"message": "HTTP token management is disabled. Use file-based token management instead.",
"suggestion": "Edit tokens.json file directly or enable HTTP management with proper security configuration"
}, status_code=403)
# Extract client IP
client_ip = self._get_client_ip(request)
# Check IP restrictions
if not self._is_ip_allowed(client_ip):
self.logger.warning(f"Token management access denied for IP {client_ip}: not in allowed list")
return JSONResponse({
"error": "Access denied - IP not allowed",
"client_ip": client_ip,
"message": "Token management is restricted to specific IP addresses",
"allowed_networks": [str(net) for net in self._allowed_networks]
}, status_code=403)
# Check admin authentication if required
if self.config.security.require_admin_auth:
admin_token = self._extract_admin_token(request)
if not admin_token:
self.logger.warning(f"Token management access denied for IP {client_ip}: missing admin token")
return JSONResponse({
"error": "Admin authentication required",
"message": "Token management requires admin authentication",
"hint": "Provide admin token in Authorization header: 'Bearer <admin_token>' or 'X-Admin-Token: <admin_token>'"
}, status_code=401)
if not self._verify_admin_token(admin_token):
self.logger.warning(f"Token management access denied for IP {client_ip}: invalid admin token")
return JSONResponse({
"error": "Invalid admin token",
"message": "The provided admin token is invalid"
}, status_code=401)
# Log successful access
self.logger.info(f"Token management access granted for IP {client_ip} to {request.url.path}")
# Access granted
return None
def get_security_info(self) -> Dict[str, Any]:
"""Get current security configuration info (for demo/status pages)"""
return {
"http_token_management_enabled": self.config.security.enable_http_token_management,
"admin_auth_required": self.config.security.require_admin_auth,
"admin_token_configured": bool(self._admin_token_hash),
"allowed_networks_count": len(self._allowed_networks),
"allowed_networks": [str(net) for net in self._allowed_networks],
"security_features": [
"IP address restrictions",
"Admin token authentication" if self.config.security.require_admin_auth else "Optional admin authentication",
"Secure token hashing",
"Request logging and auditing"
]
}
def generate_admin_token(self) -> str:
"""Generate a secure admin token"""
return secrets.token_urlsafe(32)
# Convenience function for middleware creation
def create_token_security_middleware(config: DorisConfig) -> TokenSecurityMiddleware:
"""Create token security middleware with configuration"""
return TokenSecurityMiddleware(config)

View File

@@ -546,7 +546,7 @@ class DorisServer:
# Token management endpoints
from .auth.token_handlers import TokenHandlers
token_handlers = TokenHandlers(self.security_manager)
token_handlers = TokenHandlers(self.security_manager, self.config)
async def token_create(request):
return await token_handlers.handle_create_token(request)
@@ -563,8 +563,8 @@ class DorisServer:
async def token_cleanup(request):
return await token_handlers.handle_cleanup_tokens(request)
async def token_demo(request):
return await token_handlers.handle_demo_page(request)
async def token_management(request):
return await token_handlers.handle_management_page(request)
# Lifecycle manager - simplified since we manage session_manager externally
@contextlib.asynccontextmanager
@@ -592,7 +592,7 @@ class DorisServer:
Route("/token/list", token_list, methods=["GET"]),
Route("/token/stats", token_stats, methods=["GET"]),
Route("/token/cleanup", token_cleanup, methods=["GET", "POST"]),
Route("/token/demo", token_demo, methods=["GET"]),
Route("/token/management", token_management, methods=["GET"]),
],
lifespan=lifespan,
)

View File

@@ -1,4 +1,20 @@
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Multi-worker application module for doris-mcp-server
@@ -396,7 +412,7 @@ async def initialize_worker():
from .auth.oauth_handlers import OAuthHandlers
from .auth.token_handlers import TokenHandlers
_oauth_handlers = OAuthHandlers(_worker_security_manager)
_token_handlers = TokenHandlers(_worker_security_manager)
_token_handlers = TokenHandlers(_worker_security_manager, config)
_worker_initialized = True
logger.info(f"Worker {os.getpid()} MCP initialization completed successfully")
@@ -481,12 +497,12 @@ async def token_cleanup(request):
return JSONResponse({"error": "Token handlers not initialized"}, status_code=503)
return await _token_handlers.handle_cleanup_tokens(request)
async def token_demo(request):
"""Token demo page endpoint"""
async def token_management(request):
"""Token management page endpoint"""
if not _token_handlers:
from starlette.responses import HTMLResponse
return HTMLResponse("<h1>Token handlers not initialized</h1>")
return await _token_handlers.handle_demo_page(request)
return await _token_handlers.handle_management_page(request)
async def root_info(request):
"""Root endpoint"""
@@ -593,7 +609,7 @@ basic_app = Starlette(
Route("/token/list", token_list, methods=["GET"]),
Route("/token/stats", token_stats, methods=["GET"]),
Route("/token/cleanup", token_cleanup, methods=["GET", "POST"]),
Route("/token/demo", token_demo, methods=["GET"]),
Route("/token/management", token_management, methods=["GET"]),
],
lifespan=lifespan
)

View File

@@ -93,6 +93,12 @@ class SecurityConfig:
default_token_expiry_hours: int = 24 * 30 # Default expiry: 30 days
token_hash_algorithm: str = "sha256" # Token hashing algorithm: sha256, sha512
# Token Management Security (New in v0.6.0)
enable_http_token_management: bool = False # Enable HTTP token management endpoints (default: disabled for security)
token_management_admin_token: str = "" # Admin token for token management endpoints (required if HTTP management enabled)
token_management_allowed_ips: list[str] = field(default_factory=lambda: ["127.0.0.1", "::1", "localhost"]) # Allowed IPs for token management
require_admin_auth: bool = True # Require admin authentication for token management (default: true)
# JWT Configuration
jwt_algorithm: str = "RS256" # RS256, ES256, HS256
jwt_issuer: str = "doris-mcp-server"
@@ -469,6 +475,21 @@ class DorisConfig:
os.getenv("DEFAULT_TOKEN_EXPIRY_HOURS", str(config.security.default_token_expiry_hours))
)
config.security.token_hash_algorithm = os.getenv("TOKEN_HASH_ALGORITHM", config.security.token_hash_algorithm)
# Token Management Security Configuration (New in v0.6.0)
config.security.enable_http_token_management = (
os.getenv("ENABLE_HTTP_TOKEN_MANAGEMENT", str(config.security.enable_http_token_management).lower()).lower() == "true"
)
config.security.token_management_admin_token = os.getenv("TOKEN_MANAGEMENT_ADMIN_TOKEN", config.security.token_management_admin_token)
# Parse allowed IPs from comma-separated string
allowed_ips_str = os.getenv("TOKEN_MANAGEMENT_ALLOWED_IPS", "")
if allowed_ips_str:
config.security.token_management_allowed_ips = [ip.strip() for ip in allowed_ips_str.split(",") if ip.strip()]
config.security.require_admin_auth = (
os.getenv("REQUIRE_ADMIN_AUTH", str(config.security.require_admin_auth).lower()).lower() == "true"
)
# Performance configuration
config.performance.enable_query_cache = (

View File

@@ -32,6 +32,7 @@ from sqlparse.sql import Statement
from sqlparse.tokens import Keyword, Name
from .logger import get_logger
from .config import DatabaseConfig
class SecurityLevel(Enum):
@@ -333,7 +334,8 @@ class DorisSecurityManager:
token_id: str,
expires_hours: Optional[int] = None,
description: str = "",
custom_token: Optional[str] = None
custom_token: Optional[str] = None,
database_config: Optional[DatabaseConfig] = None
) -> str:
"""Create a new API access token
@@ -342,6 +344,7 @@ class DorisSecurityManager:
expires_hours: Token expiration in hours (None for no expiration)
description: Token description for management purposes
custom_token: Custom token string (if None, generates random token)
database_config: Optional database configuration for this token
Returns:
Generated token string
@@ -353,7 +356,8 @@ class DorisSecurityManager:
token_id=token_id,
expires_hours=expires_hours,
description=description,
custom_token=custom_token
custom_token=custom_token,
database_config=database_config
)
async def revoke_token(self, token_id: str) -> bool: