diff --git a/.env.example b/.env.example
index 167c2f7..616ccd5 100644
--- a/.env.example
+++ b/.env.example
@@ -36,8 +36,145 @@ DORIS_MAX_CONNECTION_AGE=3600
# Security Configuration
# ===================================================================
-# Authentication configuration
+# Independent Authentication Switches - NEW DESIGN!
+# Each authentication method can be enabled/disabled independently
+# Any enabled method that succeeds will allow access
+# If all methods are disabled, anonymous access is allowed
+
+# Legacy configuration - kept for backward compatibility
+# AUTH_TYPE is now deprecated - use individual switches above
AUTH_TYPE=token
+
+# Token Authentication (Default method - simple and effective)
+ENABLE_TOKEN_AUTH=false
+
+# JWT Authentication (For stateless applications)
+ENABLE_JWT_AUTH=false
+
+# OAuth 2.0/OIDC Authentication (For enterprise integration)
+ENABLE_OAUTH_AUTH=false
+
+# ===================================================================
+# Token Authentication Configuration (Enable with ENABLE_TOKEN_AUTH=true)
+# ===================================================================
+
+# Basic token authentication settings
+TOKEN_FILE_PATH=tokens.json
+ENABLE_TOKEN_EXPIRY=true
+DEFAULT_TOKEN_EXPIRY_HOURS=720
+TOKEN_HASH_ALGORITHM=sha256
+
+# ===================================================================
+# JWT Authentication Configuration (Enable with ENABLE_JWT_AUTH=true)
+# ===================================================================
+
+# JWT token settings (when ENABLE_JWT_AUTH=true)
+JWT_SECRET_KEY=your_jwt_secret_key_here_change_in_production
+JWT_ALGORITHM=HS256
+JWT_EXPIRATION_HOURS=24
+JWT_ISSUER=doris-mcp-server
+JWT_AUDIENCE=doris-mcp-client
+
+# JWT token validation settings
+JWT_VERIFY_SIGNATURE=true
+JWT_VERIFY_EXPIRATION=true
+JWT_VERIFY_AUDIENCE=true
+JWT_VERIFY_ISSUER=true
+
+# JWT refresh token settings
+ENABLE_JWT_REFRESH=true
+JWT_REFRESH_EXPIRATION_DAYS=30
+JWT_REFRESH_SECRET_KEY=your_jwt_refresh_secret_key_here
+
+# JWT user claims configuration
+JWT_USER_ID_CLAIM=user_id
+JWT_ROLES_CLAIM=roles
+JWT_PERMISSIONS_CLAIM=permissions
+JWT_SECURITY_LEVEL_CLAIM=security_level
+
+# ===================================================================
+# OAuth 2.0 / OpenID Connect Configuration (Enable with ENABLE_OAUTH_AUTH=true)
+# ===================================================================
+
+# OAuth provider settings (when ENABLE_OAUTH_AUTH=true)
+OAUTH_PROVIDER_TYPE=generic
+OAUTH_CLIENT_ID=your_oauth_client_id
+OAUTH_CLIENT_SECRET=your_oauth_client_secret
+OAUTH_REDIRECT_URI=http://localhost:3000/auth/callback
+
+# OAuth endpoints (for generic provider)
+OAUTH_AUTHORIZATION_URL=https://your-provider.com/auth
+OAUTH_TOKEN_URL=https://your-provider.com/token
+OAUTH_USERINFO_URL=https://your-provider.com/userinfo
+OAUTH_JWKS_URL=https://your-provider.com/.well-known/jwks.json
+
+# OAuth scope and claims
+OAUTH_SCOPE=openid profile email
+OAUTH_USER_ID_CLAIM=sub
+OAUTH_USERNAME_CLAIM=preferred_username
+OAUTH_EMAIL_CLAIM=email
+OAUTH_ROLES_CLAIM=roles
+OAUTH_GROUPS_CLAIM=groups
+
+# OAuth session settings
+OAUTH_SESSION_SECRET=your_oauth_session_secret_here
+OAUTH_SESSION_EXPIRY=3600
+OAUTH_STATE_EXPIRY=300
+
+# Popular OAuth providers presets (uncomment and configure as needed)
+
+# Google OAuth Configuration
+# OAUTH_PROVIDER_TYPE=google
+# OAUTH_CLIENT_ID=your_google_client_id.apps.googleusercontent.com
+# OAUTH_CLIENT_SECRET=your_google_client_secret
+# OAUTH_AUTHORIZATION_URL=https://accounts.google.com/o/oauth2/auth
+# OAUTH_TOKEN_URL=https://oauth2.googleapis.com/token
+# OAUTH_USERINFO_URL=https://www.googleapis.com/oauth2/v1/userinfo
+# OAUTH_JWKS_URL=https://www.googleapis.com/oauth2/v3/certs
+# OAUTH_SCOPE=openid profile email
+
+# Microsoft Azure AD Configuration
+# OAUTH_PROVIDER_TYPE=azure
+# OAUTH_CLIENT_ID=your_azure_client_id
+# OAUTH_CLIENT_SECRET=your_azure_client_secret
+# OAUTH_TENANT_ID=your_tenant_id
+# OAUTH_AUTHORIZATION_URL=https://login.microsoftonline.com/{tenant}/oauth2/v2.0/authorize
+# OAUTH_TOKEN_URL=https://login.microsoftonline.com/{tenant}/oauth2/v2.0/token
+# OAUTH_USERINFO_URL=https://graph.microsoft.com/v1.0/me
+# OAUTH_JWKS_URL=https://login.microsoftonline.com/{tenant}/discovery/v2.0/keys
+# OAUTH_SCOPE=openid profile email
+
+# GitHub OAuth Configuration
+# OAUTH_PROVIDER_TYPE=github
+# OAUTH_CLIENT_ID=your_github_client_id
+# OAUTH_CLIENT_SECRET=your_github_client_secret
+# OAUTH_AUTHORIZATION_URL=https://github.com/login/oauth/authorize
+# OAUTH_TOKEN_URL=https://github.com/login/oauth/access_token
+# OAUTH_USERINFO_URL=https://api.github.com/user
+# OAUTH_SCOPE=user:email
+
+# GitLab OAuth Configuration
+# OAUTH_PROVIDER_TYPE=gitlab
+# OAUTH_CLIENT_ID=your_gitlab_client_id
+# OAUTH_CLIENT_SECRET=your_gitlab_client_secret
+# OAUTH_AUTHORIZATION_URL=https://gitlab.com/oauth/authorize
+# OAUTH_TOKEN_URL=https://gitlab.com/oauth/token
+# OAUTH_USERINFO_URL=https://gitlab.com/api/v4/user
+# OAUTH_SCOPE=read_user
+
+# Keycloak OAuth Configuration
+# OAUTH_PROVIDER_TYPE=keycloak
+# OAUTH_CLIENT_ID=your_keycloak_client_id
+# OAUTH_CLIENT_SECRET=your_keycloak_client_secret
+# OAUTH_REALM=your_realm
+# OAUTH_SERVER_URL=https://your-keycloak-server.com
+# OAUTH_AUTHORIZATION_URL=https://your-keycloak-server.com/auth/realms/{realm}/protocol/openid-connect/auth
+# OAUTH_TOKEN_URL=https://your-keycloak-server.com/auth/realms/{realm}/protocol/openid-connect/token
+# OAUTH_USERINFO_URL=https://your-keycloak-server.com/auth/realms/{realm}/protocol/openid-connect/userinfo
+# OAUTH_JWKS_URL=https://your-keycloak-server.com/auth/realms/{realm}/protocol/openid-connect/certs
+# OAUTH_SCOPE=openid profile email
+
+# Legacy token settings (for backward compatibility)
TOKEN_SECRET=your_secret_key_here
TOKEN_EXPIRY=3600
@@ -172,7 +309,13 @@ TEMP_FILES_DIR=tmp
# - LOG_CLEANUP_INTERVAL_HOURS: Check frequency, recommended 24 hours
# 2. Security Best Practices:
-# - Must change TOKEN_SECRET in production environment
+# - NEW: Enable individual authentication methods using ENABLE_TOKEN_AUTH, ENABLE_JWT_AUTH, ENABLE_OAUTH_AUTH
+# - When all methods are disabled, ALL requests are allowed with anonymous access
+# - Authentication methods work independently - any one succeeding allows access
+# - Token Auth: Change default tokens (DEFAULT_ADMIN_TOKEN, etc.) in production
+# - JWT Auth: Change JWT_SECRET_KEY and JWT_REFRESH_SECRET_KEY in production
+# - OAuth Auth: Configure OAuth provider settings and secure client secrets
+# - Must change TOKEN_SECRET in production environment (legacy compatibility)
# - Adjust BLOCKED_KEYWORDS according to business needs
# - Enable ENABLE_SECURITY_CHECK and ENABLE_MASKING
@@ -193,4 +336,43 @@ TEMP_FILES_DIR=tmp
# - ADBC_DEFAULT_RETURN_FORMAT: Default return format (arrow/pandas/dict, recommended: arrow)
# - ADBC_CONNECTION_TIMEOUT: Connection timeout for ADBC (recommended: 30)
# - ADBC_ENABLED: Enable or disable ADBC tools (true/false)
-# - Prerequisites: Install adbc_driver_manager, adbc_driver_flightsql, pyarrow packages
\ No newline at end of file
+# - Prerequisites: Install adbc_driver_manager, adbc_driver_flightsql, pyarrow packages
+
+# 6. Authentication Configuration Guide - UPDATED DESIGN!
+#
+# Independent Authentication Control (NEW):
+# - ENABLE_TOKEN_AUTH=false (default): Disable token authentication
+# - ENABLE_JWT_AUTH=false (default): Disable JWT authentication
+# - ENABLE_OAUTH_AUTH=false (default): Disable OAuth authentication
+# - When all methods are disabled, no authentication is required (anonymous access)
+# - When multiple methods are enabled, any one succeeding allows access
+# - Recommended for development/testing: all false, production: enable needed methods
+#
+# Token Authentication (ENABLE_TOKEN_AUTH=true) - Recommended for most use cases:
+# - Simple and secure token-based authentication
+# - Configurable default tokens via environment variables
+# - Support for custom tokens via TOKEN_* environment variables
+# - Token file configuration via tokens.json
+# - Built-in token management HTTP endpoints
+# - No user management complexity - pure API access control
+#
+# JWT Authentication (ENABLE_JWT_AUTH=true) - For stateless applications:
+# - JSON Web Token based authentication
+# - Configurable token expiration and refresh
+# - Support for standard JWT claims
+# - RSA/ECDSA/HS256 algorithm support
+# - Suitable for microservices and distributed systems
+#
+# OAuth 2.0/OIDC (ENABLE_OAUTH_AUTH=true) - For enterprise integration:
+# - Integration with external identity providers
+# - Support for popular providers (Google, Microsoft, GitHub, GitLab, Keycloak)
+# - OpenID Connect compatibility
+# - Automatic user provisioning from provider
+# - Secure authorization code flow
+#
+# Authentication Method Selection Guide:
+# - No Auth (all switches false): Development, testing, trusted networks
+# - Token Auth only: Small teams, simple deployment, direct API access
+# - JWT Auth only: Stateless apps, microservices, mobile clients
+# - OAuth Auth only: Enterprise SSO, large teams, external identity providers
+# - Multiple methods: Flexible access, different client types, migration scenarios
\ No newline at end of file
diff --git a/doris_mcp_server/auth/__init__.py b/doris_mcp_server/auth/__init__.py
new file mode 100644
index 0000000..fa82f70
--- /dev/null
+++ b/doris_mcp_server/auth/__init__.py
@@ -0,0 +1,56 @@
+#!/usr/bin/env python3
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Doris MCP Server Authentication Module
+Provides JWT-based, Token-based, and OAuth 2.0/OIDC authentication and authorization services
+"""
+
+from .jwt_manager import JWTManager
+from .key_manager import KeyManager
+from .token_validators import TokenValidator, TokenBlacklist
+from .auth_middleware import AuthMiddleware
+from .token_manager import TokenManager, TokenInfo, TokenValidationResult
+from .token_handlers import TokenHandlers
+from .oauth_client import OAuthClient, OAuthStateManager
+from .oauth_provider import OAuthAuthenticationProvider
+from .oauth_types import (
+ OAuthProvider, OAuthState, OAuthTokens, OAuthUserInfo,
+ OIDCDiscovery, OAuthError, OAuthProviderConfig
+)
+
+__all__ = [
+ "JWTManager",
+ "KeyManager",
+ "TokenValidator",
+ "TokenBlacklist",
+ "AuthMiddleware",
+ "TokenManager",
+ "TokenInfo",
+ "TokenValidationResult",
+ "TokenHandlers",
+ "OAuthClient",
+ "OAuthStateManager",
+ "OAuthAuthenticationProvider",
+ "OAuthProvider",
+ "OAuthState",
+ "OAuthTokens",
+ "OAuthUserInfo",
+ "OIDCDiscovery",
+ "OAuthError",
+ "OAuthProviderConfig"
+]
\ No newline at end of file
diff --git a/doris_mcp_server/auth/auth_middleware.py b/doris_mcp_server/auth/auth_middleware.py
new file mode 100644
index 0000000..3fd1792
--- /dev/null
+++ b/doris_mcp_server/auth/auth_middleware.py
@@ -0,0 +1,269 @@
+#!/usr/bin/env python3
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Authentication Middleware Module
+Provides middleware for JWT authentication in HTTP and MCP contexts
+"""
+
+from typing import Optional, Dict, Any, Callable, Awaitable
+from datetime import datetime
+
+from .jwt_manager import JWTManager
+from ..utils.security import AuthContext, SecurityLevel
+from ..utils.logger import get_logger
+
+logger = get_logger(__name__)
+
+
+class AuthMiddleware:
+ """Authentication Middleware
+
+ Provides JWT authentication functionality for HTTP and MCP requests
+ """
+
+ def __init__(self, jwt_manager: JWTManager):
+ """Initialize authentication middleware
+
+ Args:
+ jwt_manager: JWT manager instance
+ """
+ self.jwt_manager = jwt_manager
+ logger.info("AuthMiddleware initialized")
+
+ def extract_token_from_header(self, authorization: str) -> Optional[str]:
+ """Extract JWT token from Authorization header
+
+ Args:
+ authorization: Authorization header value
+
+ Returns:
+ JWT token string, or None if not found
+ """
+ if not authorization:
+ return None
+
+ # Support Bearer format
+ if authorization.startswith('Bearer '):
+ return authorization[7:] # Remove "Bearer " prefix
+
+ # Support direct token format
+ if not authorization.startswith('Basic '):
+ return authorization
+
+ return None
+
+ async def authenticate_request(self, auth_info: Dict[str, Any]) -> AuthContext:
+ """Authenticate request and return authentication context
+
+ Args:
+ auth_info: Authentication information dictionary
+
+ Returns:
+ AuthContext authentication context
+
+ Raises:
+ ValueError: Authentication failed
+ """
+ try:
+ auth_type = auth_info.get("type", "jwt")
+
+ if auth_type == "jwt" or auth_type == "token":
+ return await self._authenticate_jwt(auth_info)
+ else:
+ raise ValueError(f"Unsupported authentication type: {auth_type}")
+
+ except Exception as e:
+ logger.error(f"Request authentication failed: {e}")
+ raise
+
+ async def _authenticate_jwt(self, auth_info: Dict[str, Any]) -> AuthContext:
+ """JWT authentication processing
+
+ Args:
+ auth_info: Authentication information containing JWT token
+
+ Returns:
+ AuthContext authentication context
+ """
+ # Get token
+ token = auth_info.get("token")
+ if not token:
+ # Try to get from Authorization header
+ authorization = auth_info.get("authorization")
+ token = self.extract_token_from_header(authorization)
+
+ if not token:
+ raise ValueError("Missing JWT token")
+
+ try:
+ # Validate token
+ validation_result = await self.jwt_manager.validate_token(token, 'access')
+ payload = validation_result['payload']
+
+ # Build authentication context
+ auth_context = AuthContext(
+ user_id=payload.get('sub'),
+ roles=payload.get('roles', []),
+ permissions=payload.get('permissions', []),
+ session_id=payload.get('jti'), # Use JWT ID as session ID
+ login_time=datetime.fromtimestamp(payload.get('iat', 0)),
+ last_activity=datetime.utcnow(),
+ security_level=SecurityLevel(payload.get('security_level', 'internal'))
+ )
+
+ logger.info(f"JWT authentication successful for user: {auth_context.user_id}")
+ return auth_context
+
+ except Exception as e:
+ logger.error(f"JWT authentication failed: {e}")
+ raise ValueError(f"JWT authentication failed: {str(e)}")
+
+ async def create_auth_response_headers(self, auth_context: AuthContext) -> Dict[str, str]:
+ """Create authentication response headers
+
+ Args:
+ auth_context: Authentication context
+
+ Returns:
+ Response headers dictionary
+ """
+ return {
+ 'X-Auth-User': auth_context.user_id,
+ 'X-Auth-Roles': ','.join(auth_context.roles),
+ 'X-Auth-Session': auth_context.session_id,
+ 'X-Auth-Security-Level': auth_context.security_level.value
+ }
+
+ def create_http_middleware(self, skip_paths: Optional[list] = None):
+ """Create HTTP middleware function
+
+ Args:
+ skip_paths: List of paths to skip authentication
+
+ Returns:
+ ASGI middleware function
+ """
+ skip_paths = skip_paths or ['/health', '/docs', '/openapi.json']
+
+ async def middleware(scope, receive, send):
+ """HTTP authentication middleware"""
+ if scope['type'] != 'http':
+ # Pass through non-HTTP requests directly
+ return await self.app(scope, receive, send)
+
+ path = scope.get('path', '')
+
+ # Check if authentication should be skipped
+ if any(path.startswith(skip) for skip in skip_paths):
+ return await self.app(scope, receive, send)
+
+ # Extract authentication information
+ headers = dict(scope.get('headers', []))
+ authorization = headers.get(b'authorization', b'').decode()
+
+ try:
+ # Perform authentication
+ auth_info = {
+ 'type': 'jwt',
+ 'authorization': authorization
+ }
+ auth_context = await self.authenticate_request(auth_info)
+
+ # Add authentication context to scope
+ scope['auth_context'] = auth_context
+
+ # Create response wrapper to add authentication headers
+ async def send_wrapper(message):
+ if message['type'] == 'http.response.start':
+ headers = dict(message.get('headers', []))
+ auth_headers = await self.create_auth_response_headers(auth_context)
+
+ for key, value in auth_headers.items():
+ headers[key.encode()] = value.encode()
+
+ message['headers'] = list(headers.items())
+
+ await send(message)
+
+ return await self.app(scope, receive, send_wrapper)
+
+ except Exception as e:
+ # Authentication failed, return 401 error
+ response_body = f'{{"error": "Authentication failed", "message": "{str(e)}"}}'
+
+ await send({
+ 'type': 'http.response.start',
+ 'status': 401,
+ 'headers': [
+ (b'content-type', b'application/json'),
+ (b'www-authenticate', b'Bearer')
+ ]
+ })
+ await send({
+ 'type': 'http.response.body',
+ 'body': response_body.encode()
+ })
+
+ return middleware
+
+ async def authenticate_mcp_request(self, headers: Dict[str, str]) -> AuthContext:
+ """Authenticate MCP request
+
+ Args:
+ headers: MCP request headers
+
+ Returns:
+ AuthContext authentication context
+ """
+ try:
+ # Extract authentication information from multiple possible header fields
+ authorization = (
+ headers.get('Authorization') or
+ headers.get('authorization') or
+ headers.get('X-Auth-Token') or
+ headers.get('x-auth-token')
+ )
+
+ auth_info = {
+ 'type': 'jwt',
+ 'authorization': authorization
+ }
+
+ return await self.authenticate_request(auth_info)
+
+ except Exception as e:
+ logger.error(f"MCP request authentication failed: {e}")
+ raise
+
+
+class AuthenticationError(Exception):
+ """Authentication error exception"""
+
+ def __init__(self, message: str, error_code: str = "AUTH_FAILED"):
+ self.message = message
+ self.error_code = error_code
+ super().__init__(message)
+
+
+class AuthorizationError(Exception):
+ """Authorization error exception"""
+
+ def __init__(self, message: str, error_code: str = "ACCESS_DENIED"):
+ self.message = message
+ self.error_code = error_code
+ super().__init__(message)
\ No newline at end of file
diff --git a/doris_mcp_server/auth/jwt_manager.py b/doris_mcp_server/auth/jwt_manager.py
new file mode 100644
index 0000000..fa79d1a
--- /dev/null
+++ b/doris_mcp_server/auth/jwt_manager.py
@@ -0,0 +1,471 @@
+#!/usr/bin/env python3
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+JWT Manager Module
+Provides comprehensive JWT token management including generation, validation, refresh and revocation
+"""
+
+import time
+import uuid
+import asyncio
+from typing import Dict, Any, Optional, Tuple
+from datetime import datetime, timedelta
+
+try:
+ import jwt
+except ImportError:
+ raise ImportError("PyJWT is required for JWT functionality. Install with: pip install PyJWT[crypto]")
+
+from .key_manager import KeyManager
+from .token_validators import TokenValidator, TokenBlacklist
+from ..utils.logger import get_logger
+
+logger = get_logger(__name__)
+
+
+class JWTManager:
+ """JWT Token Manager
+
+ Provides comprehensive JWT token lifecycle management, including:
+ - Token generation and signing
+ - Token validation and parsing
+ - Token refresh mechanism
+ - Token revocation and blacklist
+ - Automatic key rotation
+ """
+
+ def __init__(self, config):
+ """Initialize JWT manager
+
+ Args:
+ config: DorisConfig configuration object (with security attribute)
+ """
+ self.config = config
+ # Access JWT settings through the security configuration
+ if hasattr(config, 'security'):
+ security_config = config.security
+ else:
+ # Fallback if config is passed directly as SecurityConfig
+ security_config = config
+
+ self.algorithm = security_config.jwt_algorithm
+ self.issuer = security_config.jwt_issuer
+ self.audience = security_config.jwt_audience
+ self.access_token_expiry = security_config.jwt_access_token_expiry
+ self.refresh_token_expiry = security_config.jwt_refresh_token_expiry
+ self.enable_refresh = security_config.enable_token_refresh
+ self.enable_revocation = security_config.enable_token_revocation
+
+ # Initialize components
+ self.key_manager = KeyManager(config)
+ self.token_blacklist = TokenBlacklist()
+ self.validator = TokenValidator(config, self.token_blacklist)
+
+ # Automatic key rotation task
+ self._key_rotation_task = None
+
+ logger.info(f"JWTManager initialized with algorithm: {self.algorithm}")
+
+ async def initialize(self) -> bool:
+ """Initialize JWT manager"""
+ try:
+ # Initialize key manager
+ if not await self.key_manager.initialize():
+ logger.error("Failed to initialize key manager")
+ return False
+
+ # Start token validator
+ await self.validator.start()
+
+ # Start automatic key rotation
+ if self.key_manager.key_rotation_interval > 0:
+ self._key_rotation_task = asyncio.create_task(self._auto_key_rotation())
+
+ logger.info("JWTManager initialization completed")
+ return True
+
+ except Exception as e:
+ logger.error(f"Failed to initialize JWTManager: {e}")
+ return False
+
+ async def shutdown(self):
+ """Shutdown JWT manager"""
+ try:
+ # Stop key rotation task
+ if self._key_rotation_task:
+ self._key_rotation_task.cancel()
+ try:
+ await self._key_rotation_task
+ except asyncio.CancelledError:
+ pass
+
+ # Stop validator
+ await self.validator.stop()
+
+ logger.info("JWTManager shutdown completed")
+
+ except Exception as e:
+ logger.error(f"Error during JWTManager shutdown: {e}")
+
+ async def generate_tokens(self, user_info: Dict[str, Any],
+ custom_claims: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
+ """Generate access token and refresh token
+
+ Args:
+ user_info: User information dictionary, containing user_id, roles, permissions, etc.
+ custom_claims: Custom claims
+
+ Returns:
+ Dictionary containing access_token and refresh_token
+ """
+ try:
+ current_time = int(time.time())
+ jti = str(uuid.uuid4())
+
+ # Build base payload
+ base_payload = {
+ 'iss': self.issuer,
+ 'aud': self.audience,
+ 'iat': current_time,
+ 'jti': jti,
+ 'sub': user_info.get('user_id'),
+ 'roles': user_info.get('roles', []),
+ 'permissions': user_info.get('permissions', []),
+ 'security_level': user_info.get('security_level', 'internal')
+ }
+
+ # Add custom claims
+ if custom_claims:
+ base_payload.update(custom_claims)
+
+ # Generate access token
+ access_payload = base_payload.copy()
+ access_payload.update({
+ 'exp': current_time + self.access_token_expiry,
+ 'token_type': 'access'
+ })
+
+ access_token = await self._sign_token(access_payload)
+
+ result = {
+ 'access_token': access_token,
+ 'token_type': 'Bearer',
+ 'expires_in': self.access_token_expiry,
+ 'user_id': user_info.get('user_id'),
+ 'issued_at': current_time
+ }
+
+ # Generate refresh token (if enabled)
+ if self.enable_refresh:
+ refresh_jti = str(uuid.uuid4())
+ refresh_payload = {
+ 'iss': self.issuer,
+ 'aud': self.audience,
+ 'iat': current_time,
+ 'exp': current_time + self.refresh_token_expiry,
+ 'jti': refresh_jti,
+ 'sub': user_info.get('user_id'),
+ 'token_type': 'refresh',
+ 'access_jti': jti # Associated access token ID
+ }
+
+ refresh_token = await self._sign_token(refresh_payload)
+ result.update({
+ 'refresh_token': refresh_token,
+ 'refresh_expires_in': self.refresh_token_expiry
+ })
+
+ logger.info(f"Generated tokens for user: {user_info.get('user_id')}")
+ return result
+
+ except Exception as e:
+ logger.error(f"Failed to generate tokens: {e}")
+ raise
+
+ async def _sign_token(self, payload: Dict[str, Any]) -> str:
+ """Sign JWT token
+
+ Args:
+ payload: JWT payload
+
+ Returns:
+ Signed JWT token
+ """
+ try:
+ signing_key = self.key_manager.get_private_key()
+
+ if self.algorithm == "HS256":
+ # Symmetric key signing
+ token = jwt.encode(payload, signing_key, algorithm=self.algorithm)
+ else:
+ # Asymmetric key signing
+ token = jwt.encode(payload, signing_key, algorithm=self.algorithm)
+
+ return token
+
+ except Exception as e:
+ logger.error(f"Failed to sign token: {e}")
+ raise
+
+ async def validate_token(self, token: str, token_type: str = 'access') -> Dict[str, Any]:
+ """Validate JWT token
+
+ Args:
+ token: JWT token string
+ token_type: Token type ('access' or 'refresh')
+
+ Returns:
+ Validation result and user information
+
+ Raises:
+ ValueError: Token validation failed
+ """
+ try:
+ # Decode token
+ verification_key = self.key_manager.get_public_key()
+
+ # Get security configuration
+ if hasattr(self.config, 'security'):
+ security_config = self.config.security
+ else:
+ security_config = self.config
+
+ # JWT decoding options
+ options = {
+ 'verify_signature': security_config.jwt_verify_signature,
+ 'verify_exp': security_config.jwt_require_exp,
+ 'verify_iat': security_config.jwt_require_iat,
+ 'verify_nbf': security_config.jwt_require_nbf,
+ 'verify_aud': security_config.jwt_verify_audience,
+ 'verify_iss': security_config.jwt_verify_issuer,
+ }
+
+ # Decode JWT
+ payload = jwt.decode(
+ token,
+ verification_key,
+ algorithms=[self.algorithm],
+ audience=self.audience if security_config.jwt_verify_audience else None,
+ issuer=self.issuer if security_config.jwt_verify_issuer else None,
+ leeway=security_config.jwt_leeway,
+ options=options
+ )
+
+ # Check token type
+ if payload.get('token_type') != token_type:
+ raise ValueError(f"Invalid token type: expected {token_type}")
+
+ # Use validator for additional checks
+ validation_result = await self.validator.validate_claims(payload)
+
+ logger.info(f"Token validation successful for user: {payload.get('sub')}")
+ return validation_result
+
+ except jwt.ExpiredSignatureError:
+ raise ValueError("Token has expired")
+ except jwt.InvalidTokenError as e:
+ raise ValueError(f"Invalid token: {str(e)}")
+ except Exception as e:
+ logger.error(f"Token validation failed: {e}")
+ raise ValueError(f"Token validation failed: {str(e)}")
+
+ async def refresh_token(self, refresh_token: str) -> Dict[str, Any]:
+ """Refresh access token
+
+ Args:
+ refresh_token: Refresh token
+
+ Returns:
+ New token pair
+ """
+ if not self.enable_refresh:
+ raise ValueError("Token refresh is disabled")
+
+ try:
+ # Validate refresh token
+ refresh_result = await self.validate_token(refresh_token, 'refresh')
+ refresh_payload = refresh_result['payload']
+
+ # Revoke associated access token (if revocation is enabled)
+ if self.enable_revocation:
+ access_jti = refresh_payload.get('access_jti')
+ if access_jti:
+ # Should revoke old access token here, but since we don't know its expiration time,
+ # in practice might need to store more information or use different strategy
+ pass
+
+ # Build new user information
+ user_info = {
+ 'user_id': refresh_payload.get('sub'),
+ 'roles': refresh_payload.get('roles', []),
+ 'permissions': refresh_payload.get('permissions', []),
+ 'security_level': refresh_payload.get('security_level', 'internal')
+ }
+
+ # Generate new token pair
+ new_tokens = await self.generate_tokens(user_info)
+
+ logger.info(f"Token refreshed for user: {user_info['user_id']}")
+ return new_tokens
+
+ except Exception as e:
+ logger.error(f"Token refresh failed: {e}")
+ raise
+
+ async def revoke_token(self, token: str) -> bool:
+ """Revoke token
+
+ Args:
+ token: Token to revoke
+
+ Returns:
+ Whether revocation was successful
+ """
+ if not self.enable_revocation:
+ logger.warning("Token revocation is disabled")
+ return False
+
+ try:
+ # Decode token to get JTI and expiration time
+ verification_key = self.key_manager.get_public_key()
+ payload = jwt.decode(
+ token,
+ verification_key,
+ algorithms=[self.algorithm],
+ options={'verify_exp': False} # Allow decoding expired tokens
+ )
+
+ jti = payload.get('jti')
+ exp = payload.get('exp')
+
+ if not jti or not exp:
+ logger.error("Token missing required claims for revocation")
+ return False
+
+ # Add to blacklist
+ await self.validator.revoke_token(jti, exp)
+
+ logger.info(f"Token {jti} revoked successfully")
+ return True
+
+ except Exception as e:
+ logger.error(f"Token revocation failed: {e}")
+ return False
+
+ async def decode_token_unsafe(self, token: str) -> Dict[str, Any]:
+ """Decode token without verifying signature (for debugging only)
+
+ Args:
+ token: JWT token
+
+ Returns:
+ Token payload
+ """
+ try:
+ payload = jwt.decode(token, options={'verify_signature': False})
+ return payload
+ except Exception as e:
+ logger.error(f"Failed to decode token: {e}")
+ raise
+
+ async def get_token_info(self, token: str) -> Dict[str, Any]:
+ """Get token information (without verifying signature)
+
+ Args:
+ token: JWT token
+
+ Returns:
+ Token information
+ """
+ try:
+ payload = await self.decode_token_unsafe(token)
+
+ return {
+ 'jti': payload.get('jti'),
+ 'sub': payload.get('sub'),
+ 'iss': payload.get('iss'),
+ 'aud': payload.get('aud'),
+ 'iat': payload.get('iat'),
+ 'exp': payload.get('exp'),
+ 'token_type': payload.get('token_type'),
+ 'roles': payload.get('roles'),
+ 'permissions': payload.get('permissions'),
+ 'security_level': payload.get('security_level'),
+ 'is_expired': payload.get('exp', 0) < time.time() if payload.get('exp') else None
+ }
+
+ except Exception as e:
+ logger.error(f"Failed to get token info: {e}")
+ raise
+
+ async def _auto_key_rotation(self):
+ """Automatic key rotation task"""
+ while True:
+ try:
+ # Check if key rotation is needed
+ if await self.key_manager.is_key_expired():
+ logger.info("Key rotation needed, rotating keys...")
+ await self.key_manager.rotate_keys()
+
+ # Wait until next check
+ await asyncio.sleep(3600) # Check every hour
+
+ except asyncio.CancelledError:
+ break
+ except Exception as e:
+ logger.error(f"Error in auto key rotation: {e}")
+ # Wait longer before retry after error
+ await asyncio.sleep(3600)
+
+ async def get_public_key_info(self) -> Dict[str, Any]:
+ """Get public key information (for client verification)
+
+ Returns:
+ Public key information
+ """
+ key_info = await self.key_manager.get_key_info()
+ public_key_pem = await self.key_manager.export_public_key_pem()
+
+ return {
+ 'algorithm': self.algorithm,
+ 'public_key_pem': public_key_pem,
+ 'key_info': key_info
+ }
+
+ async def get_manager_stats(self) -> Dict[str, Any]:
+ """Get manager statistics
+
+ Returns:
+ Statistics information
+ """
+ key_info = await self.key_manager.get_key_info()
+ validation_stats = await self.validator.get_validation_stats()
+
+ return {
+ 'jwt_config': {
+ 'algorithm': self.algorithm,
+ 'issuer': self.issuer,
+ 'audience': self.audience,
+ 'access_token_expiry': self.access_token_expiry,
+ 'refresh_token_expiry': self.refresh_token_expiry,
+ 'enable_refresh': self.enable_refresh,
+ 'enable_revocation': self.enable_revocation
+ },
+ 'key_manager': key_info,
+ 'validator': validation_stats
+ }
\ No newline at end of file
diff --git a/doris_mcp_server/auth/key_manager.py b/doris_mcp_server/auth/key_manager.py
new file mode 100644
index 0000000..9131aef
--- /dev/null
+++ b/doris_mcp_server/auth/key_manager.py
@@ -0,0 +1,343 @@
+#!/usr/bin/env python3
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+JWT Key Management Module
+Provides secure key generation, loading, rotation and management for JWT tokens
+"""
+
+import os
+import time
+import secrets
+from pathlib import Path
+from typing import Optional, Tuple, Union
+from datetime import datetime, timedelta
+from cryptography.hazmat.primitives import serialization
+from cryptography.hazmat.primitives.asymmetric import rsa, ec
+from cryptography.hazmat.backends import default_backend
+
+from ..utils.logger import get_logger
+
+logger = get_logger(__name__)
+
+
+class KeyManager:
+ """JWT Key Manager
+
+ Responsible for generating, loading, rotating and securely storing JWT signing keys
+ Supports RSA and EC algorithms, provides automatic key rotation functionality
+ """
+
+ def __init__(self, config):
+ """Initialize key manager
+
+ Args:
+ config: DorisConfig configuration object (with security attribute)
+ """
+ self.config = config
+ # Access JWT settings through the security configuration
+ if hasattr(config, 'security'):
+ security_config = config.security
+ else:
+ # Fallback if config is passed directly as SecurityConfig
+ security_config = config
+
+ self.algorithm = security_config.jwt_algorithm
+ self.key_rotation_interval = security_config.key_rotation_interval
+ self.private_key_path = security_config.jwt_private_key_path
+ self.public_key_path = security_config.jwt_public_key_path
+ self.secret_key = security_config.jwt_secret_key
+
+ # Key storage
+ self._private_key = None
+ self._public_key = None
+ self._secret_key = None
+ self._key_generated_at = None
+
+ logger.info(f"KeyManager initialized with algorithm: {self.algorithm}")
+
+ async def initialize(self) -> bool:
+ """Initialize key manager, load or generate keys"""
+ try:
+ if self.algorithm == "HS256":
+ await self._initialize_symmetric_key()
+ else:
+ await self._initialize_asymmetric_keys()
+
+ logger.info("KeyManager initialization completed")
+ return True
+
+ except Exception as e:
+ logger.error(f"Failed to initialize KeyManager: {e}")
+ return False
+
+ async def _initialize_symmetric_key(self):
+ """Initialize symmetric key (HS256)"""
+ if self.secret_key:
+ # Use configured key
+ self._secret_key = self.secret_key.encode()
+ logger.info("Loaded symmetric key from configuration")
+ else:
+ # Generate new key
+ self._secret_key = await self.generate_symmetric_key()
+ logger.info("Generated new symmetric key")
+
+ self._key_generated_at = datetime.utcnow()
+
+ async def _initialize_asymmetric_keys(self):
+ """Initialize asymmetric key pair (RS256/ES256)"""
+ # Try to load keys from files
+ if await self._load_keys_from_files():
+ logger.info("Loaded asymmetric keys from files")
+ return
+
+ # Try to load from environment variables
+ if await self._load_keys_from_env():
+ logger.info("Loaded asymmetric keys from environment")
+ return
+
+ # Generate new key pair
+ await self.generate_key_pair()
+ logger.info("Generated new asymmetric key pair")
+
+ async def _load_keys_from_files(self) -> bool:
+ """Load keys from files"""
+ try:
+ if not self.private_key_path or not self.public_key_path:
+ return False
+
+ private_path = Path(self.private_key_path)
+ public_path = Path(self.public_key_path)
+
+ if not (private_path.exists() and public_path.exists()):
+ return False
+
+ # Read private key
+ with open(private_path, 'rb') as f:
+ private_key_data = f.read()
+ self._private_key = serialization.load_pem_private_key(
+ private_key_data, password=None, backend=default_backend()
+ )
+
+ # Read public key
+ with open(public_path, 'rb') as f:
+ public_key_data = f.read()
+ self._public_key = serialization.load_pem_public_key(
+ public_key_data, backend=default_backend()
+ )
+
+ # Get key generation time (using file modification time)
+ self._key_generated_at = datetime.fromtimestamp(private_path.stat().st_mtime)
+
+ return True
+
+ except Exception as e:
+ logger.error(f"Failed to load keys from files: {e}")
+ return False
+
+ async def _load_keys_from_env(self) -> bool:
+ """Load keys from environment variables"""
+ try:
+ private_key_env = os.getenv('JWT_PRIVATE_KEY')
+ public_key_env = os.getenv('JWT_PUBLIC_KEY')
+
+ if not (private_key_env and public_key_env):
+ return False
+
+ # Parse private key
+ self._private_key = serialization.load_pem_private_key(
+ private_key_env.encode(), password=None, backend=default_backend()
+ )
+
+ # Parse public key
+ self._public_key = serialization.load_pem_public_key(
+ public_key_env.encode(), backend=default_backend()
+ )
+
+ self._key_generated_at = datetime.utcnow()
+ return True
+
+ except Exception as e:
+ logger.error(f"Failed to load keys from environment: {e}")
+ return False
+
+ async def generate_symmetric_key(self, length: int = 32) -> bytes:
+ """Generate symmetric key
+
+ Args:
+ length: Key length (bytes), default 32 bytes (256 bits)
+
+ Returns:
+ Generated key
+ """
+ return secrets.token_bytes(length)
+
+ async def generate_key_pair(self) -> Tuple[bytes, bytes]:
+ """Generate asymmetric key pair
+
+ Returns:
+ (private key PEM, public key PEM) tuple
+ """
+ try:
+ if self.algorithm == "RS256":
+ private_key = rsa.generate_private_key(
+ public_exponent=65537,
+ key_size=2048,
+ backend=default_backend()
+ )
+ elif self.algorithm == "ES256":
+ private_key = ec.generate_private_key(
+ ec.SECP256R1(), backend=default_backend()
+ )
+ else:
+ raise ValueError(f"Unsupported algorithm for key generation: {self.algorithm}")
+
+ # Get public key
+ public_key = private_key.public_key()
+
+ # Serialize private key
+ private_pem = private_key.private_bytes(
+ encoding=serialization.Encoding.PEM,
+ format=serialization.PrivateFormat.PKCS8,
+ encryption_algorithm=serialization.NoEncryption()
+ )
+
+ # Serialize public key
+ public_pem = public_key.public_bytes(
+ encoding=serialization.Encoding.PEM,
+ format=serialization.PublicFormat.SubjectPublicKeyInfo
+ )
+
+ # Store keys
+ self._private_key = private_key
+ self._public_key = public_key
+ self._key_generated_at = datetime.utcnow()
+
+ # If file paths are configured, save to files
+ if self.private_key_path and self.public_key_path:
+ await self._save_keys_to_files(private_pem, public_pem)
+
+ logger.info(f"Generated new {self.algorithm} key pair")
+ return private_pem, public_pem
+
+ except Exception as e:
+ logger.error(f"Failed to generate key pair: {e}")
+ raise
+
+ async def _save_keys_to_files(self, private_pem: bytes, public_pem: bytes):
+ """Save keys to files"""
+ try:
+ # Ensure directories exist
+ private_path = Path(self.private_key_path)
+ public_path = Path(self.public_key_path)
+
+ private_path.parent.mkdir(parents=True, exist_ok=True)
+ public_path.parent.mkdir(parents=True, exist_ok=True)
+
+ # Save private key (set secure permissions)
+ with open(private_path, 'wb') as f:
+ f.write(private_pem)
+ os.chmod(private_path, 0o600) # Only owner can read/write
+
+ # Save public key
+ with open(public_path, 'wb') as f:
+ f.write(public_pem)
+ os.chmod(public_path, 0o644) # Owner read/write, others read only
+
+ logger.info(f"Saved keys to files: {private_path}, {public_path}")
+
+ except Exception as e:
+ logger.error(f"Failed to save keys to files: {e}")
+ raise
+
+ def get_private_key(self):
+ """Get private key for signing"""
+ if self.algorithm == "HS256":
+ return self._secret_key
+ else:
+ return self._private_key
+
+ def get_public_key(self):
+ """Get public key for verification"""
+ if self.algorithm == "HS256":
+ return self._secret_key
+ else:
+ return self._public_key
+
+ def get_algorithm(self) -> str:
+ """Get signing algorithm"""
+ return self.algorithm
+
+ async def is_key_expired(self) -> bool:
+ """Check if key is expired"""
+ if not self._key_generated_at:
+ return True
+
+ expiry_time = self._key_generated_at + timedelta(seconds=self.key_rotation_interval)
+ return datetime.utcnow() > expiry_time
+
+ async def rotate_keys(self) -> bool:
+ """Rotate keys"""
+ try:
+ logger.info("Starting key rotation")
+
+ if self.algorithm == "HS256":
+ # Generate new symmetric key
+ self._secret_key = await self.generate_symmetric_key()
+ self._key_generated_at = datetime.utcnow()
+ else:
+ # Generate new asymmetric key pair
+ await self.generate_key_pair()
+
+ logger.info("Key rotation completed successfully")
+ return True
+
+ except Exception as e:
+ logger.error(f"Key rotation failed: {e}")
+ return False
+
+ async def get_key_info(self) -> dict:
+ """Get key information"""
+ return {
+ "algorithm": self.algorithm,
+ "key_generated_at": self._key_generated_at.isoformat() if self._key_generated_at else None,
+ "key_expires_at": (
+ self._key_generated_at + timedelta(seconds=self.key_rotation_interval)
+ ).isoformat() if self._key_generated_at else None,
+ "is_expired": await self.is_key_expired(),
+ "has_private_key": self._private_key is not None or self._secret_key is not None,
+ "has_public_key": self._public_key is not None or self._secret_key is not None
+ }
+
+ async def export_public_key_pem(self) -> Optional[str]:
+ """Export public key in PEM format"""
+ if self.algorithm == "HS256":
+ return None # Symmetric key not exported
+
+ if not self._public_key:
+ return None
+
+ try:
+ public_pem = self._public_key.public_bytes(
+ encoding=serialization.Encoding.PEM,
+ format=serialization.PublicFormat.SubjectPublicKeyInfo
+ )
+ return public_pem.decode()
+
+ except Exception as e:
+ logger.error(f"Failed to export public key: {e}")
+ return None
\ No newline at end of file
diff --git a/doris_mcp_server/auth/oauth_client.py b/doris_mcp_server/auth/oauth_client.py
new file mode 100644
index 0000000..4f8ca58
--- /dev/null
+++ b/doris_mcp_server/auth/oauth_client.py
@@ -0,0 +1,536 @@
+#!/usr/bin/env python3
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+OAuth 2.0/OIDC Client Manager
+Provides OAuth authentication client implementation with PKCE and OIDC support
+"""
+
+import base64
+import hashlib
+import secrets
+import uuid
+from datetime import datetime, timedelta
+from typing import Dict, Optional, Any, Tuple
+from urllib.parse import urlencode, parse_qs, urlparse
+import asyncio
+import json
+
+try:
+ import aiohttp
+except ImportError:
+ raise ImportError("aiohttp is required for OAuth functionality. Install with: pip install aiohttp")
+
+from .oauth_types import (
+ OAuthProvider, OAuthState, OAuthTokens, OAuthUserInfo,
+ OIDCDiscovery, OAuthError, OAuthProviderConfig, OAUTH_PROVIDERS
+)
+from ..utils.logger import get_logger
+
+logger = get_logger(__name__)
+
+
+class OAuthStateManager:
+ """Manages OAuth state parameters for CSRF protection"""
+
+ def __init__(self, state_expiry: int = 600):
+ """Initialize state manager
+
+ Args:
+ state_expiry: State expiry time in seconds
+ """
+ self.state_expiry = state_expiry
+ self._states: Dict[str, OAuthState] = {}
+ self._cleanup_task = None
+
+ logger.info("OAuthStateManager initialized")
+
+ async def start(self):
+ """Start periodic cleanup task"""
+ self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
+ logger.info("OAuth state manager started")
+
+ async def stop(self):
+ """Stop periodic cleanup task"""
+ if self._cleanup_task:
+ self._cleanup_task.cancel()
+ try:
+ await self._cleanup_task
+ except asyncio.CancelledError:
+ pass
+ logger.info("OAuth state manager stopped")
+
+ def create_state(self, redirect_uri: str, pkce_enabled: bool = True,
+ nonce_enabled: bool = True) -> OAuthState:
+ """Create new OAuth state
+
+ Args:
+ redirect_uri: OAuth redirect URI
+ pkce_enabled: Whether to enable PKCE
+ nonce_enabled: Whether to enable nonce (for OIDC)
+
+ Returns:
+ OAuth state object
+ """
+ state = secrets.token_urlsafe(32)
+ nonce = secrets.token_urlsafe(32) if nonce_enabled else None
+
+ pkce_verifier = None
+ pkce_challenge = None
+ if pkce_enabled:
+ pkce_verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).decode('utf-8').rstrip('=')
+ challenge_bytes = hashlib.sha256(pkce_verifier.encode()).digest()
+ pkce_challenge = base64.urlsafe_b64encode(challenge_bytes).decode('utf-8').rstrip('=')
+
+ oauth_state = OAuthState(
+ state=state,
+ nonce=nonce,
+ pkce_verifier=pkce_verifier,
+ pkce_challenge=pkce_challenge,
+ redirect_uri=redirect_uri,
+ created_at=datetime.utcnow(),
+ expires_at=datetime.utcnow() + timedelta(seconds=self.state_expiry)
+ )
+
+ self._states[state] = oauth_state
+ logger.debug(f"Created OAuth state: {state}")
+ return oauth_state
+
+ def get_state(self, state: str) -> Optional[OAuthState]:
+ """Get OAuth state by state parameter
+
+ Args:
+ state: State parameter
+
+ Returns:
+ OAuth state object or None if not found/expired
+ """
+ oauth_state = self._states.get(state)
+ if oauth_state and oauth_state.expires_at > datetime.utcnow():
+ return oauth_state
+ elif oauth_state:
+ # Remove expired state
+ del self._states[state]
+ logger.debug(f"Removed expired OAuth state: {state}")
+ return None
+
+ def consume_state(self, state: str) -> Optional[OAuthState]:
+ """Get and remove OAuth state
+
+ Args:
+ state: State parameter
+
+ Returns:
+ OAuth state object or None if not found/expired
+ """
+ oauth_state = self.get_state(state)
+ if oauth_state:
+ del self._states[state]
+ logger.debug(f"Consumed OAuth state: {state}")
+ return oauth_state
+
+ async def _periodic_cleanup(self):
+ """Periodic cleanup of expired states"""
+ while True:
+ try:
+ await asyncio.sleep(300) # Clean up every 5 minutes
+ current_time = datetime.utcnow()
+ expired_states = [
+ state for state, oauth_state in self._states.items()
+ if oauth_state.expires_at <= current_time
+ ]
+
+ for state in expired_states:
+ del self._states[state]
+
+ if expired_states:
+ logger.info(f"Cleaned up {len(expired_states)} expired OAuth states")
+
+ except asyncio.CancelledError:
+ break
+ except Exception as e:
+ logger.error(f"Error during OAuth state cleanup: {e}")
+
+
+class OAuthClient:
+ """OAuth 2.0/OIDC Client implementation"""
+
+ def __init__(self, config):
+ """Initialize OAuth client
+
+ Args:
+ config: DorisConfig with OAuth configuration
+ """
+ self.config = config
+
+ # Access OAuth settings through security configuration
+ if hasattr(config, 'security'):
+ security_config = config.security
+ else:
+ security_config = config
+
+ self.enabled = security_config.oauth_enabled
+ if not self.enabled:
+ logger.info("OAuth client disabled by configuration")
+ return
+
+ # Build provider configuration
+ self.provider_config = self._build_provider_config(security_config)
+ self.state_manager = OAuthStateManager(security_config.oauth_state_expiry)
+
+ # HTTP client session
+ self._session: Optional[aiohttp.ClientSession] = None
+
+ # Discovery cache
+ self._discovery_cache: Optional[OIDCDiscovery] = None
+ self._discovery_cache_time: Optional[datetime] = None
+
+ logger.info(f"OAuthClient initialized for provider: {self.provider_config.provider.value}")
+
+ def _build_provider_config(self, security_config) -> OAuthProviderConfig:
+ """Build OAuth provider configuration
+
+ Args:
+ security_config: Security configuration object
+
+ Returns:
+ OAuth provider configuration
+ """
+ try:
+ provider = OAuthProvider(security_config.oauth_provider)
+ except ValueError:
+ provider = OAuthProvider.CUSTOM
+
+ # Get default configuration for known providers
+ defaults = OAUTH_PROVIDERS.get(provider, {})
+
+ return OAuthProviderConfig(
+ provider=provider,
+ client_id=security_config.oauth_client_id,
+ client_secret=security_config.oauth_client_secret,
+ redirect_uri=security_config.oauth_redirect_uri,
+ scopes=security_config.oauth_scopes or defaults.get("scopes", ["openid", "email", "profile"]),
+
+ # Endpoints (use configured or defaults)
+ authorization_endpoint=security_config.oauth_authorization_endpoint or defaults.get("authorization_endpoint", ""),
+ token_endpoint=security_config.oauth_token_endpoint or defaults.get("token_endpoint", ""),
+ userinfo_endpoint=security_config.oauth_userinfo_endpoint or defaults.get("userinfo_endpoint"),
+ jwks_uri=security_config.oauth_jwks_uri or defaults.get("jwks_uri"),
+
+ # Discovery
+ discovery_url=security_config.oidc_discovery_url or defaults.get("discovery_url"),
+
+ # Settings
+ pkce_enabled=security_config.oauth_pkce_enabled,
+ nonce_enabled=security_config.oauth_nonce_enabled,
+
+ # User mapping
+ user_id_claim=security_config.oauth_user_id_claim or defaults.get("user_id_claim", "sub"),
+ email_claim=security_config.oauth_email_claim or defaults.get("email_claim", "email"),
+ name_claim=security_config.oauth_name_claim or defaults.get("name_claim", "name"),
+ roles_claim=security_config.oauth_roles_claim,
+ default_roles=security_config.oauth_default_roles
+ )
+
+ async def initialize(self) -> bool:
+ """Initialize OAuth client
+
+ Returns:
+ True if initialization successful
+ """
+ if not self.enabled:
+ return True
+
+ try:
+ # Create HTTP session
+ self._session = aiohttp.ClientSession()
+
+ # Start state manager
+ await self.state_manager.start()
+
+ # Perform OIDC discovery if configured
+ if self.provider_config.discovery_url:
+ await self._discover_oidc_endpoints()
+
+ logger.info("OAuth client initialization completed")
+ return True
+
+ except Exception as e:
+ logger.error(f"Failed to initialize OAuth client: {e}")
+ return False
+
+ async def shutdown(self):
+ """Shutdown OAuth client"""
+ if not self.enabled:
+ return
+
+ try:
+ # Stop state manager
+ await self.state_manager.stop()
+
+ # Close HTTP session
+ if self._session:
+ await self._session.close()
+
+ logger.info("OAuth client shutdown completed")
+
+ except Exception as e:
+ logger.error(f"Error during OAuth client shutdown: {e}")
+
+ async def _discover_oidc_endpoints(self):
+ """Discover OIDC endpoints using discovery URL"""
+ try:
+ # Check cache first
+ if (self._discovery_cache and self._discovery_cache_time and
+ datetime.utcnow() - self._discovery_cache_time < timedelta(hours=1)):
+ return self._discovery_cache
+
+ logger.info(f"Discovering OIDC endpoints: {self.provider_config.discovery_url}")
+
+ async with self._session.get(self.provider_config.discovery_url) as response:
+ response.raise_for_status()
+ data = await response.json()
+
+ discovery = OIDCDiscovery(
+ issuer=data["issuer"],
+ authorization_endpoint=data["authorization_endpoint"],
+ token_endpoint=data["token_endpoint"],
+ userinfo_endpoint=data.get("userinfo_endpoint"),
+ jwks_uri=data.get("jwks_uri"),
+ scopes_supported=data.get("scopes_supported"),
+ response_types_supported=data.get("response_types_supported"),
+ subject_types_supported=data.get("subject_types_supported"),
+ id_token_signing_alg_values_supported=data.get("id_token_signing_alg_values_supported")
+ )
+
+ # Update provider configuration with discovered endpoints
+ if not self.provider_config.authorization_endpoint:
+ self.provider_config.authorization_endpoint = discovery.authorization_endpoint
+ if not self.provider_config.token_endpoint:
+ self.provider_config.token_endpoint = discovery.token_endpoint
+ if not self.provider_config.userinfo_endpoint:
+ self.provider_config.userinfo_endpoint = discovery.userinfo_endpoint
+ if not self.provider_config.jwks_uri:
+ self.provider_config.jwks_uri = discovery.jwks_uri
+
+ # Cache discovery result
+ self._discovery_cache = discovery
+ self._discovery_cache_time = datetime.utcnow()
+
+ logger.info("OIDC endpoint discovery completed successfully")
+ return discovery
+
+ except Exception as e:
+ logger.error(f"OIDC endpoint discovery failed: {e}")
+ raise
+
+ def build_authorization_url(self) -> Tuple[str, OAuthState]:
+ """Build OAuth authorization URL
+
+ Returns:
+ Tuple of (authorization_url, oauth_state)
+ """
+ if not self.enabled:
+ raise ValueError("OAuth client is not enabled")
+
+ # Create state for CSRF protection
+ oauth_state = self.state_manager.create_state(
+ redirect_uri=self.provider_config.redirect_uri,
+ pkce_enabled=self.provider_config.pkce_enabled,
+ nonce_enabled=self.provider_config.nonce_enabled
+ )
+
+ # Build authorization parameters
+ params = {
+ 'response_type': 'code',
+ 'client_id': self.provider_config.client_id,
+ 'redirect_uri': self.provider_config.redirect_uri,
+ 'scope': ' '.join(self.provider_config.scopes),
+ 'state': oauth_state.state
+ }
+
+ # Add PKCE challenge
+ if oauth_state.pkce_challenge:
+ params['code_challenge'] = oauth_state.pkce_challenge
+ params['code_challenge_method'] = 'S256'
+
+ # Add nonce for OIDC
+ if oauth_state.nonce:
+ params['nonce'] = oauth_state.nonce
+
+ # Build URL
+ authorization_url = f"{self.provider_config.authorization_endpoint}?{urlencode(params)}"
+
+ logger.info(f"Built OAuth authorization URL for state: {oauth_state.state}")
+ return authorization_url, oauth_state
+
+ async def exchange_code_for_tokens(self, code: str, state: str) -> Tuple[OAuthTokens, OAuthState]:
+ """Exchange authorization code for tokens
+
+ Args:
+ code: Authorization code
+ state: State parameter
+
+ Returns:
+ Tuple of (OAuth tokens, OAuth state)
+
+ Raises:
+ ValueError: If state is invalid or exchange fails
+ """
+ if not self.enabled:
+ raise ValueError("OAuth client is not enabled")
+
+ # Validate and consume state
+ oauth_state = self.state_manager.consume_state(state)
+ if not oauth_state:
+ raise ValueError("Invalid or expired state parameter")
+
+ try:
+ # Prepare token request
+ data = {
+ 'grant_type': 'authorization_code',
+ 'client_id': self.provider_config.client_id,
+ 'client_secret': self.provider_config.client_secret,
+ 'code': code,
+ 'redirect_uri': oauth_state.redirect_uri
+ }
+
+ # Add PKCE verifier
+ if oauth_state.pkce_verifier:
+ data['code_verifier'] = oauth_state.pkce_verifier
+
+ # Make token request
+ async with self._session.post(
+ self.provider_config.token_endpoint,
+ data=data,
+ headers={'Content-Type': 'application/x-www-form-urlencoded'}
+ ) as response:
+ response_data = await response.json()
+
+ if response.status != 200:
+ error_msg = response_data.get('error_description', response_data.get('error', 'Token exchange failed'))
+ raise ValueError(f"Token exchange failed: {error_msg}")
+
+ tokens = OAuthTokens(
+ access_token=response_data['access_token'],
+ token_type=response_data.get('token_type', 'Bearer'),
+ expires_in=response_data.get('expires_in'),
+ refresh_token=response_data.get('refresh_token'),
+ scope=response_data.get('scope'),
+ id_token=response_data.get('id_token')
+ )
+
+ logger.info("Successfully exchanged authorization code for tokens")
+ return tokens, oauth_state
+
+ except Exception as e:
+ logger.error(f"Token exchange failed: {e}")
+ raise ValueError(f"Token exchange failed: {str(e)}")
+
+ async def get_user_info(self, tokens: OAuthTokens) -> OAuthUserInfo:
+ """Get user information from OAuth provider
+
+ Args:
+ tokens: OAuth tokens
+
+ Returns:
+ OAuth user information
+ """
+ if not self.enabled:
+ raise ValueError("OAuth client is not enabled")
+
+ if not self.provider_config.userinfo_endpoint:
+ raise ValueError("Userinfo endpoint not configured")
+
+ try:
+ # Make userinfo request
+ headers = {'Authorization': f'{tokens.token_type} {tokens.access_token}'}
+
+ async with self._session.get(
+ self.provider_config.userinfo_endpoint,
+ headers=headers
+ ) as response:
+ response.raise_for_status()
+ user_data = await response.json()
+
+ # Extract user information using configured claims
+ user_info = OAuthUserInfo(
+ sub=str(user_data.get(self.provider_config.user_id_claim, '')),
+ email=user_data.get(self.provider_config.email_claim),
+ name=user_data.get(self.provider_config.name_claim),
+ given_name=user_data.get('given_name'),
+ family_name=user_data.get('family_name'),
+ picture=user_data.get('picture'),
+ locale=user_data.get('locale'),
+ email_verified=user_data.get('email_verified'),
+ roles=user_data.get(self.provider_config.roles_claim, self.provider_config.default_roles.copy()),
+ raw_claims=user_data
+ )
+
+ logger.info(f"Retrieved user info for user: {user_info.sub}")
+ return user_info
+
+ except Exception as e:
+ logger.error(f"Failed to get user info: {e}")
+ raise ValueError(f"Failed to get user info: {str(e)}")
+
+ async def refresh_tokens(self, refresh_token: str) -> OAuthTokens:
+ """Refresh OAuth tokens
+
+ Args:
+ refresh_token: Refresh token
+
+ Returns:
+ New OAuth tokens
+ """
+ if not self.enabled:
+ raise ValueError("OAuth client is not enabled")
+
+ try:
+ data = {
+ 'grant_type': 'refresh_token',
+ 'client_id': self.provider_config.client_id,
+ 'client_secret': self.provider_config.client_secret,
+ 'refresh_token': refresh_token
+ }
+
+ async with self._session.post(
+ self.provider_config.token_endpoint,
+ data=data,
+ headers={'Content-Type': 'application/x-www-form-urlencoded'}
+ ) as response:
+ response_data = await response.json()
+
+ if response.status != 200:
+ error_msg = response_data.get('error_description', response_data.get('error', 'Token refresh failed'))
+ raise ValueError(f"Token refresh failed: {error_msg}")
+
+ tokens = OAuthTokens(
+ access_token=response_data['access_token'],
+ token_type=response_data.get('token_type', 'Bearer'),
+ expires_in=response_data.get('expires_in'),
+ refresh_token=response_data.get('refresh_token', refresh_token), # Keep old if not provided
+ scope=response_data.get('scope'),
+ id_token=response_data.get('id_token')
+ )
+
+ logger.info("Successfully refreshed OAuth tokens")
+ return tokens
+
+ except Exception as e:
+ logger.error(f"Token refresh failed: {e}")
+ raise ValueError(f"Token refresh failed: {str(e)}")
\ No newline at end of file
diff --git a/doris_mcp_server/auth/oauth_handlers.py b/doris_mcp_server/auth/oauth_handlers.py
new file mode 100644
index 0000000..4b8a46d
--- /dev/null
+++ b/doris_mcp_server/auth/oauth_handlers.py
@@ -0,0 +1,312 @@
+#!/usr/bin/env python3
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+OAuth HTTP Handlers
+Provides HTTP endpoints for OAuth authentication flow
+"""
+
+from typing import Dict, Any
+from urllib.parse import parse_qs, urlparse
+import json
+
+from starlette.responses import JSONResponse, RedirectResponse, HTMLResponse
+from starlette.requests import Request
+
+from ..utils.logger import get_logger
+
+logger = get_logger(__name__)
+
+
+class OAuthHandlers:
+ """OAuth HTTP request handlers"""
+
+ def __init__(self, security_manager):
+ """Initialize OAuth handlers
+
+ Args:
+ security_manager: DorisSecurityManager instance
+ """
+ self.security_manager = security_manager
+ logger.info("OAuth handlers initialized")
+
+ async def handle_login(self, request: Request) -> JSONResponse:
+ """Handle OAuth login initiation
+
+ Returns JSON with authorization URL and state
+ """
+ try:
+ # Check if OAuth is enabled
+ oauth_info = self.security_manager.get_oauth_provider_info()
+ if not oauth_info.get("enabled"):
+ return JSONResponse(
+ {"error": "OAuth authentication is not enabled"},
+ status_code=400
+ )
+
+ # Get authorization URL
+ authorization_url, state = self.security_manager.get_oauth_authorization_url()
+
+ return JSONResponse({
+ "authorization_url": authorization_url,
+ "state": state,
+ "provider": oauth_info.get("provider"),
+ "message": "Navigate to authorization_url to complete OAuth login"
+ })
+
+ except Exception as e:
+ logger.error(f"OAuth login initiation failed: {e}")
+ return JSONResponse(
+ {"error": f"OAuth login failed: {str(e)}"},
+ status_code=500
+ )
+
+ async def handle_callback(self, request: Request) -> JSONResponse:
+ """Handle OAuth callback
+
+ Processes the OAuth callback and returns authentication result
+ """
+ try:
+ # Get query parameters
+ query_params = dict(request.query_params)
+
+ # Check for error in callback
+ if "error" in query_params:
+ error_description = query_params.get("error_description", "Unknown error")
+ logger.warning(f"OAuth callback error: {query_params['error']} - {error_description}")
+ return JSONResponse(
+ {
+ "error": query_params["error"],
+ "error_description": error_description,
+ "error_uri": query_params.get("error_uri")
+ },
+ status_code=400
+ )
+
+ # Extract required parameters
+ code = query_params.get("code")
+ state = query_params.get("state")
+
+ if not code or not state:
+ return JSONResponse(
+ {"error": "Missing required parameters: code and state"},
+ status_code=400
+ )
+
+ # Handle OAuth callback
+ auth_context = await self.security_manager.handle_oauth_callback(code, state)
+
+ # Return successful authentication response
+ return JSONResponse({
+ "success": True,
+ "user_id": auth_context.user_id,
+ "roles": auth_context.roles,
+ "permissions": auth_context.permissions,
+ "security_level": auth_context.security_level.value,
+ "session_id": auth_context.session_id,
+ "message": "OAuth authentication successful"
+ })
+
+ except Exception as e:
+ logger.error(f"OAuth callback handling failed: {e}")
+ return JSONResponse(
+ {"error": f"OAuth callback failed: {str(e)}"},
+ status_code=500
+ )
+
+ async def handle_provider_info(self, request: Request) -> JSONResponse:
+ """Handle OAuth provider information request
+
+ Returns information about the configured OAuth provider
+ """
+ try:
+ provider_info = self.security_manager.get_oauth_provider_info()
+ return JSONResponse(provider_info)
+
+ except Exception as e:
+ logger.error(f"Failed to get OAuth provider info: {e}")
+ return JSONResponse(
+ {"error": f"Failed to get provider info: {str(e)}"},
+ status_code=500
+ )
+
+ async def handle_demo_page(self, request: Request) -> HTMLResponse:
+ """Handle OAuth demo page
+
+ Returns a simple HTML page for testing OAuth flow
+ """
+ oauth_info = self.security_manager.get_oauth_provider_info()
+ if not oauth_info.get("enabled"):
+ return HTMLResponse("""
+
+
OAuth Demo
+
+ OAuth Demo
+ OAuth authentication is not enabled.
+ Please configure OAuth settings in your security configuration.
+
+
+ """)
+
+ html_content = f"""
+
+
+
+ Doris MCP Server - OAuth Demo
+
+
+
+ Doris MCP Server - OAuth Demo
+
+
+
OAuth Configuration
+
Provider: {oauth_info.get('provider', 'N/A')}
+
Client ID: {oauth_info.get('client_id', 'N/A')}
+
Scopes: {', '.join(oauth_info.get('scopes', []))}
+
PKCE Enabled: {oauth_info.get('pkce_enabled', False)}
+
+
+
+
OAuth Authentication Test
+
Click the button below to start OAuth authentication flow:
+
+
+
+
+
+
+
API Endpoints
+
+ GET /auth/login - Initiate OAuth login
+ GET /auth/callback - OAuth callback handler
+ GET /auth/provider - Provider information
+
+
+
+
+
+
+ """
+
+ return HTMLResponse(html_content)
\ No newline at end of file
diff --git a/doris_mcp_server/auth/oauth_provider.py b/doris_mcp_server/auth/oauth_provider.py
new file mode 100644
index 0000000..71ce095
--- /dev/null
+++ b/doris_mcp_server/auth/oauth_provider.py
@@ -0,0 +1,287 @@
+#!/usr/bin/env python3
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+OAuth Authentication Provider
+Integrates OAuth 2.0/OIDC authentication with the existing authentication framework
+"""
+
+from typing import Dict, Any, Optional, Tuple
+from datetime import datetime
+
+from .oauth_client import OAuthClient
+from .oauth_types import OAuthTokens, OAuthUserInfo, OAuthState
+from ..utils.security import AuthContext, SecurityLevel
+from ..utils.logger import get_logger
+
+logger = get_logger(__name__)
+
+
+class OAuthAuthenticationProvider:
+ """OAuth authentication provider for Doris MCP Server"""
+
+ def __init__(self, config):
+ """Initialize OAuth authentication provider
+
+ Args:
+ config: DorisConfig with OAuth configuration
+ """
+ self.config = config
+ self.oauth_client = OAuthClient(config)
+ self.enabled = self.oauth_client.enabled
+
+ logger.info(f"OAuthAuthenticationProvider initialized (enabled: {self.enabled})")
+
+ async def initialize(self) -> bool:
+ """Initialize OAuth authentication provider
+
+ Returns:
+ True if initialization successful
+ """
+ if not self.enabled:
+ return True
+
+ success = await self.oauth_client.initialize()
+ if success:
+ logger.info("OAuth authentication provider initialized successfully")
+ else:
+ logger.error("Failed to initialize OAuth authentication provider")
+ return success
+
+ async def shutdown(self):
+ """Shutdown OAuth authentication provider"""
+ if self.enabled:
+ await self.oauth_client.shutdown()
+ logger.info("OAuth authentication provider shutdown completed")
+
+ def get_authorization_url(self) -> Tuple[str, str]:
+ """Get OAuth authorization URL
+
+ Returns:
+ Tuple of (authorization_url, state)
+ """
+ if not self.enabled:
+ raise ValueError("OAuth authentication is not enabled")
+
+ authorization_url, oauth_state = self.oauth_client.build_authorization_url()
+ return authorization_url, oauth_state.state
+
+ async def handle_callback(self, code: str, state: str) -> AuthContext:
+ """Handle OAuth callback and create authentication context
+
+ Args:
+ code: Authorization code from OAuth provider
+ state: State parameter for CSRF protection
+
+ Returns:
+ AuthContext for the authenticated user
+
+ Raises:
+ ValueError: If authentication fails
+ """
+ if not self.enabled:
+ raise ValueError("OAuth authentication is not enabled")
+
+ try:
+ # Exchange code for tokens
+ tokens, oauth_state = await self.oauth_client.exchange_code_for_tokens(code, state)
+
+ # Get user information
+ user_info = await self.oauth_client.get_user_info(tokens)
+
+ # Create authentication context
+ auth_context = await self._create_auth_context(user_info, tokens)
+
+ logger.info(f"OAuth authentication successful for user: {auth_context.user_id}")
+ return auth_context
+
+ except Exception as e:
+ logger.error(f"OAuth callback handling failed: {e}")
+ raise ValueError(f"OAuth authentication failed: {str(e)}")
+
+ async def authenticate_with_token(self, access_token: str) -> AuthContext:
+ """Authenticate using OAuth access token
+
+ Args:
+ access_token: OAuth access token
+
+ Returns:
+ AuthContext for the authenticated user
+ """
+ if not self.enabled:
+ raise ValueError("OAuth authentication is not enabled")
+
+ try:
+ # Create token object
+ tokens = OAuthTokens(access_token=access_token)
+
+ # Get user information
+ user_info = await self.oauth_client.get_user_info(tokens)
+
+ # Create authentication context
+ auth_context = await self._create_auth_context(user_info, tokens)
+
+ logger.info(f"OAuth token authentication successful for user: {auth_context.user_id}")
+ return auth_context
+
+ except Exception as e:
+ logger.error(f"OAuth token authentication failed: {e}")
+ raise ValueError(f"OAuth token authentication failed: {str(e)}")
+
+ async def refresh_authentication(self, refresh_token: str) -> Tuple[AuthContext, str]:
+ """Refresh OAuth authentication
+
+ Args:
+ refresh_token: OAuth refresh token
+
+ Returns:
+ Tuple of (AuthContext, new_access_token)
+ """
+ if not self.enabled:
+ raise ValueError("OAuth authentication is not enabled")
+
+ try:
+ # Refresh tokens
+ tokens = await self.oauth_client.refresh_tokens(refresh_token)
+
+ # Get updated user information
+ user_info = await self.oauth_client.get_user_info(tokens)
+
+ # Create authentication context
+ auth_context = await self._create_auth_context(user_info, tokens)
+
+ logger.info(f"OAuth refresh successful for user: {auth_context.user_id}")
+ return auth_context, tokens.access_token
+
+ except Exception as e:
+ logger.error(f"OAuth refresh failed: {e}")
+ raise ValueError(f"OAuth refresh failed: {str(e)}")
+
+ async def _create_auth_context(self, user_info: OAuthUserInfo, tokens: OAuthTokens) -> AuthContext:
+ """Create authentication context from OAuth user info
+
+ Args:
+ user_info: OAuth user information
+ tokens: OAuth tokens
+
+ Returns:
+ AuthContext for the user
+ """
+ # Determine security level based on roles or email domain
+ security_level = await self._determine_security_level(user_info)
+
+ # Map OAuth roles to application permissions
+ permissions = await self._map_permissions(user_info.roles)
+
+ # Generate session ID
+ session_id = f"oauth_{user_info.sub}_{datetime.utcnow().timestamp()}"
+
+ return AuthContext(
+ user_id=user_info.sub,
+ roles=user_info.roles,
+ permissions=permissions,
+ session_id=session_id,
+ login_time=datetime.utcnow(),
+ last_activity=datetime.utcnow(),
+ security_level=security_level
+ )
+
+ async def _determine_security_level(self, user_info: OAuthUserInfo) -> SecurityLevel:
+ """Determine security level for OAuth user
+
+ Args:
+ user_info: OAuth user information
+
+ Returns:
+ SecurityLevel for the user
+ """
+ # Check if user has admin roles
+ admin_roles = {"admin", "administrator", "data_admin", "super_admin"}
+ if any(role.lower() in admin_roles for role in user_info.roles):
+ return SecurityLevel.SECRET
+
+ # Check email domain for internal users
+ if user_info.email:
+ # You can configure trusted domains for internal access
+ trusted_domains = ["yourcompany.com", "internal.org"] # Configure as needed
+ email_domain = user_info.email.split("@")[-1].lower()
+ if email_domain in trusted_domains:
+ return SecurityLevel.CONFIDENTIAL
+
+ # Check for special roles
+ elevated_roles = {"data_analyst", "developer", "manager"}
+ if any(role.lower() in elevated_roles for role in user_info.roles):
+ return SecurityLevel.CONFIDENTIAL
+
+ # Default to internal level for OAuth users
+ return SecurityLevel.INTERNAL
+
+ async def _map_permissions(self, roles: list[str]) -> list[str]:
+ """Map OAuth roles to application permissions
+
+ Args:
+ roles: OAuth user roles
+
+ Returns:
+ List of application permissions
+ """
+ permissions = set()
+
+ # Role to permission mapping
+ role_permissions = {
+ "admin": ["admin", "read_data", "write_data", "manage_users"],
+ "administrator": ["admin", "read_data", "write_data", "manage_users"],
+ "data_admin": ["admin", "read_data", "write_data"],
+ "super_admin": ["admin", "read_data", "write_data", "manage_users", "system_admin"],
+ "data_analyst": ["read_data", "query_database"],
+ "developer": ["read_data", "query_database", "debug"],
+ "viewer": ["read_data"],
+ "user": ["read_data"],
+ "oauth_user": ["read_data"] # Default OAuth user permission
+ }
+
+ # Map roles to permissions
+ for role in roles:
+ role_lower = role.lower()
+ if role_lower in role_permissions:
+ permissions.update(role_permissions[role_lower])
+
+ # Ensure OAuth users have at least basic permissions
+ if not permissions:
+ permissions.add("read_data")
+
+ return list(permissions)
+
+ def get_provider_info(self) -> Dict[str, Any]:
+ """Get OAuth provider information
+
+ Returns:
+ Provider information dictionary
+ """
+ if not self.enabled:
+ return {"enabled": False}
+
+ config = self.oauth_client.provider_config
+ return {
+ "enabled": True,
+ "provider": config.provider.value,
+ "client_id": config.client_id,
+ "scopes": config.scopes,
+ "redirect_uri": config.redirect_uri,
+ "pkce_enabled": config.pkce_enabled,
+ "nonce_enabled": config.nonce_enabled
+ }
\ No newline at end of file
diff --git a/doris_mcp_server/auth/oauth_types.py b/doris_mcp_server/auth/oauth_types.py
new file mode 100644
index 0000000..ce9252f
--- /dev/null
+++ b/doris_mcp_server/auth/oauth_types.py
@@ -0,0 +1,196 @@
+#!/usr/bin/env python3
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+OAuth 2.0/OIDC Type Definitions
+Provides data types and models for OAuth authentication flow
+"""
+
+from dataclasses import dataclass
+from datetime import datetime
+from enum import Enum
+from typing import Dict, Any, Optional, List
+
+
+class OAuthProvider(Enum):
+ """OAuth provider enumeration"""
+ GOOGLE = "google"
+ MICROSOFT = "microsoft"
+ GITHUB = "github"
+ CUSTOM = "custom"
+
+
+class OAuthGrantType(Enum):
+ """OAuth grant type enumeration"""
+ AUTHORIZATION_CODE = "authorization_code"
+ REFRESH_TOKEN = "refresh_token"
+
+
+@dataclass
+class OAuthState:
+ """OAuth state parameter for CSRF protection"""
+ state: str
+ nonce: Optional[str] = None
+ pkce_verifier: Optional[str] = None
+ pkce_challenge: Optional[str] = None
+ redirect_uri: str = ""
+ created_at: datetime = None
+ expires_at: datetime = None
+
+ def __post_init__(self):
+ if self.created_at is None:
+ self.created_at = datetime.utcnow()
+
+
+@dataclass
+class OAuthTokens:
+ """OAuth token response"""
+ access_token: str
+ token_type: str = "Bearer"
+ expires_in: Optional[int] = None
+ refresh_token: Optional[str] = None
+ scope: Optional[str] = None
+ id_token: Optional[str] = None # OIDC ID token
+ created_at: datetime = None
+
+ def __post_init__(self):
+ if self.created_at is None:
+ self.created_at = datetime.utcnow()
+
+
+@dataclass
+class OAuthUserInfo:
+ """OAuth/OIDC user information"""
+ sub: str # Subject identifier
+ email: Optional[str] = None
+ email_verified: Optional[bool] = None
+ name: Optional[str] = None
+ given_name: Optional[str] = None
+ family_name: Optional[str] = None
+ picture: Optional[str] = None
+ locale: Optional[str] = None
+ roles: List[str] = None
+ raw_claims: Dict[str, Any] = None
+
+ def __post_init__(self):
+ if self.roles is None:
+ self.roles = []
+ if self.raw_claims is None:
+ self.raw_claims = {}
+
+
+@dataclass
+class OIDCDiscovery:
+ """OIDC Discovery document"""
+ issuer: str
+ authorization_endpoint: str
+ token_endpoint: str
+ userinfo_endpoint: Optional[str] = None
+ jwks_uri: Optional[str] = None
+ scopes_supported: List[str] = None
+ response_types_supported: List[str] = None
+ subject_types_supported: List[str] = None
+ id_token_signing_alg_values_supported: List[str] = None
+
+ def __post_init__(self):
+ if self.scopes_supported is None:
+ self.scopes_supported = ["openid"]
+ if self.response_types_supported is None:
+ self.response_types_supported = ["code"]
+ if self.subject_types_supported is None:
+ self.subject_types_supported = ["public"]
+ if self.id_token_signing_alg_values_supported is None:
+ self.id_token_signing_alg_values_supported = ["RS256"]
+
+
+@dataclass
+class OAuthError:
+ """OAuth error response"""
+ error: str
+ error_description: Optional[str] = None
+ error_uri: Optional[str] = None
+ state: Optional[str] = None
+
+
+@dataclass
+class OAuthProviderConfig:
+ """OAuth provider configuration"""
+ provider: OAuthProvider
+ client_id: str
+ client_secret: str
+ redirect_uri: str
+ scopes: List[str]
+
+ # Endpoints
+ authorization_endpoint: str
+ token_endpoint: str
+ userinfo_endpoint: Optional[str] = None
+ jwks_uri: Optional[str] = None
+
+ # Discovery
+ discovery_url: Optional[str] = None
+
+ # Settings
+ pkce_enabled: bool = True
+ nonce_enabled: bool = True
+
+ # User mapping
+ user_id_claim: str = "sub"
+ email_claim: str = "email"
+ name_claim: str = "name"
+ roles_claim: str = "roles"
+ default_roles: List[str] = None
+
+ def __post_init__(self):
+ if self.default_roles is None:
+ self.default_roles = ["oauth_user"]
+
+
+# Pre-defined provider configurations
+OAUTH_PROVIDERS = {
+ OAuthProvider.GOOGLE: {
+ "authorization_endpoint": "https://accounts.google.com/o/oauth2/auth",
+ "token_endpoint": "https://oauth2.googleapis.com/token",
+ "userinfo_endpoint": "https://openidconnect.googleapis.com/v1/userinfo",
+ "jwks_uri": "https://www.googleapis.com/oauth2/v3/certs",
+ "discovery_url": "https://accounts.google.com/.well-known/openid_configuration",
+ "scopes": ["openid", "email", "profile"],
+ "user_id_claim": "sub",
+ "email_claim": "email",
+ "name_claim": "name"
+ },
+ OAuthProvider.MICROSOFT: {
+ "authorization_endpoint": "https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
+ "token_endpoint": "https://login.microsoftonline.com/common/oauth2/v2.0/token",
+ "userinfo_endpoint": "https://graph.microsoft.com/v1.0/me",
+ "jwks_uri": "https://login.microsoftonline.com/common/discovery/v2.0/keys",
+ "discovery_url": "https://login.microsoftonline.com/common/v2.0/.well-known/openid_configuration",
+ "scopes": ["openid", "profile", "email", "User.Read"],
+ "user_id_claim": "sub",
+ "email_claim": "email",
+ "name_claim": "name"
+ },
+ OAuthProvider.GITHUB: {
+ "authorization_endpoint": "https://github.com/login/oauth/authorize",
+ "token_endpoint": "https://github.com/login/oauth/access_token",
+ "userinfo_endpoint": "https://api.github.com/user",
+ "scopes": ["user:email", "read:user"],
+ "user_id_claim": "id",
+ "email_claim": "email",
+ "name_claim": "name"
+ }
+}
\ No newline at end of file
diff --git a/doris_mcp_server/auth/token_handlers.py b/doris_mcp_server/auth/token_handlers.py
new file mode 100644
index 0000000..7001d4d
--- /dev/null
+++ b/doris_mcp_server/auth/token_handlers.py
@@ -0,0 +1,481 @@
+#!/usr/bin/env python3
+"""
+Token Authentication HTTP Handlers
+
+Provides HTTP endpoints for token management including creation, revocation,
+listing, and statistics. Used for administrative token management in HTTP mode.
+"""
+
+import json
+from typing import Dict, Any
+
+from starlette.requests import Request
+from starlette.responses import JSONResponse, HTMLResponse
+
+from ..utils.logger import get_logger
+from ..utils.security import SecurityLevel
+
+
+class TokenHandlers:
+ """Token Authentication HTTP Handlers"""
+
+ def __init__(self, security_manager):
+ self.security_manager = security_manager
+ self.logger = get_logger(__name__)
+
+ async def handle_create_token(self, request: Request) -> JSONResponse:
+ """Handle token creation request"""
+ try:
+ # Check if token manager is available
+ if not self.security_manager.auth_provider.token_manager:
+ return JSONResponse({
+ "error": "Token authentication is not enabled"
+ }, status_code=503)
+
+ # Parse request data
+ if request.method == "GET":
+ # GET request with query parameters
+ query_params = dict(request.query_params)
+ token_id = query_params.get("token_id")
+ user_id = query_params.get("user_id")
+ roles = query_params.get("roles", "").split(",") if query_params.get("roles") else []
+ permissions = query_params.get("permissions", "").split(",") if query_params.get("permissions") else []
+ security_level_str = query_params.get("security_level", "internal")
+ expires_hours_str = query_params.get("expires_hours")
+ description = query_params.get("description", "")
+ custom_token = query_params.get("custom_token")
+ else:
+ # POST request with JSON body
+ try:
+ body = await request.json()
+ except:
+ return JSONResponse({
+ "error": "Invalid JSON body"
+ }, status_code=400)
+
+ token_id = body.get("token_id")
+ user_id = body.get("user_id")
+ roles = body.get("roles", [])
+ permissions = body.get("permissions", [])
+ security_level_str = body.get("security_level", "internal")
+ expires_hours_str = body.get("expires_hours")
+ description = body.get("description", "")
+ custom_token = body.get("custom_token")
+
+ # Validate required fields
+ if not token_id or not user_id:
+ return JSONResponse({
+ "error": "token_id and user_id are required"
+ }, status_code=400)
+
+ # Parse security level
+ try:
+ security_level = SecurityLevel(security_level_str.lower())
+ except ValueError:
+ security_level = SecurityLevel.INTERNAL
+
+ # Parse expires_hours
+ expires_hours = None
+ if expires_hours_str:
+ try:
+ expires_hours = int(expires_hours_str)
+ except ValueError:
+ return JSONResponse({
+ "error": "expires_hours must be an integer"
+ }, status_code=400)
+
+ # Create token
+ try:
+ token = await self.security_manager.create_token(
+ token_id=token_id,
+ user_id=user_id,
+ roles=roles,
+ permissions=permissions,
+ security_level=security_level,
+ expires_hours=expires_hours,
+ description=description,
+ custom_token=custom_token
+ )
+
+ return JSONResponse({
+ "success": True,
+ "token_id": token_id,
+ "user_id": user_id,
+ "token": token,
+ "roles": roles,
+ "permissions": permissions,
+ "security_level": security_level.value,
+ "expires_hours": expires_hours,
+ "description": description,
+ "message": "Token created successfully"
+ })
+
+ except Exception as e:
+ self.logger.error(f"Token creation failed: {e}")
+ return JSONResponse({
+ "error": f"Token creation failed: {str(e)}"
+ }, status_code=400)
+
+ except Exception as e:
+ self.logger.error(f"Error in handle_create_token: {e}")
+ return JSONResponse({
+ "error": f"Internal server error: {str(e)}"
+ }, status_code=500)
+
+ async def handle_revoke_token(self, request: Request) -> JSONResponse:
+ """Handle token revocation request"""
+ try:
+ # Check if token manager is available
+ if not self.security_manager.auth_provider.token_manager:
+ return JSONResponse({
+ "error": "Token authentication is not enabled"
+ }, status_code=503)
+
+ # Get token_id from query parameters or path
+ token_id = request.query_params.get("token_id")
+ if not token_id and request.method == "DELETE":
+ # Try to get from path: /token/revoke/{token_id}
+ path_parts = str(request.url.path).split("/")
+ if len(path_parts) >= 4:
+ token_id = path_parts[-1]
+
+ if not token_id:
+ return JSONResponse({
+ "error": "token_id is required"
+ }, status_code=400)
+
+ # Revoke token
+ success = await self.security_manager.revoke_token(token_id)
+
+ if success:
+ return JSONResponse({
+ "success": True,
+ "token_id": token_id,
+ "message": "Token revoked successfully"
+ })
+ else:
+ return JSONResponse({
+ "success": False,
+ "token_id": token_id,
+ "message": "Token not found or already revoked"
+ }, status_code=404)
+
+ except Exception as e:
+ self.logger.error(f"Error in handle_revoke_token: {e}")
+ return JSONResponse({
+ "error": f"Internal server error: {str(e)}"
+ }, status_code=500)
+
+ async def handle_list_tokens(self, request: Request) -> JSONResponse:
+ """Handle token listing request"""
+ try:
+ # Check if token manager is available
+ if not self.security_manager.auth_provider.token_manager:
+ return JSONResponse({
+ "error": "Token authentication is not enabled"
+ }, status_code=503)
+
+ # Get tokens list
+ tokens = await self.security_manager.list_tokens()
+
+ return JSONResponse({
+ "success": True,
+ "count": len(tokens),
+ "tokens": tokens
+ })
+
+ except Exception as e:
+ self.logger.error(f"Error in handle_list_tokens: {e}")
+ return JSONResponse({
+ "error": f"Internal server error: {str(e)}"
+ }, status_code=500)
+
+ async def handle_token_stats(self, request: Request) -> JSONResponse:
+ """Handle token statistics request"""
+ try:
+ # Check if token manager is available
+ if not self.security_manager.auth_provider.token_manager:
+ return JSONResponse({
+ "error": "Token authentication is not enabled"
+ }, status_code=503)
+
+ # Get token statistics
+ stats = self.security_manager.get_token_stats()
+
+ return JSONResponse({
+ "success": True,
+ "stats": stats
+ })
+
+ except Exception as e:
+ self.logger.error(f"Error in handle_token_stats: {e}")
+ return JSONResponse({
+ "error": f"Internal server error: {str(e)}"
+ }, status_code=500)
+
+ async def handle_cleanup_tokens(self, request: Request) -> JSONResponse:
+ """Handle expired tokens cleanup request"""
+ try:
+ # Check if token manager is available
+ if not self.security_manager.auth_provider.token_manager:
+ return JSONResponse({
+ "error": "Token authentication is not enabled"
+ }, status_code=503)
+
+ # Cleanup expired tokens
+ cleaned_count = await self.security_manager.cleanup_expired_tokens()
+
+ return JSONResponse({
+ "success": True,
+ "cleaned_count": cleaned_count,
+ "message": f"Cleaned up {cleaned_count} expired tokens"
+ })
+
+ except Exception as e:
+ self.logger.error(f"Error in handle_cleanup_tokens: {e}")
+ return JSONResponse({
+ "error": f"Internal server error: {str(e)}"
+ }, status_code=500)
+
+ async def handle_demo_page(self, request: Request) -> HTMLResponse:
+ """Handle token management demo page"""
+ try:
+ # Check if token manager is available
+ if not self.security_manager.auth_provider.token_manager:
+ html_content = """
+
+
+
+ Token Management - Not Available
+
+
+
+ Token Management
+ Token authentication is not enabled on this server.
+
+
+ """
+ return HTMLResponse(html_content)
+
+ # Get current stats for demo
+ stats = self.security_manager.get_token_stats()
+
+ html_content = f"""
+
+
+
+ Doris MCP Server - Token Management
+
+
+
+
+
🔐 Doris MCP Server - Token Management
+
+
+
📊 Token Statistics
+
+
+
{stats.get('total_tokens', 0)}
+
Total Tokens
+
+
+
{stats.get('active_tokens', 0)}
+
Active Tokens
+
+
+
{stats.get('expired_tokens', 0)}
+
Expired Tokens
+
+
+
Token Expiry: {'Enabled' if stats.get('expiry_enabled') else 'Disabled'}
+
Default Expiry: {stats.get('default_expiry_hours', 0)} hours
+
+
+
+
➕ Create New Token
+
+
+
+
+
+
📋 Token Management
+
+
+
+
+
Revoke Token
+
+
+
+
+
+
+
+
+
🔧 API Endpoints
+
Use these endpoints for programmatic token management:
+
+ - POST /token/create - Create new token
+ - DELETE /token/revoke?token_id=... - Revoke token
+ - GET /token/list - List all tokens
+ - GET /token/stats - Get token statistics
+ - POST /token/cleanup - Cleanup expired tokens
+
+
+
+
+
+
+
+ """
+
+ return HTMLResponse(html_content)
+
+ except Exception as e:
+ self.logger.error(f"Error in handle_demo_page: {e}")
+ error_html = f"""
+
+
+
+ Token Management Error
+
+
+
+ Token Management Error
+ Error loading token management page: {str(e)}
+
+
+ """
+ return HTMLResponse(error_html, status_code=500)
\ No newline at end of file
diff --git a/doris_mcp_server/auth/token_manager.py b/doris_mcp_server/auth/token_manager.py
new file mode 100644
index 0000000..ed742d9
--- /dev/null
+++ b/doris_mcp_server/auth/token_manager.py
@@ -0,0 +1,456 @@
+#!/usr/bin/env python3
+"""
+Token Authentication Management Module
+
+Provides enterprise-grade token authentication system with configurable tokens,
+expiration management, role-based access control and secure token storage.
+"""
+
+import hashlib
+import json
+import os
+import secrets
+import time
+from dataclasses import dataclass, field
+from datetime import datetime, timedelta
+from typing import Dict, List, Optional, Any
+
+from ..utils.logger import get_logger
+from ..utils.security import SecurityLevel
+
+
+@dataclass
+class TokenInfo:
+ """Token information structure"""
+
+ token_id: str # Unique token identifier for audit and management
+ created_at: datetime = field(default_factory=datetime.utcnow)
+ expires_at: Optional[datetime] = None
+ last_used: Optional[datetime] = None
+ description: str = "" # Optional description for token purpose
+ is_active: bool = True
+
+
+@dataclass
+class TokenValidationResult:
+ """Token validation result"""
+
+ is_valid: bool
+ token_info: Optional[TokenInfo] = None
+ error_message: Optional[str] = None
+
+
+class TokenManager:
+ """Enterprise Token Authentication Manager
+
+ Features:
+ - Configurable token storage (file-based or environment variables)
+ - Token expiration management
+ - Secure token hashing
+ - Role-based access control
+ - Token lifecycle management
+ """
+
+ def __init__(self, config):
+ self.config = config
+ self.logger = get_logger(__name__)
+
+ # Token storage
+ self._tokens: Dict[str, TokenInfo] = {} # token_hash -> TokenInfo
+ self._token_ids: Dict[str, str] = {} # token_id -> token_hash
+
+ # Configuration
+ self.token_file_path = getattr(config.security, 'token_file_path', 'tokens.json')
+ self.enable_token_expiry = getattr(config.security, 'enable_token_expiry', True)
+ self.default_token_expiry_hours = getattr(config.security, 'default_token_expiry_hours', 24 * 30) # 30 days
+ self.token_hash_algorithm = getattr(config.security, 'token_hash_algorithm', 'sha256')
+
+ # Initialize with default tokens if none exist
+ self._initialize_default_tokens()
+
+ # Load tokens from configuration
+ self._load_tokens()
+
+ self.logger.info(f"TokenManager initialized with {len(self._tokens)} tokens")
+
+ def _initialize_default_tokens(self):
+ """Initialize default tokens for basic authentication (configurable via environment)"""
+ # Default token configurations (can be overridden by environment variables)
+ default_tokens = [
+ {
+ 'token_id': 'admin-token',
+ 'token': os.getenv('DEFAULT_ADMIN_TOKEN', 'doris_admin_token_123456'),
+ 'description': os.getenv('DEFAULT_ADMIN_DESCRIPTION', 'Default admin API access token'),
+ 'expires_hours': None # Never expires
+ },
+ {
+ 'token_id': 'analyst-token',
+ 'token': os.getenv('DEFAULT_ANALYST_TOKEN', 'doris_analyst_token_123456'),
+ 'description': os.getenv('DEFAULT_ANALYST_DESCRIPTION', 'Default data analysis API access token'),
+ 'expires_hours': None # Never expires
+ },
+ {
+ 'token_id': 'readonly-token',
+ 'token': os.getenv('DEFAULT_READONLY_TOKEN', 'doris_readonly_token_123456'),
+ 'description': os.getenv('DEFAULT_READONLY_DESCRIPTION', 'Default read-only API access token'),
+ 'expires_hours': None # Never expires
+ }
+ ]
+
+
+ # Only add default tokens if no custom tokens are defined via environment variables
+ # Check if any TOKEN_* environment variables exist (excluding system and legacy configs)
+ excluded_prefixes = ('DEFAULT_', 'TOKEN_FILE_PATH', 'TOKEN_HASH_')
+ excluded_vars = {'TOKEN_SECRET', 'TOKEN_EXPIRY'}
+
+ custom_tokens_exist = any(
+ key.startswith('TOKEN_') and
+ not key.startswith(excluded_prefixes) and
+ not key.endswith(('_EXPIRES_HOURS', '_DESCRIPTION')) and
+ key not in excluded_vars
+ for key in os.environ.keys()
+ )
+
+ # Also check if token file exists and has content
+ token_file_exists = False
+ if os.path.exists(self.token_file_path):
+ try:
+ with open(self.token_file_path, 'r') as f:
+ content = f.read().strip()
+ if content and content != '{}':
+ token_file_exists = True
+ except:
+ pass
+
+ # Add default tokens only if no custom configuration exists
+ if not custom_tokens_exist and not token_file_exists:
+ for token_config in default_tokens:
+ self._add_token_from_config(token_config)
+
+ self.logger.info(f"Initialized {len(default_tokens)} default tokens (no custom config found)")
+ else:
+ self.logger.info("Skipped default tokens initialization (custom tokens detected)")
+
+ def _add_token_from_config(self, token_config: Dict[str, Any]):
+ """Add token from configuration"""
+ try:
+ # Calculate expiration time
+ expires_at = None
+ if self.enable_token_expiry:
+ expires_hours = token_config.get('expires_hours', self.default_token_expiry_hours)
+ if expires_hours is not None:
+ expires_at = datetime.utcnow() + timedelta(hours=expires_hours)
+
+ # Create token info
+ token_info = TokenInfo(
+ token_id=token_config['token_id'],
+ expires_at=expires_at,
+ description=token_config.get('description', ''),
+ is_active=token_config.get('is_active', True)
+ )
+
+ # Hash the token
+ raw_token = token_config['token']
+ token_hash = self._hash_token(raw_token)
+
+ # Store token
+ self._tokens[token_hash] = token_info
+ self._token_ids[token_info.token_id] = token_hash
+
+ self.logger.debug(f"Added token '{token_info.token_id}'")
+
+ except Exception as e:
+ self.logger.error(f"Failed to add token from config: {e}")
+
+ def _load_tokens(self):
+ """Load tokens from configuration sources"""
+ # 1. Load from environment variables
+ self._load_tokens_from_env()
+
+ # 2. Load from token file if exists
+ if os.path.exists(self.token_file_path):
+ self._load_tokens_from_file()
+
+ self.logger.info(f"Token loading completed, total tokens: {len(self._tokens)}")
+
+ def _load_tokens_from_env(self):
+ """Load tokens from environment variables
+
+ Simplified format:
+ TOKEN_=
+ TOKEN__EXPIRES_HOURS=
+ TOKEN__DESCRIPTION=
+ """
+ token_prefixes = set()
+
+ # Find all TOKEN_ environment variables (exclude legacy and system variables)
+ excluded_token_vars = {
+ 'TOKEN_SECRET', # Legacy token secret
+ 'TOKEN_EXPIRY', # Legacy token expiry
+ 'TOKEN_FILE_PATH', # System config
+ 'TOKEN_HASH_ALGORITHM' # System config
+ }
+
+ for key in os.environ:
+ if (key.startswith('TOKEN_') and
+ not key.endswith(('_EXPIRES_HOURS', '_DESCRIPTION')) and
+ key not in excluded_token_vars):
+ token_id = key[6:] # Remove 'TOKEN_' prefix
+ token_prefixes.add(token_id)
+
+ # Load each token
+ for token_id in token_prefixes:
+ try:
+ token = os.environ.get(f'TOKEN_{token_id}')
+ if not token:
+ continue
+
+ expires_hours_str = os.environ.get(f'TOKEN_{token_id}_EXPIRES_HOURS', str(self.default_token_expiry_hours))
+ description = os.environ.get(f'TOKEN_{token_id}_DESCRIPTION', f'Environment token {token_id}')
+
+ expires_hours = None
+ try:
+ if expires_hours_str and expires_hours_str.lower() != 'none':
+ expires_hours = int(expires_hours_str)
+ except ValueError:
+ expires_hours = self.default_token_expiry_hours
+
+ # Add token
+ token_config = {
+ 'token_id': token_id.lower(),
+ 'token': token,
+ 'expires_hours': expires_hours,
+ 'description': description
+ }
+
+ self._add_token_from_config(token_config)
+
+ except Exception as e:
+ self.logger.error(f"Failed to load token {token_id} from environment: {e}")
+
+ def _load_tokens_from_file(self):
+ """Load tokens from JSON file"""
+ try:
+ with open(self.token_file_path, 'r', encoding='utf-8') as f:
+ tokens_data = json.load(f)
+
+ if isinstance(tokens_data, dict) and 'tokens' in tokens_data:
+ tokens_list = tokens_data['tokens']
+ elif isinstance(tokens_data, list):
+ tokens_list = tokens_data
+ else:
+ self.logger.error(f"Invalid token file format: {self.token_file_path}")
+ return
+
+ for token_config in tokens_list:
+ self._add_token_from_config(token_config)
+
+ self.logger.info(f"Loaded {len(tokens_list)} tokens from file: {self.token_file_path}")
+
+ except Exception as e:
+ self.logger.error(f"Failed to load tokens from file {self.token_file_path}: {e}")
+
+ def _hash_token(self, token: str) -> str:
+ """Hash token for secure storage"""
+ if self.token_hash_algorithm == 'sha256':
+ return hashlib.sha256(token.encode('utf-8')).hexdigest()
+ elif self.token_hash_algorithm == 'sha512':
+ return hashlib.sha512(token.encode('utf-8')).hexdigest()
+ else:
+ # Fallback to sha256
+ return hashlib.sha256(token.encode('utf-8')).hexdigest()
+
+ async def validate_token(self, token: str) -> TokenValidationResult:
+ """Validate token and return user information"""
+ try:
+ # Hash the provided token
+ token_hash = self._hash_token(token)
+
+ # Find token info
+ token_info = self._tokens.get(token_hash)
+ if not token_info:
+ return TokenValidationResult(
+ is_valid=False,
+ error_message="Invalid token"
+ )
+
+ # Check if token is active
+ if not token_info.is_active:
+ return TokenValidationResult(
+ is_valid=False,
+ error_message="Token is inactive"
+ )
+
+ # Check expiration
+ if token_info.expires_at and datetime.utcnow() > token_info.expires_at:
+ return TokenValidationResult(
+ is_valid=False,
+ error_message="Token has expired"
+ )
+
+ # Update last used time
+ token_info.last_used = datetime.utcnow()
+
+ return TokenValidationResult(
+ is_valid=True,
+ token_info=token_info
+ )
+
+ except Exception as e:
+ self.logger.error(f"Token validation error: {e}")
+ return TokenValidationResult(
+ is_valid=False,
+ error_message=f"Token validation failed: {str(e)}"
+ )
+
+ def generate_token(self, length: int = 32) -> str:
+ """Generate a cryptographically secure random token"""
+ return secrets.token_urlsafe(length)
+
+ async def create_token(
+ self,
+ token_id: str,
+ expires_hours: Optional[int] = None,
+ description: str = "",
+ custom_token: Optional[str] = None
+ ) -> str:
+ """Create a new token"""
+ try:
+ # Check if token_id already exists
+ if token_id in self._token_ids:
+ raise ValueError(f"Token ID '{token_id}' already exists")
+
+ # Generate or use provided token
+ if custom_token:
+ raw_token = custom_token
+ else:
+ raw_token = self.generate_token()
+
+ # Calculate expiration
+ expires_at = None
+ if expires_hours is not None:
+ expires_at = datetime.utcnow() + timedelta(hours=expires_hours)
+ elif self.enable_token_expiry:
+ expires_at = datetime.utcnow() + timedelta(hours=self.default_token_expiry_hours)
+
+ # Create token info
+ token_info = TokenInfo(
+ token_id=token_id,
+ expires_at=expires_at,
+ description=description
+ )
+
+ # Hash and store token
+ token_hash = self._hash_token(raw_token)
+ self._tokens[token_hash] = token_info
+ self._token_ids[token_id] = token_hash
+
+ self.logger.info(f"Created new token '{token_id}'")
+
+ return raw_token
+
+ except Exception as e:
+ self.logger.error(f"Failed to create token: {e}")
+ raise
+
+ async def revoke_token(self, token_id: str) -> bool:
+ """Revoke a token by token ID"""
+ try:
+ if token_id not in self._token_ids:
+ self.logger.warning(f"Token ID '{token_id}' not found")
+ return False
+
+ # Get token hash and remove from storage
+ token_hash = self._token_ids[token_id]
+ if token_hash in self._tokens:
+ del self._tokens[token_hash]
+ del self._token_ids[token_id]
+
+ self.logger.info(f"Revoked token '{token_id}'")
+ return True
+
+ except Exception as e:
+ self.logger.error(f"Failed to revoke token '{token_id}': {e}")
+ return False
+
+ async def list_tokens(self) -> List[Dict[str, Any]]:
+ """List all tokens (without sensitive data)"""
+ tokens = []
+
+ for token_hash, token_info in self._tokens.items():
+ tokens.append({
+ 'token_id': token_info.token_id,
+ 'created_at': token_info.created_at.isoformat(),
+ 'expires_at': token_info.expires_at.isoformat() if token_info.expires_at else None,
+ 'last_used': token_info.last_used.isoformat() if token_info.last_used else None,
+ 'is_active': token_info.is_active,
+ 'description': token_info.description,
+ 'is_expired': token_info.expires_at and datetime.utcnow() > token_info.expires_at if token_info.expires_at else False
+ })
+
+ # Sort by creation time
+ tokens.sort(key=lambda x: x['created_at'], reverse=True)
+
+ return tokens
+
+ async def cleanup_expired_tokens(self) -> int:
+ """Remove expired tokens and return count"""
+ if not self.enable_token_expiry:
+ return 0
+
+ now = datetime.utcnow()
+ expired_tokens = []
+
+ # Find expired tokens
+ for token_hash, token_info in self._tokens.items():
+ if token_info.expires_at and now > token_info.expires_at:
+ expired_tokens.append((token_hash, token_info.token_id))
+
+ # Remove expired tokens
+ for token_hash, token_id in expired_tokens:
+ del self._tokens[token_hash]
+ if token_id in self._token_ids:
+ del self._token_ids[token_id]
+
+ if expired_tokens:
+ self.logger.info(f"Cleaned up {len(expired_tokens)} expired tokens")
+
+ return len(expired_tokens)
+
+ async def save_tokens_to_file(self, file_path: Optional[str] = None) -> bool:
+ """Save current tokens to JSON file"""
+ try:
+ file_path = file_path or self.token_file_path
+ tokens_list = await self.list_tokens()
+
+ tokens_data = {
+ 'version': '1.0',
+ 'created_at': datetime.utcnow().isoformat(),
+ 'tokens': tokens_list
+ }
+
+ with open(file_path, 'w', encoding='utf-8') as f:
+ json.dump(tokens_data, f, indent=2, ensure_ascii=False)
+
+ self.logger.info(f"Saved {len(tokens_list)} tokens to file: {file_path}")
+ return True
+
+ except Exception as e:
+ self.logger.error(f"Failed to save tokens to file: {e}")
+ return False
+
+ def get_token_stats(self) -> Dict[str, Any]:
+ """Get token statistics"""
+ now = datetime.utcnow()
+ total_tokens = len(self._tokens)
+ active_tokens = sum(1 for info in self._tokens.values() if info.is_active)
+ expired_tokens = sum(1 for info in self._tokens.values()
+ if info.expires_at and now > info.expires_at)
+
+ return {
+ 'total_tokens': total_tokens,
+ 'active_tokens': active_tokens,
+ 'expired_tokens': expired_tokens,
+ 'expiry_enabled': self.enable_token_expiry,
+ 'default_expiry_hours': self.default_token_expiry_hours
+ }
\ No newline at end of file
diff --git a/doris_mcp_server/auth/token_validators.py b/doris_mcp_server/auth/token_validators.py
new file mode 100644
index 0000000..43731af
--- /dev/null
+++ b/doris_mcp_server/auth/token_validators.py
@@ -0,0 +1,365 @@
+#!/usr/bin/env python3
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+JWT Token Validation Module
+Provides token validation, blacklist management and security features
+"""
+
+import time
+import asyncio
+from typing import Dict, Set, Optional, Any
+from datetime import datetime, timedelta
+from collections import defaultdict
+
+from ..utils.logger import get_logger
+
+logger = get_logger(__name__)
+
+
+class TokenBlacklist:
+ """JWT Token Blacklist Manager
+
+ Manages revoked tokens to prevent revoked tokens from being used again
+ Supports both in-memory and persistent storage
+ """
+
+ def __init__(self, cleanup_interval: int = 3600):
+ """Initialize token blacklist
+
+ Args:
+ cleanup_interval: Interval for cleaning up expired tokens (seconds)
+ """
+ self.cleanup_interval = cleanup_interval
+ # Storage format: {token_jti: expiry_timestamp}
+ self._blacklisted_tokens: Dict[str, float] = {}
+ self._cleanup_task = None
+
+ logger.info("TokenBlacklist initialized")
+
+ async def start(self):
+ """Start blacklist manager"""
+ self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
+ logger.info("TokenBlacklist started with periodic cleanup")
+
+ async def stop(self):
+ """Stop blacklist manager"""
+ if self._cleanup_task:
+ self._cleanup_task.cancel()
+ try:
+ await self._cleanup_task
+ except asyncio.CancelledError:
+ pass
+ logger.info("TokenBlacklist stopped")
+
+ async def add_token(self, jti: str, exp: float):
+ """Add token to blacklist
+
+ Args:
+ jti: JWT ID (unique identifier)
+ exp: Token expiration timestamp
+ """
+ self._blacklisted_tokens[jti] = exp
+ logger.info(f"Token {jti} added to blacklist")
+
+ async def is_blacklisted(self, jti: str) -> bool:
+ """Check if token is blacklisted
+
+ Args:
+ jti: JWT ID
+
+ Returns:
+ True if blacklisted, False otherwise
+ """
+ return jti in self._blacklisted_tokens
+
+ async def remove_token(self, jti: str) -> bool:
+ """Remove token from blacklist
+
+ Args:
+ jti: JWT ID
+
+ Returns:
+ True if removed, False if not found
+ """
+ if jti in self._blacklisted_tokens:
+ del self._blacklisted_tokens[jti]
+ logger.info(f"Token {jti} removed from blacklist")
+ return True
+ return False
+
+ async def cleanup_expired(self) -> int:
+ """Clean up expired blacklisted tokens
+
+ Returns:
+ Number of tokens cleaned up
+ """
+ current_time = time.time()
+ expired_tokens = [
+ jti for jti, exp in self._blacklisted_tokens.items()
+ if exp <= current_time
+ ]
+
+ for jti in expired_tokens:
+ del self._blacklisted_tokens[jti]
+
+ if expired_tokens:
+ logger.info(f"Cleaned up {len(expired_tokens)} expired tokens from blacklist")
+
+ return len(expired_tokens)
+
+ async def get_stats(self) -> Dict[str, Any]:
+ """Get blacklist statistics"""
+ current_time = time.time()
+ active_tokens = sum(1 for exp in self._blacklisted_tokens.values() if exp > current_time)
+
+ return {
+ "total_blacklisted": len(self._blacklisted_tokens),
+ "active_blacklisted": active_tokens,
+ "expired_blacklisted": len(self._blacklisted_tokens) - active_tokens,
+ "cleanup_interval": self.cleanup_interval
+ }
+
+ async def _periodic_cleanup(self):
+ """Periodically clean up expired tokens"""
+ while True:
+ try:
+ await asyncio.sleep(self.cleanup_interval)
+ await self.cleanup_expired()
+ except asyncio.CancelledError:
+ break
+ except Exception as e:
+ logger.error(f"Error during periodic cleanup: {e}")
+
+
+class RateLimiter:
+ """Token usage rate limiter"""
+
+ def __init__(self, max_requests: int = 100, time_window: int = 3600):
+ """Initialize rate limiter
+
+ Args:
+ max_requests: Maximum requests within time window
+ time_window: Time window in seconds
+ """
+ self.max_requests = max_requests
+ self.time_window = time_window
+ # Storage format: {user_id: [timestamp1, timestamp2, ...]}
+ self._request_history: Dict[str, list] = defaultdict(list)
+
+ logger.info(f"RateLimiter initialized: {max_requests} requests per {time_window} seconds")
+
+ async def is_allowed(self, user_id: str) -> bool:
+ """Check if user is allowed to make request
+
+ Args:
+ user_id: User ID
+
+ Returns:
+ True if allowed, False otherwise
+ """
+ current_time = time.time()
+ user_requests = self._request_history[user_id]
+
+ # Clean up expired request records
+ cutoff_time = current_time - self.time_window
+ user_requests[:] = [t for t in user_requests if t > cutoff_time]
+
+ # Check if limit exceeded
+ if len(user_requests) >= self.max_requests:
+ logger.warning(f"Rate limit exceeded for user {user_id}")
+ return False
+
+ # Record current request
+ user_requests.append(current_time)
+ return True
+
+ async def get_usage(self, user_id: str) -> Dict[str, Any]:
+ """Get user usage information
+
+ Args:
+ user_id: User ID
+
+ Returns:
+ Usage statistics
+ """
+ current_time = time.time()
+ user_requests = self._request_history[user_id]
+
+ # Clean up expired records
+ cutoff_time = current_time - self.time_window
+ active_requests = [t for t in user_requests if t > cutoff_time]
+
+ return {
+ "user_id": user_id,
+ "requests_in_window": len(active_requests),
+ "max_requests": self.max_requests,
+ "time_window": self.time_window,
+ "remaining_requests": max(0, self.max_requests - len(active_requests))
+ }
+
+
+class TokenValidator:
+ """JWT Token Validator
+
+ Provides comprehensive JWT token validation functionality, including signature verification,
+ claim validation, blacklist checking and rate limiting
+ """
+
+ def __init__(self, config, blacklist: Optional[TokenBlacklist] = None):
+ """Initialize token validator
+
+ Args:
+ config: DorisConfig configuration object (with security attribute)
+ blacklist: Token blacklist manager
+ """
+ self.config = config
+ self.blacklist = blacklist or TokenBlacklist()
+ self.rate_limiter = RateLimiter()
+
+ # Access JWT settings through the security configuration
+ if hasattr(config, 'security'):
+ security_config = config.security
+ else:
+ # Fallback if config is passed directly as SecurityConfig
+ security_config = config
+
+ # Validation options
+ self.verify_signature = security_config.jwt_verify_signature
+ self.verify_audience = security_config.jwt_verify_audience
+ self.verify_issuer = security_config.jwt_verify_issuer
+ self.require_exp = security_config.jwt_require_exp
+ self.require_iat = security_config.jwt_require_iat
+ self.require_nbf = security_config.jwt_require_nbf
+ self.leeway = security_config.jwt_leeway
+
+ # Expected values
+ self.expected_audience = security_config.jwt_audience
+ self.expected_issuer = security_config.jwt_issuer
+
+ logger.info("TokenValidator initialized")
+
+ async def validate_claims(self, payload: Dict[str, Any]) -> Dict[str, Any]:
+ """Validate JWT claims
+
+ Args:
+ payload: JWT payload
+
+ Returns:
+ Validation result
+
+ Raises:
+ ValueError: Validation failed
+ """
+ current_time = time.time()
+
+ # Validate issuer
+ if self.verify_issuer:
+ if payload.get('iss') != self.expected_issuer:
+ raise ValueError(f"Invalid issuer: expected {self.expected_issuer}")
+
+ # Validate audience
+ if self.verify_audience:
+ aud = payload.get('aud')
+ if isinstance(aud, list):
+ if self.expected_audience not in aud:
+ raise ValueError(f"Invalid audience: {self.expected_audience} not in {aud}")
+ elif aud != self.expected_audience:
+ raise ValueError(f"Invalid audience: expected {self.expected_audience}")
+
+ # Validate expiration time
+ if self.require_exp or 'exp' in payload:
+ exp = payload.get('exp')
+ if not exp:
+ raise ValueError("Missing 'exp' claim")
+ if current_time > exp + self.leeway:
+ raise ValueError("Token has expired")
+
+ # Validate not before time
+ if self.require_nbf or 'nbf' in payload:
+ nbf = payload.get('nbf')
+ if not nbf:
+ raise ValueError("Missing 'nbf' claim")
+ if current_time < nbf - self.leeway:
+ raise ValueError("Token not yet valid")
+
+ # Validate issued at time
+ if self.require_iat or 'iat' in payload:
+ iat = payload.get('iat')
+ if not iat:
+ raise ValueError("Missing 'iat' claim")
+ # Allow some clock skew, but cannot be future time
+ if iat > current_time + self.leeway:
+ raise ValueError("Token issued in the future")
+
+ # Check blacklist
+ jti = payload.get('jti')
+ if jti and await self.blacklist.is_blacklisted(jti):
+ raise ValueError("Token has been revoked")
+
+ # Rate limit check
+ user_id = payload.get('sub')
+ if user_id:
+ if not await self.rate_limiter.is_allowed(user_id):
+ raise ValueError("Rate limit exceeded")
+
+ return {
+ "valid": True,
+ "user_id": user_id,
+ "payload": payload
+ }
+
+ async def start(self):
+ """Start validator"""
+ await self.blacklist.start()
+ logger.info("TokenValidator started")
+
+ async def stop(self):
+ """Stop validator"""
+ await self.blacklist.stop()
+ logger.info("TokenValidator stopped")
+
+ async def revoke_token(self, jti: str, exp: float):
+ """Revoke token
+
+ Args:
+ jti: JWT ID
+ exp: Token expiration time
+ """
+ await self.blacklist.add_token(jti, exp)
+ logger.info(f"Token {jti} has been revoked")
+
+ async def get_validation_stats(self) -> Dict[str, Any]:
+ """Get validation statistics"""
+ blacklist_stats = await self.blacklist.get_stats()
+
+ return {
+ "blacklist": blacklist_stats,
+ "validation_config": {
+ "verify_signature": self.verify_signature,
+ "verify_audience": self.verify_audience,
+ "verify_issuer": self.verify_issuer,
+ "require_exp": self.require_exp,
+ "require_iat": self.require_iat,
+ "require_nbf": self.require_nbf,
+ "leeway": self.leeway
+ }
+ }
+
+ async def get_user_rate_limit_info(self, user_id: str) -> Dict[str, Any]:
+ """Get user rate limit information"""
+ return await self.rate_limiter.get_usage(user_id)
\ No newline at end of file
diff --git a/doris_mcp_server/main.py b/doris_mcp_server/main.py
index ccf0c29..5b60fde 100644
--- a/doris_mcp_server/main.py
+++ b/doris_mcp_server/main.py
@@ -246,6 +246,40 @@ class DorisServer:
self.logger = get_logger(f"{__name__}.DorisServer")
self._setup_handlers()
+ async def _extract_auth_info_from_scope(self, scope, headers):
+ """Extract authentication information from ASGI scope and headers"""
+ auth_info = {}
+
+ # Extract client IP
+ client = scope.get("client")
+ if client:
+ auth_info["client_ip"] = client[0]
+ else:
+ auth_info["client_ip"] = "unknown"
+
+ # Extract token from Authorization header
+ authorization = headers.get(b'authorization', b'').decode('utf-8')
+ if authorization:
+ if authorization.startswith('Bearer '):
+ auth_info["token"] = authorization[7:]
+ auth_info["authorization"] = authorization
+ elif authorization.startswith('Token '):
+ auth_info["token"] = authorization[6:]
+ auth_info["authorization"] = authorization
+
+ # Extract token from query parameters (for compatibility)
+ query_string = scope.get("query_string", b"").decode('utf-8')
+ if query_string and "token=" in query_string:
+ import urllib.parse
+ query_params = urllib.parse.parse_qs(query_string)
+ if "token" in query_params:
+ auth_info["token"] = query_params["token"][0]
+
+ # If no token found, this will be handled by the authentication system
+ # (either return anonymous context if auth disabled, or raise error if auth enabled)
+
+ return auth_info
+
def _get_mcp_capabilities(self):
"""Get MCP capabilities with version compatibility"""
try:
@@ -390,6 +424,10 @@ class DorisServer:
self.logger.info("Starting Doris MCP Server (stdio mode)")
try:
+ # Initialize security manager first (includes JWT setup if enabled)
+ await self.security_manager.initialize()
+ self.logger.info("Security manager initialization completed")
+
# Ensure connection manager is initialized
await self.connection_manager.initialize()
self.logger.info("Connection manager initialization completed")
@@ -456,6 +494,10 @@ class DorisServer:
self.logger.info(f"Starting Doris MCP Server (Streamable HTTP mode) - {host}:{port}, workers: {workers}")
try:
+ # Initialize security manager first (includes JWT setup if enabled)
+ await self.security_manager.initialize()
+ self.logger.info("Security manager initialization completed")
+
# Ensure connection manager is initialized
await self.connection_manager.initialize()
@@ -482,6 +524,44 @@ class DorisServer:
async def health_check(request):
return JSONResponse({"status": "healthy", "service": "doris-mcp-server"})
+ # OAuth endpoints
+ from .auth.oauth_handlers import OAuthHandlers
+ oauth_handlers = OAuthHandlers(self.security_manager)
+
+ async def oauth_login(request):
+ return await oauth_handlers.handle_login(request)
+
+ async def oauth_callback(request):
+ return await oauth_handlers.handle_callback(request)
+
+ async def oauth_provider_info(request):
+ return await oauth_handlers.handle_provider_info(request)
+
+ async def oauth_demo(request):
+ return await oauth_handlers.handle_demo_page(request)
+
+ # Token management endpoints
+ from .auth.token_handlers import TokenHandlers
+ token_handlers = TokenHandlers(self.security_manager)
+
+ async def token_create(request):
+ return await token_handlers.handle_create_token(request)
+
+ async def token_revoke(request):
+ return await token_handlers.handle_revoke_token(request)
+
+ async def token_list(request):
+ return await token_handlers.handle_list_tokens(request)
+
+ async def token_stats(request):
+ return await token_handlers.handle_token_stats(request)
+
+ async def token_cleanup(request):
+ return await token_handlers.handle_cleanup_tokens(request)
+
+ async def token_demo(request):
+ return await token_handlers.handle_demo_page(request)
+
# Lifecycle manager - simplified since we manage session_manager externally
@contextlib.asynccontextmanager
async def lifespan(app: Starlette) -> AsyncIterator[None]:
@@ -497,6 +577,18 @@ class DorisServer:
debug=True,
routes=[
Route("/health", health_check, methods=["GET"]),
+ # OAuth endpoints
+ Route("/auth/login", oauth_login, methods=["GET"]),
+ Route("/auth/callback", oauth_callback, methods=["GET"]),
+ Route("/auth/provider", oauth_provider_info, methods=["GET"]),
+ Route("/auth/demo", oauth_demo, methods=["GET"]),
+ # Token management endpoints
+ Route("/token/create", token_create, methods=["GET", "POST"]),
+ Route("/token/revoke", token_revoke, methods=["GET", "DELETE"]),
+ Route("/token/list", token_list, methods=["GET"]),
+ Route("/token/stats", token_stats, methods=["GET"]),
+ Route("/token/cleanup", token_cleanup, methods=["GET", "POST"]),
+ Route("/token/demo", token_demo, methods=["GET"]),
],
lifespan=lifespan,
)
@@ -514,8 +606,10 @@ class DorisServer:
self.logger.info(f"Received request for path: {path}")
try:
- # Handle health check
- if path.startswith("/health"):
+ # Handle health check, auth, and token management endpoints
+ if (path.startswith("/health") or
+ path.startswith("/auth/") or
+ path.startswith("/token/")):
await starlette_app(scope, receive, send)
return
@@ -528,6 +622,29 @@ class DorisServer:
self.logger.info(f"MCP Request - Method: {method}")
self.logger.info(f"MCP Request - Headers: {headers}")
+ # Authentication check for MCP requests
+ try:
+ # Extract authentication information
+ auth_info = await self._extract_auth_info_from_scope(scope, headers)
+
+ # Authenticate the request
+ auth_context = await self.security_manager.authenticate_request(auth_info)
+ self.logger.info(f"MCP request authenticated: token_id={auth_context.token_id}, client_ip={auth_context.client_ip}")
+
+ # Store auth context in scope for potential use by tools/resources
+ scope["auth_context"] = auth_context
+
+ except Exception as auth_error:
+ self.logger.error(f"MCP authentication failed: {auth_error}")
+ # Return 401 Unauthorized
+ from starlette.responses import JSONResponse
+ response = JSONResponse(
+ {"error": "Authentication required", "message": str(auth_error)},
+ status_code=401
+ )
+ await response(scope, receive, send)
+ return
+
# Handle Dify compatibility for GET requests
if method == "GET":
accept_header = headers.get(b'accept', b'').decode('utf-8')
@@ -619,6 +736,10 @@ class DorisServer:
"""Shutdown server"""
self.logger.info("Shutting down Doris MCP Server")
try:
+ # Shutdown security manager first (includes JWT cleanup)
+ await self.security_manager.shutdown()
+ self.logger.info("Security manager shutdown completed")
+
await self.connection_manager.close()
self.logger.info("Doris MCP Server has been shut down")
except Exception as e:
diff --git a/doris_mcp_server/multiworker_app.py b/doris_mcp_server/multiworker_app.py
index 9bf7c01..ba6b06c 100644
--- a/doris_mcp_server/multiworker_app.py
+++ b/doris_mcp_server/multiworker_app.py
@@ -209,6 +209,7 @@ from .utils.security import DorisSecurityManager
_worker_server = None
_worker_session_manager = None
_worker_connection_manager = None
+_worker_security_manager = None
_worker_session_manager_context = None
_worker_initialized = False
@@ -242,7 +243,7 @@ def get_mcp_capabilities():
async def initialize_worker():
"""Initialize MCP server and managers for this worker process"""
- global _worker_server, _worker_session_manager, _worker_connection_manager, _worker_session_manager_context, _worker_initialized
+ global _worker_server, _worker_session_manager, _worker_connection_manager, _worker_security_manager, _worker_session_manager_context, _worker_initialized, _oauth_handlers, _token_handlers
if _worker_initialized:
return
@@ -263,10 +264,14 @@ async def initialize_worker():
config_manager.setup_logging()
# Create security manager
- security_manager = DorisSecurityManager(config)
+ _worker_security_manager = DorisSecurityManager(config)
+
+ # Initialize security manager first (includes JWT setup if enabled)
+ await _worker_security_manager.initialize()
+ logger.info(f"Worker {os.getpid()} security manager initialization completed")
# Create connection manager
- _worker_connection_manager = DorisConnectionManager(config, security_manager)
+ _worker_connection_manager = DorisConnectionManager(config, _worker_security_manager)
await _worker_connection_manager.initialize()
# Create MCP server
@@ -382,6 +387,12 @@ async def initialize_worker():
_worker_session_manager_context = _worker_session_manager.run()
await _worker_session_manager_context.__aenter__()
+ # Initialize OAuth and Token handlers
+ from .auth.oauth_handlers import OAuthHandlers
+ from .auth.token_handlers import TokenHandlers
+ _oauth_handlers = OAuthHandlers(_worker_security_manager)
+ _token_handlers = TokenHandlers(_worker_security_manager)
+
_worker_initialized = True
logger.info(f"Worker {os.getpid()} MCP initialization completed successfully")
@@ -405,6 +416,73 @@ async def health_check(request):
"mcp_version": MCP_VERSION
})
+# OAuth and Token handlers (initialize after worker setup)
+_oauth_handlers = None
+_token_handlers = None
+
+async def oauth_login(request):
+ """OAuth login endpoint"""
+ if not _oauth_handlers:
+ return JSONResponse({"error": "OAuth not initialized"}, status_code=503)
+ return await _oauth_handlers.handle_login(request)
+
+async def oauth_callback(request):
+ """OAuth callback endpoint"""
+ if not _oauth_handlers:
+ return JSONResponse({"error": "OAuth not initialized"}, status_code=503)
+ return await _oauth_handlers.handle_callback(request)
+
+async def oauth_provider_info(request):
+ """OAuth provider info endpoint"""
+ if not _oauth_handlers:
+ return JSONResponse({"error": "OAuth not initialized"}, status_code=503)
+ return await _oauth_handlers.handle_provider_info(request)
+
+async def oauth_demo(request):
+ """OAuth demo page endpoint"""
+ if not _oauth_handlers:
+ from starlette.responses import HTMLResponse
+ return HTMLResponse("OAuth not initialized
")
+ return await _oauth_handlers.handle_demo_page(request)
+
+# Token management endpoints
+async def token_create(request):
+ """Token creation endpoint"""
+ if not _token_handlers:
+ return JSONResponse({"error": "Token handlers not initialized"}, status_code=503)
+ return await _token_handlers.handle_create_token(request)
+
+async def token_revoke(request):
+ """Token revocation endpoint"""
+ if not _token_handlers:
+ return JSONResponse({"error": "Token handlers not initialized"}, status_code=503)
+ return await _token_handlers.handle_revoke_token(request)
+
+async def token_list(request):
+ """Token listing endpoint"""
+ if not _token_handlers:
+ return JSONResponse({"error": "Token handlers not initialized"}, status_code=503)
+ return await _token_handlers.handle_list_tokens(request)
+
+async def token_stats(request):
+ """Token statistics endpoint"""
+ if not _token_handlers:
+ return JSONResponse({"error": "Token handlers not initialized"}, status_code=503)
+ return await _token_handlers.handle_token_stats(request)
+
+async def token_cleanup(request):
+ """Token cleanup endpoint"""
+ if not _token_handlers:
+ return JSONResponse({"error": "Token handlers not initialized"}, status_code=503)
+ return await _token_handlers.handle_cleanup_tokens(request)
+
+async def token_demo(request):
+ """Token demo page endpoint"""
+ if not _token_handlers:
+ from starlette.responses import HTMLResponse
+ return HTMLResponse("Token handlers not initialized
")
+ return await _token_handlers.handle_demo_page(request)
+
async def root_info(request):
"""Root endpoint"""
return JSONResponse({
@@ -452,6 +530,13 @@ async def lifespan(app):
except Exception as e:
logger.error(f"Error closing worker connection manager: {e}")
+ if _worker_security_manager:
+ try:
+ await _worker_security_manager.shutdown()
+ logger.info(f"Worker {os.getpid()} security manager shutdown completed")
+ except Exception as e:
+ logger.error(f"Error shutting down worker security manager: {e}")
+
# Shutdown logging system
try:
from .utils.logger import shutdown_logging
@@ -492,6 +577,18 @@ basic_app = Starlette(
routes=[
Route("/", root_info, methods=["GET"]),
Route("/health", health_check, methods=["GET"]),
+ # OAuth endpoints
+ Route("/auth/login", oauth_login, methods=["GET"]),
+ Route("/auth/callback", oauth_callback, methods=["GET"]),
+ Route("/auth/provider", oauth_provider_info, methods=["GET"]),
+ Route("/auth/demo", oauth_demo, methods=["GET"]),
+ # Token management endpoints
+ Route("/token/create", token_create, methods=["GET", "POST"]),
+ Route("/token/revoke", token_revoke, methods=["GET", "DELETE"]),
+ Route("/token/list", token_list, methods=["GET"]),
+ Route("/token/stats", token_stats, methods=["GET"]),
+ Route("/token/cleanup", token_cleanup, methods=["GET", "POST"]),
+ Route("/token/demo", token_demo, methods=["GET"]),
],
lifespan=lifespan
)
@@ -505,5 +602,5 @@ async def app(scope, receive, send):
# Handle MCP requests with session manager
await mcp_asgi_app(scope, receive, send)
else:
- # Handle other requests with basic Starlette app
+ # Handle other requests with basic Starlette app (includes auth endpoints)
await basic_app(scope, receive, send)
diff --git a/doris_mcp_server/utils/config.py b/doris_mcp_server/utils/config.py
index 8ee50a9..fe52f66 100644
--- a/doris_mcp_server/utils/config.py
+++ b/doris_mcp_server/utils/config.py
@@ -77,10 +77,43 @@ class DatabaseConfig:
class SecurityConfig:
"""Security configuration"""
- # Authentication configuration
- auth_type: str = "token" # token, basic, oauth
- token_secret: str = "default_secret"
+ # Independent authentication switches - any one enabled allows that method
+ enable_token_auth: bool = False # Enable token-based authentication (default: disabled)
+ enable_jwt_auth: bool = False # Enable JWT authentication (default: disabled)
+ enable_oauth_auth: bool = False # Enable OAuth 2.0/OIDC authentication (default: disabled)
+
+ # Legacy configuration (kept for backward compatibility)
+ auth_type: str = "token" # jwt, token, basic, oauth (deprecated: use individual switches)
+ token_secret: str = "default_secret" # Legacy token secret for backward compatibility
token_expiry: int = 3600
+
+ # Enhanced Token Authentication Configuration
+ token_file_path: str = "tokens.json" # Path to token configuration file
+ enable_token_expiry: bool = True # Enable token expiration
+ default_token_expiry_hours: int = 24 * 30 # Default expiry: 30 days
+ token_hash_algorithm: str = "sha256" # Token hashing algorithm: sha256, sha512
+
+ # JWT Configuration
+ jwt_algorithm: str = "RS256" # RS256, ES256, HS256
+ jwt_issuer: str = "doris-mcp-server"
+ jwt_audience: str = "doris-mcp-client"
+ jwt_private_key_path: str = ""
+ jwt_public_key_path: str = ""
+ jwt_secret_key: str = "" # Only used for HS256 algorithm
+ jwt_access_token_expiry: int = 3600 # 1 hour
+ jwt_refresh_token_expiry: int = 86400 # 24 hours
+ enable_token_refresh: bool = True
+ enable_token_revocation: bool = True
+ key_rotation_interval: int = 30 * 24 * 3600 # 30 days in seconds
+
+ # JWT Security Features
+ jwt_require_iat: bool = True # Require "issued at" claim
+ jwt_require_exp: bool = True # Require "expires at" claim
+ jwt_require_nbf: bool = False # Require "not before" claim
+ jwt_leeway: int = 10 # Clock skew tolerance in seconds
+ jwt_verify_signature: bool = True # Verify JWT signature
+ jwt_verify_audience: bool = True # Verify audience claim
+ jwt_verify_issuer: bool = True # Verify issuer claim
# SQL security configuration
enable_security_check: bool = True # Main switch: whether to enable SQL security check
@@ -115,6 +148,45 @@ class SecurityConfig:
enable_masking: bool = True
masking_rules: list[dict[str, Any]] = field(default_factory=list)
+ # OAuth 2.0/OIDC Configuration
+ oauth_enabled: bool = False
+ oauth_provider: str = "" # 'google', 'microsoft', 'github', 'custom'
+ oauth_client_id: str = ""
+ oauth_client_secret: str = ""
+ oauth_redirect_uri: str = "http://localhost:3000/auth/callback"
+
+ # OIDC Discovery
+ oidc_discovery_url: str = "" # e.g., https://accounts.google.com/.well-known/openid_configuration
+ oauth_authorization_endpoint: str = ""
+ oauth_token_endpoint: str = ""
+ oauth_userinfo_endpoint: str = ""
+ oauth_jwks_uri: str = ""
+
+ # OAuth Scopes and Settings
+ oauth_scopes: list[str] = field(default_factory=list)
+ oauth_state_expiry: int = 600 # State parameter expiry in seconds (10 minutes)
+ oauth_pkce_enabled: bool = True # Enable PKCE for better security
+ oauth_nonce_enabled: bool = True # Enable nonce for OIDC
+
+ # User Mapping Configuration
+ oauth_user_id_claim: str = "sub" # JWT claim for user ID
+ oauth_email_claim: str = "email"
+ oauth_name_claim: str = "name"
+ oauth_roles_claim: str = "roles" # Custom claim for roles
+ oauth_default_roles: list[str] = field(default_factory=lambda: ["oauth_user"])
+
+ def __post_init__(self):
+ """Initialize default OAuth scopes based on provider"""
+ if not self.oauth_scopes and self.oauth_provider:
+ if self.oauth_provider == "google":
+ self.oauth_scopes = ["openid", "email", "profile"]
+ elif self.oauth_provider == "microsoft":
+ self.oauth_scopes = ["openid", "profile", "email", "User.Read"]
+ elif self.oauth_provider == "github":
+ self.oauth_scopes = ["user:email", "read:user"]
+ else:
+ self.oauth_scopes = ["openid", "email", "profile"]
+
@dataclass
class PerformanceConfig:
@@ -338,6 +410,10 @@ class DorisConfig:
)
# Security configuration
+ # Independent authentication switches
+ config.security.enable_token_auth = os.getenv("ENABLE_TOKEN_AUTH", str(config.security.enable_token_auth)).lower() == "true"
+ config.security.enable_jwt_auth = os.getenv("ENABLE_JWT_AUTH", str(config.security.enable_jwt_auth)).lower() == "true"
+ config.security.enable_oauth_auth = os.getenv("ENABLE_OAUTH_AUTH", str(config.security.enable_oauth_auth)).lower() == "true"
config.security.auth_type = os.getenv("AUTH_TYPE", config.security.auth_type)
config.security.token_secret = os.getenv("TOKEN_SECRET", config.security.token_secret)
config.security.token_expiry = int(
@@ -368,6 +444,16 @@ class DorisConfig:
config.security.enable_masking = (
os.getenv("ENABLE_MASKING", str(config.security.enable_masking).lower()).lower() == "true"
)
+
+ # Enhanced Token Authentication configuration
+ config.security.token_file_path = os.getenv("TOKEN_FILE_PATH", config.security.token_file_path)
+ config.security.enable_token_expiry = (
+ os.getenv("ENABLE_TOKEN_EXPIRY", str(config.security.enable_token_expiry).lower()).lower() == "true"
+ )
+ config.security.default_token_expiry_hours = int(
+ os.getenv("DEFAULT_TOKEN_EXPIRY_HOURS", str(config.security.default_token_expiry_hours))
+ )
+ config.security.token_hash_algorithm = os.getenv("TOKEN_HASH_ALGORITHM", config.security.token_hash_algorithm)
# Performance configuration
config.performance.enable_query_cache = (
diff --git a/doris_mcp_server/utils/security.py b/doris_mcp_server/utils/security.py
index c1c4dc8..9b319b0 100644
--- a/doris_mcp_server/utils/security.py
+++ b/doris_mcp_server/utils/security.py
@@ -22,10 +22,10 @@ Implements enterprise-level authentication, authorization, SQL security validati
import logging
import re
-from dataclasses import dataclass
+from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
-from typing import Any
+from typing import Any, Optional
import sqlparse
from sqlparse.sql import Statement
@@ -45,15 +45,13 @@ class SecurityLevel(Enum):
@dataclass
class AuthContext:
- """Authentication context"""
+ """Authentication context for audit and session tracking"""
- user_id: str
- roles: list[str]
- permissions: list[str]
- session_id: str
- login_time: datetime | None = None
+ token_id: str # Token identifier for audit logging
+ client_ip: str = "unknown" # Client IP address
+ session_id: str = "" # Session identifier
+ login_time: datetime = field(default_factory=datetime.utcnow)
last_activity: datetime | None = None
- security_level: SecurityLevel = SecurityLevel.INTERNAL
@dataclass
@@ -100,6 +98,36 @@ class DorisSecurityManager:
self.blocked_keywords = self._load_blocked_keywords()
self.sensitive_tables = self._load_sensitive_tables()
self.masking_rules = self._load_masking_rules()
+
+ # Track initialization state
+ self._initialized = False
+
+ async def initialize(self):
+ """Initialize security manager components"""
+ if self._initialized:
+ return
+
+ try:
+ # Initialize authentication provider (for JWT setup)
+ await self.auth_provider.initialize()
+
+ self._initialized = True
+ self.logger.info("DorisSecurityManager initialized successfully")
+
+ except Exception as e:
+ self.logger.error(f"Failed to initialize DorisSecurityManager: {e}")
+ raise
+
+ async def shutdown(self):
+ """Shutdown security manager components"""
+ try:
+ await self.auth_provider.shutdown()
+ self._initialized = False
+ self.logger.info("DorisSecurityManager shutdown completed")
+
+ except Exception as e:
+ self.logger.error(f"Error during DorisSecurityManager shutdown: {e}")
+ raise
def _load_blocked_keywords(self) -> set[str]:
"""Load blocked SQL keywords from configuration"""
@@ -184,8 +212,55 @@ class DorisSecurityManager:
return default_rules
async def authenticate_request(self, auth_info: dict[str, Any]) -> AuthContext:
- """Validate request authentication information"""
- return await self.auth_provider.authenticate(auth_info)
+ """Validate request authentication information
+
+ Tries authentication methods in order: Token -> JWT -> OAuth
+ Any one method succeeding allows access
+ If all methods are disabled, returns anonymous context
+ """
+ # Check if any authentication method is enabled
+ if not (self.config.security.enable_token_auth or
+ self.config.security.enable_jwt_auth or
+ self.config.security.enable_oauth_auth):
+ self.logger.debug("All authentication methods are disabled")
+ # Return anonymous context when no authentication is enabled
+ return AuthContext(
+ token_id="anonymous",
+ client_ip=auth_info.get("client_ip", "unknown"),
+ session_id="anonymous_session"
+ )
+
+ # Try authentication methods in order of preference
+ last_error = None
+
+ # 1. Try Token authentication first (most common)
+ if self.config.security.enable_token_auth:
+ try:
+ return await self.auth_provider.authenticate_token(auth_info)
+ except Exception as e:
+ self.logger.debug(f"Token authentication failed: {e}")
+ last_error = e
+
+ # 2. Try JWT authentication
+ if self.config.security.enable_jwt_auth:
+ try:
+ return await self.auth_provider.authenticate_jwt(auth_info)
+ except Exception as e:
+ self.logger.debug(f"JWT authentication failed: {e}")
+ last_error = e
+
+ # 3. Try OAuth authentication
+ if self.config.security.enable_oauth_auth:
+ try:
+ return await self.auth_provider.authenticate_oauth(auth_info)
+ except Exception as e:
+ self.logger.debug(f"OAuth authentication failed: {e}")
+ last_error = e
+
+ # All enabled authentication methods failed
+ error_message = f"Authentication failed: {str(last_error)}" if last_error else "No authentication method succeeded"
+ self.logger.warning(f"Authentication failed for client {auth_info.get('client_ip', 'unknown')}: {error_message}")
+ raise ValueError(error_message)
async def authorize_resource_access(
self, auth_context: AuthContext, resource_uri: str
@@ -207,6 +282,117 @@ class DorisSecurityManager:
"""Apply data masking processing"""
return await self.masking_processor.process(data, auth_context)
+ # OAuth-specific methods
+ def get_oauth_authorization_url(self) -> tuple[str, str]:
+ """Get OAuth authorization URL
+
+ Returns:
+ Tuple of (authorization_url, state)
+ """
+ if not self.auth_provider.oauth_provider:
+ raise ValueError("OAuth is not enabled")
+ return self.auth_provider.oauth_provider.get_authorization_url()
+
+ async def handle_oauth_callback(self, code: str, state: str) -> AuthContext:
+ """Handle OAuth callback
+
+ Args:
+ code: Authorization code from OAuth provider
+ state: State parameter for CSRF protection
+
+ Returns:
+ AuthContext for authenticated user
+ """
+ if not self.auth_provider.oauth_provider:
+ raise ValueError("OAuth is not enabled")
+ return await self.auth_provider.oauth_provider.handle_callback(code, state)
+
+ def get_oauth_provider_info(self) -> dict[str, Any]:
+ """Get OAuth provider information
+
+ Returns:
+ OAuth provider information
+ """
+ if not self.auth_provider.oauth_provider:
+ return {"enabled": False}
+ return self.auth_provider.oauth_provider.get_provider_info()
+
+ # Token management methods
+ async def create_token(
+ self,
+ token_id: str,
+ expires_hours: Optional[int] = None,
+ description: str = "",
+ custom_token: Optional[str] = None
+ ) -> str:
+ """Create a new API access token
+
+ Args:
+ token_id: Unique token identifier for audit and management
+ expires_hours: Token expiration in hours (None for no expiration)
+ description: Token description for management purposes
+ custom_token: Custom token string (if None, generates random token)
+
+ Returns:
+ Generated token string
+ """
+ if not self.auth_provider.token_manager:
+ raise ValueError("Token manager not initialized")
+
+ return await self.auth_provider.token_manager.create_token(
+ token_id=token_id,
+ expires_hours=expires_hours,
+ description=description,
+ custom_token=custom_token
+ )
+
+ async def revoke_token(self, token_id: str) -> bool:
+ """Revoke a token by token ID
+
+ Args:
+ token_id: Token ID to revoke
+
+ Returns:
+ True if token was revoked successfully
+ """
+ if not self.auth_provider.token_manager:
+ raise ValueError("Token manager not initialized")
+
+ return await self.auth_provider.token_manager.revoke_token(token_id)
+
+ async def list_tokens(self) -> list[dict[str, Any]]:
+ """List all tokens (without sensitive data)
+
+ Returns:
+ List of token information
+ """
+ if not self.auth_provider.token_manager:
+ raise ValueError("Token manager not initialized")
+
+ return await self.auth_provider.token_manager.list_tokens()
+
+ async def cleanup_expired_tokens(self) -> int:
+ """Remove expired tokens and return count
+
+ Returns:
+ Number of expired tokens removed
+ """
+ if not self.auth_provider.token_manager:
+ return 0
+
+ return await self.auth_provider.token_manager.cleanup_expired_tokens()
+
+ def get_token_stats(self) -> dict[str, Any]:
+ """Get token statistics
+
+ Returns:
+ Token statistics dictionary
+ """
+ if not self.auth_provider.token_manager:
+ return {"error": "Token manager not initialized"}
+
+ return self.auth_provider.token_manager.get_token_stats()
+
class AuthenticationProvider:
"""Authentication provider"""
@@ -215,35 +401,199 @@ class AuthenticationProvider:
self.config = config
self.logger = get_logger(__name__)
self.session_cache = {}
-
- async def authenticate(self, auth_info: dict[str, Any]) -> AuthContext:
- """Perform identity authentication"""
- auth_type = auth_info.get("type", "token")
-
- if auth_type == "token":
- return await self._authenticate_token(auth_info)
- elif auth_type == "basic":
- return await self._authenticate_basic(auth_info)
+ self.jwt_manager = None
+ self.oauth_provider = None
+ self.token_manager = None
+
+ # Initialize authentication providers based on individual switches
+ auth_methods_enabled = []
+
+ # Initialize Token manager if enabled
+ if config.security.enable_token_auth:
+ self._initialize_token_manager()
+ auth_methods_enabled.append("Token")
+
+ # Initialize JWT manager if enabled
+ if config.security.enable_jwt_auth:
+ self._initialize_jwt_manager()
+ auth_methods_enabled.append("JWT")
+
+ # Initialize OAuth provider if enabled
+ if config.security.enable_oauth_auth or (hasattr(config.security, 'oauth_enabled') and config.security.oauth_enabled):
+ self._initialize_oauth_provider()
+ auth_methods_enabled.append("OAuth")
+
+ if auth_methods_enabled:
+ self.logger.info(f"Authentication enabled with methods: {', '.join(auth_methods_enabled)}")
else:
- raise ValueError(f"Unsupported authentication type: {auth_type}")
+ self.logger.info("All authentication methods are disabled - anonymous access allowed")
+
+ def _initialize_jwt_manager(self):
+ """Initialize JWT manager"""
+ try:
+ from ..auth.jwt_manager import JWTManager
+ self.jwt_manager = JWTManager(self.config)
+ self.logger.info("JWT manager initialized")
+ except ImportError as e:
+ self.logger.error(f"Failed to import JWT manager: {e}")
+ raise
+ except Exception as e:
+ self.logger.error(f"Failed to initialize JWT manager: {e}")
+ raise
+
+ def _initialize_token_manager(self):
+ """Initialize Token manager"""
+ try:
+ from ..auth.token_manager import TokenManager
+ self.token_manager = TokenManager(self.config)
+ self.logger.info("Token manager initialized")
+ except ImportError as e:
+ self.logger.error(f"Failed to import Token manager: {e}")
+ raise
+ except Exception as e:
+ self.logger.error(f"Failed to initialize Token manager: {e}")
+ raise
+
+ def _initialize_oauth_provider(self):
+ """Initialize OAuth provider"""
+ try:
+ from ..auth.oauth_provider import OAuthAuthenticationProvider
+ self.oauth_provider = OAuthAuthenticationProvider(self.config)
+ self.logger.info("OAuth provider initialized")
+ except ImportError as e:
+ self.logger.error(f"Failed to import OAuth provider: {e}")
+ raise
+ except Exception as e:
+ self.logger.error(f"Failed to initialize OAuth provider: {e}")
+ raise
+
+ async def initialize(self):
+ """Initialize authentication provider asynchronously"""
+ if self.jwt_manager:
+ success = await self.jwt_manager.initialize()
+ if not success:
+ raise RuntimeError("Failed to initialize JWT manager")
+ self.logger.info("JWT authentication provider initialized successfully")
+
+ if self.token_manager:
+ # Token manager doesn't need async initialization, just log success
+ self.logger.info("Token authentication provider initialized successfully")
+
+ if self.oauth_provider:
+ success = await self.oauth_provider.initialize()
+ if not success:
+ raise RuntimeError("Failed to initialize OAuth provider")
+ self.logger.info("OAuth authentication provider initialized successfully")
+
+ async def shutdown(self):
+ """Shutdown authentication provider"""
+ if self.jwt_manager:
+ await self.jwt_manager.shutdown()
+ self.logger.info("JWT authentication provider shutdown completed")
+
+ if self.token_manager:
+ # Token manager doesn't need async shutdown, just log
+ self.logger.info("Token authentication provider shutdown completed")
+
+ if self.oauth_provider:
+ await self.oauth_provider.shutdown()
+ self.logger.info("OAuth authentication provider shutdown completed")
+
+ async def authenticate_token(self, auth_info: dict[str, Any]) -> AuthContext:
+ """Perform token authentication"""
+ if not self.config.security.enable_token_auth:
+ raise ValueError("Token authentication is not enabled")
+ return await self._authenticate_token(auth_info)
+
+ async def authenticate_jwt(self, auth_info: dict[str, Any]) -> AuthContext:
+ """Perform JWT authentication"""
+ if not self.config.security.enable_jwt_auth:
+ raise ValueError("JWT authentication is not enabled")
+ return await self._authenticate_jwt(auth_info)
+
+ async def authenticate_oauth(self, auth_info: dict[str, Any]) -> AuthContext:
+ """Perform OAuth authentication"""
+ if not self.config.security.enable_oauth_auth:
+ raise ValueError("OAuth authentication is not enabled")
+ return await self._authenticate_oauth(auth_info)
+
+ async def _authenticate_jwt(self, auth_info: dict[str, Any]) -> AuthContext:
+ """JWT authentication"""
+ if not self.jwt_manager:
+ raise ValueError("JWT manager not initialized")
+
+ token = auth_info.get("token")
+ if not token:
+ # Try to extract from Authorization header
+ authorization = auth_info.get("authorization")
+ if authorization and authorization.startswith('Bearer '):
+ token = authorization[7:]
+
+ if not token:
+ raise ValueError("Missing JWT token")
+
+ try:
+ # Use JWT middleware for authentication
+ from ..auth.auth_middleware import AuthMiddleware
+ middleware = AuthMiddleware(self.jwt_manager)
+ return await middleware.authenticate_request(auth_info)
+
+ except Exception as e:
+ self.logger.error(f"JWT authentication failed: {e}")
+ raise ValueError(f"JWT authentication failed: {str(e)}")
+
+ async def _authenticate_oauth(self, auth_info: dict[str, Any]) -> AuthContext:
+ """OAuth authentication"""
+ if not self.oauth_provider:
+ raise ValueError("OAuth provider not initialized")
+
+ # Handle different OAuth authentication scenarios
+ if "access_token" in auth_info:
+ # Direct OAuth access token authentication
+ return await self.oauth_provider.authenticate_with_token(auth_info["access_token"])
+ elif "code" in auth_info and "state" in auth_info:
+ # OAuth callback authentication
+ return await self.oauth_provider.handle_callback(auth_info["code"], auth_info["state"])
+ else:
+ raise ValueError("OAuth authentication requires either access_token or code+state")
async def _authenticate_token(self, auth_info: dict[str, Any]) -> AuthContext:
"""Token authentication"""
+ if not self.token_manager:
+ raise ValueError("Token manager not initialized")
+
token = auth_info.get("token")
+ if not token:
+ # Try to extract from Authorization header
+ authorization = auth_info.get("authorization")
+ if authorization and authorization.startswith('Bearer '):
+ token = authorization[7:]
+ elif authorization and authorization.startswith('Token '):
+ token = authorization[6:]
+
if not token:
raise ValueError("Missing authentication token")
- # Validate token (simplified implementation, should validate JWT or query authentication service in practice)
- user_info = await self._validate_token(token)
-
- return AuthContext(
- user_id=user_info["user_id"],
- roles=user_info["roles"],
- permissions=user_info["permissions"],
- session_id=auth_info.get("session_id", "default"),
- login_time=datetime.utcnow(),
- security_level=SecurityLevel(user_info.get("security_level", "internal")),
- )
+ try:
+ # Validate token using TokenManager
+ validation_result = await self.token_manager.validate_token(token)
+
+ if not validation_result.is_valid:
+ raise ValueError(f"Token validation failed: {validation_result.error_message}")
+
+ token_info = validation_result.token_info
+
+ return AuthContext(
+ token_id=token_info.token_id,
+ client_ip=auth_info.get("client_ip", "unknown"),
+ session_id=auth_info.get("session_id", f"session_{token_info.token_id}"),
+ login_time=datetime.utcnow(),
+ last_activity=token_info.last_used
+ )
+
+ except Exception as e:
+ self.logger.error(f"Token authentication failed: {e}")
+ raise ValueError(f"Token authentication failed: {str(e)}")
async def _authenticate_basic(self, auth_info: dict[str, Any]) -> AuthContext:
"""Basic authentication (username password)"""
diff --git a/tokens.json b/tokens.json
new file mode 100644
index 0000000..1524a49
--- /dev/null
+++ b/tokens.json
@@ -0,0 +1,37 @@
+{
+ "version": "1.0",
+ "description": "Simplified Token configuration file for Doris MCP Server API access control",
+ "created_at": "2025-09-01T00:00:00Z",
+ "tokens": [
+ {
+ "token_id": "admin-token",
+ "token": "doris_admin_token_123456",
+ "description": "Doris admin API access token",
+ "expires_hours": null,
+ "is_active": true
+ },
+ {
+ "token_id": "analyst-token",
+ "token": "doris_analyst_token_123456",
+ "description": "Doris analyst API access token",
+ "expires_hours": 8760,
+ "is_active": true
+ },
+ {
+ "token_id": "readonly-token",
+ "token": "doris_readonly_token_123456",
+ "description": "Doris readonly API access token",
+ "expires_hours": 4320,
+ "is_active": true
+ }
+ ],
+ "notes": [
+ "The admin_token, analyst_token, readonly_token is default token,Please change the token before using in production!",
+ "The token_id is the key of the token,Please use the token_id to identify the token",
+ "The token is the value of the token,Please use the token to identify the token",
+ "The description is the description of the token,Please use the description to identify the token",
+ "The expires_hours is the expires hours of the token,Please use the expires_hours to identify the token",
+ "The is_active is the is active of the token,Please use the is_active to identify the token",
+ "The token_id, token, description, expires_hours, is_active is the metadata of the token,Please use the metadata to identify the token"
+ ]
+}
\ No newline at end of file