diff --git a/.env.example b/.env.example index 167c2f7..616ccd5 100644 --- a/.env.example +++ b/.env.example @@ -36,8 +36,145 @@ DORIS_MAX_CONNECTION_AGE=3600 # Security Configuration # =================================================================== -# Authentication configuration +# Independent Authentication Switches - NEW DESIGN! +# Each authentication method can be enabled/disabled independently +# Any enabled method that succeeds will allow access +# If all methods are disabled, anonymous access is allowed + +# Legacy configuration - kept for backward compatibility +# AUTH_TYPE is now deprecated - use individual switches above AUTH_TYPE=token + +# Token Authentication (Default method - simple and effective) +ENABLE_TOKEN_AUTH=false + +# JWT Authentication (For stateless applications) +ENABLE_JWT_AUTH=false + +# OAuth 2.0/OIDC Authentication (For enterprise integration) +ENABLE_OAUTH_AUTH=false + +# =================================================================== +# Token Authentication Configuration (Enable with ENABLE_TOKEN_AUTH=true) +# =================================================================== + +# Basic token authentication settings +TOKEN_FILE_PATH=tokens.json +ENABLE_TOKEN_EXPIRY=true +DEFAULT_TOKEN_EXPIRY_HOURS=720 +TOKEN_HASH_ALGORITHM=sha256 + +# =================================================================== +# JWT Authentication Configuration (Enable with ENABLE_JWT_AUTH=true) +# =================================================================== + +# JWT token settings (when ENABLE_JWT_AUTH=true) +JWT_SECRET_KEY=your_jwt_secret_key_here_change_in_production +JWT_ALGORITHM=HS256 +JWT_EXPIRATION_HOURS=24 +JWT_ISSUER=doris-mcp-server +JWT_AUDIENCE=doris-mcp-client + +# JWT token validation settings +JWT_VERIFY_SIGNATURE=true +JWT_VERIFY_EXPIRATION=true +JWT_VERIFY_AUDIENCE=true +JWT_VERIFY_ISSUER=true + +# JWT refresh token settings +ENABLE_JWT_REFRESH=true +JWT_REFRESH_EXPIRATION_DAYS=30 +JWT_REFRESH_SECRET_KEY=your_jwt_refresh_secret_key_here + +# JWT user claims configuration +JWT_USER_ID_CLAIM=user_id +JWT_ROLES_CLAIM=roles +JWT_PERMISSIONS_CLAIM=permissions +JWT_SECURITY_LEVEL_CLAIM=security_level + +# =================================================================== +# OAuth 2.0 / OpenID Connect Configuration (Enable with ENABLE_OAUTH_AUTH=true) +# =================================================================== + +# OAuth provider settings (when ENABLE_OAUTH_AUTH=true) +OAUTH_PROVIDER_TYPE=generic +OAUTH_CLIENT_ID=your_oauth_client_id +OAUTH_CLIENT_SECRET=your_oauth_client_secret +OAUTH_REDIRECT_URI=http://localhost:3000/auth/callback + +# OAuth endpoints (for generic provider) +OAUTH_AUTHORIZATION_URL=https://your-provider.com/auth +OAUTH_TOKEN_URL=https://your-provider.com/token +OAUTH_USERINFO_URL=https://your-provider.com/userinfo +OAUTH_JWKS_URL=https://your-provider.com/.well-known/jwks.json + +# OAuth scope and claims +OAUTH_SCOPE=openid profile email +OAUTH_USER_ID_CLAIM=sub +OAUTH_USERNAME_CLAIM=preferred_username +OAUTH_EMAIL_CLAIM=email +OAUTH_ROLES_CLAIM=roles +OAUTH_GROUPS_CLAIM=groups + +# OAuth session settings +OAUTH_SESSION_SECRET=your_oauth_session_secret_here +OAUTH_SESSION_EXPIRY=3600 +OAUTH_STATE_EXPIRY=300 + +# Popular OAuth providers presets (uncomment and configure as needed) + +# Google OAuth Configuration +# OAUTH_PROVIDER_TYPE=google +# OAUTH_CLIENT_ID=your_google_client_id.apps.googleusercontent.com +# OAUTH_CLIENT_SECRET=your_google_client_secret +# OAUTH_AUTHORIZATION_URL=https://accounts.google.com/o/oauth2/auth +# OAUTH_TOKEN_URL=https://oauth2.googleapis.com/token +# OAUTH_USERINFO_URL=https://www.googleapis.com/oauth2/v1/userinfo +# OAUTH_JWKS_URL=https://www.googleapis.com/oauth2/v3/certs +# OAUTH_SCOPE=openid profile email + +# Microsoft Azure AD Configuration +# OAUTH_PROVIDER_TYPE=azure +# OAUTH_CLIENT_ID=your_azure_client_id +# OAUTH_CLIENT_SECRET=your_azure_client_secret +# OAUTH_TENANT_ID=your_tenant_id +# OAUTH_AUTHORIZATION_URL=https://login.microsoftonline.com/{tenant}/oauth2/v2.0/authorize +# OAUTH_TOKEN_URL=https://login.microsoftonline.com/{tenant}/oauth2/v2.0/token +# OAUTH_USERINFO_URL=https://graph.microsoft.com/v1.0/me +# OAUTH_JWKS_URL=https://login.microsoftonline.com/{tenant}/discovery/v2.0/keys +# OAUTH_SCOPE=openid profile email + +# GitHub OAuth Configuration +# OAUTH_PROVIDER_TYPE=github +# OAUTH_CLIENT_ID=your_github_client_id +# OAUTH_CLIENT_SECRET=your_github_client_secret +# OAUTH_AUTHORIZATION_URL=https://github.com/login/oauth/authorize +# OAUTH_TOKEN_URL=https://github.com/login/oauth/access_token +# OAUTH_USERINFO_URL=https://api.github.com/user +# OAUTH_SCOPE=user:email + +# GitLab OAuth Configuration +# OAUTH_PROVIDER_TYPE=gitlab +# OAUTH_CLIENT_ID=your_gitlab_client_id +# OAUTH_CLIENT_SECRET=your_gitlab_client_secret +# OAUTH_AUTHORIZATION_URL=https://gitlab.com/oauth/authorize +# OAUTH_TOKEN_URL=https://gitlab.com/oauth/token +# OAUTH_USERINFO_URL=https://gitlab.com/api/v4/user +# OAUTH_SCOPE=read_user + +# Keycloak OAuth Configuration +# OAUTH_PROVIDER_TYPE=keycloak +# OAUTH_CLIENT_ID=your_keycloak_client_id +# OAUTH_CLIENT_SECRET=your_keycloak_client_secret +# OAUTH_REALM=your_realm +# OAUTH_SERVER_URL=https://your-keycloak-server.com +# OAUTH_AUTHORIZATION_URL=https://your-keycloak-server.com/auth/realms/{realm}/protocol/openid-connect/auth +# OAUTH_TOKEN_URL=https://your-keycloak-server.com/auth/realms/{realm}/protocol/openid-connect/token +# OAUTH_USERINFO_URL=https://your-keycloak-server.com/auth/realms/{realm}/protocol/openid-connect/userinfo +# OAUTH_JWKS_URL=https://your-keycloak-server.com/auth/realms/{realm}/protocol/openid-connect/certs +# OAUTH_SCOPE=openid profile email + +# Legacy token settings (for backward compatibility) TOKEN_SECRET=your_secret_key_here TOKEN_EXPIRY=3600 @@ -172,7 +309,13 @@ TEMP_FILES_DIR=tmp # - LOG_CLEANUP_INTERVAL_HOURS: Check frequency, recommended 24 hours # 2. Security Best Practices: -# - Must change TOKEN_SECRET in production environment +# - NEW: Enable individual authentication methods using ENABLE_TOKEN_AUTH, ENABLE_JWT_AUTH, ENABLE_OAUTH_AUTH +# - When all methods are disabled, ALL requests are allowed with anonymous access +# - Authentication methods work independently - any one succeeding allows access +# - Token Auth: Change default tokens (DEFAULT_ADMIN_TOKEN, etc.) in production +# - JWT Auth: Change JWT_SECRET_KEY and JWT_REFRESH_SECRET_KEY in production +# - OAuth Auth: Configure OAuth provider settings and secure client secrets +# - Must change TOKEN_SECRET in production environment (legacy compatibility) # - Adjust BLOCKED_KEYWORDS according to business needs # - Enable ENABLE_SECURITY_CHECK and ENABLE_MASKING @@ -193,4 +336,43 @@ TEMP_FILES_DIR=tmp # - ADBC_DEFAULT_RETURN_FORMAT: Default return format (arrow/pandas/dict, recommended: arrow) # - ADBC_CONNECTION_TIMEOUT: Connection timeout for ADBC (recommended: 30) # - ADBC_ENABLED: Enable or disable ADBC tools (true/false) -# - Prerequisites: Install adbc_driver_manager, adbc_driver_flightsql, pyarrow packages \ No newline at end of file +# - Prerequisites: Install adbc_driver_manager, adbc_driver_flightsql, pyarrow packages + +# 6. Authentication Configuration Guide - UPDATED DESIGN! +# +# Independent Authentication Control (NEW): +# - ENABLE_TOKEN_AUTH=false (default): Disable token authentication +# - ENABLE_JWT_AUTH=false (default): Disable JWT authentication +# - ENABLE_OAUTH_AUTH=false (default): Disable OAuth authentication +# - When all methods are disabled, no authentication is required (anonymous access) +# - When multiple methods are enabled, any one succeeding allows access +# - Recommended for development/testing: all false, production: enable needed methods +# +# Token Authentication (ENABLE_TOKEN_AUTH=true) - Recommended for most use cases: +# - Simple and secure token-based authentication +# - Configurable default tokens via environment variables +# - Support for custom tokens via TOKEN_* environment variables +# - Token file configuration via tokens.json +# - Built-in token management HTTP endpoints +# - No user management complexity - pure API access control +# +# JWT Authentication (ENABLE_JWT_AUTH=true) - For stateless applications: +# - JSON Web Token based authentication +# - Configurable token expiration and refresh +# - Support for standard JWT claims +# - RSA/ECDSA/HS256 algorithm support +# - Suitable for microservices and distributed systems +# +# OAuth 2.0/OIDC (ENABLE_OAUTH_AUTH=true) - For enterprise integration: +# - Integration with external identity providers +# - Support for popular providers (Google, Microsoft, GitHub, GitLab, Keycloak) +# - OpenID Connect compatibility +# - Automatic user provisioning from provider +# - Secure authorization code flow +# +# Authentication Method Selection Guide: +# - No Auth (all switches false): Development, testing, trusted networks +# - Token Auth only: Small teams, simple deployment, direct API access +# - JWT Auth only: Stateless apps, microservices, mobile clients +# - OAuth Auth only: Enterprise SSO, large teams, external identity providers +# - Multiple methods: Flexible access, different client types, migration scenarios \ No newline at end of file diff --git a/doris_mcp_server/auth/__init__.py b/doris_mcp_server/auth/__init__.py new file mode 100644 index 0000000..fa82f70 --- /dev/null +++ b/doris_mcp_server/auth/__init__.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Doris MCP Server Authentication Module +Provides JWT-based, Token-based, and OAuth 2.0/OIDC authentication and authorization services +""" + +from .jwt_manager import JWTManager +from .key_manager import KeyManager +from .token_validators import TokenValidator, TokenBlacklist +from .auth_middleware import AuthMiddleware +from .token_manager import TokenManager, TokenInfo, TokenValidationResult +from .token_handlers import TokenHandlers +from .oauth_client import OAuthClient, OAuthStateManager +from .oauth_provider import OAuthAuthenticationProvider +from .oauth_types import ( + OAuthProvider, OAuthState, OAuthTokens, OAuthUserInfo, + OIDCDiscovery, OAuthError, OAuthProviderConfig +) + +__all__ = [ + "JWTManager", + "KeyManager", + "TokenValidator", + "TokenBlacklist", + "AuthMiddleware", + "TokenManager", + "TokenInfo", + "TokenValidationResult", + "TokenHandlers", + "OAuthClient", + "OAuthStateManager", + "OAuthAuthenticationProvider", + "OAuthProvider", + "OAuthState", + "OAuthTokens", + "OAuthUserInfo", + "OIDCDiscovery", + "OAuthError", + "OAuthProviderConfig" +] \ No newline at end of file diff --git a/doris_mcp_server/auth/auth_middleware.py b/doris_mcp_server/auth/auth_middleware.py new file mode 100644 index 0000000..3fd1792 --- /dev/null +++ b/doris_mcp_server/auth/auth_middleware.py @@ -0,0 +1,269 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Authentication Middleware Module +Provides middleware for JWT authentication in HTTP and MCP contexts +""" + +from typing import Optional, Dict, Any, Callable, Awaitable +from datetime import datetime + +from .jwt_manager import JWTManager +from ..utils.security import AuthContext, SecurityLevel +from ..utils.logger import get_logger + +logger = get_logger(__name__) + + +class AuthMiddleware: + """Authentication Middleware + + Provides JWT authentication functionality for HTTP and MCP requests + """ + + def __init__(self, jwt_manager: JWTManager): + """Initialize authentication middleware + + Args: + jwt_manager: JWT manager instance + """ + self.jwt_manager = jwt_manager + logger.info("AuthMiddleware initialized") + + def extract_token_from_header(self, authorization: str) -> Optional[str]: + """Extract JWT token from Authorization header + + Args: + authorization: Authorization header value + + Returns: + JWT token string, or None if not found + """ + if not authorization: + return None + + # Support Bearer format + if authorization.startswith('Bearer '): + return authorization[7:] # Remove "Bearer " prefix + + # Support direct token format + if not authorization.startswith('Basic '): + return authorization + + return None + + async def authenticate_request(self, auth_info: Dict[str, Any]) -> AuthContext: + """Authenticate request and return authentication context + + Args: + auth_info: Authentication information dictionary + + Returns: + AuthContext authentication context + + Raises: + ValueError: Authentication failed + """ + try: + auth_type = auth_info.get("type", "jwt") + + if auth_type == "jwt" or auth_type == "token": + return await self._authenticate_jwt(auth_info) + else: + raise ValueError(f"Unsupported authentication type: {auth_type}") + + except Exception as e: + logger.error(f"Request authentication failed: {e}") + raise + + async def _authenticate_jwt(self, auth_info: Dict[str, Any]) -> AuthContext: + """JWT authentication processing + + Args: + auth_info: Authentication information containing JWT token + + Returns: + AuthContext authentication context + """ + # Get token + token = auth_info.get("token") + if not token: + # Try to get from Authorization header + authorization = auth_info.get("authorization") + token = self.extract_token_from_header(authorization) + + if not token: + raise ValueError("Missing JWT token") + + try: + # Validate token + validation_result = await self.jwt_manager.validate_token(token, 'access') + payload = validation_result['payload'] + + # Build authentication context + auth_context = AuthContext( + user_id=payload.get('sub'), + roles=payload.get('roles', []), + permissions=payload.get('permissions', []), + session_id=payload.get('jti'), # Use JWT ID as session ID + login_time=datetime.fromtimestamp(payload.get('iat', 0)), + last_activity=datetime.utcnow(), + security_level=SecurityLevel(payload.get('security_level', 'internal')) + ) + + logger.info(f"JWT authentication successful for user: {auth_context.user_id}") + return auth_context + + except Exception as e: + logger.error(f"JWT authentication failed: {e}") + raise ValueError(f"JWT authentication failed: {str(e)}") + + async def create_auth_response_headers(self, auth_context: AuthContext) -> Dict[str, str]: + """Create authentication response headers + + Args: + auth_context: Authentication context + + Returns: + Response headers dictionary + """ + return { + 'X-Auth-User': auth_context.user_id, + 'X-Auth-Roles': ','.join(auth_context.roles), + 'X-Auth-Session': auth_context.session_id, + 'X-Auth-Security-Level': auth_context.security_level.value + } + + def create_http_middleware(self, skip_paths: Optional[list] = None): + """Create HTTP middleware function + + Args: + skip_paths: List of paths to skip authentication + + Returns: + ASGI middleware function + """ + skip_paths = skip_paths or ['/health', '/docs', '/openapi.json'] + + async def middleware(scope, receive, send): + """HTTP authentication middleware""" + if scope['type'] != 'http': + # Pass through non-HTTP requests directly + return await self.app(scope, receive, send) + + path = scope.get('path', '') + + # Check if authentication should be skipped + if any(path.startswith(skip) for skip in skip_paths): + return await self.app(scope, receive, send) + + # Extract authentication information + headers = dict(scope.get('headers', [])) + authorization = headers.get(b'authorization', b'').decode() + + try: + # Perform authentication + auth_info = { + 'type': 'jwt', + 'authorization': authorization + } + auth_context = await self.authenticate_request(auth_info) + + # Add authentication context to scope + scope['auth_context'] = auth_context + + # Create response wrapper to add authentication headers + async def send_wrapper(message): + if message['type'] == 'http.response.start': + headers = dict(message.get('headers', [])) + auth_headers = await self.create_auth_response_headers(auth_context) + + for key, value in auth_headers.items(): + headers[key.encode()] = value.encode() + + message['headers'] = list(headers.items()) + + await send(message) + + return await self.app(scope, receive, send_wrapper) + + except Exception as e: + # Authentication failed, return 401 error + response_body = f'{{"error": "Authentication failed", "message": "{str(e)}"}}' + + await send({ + 'type': 'http.response.start', + 'status': 401, + 'headers': [ + (b'content-type', b'application/json'), + (b'www-authenticate', b'Bearer') + ] + }) + await send({ + 'type': 'http.response.body', + 'body': response_body.encode() + }) + + return middleware + + async def authenticate_mcp_request(self, headers: Dict[str, str]) -> AuthContext: + """Authenticate MCP request + + Args: + headers: MCP request headers + + Returns: + AuthContext authentication context + """ + try: + # Extract authentication information from multiple possible header fields + authorization = ( + headers.get('Authorization') or + headers.get('authorization') or + headers.get('X-Auth-Token') or + headers.get('x-auth-token') + ) + + auth_info = { + 'type': 'jwt', + 'authorization': authorization + } + + return await self.authenticate_request(auth_info) + + except Exception as e: + logger.error(f"MCP request authentication failed: {e}") + raise + + +class AuthenticationError(Exception): + """Authentication error exception""" + + def __init__(self, message: str, error_code: str = "AUTH_FAILED"): + self.message = message + self.error_code = error_code + super().__init__(message) + + +class AuthorizationError(Exception): + """Authorization error exception""" + + def __init__(self, message: str, error_code: str = "ACCESS_DENIED"): + self.message = message + self.error_code = error_code + super().__init__(message) \ No newline at end of file diff --git a/doris_mcp_server/auth/jwt_manager.py b/doris_mcp_server/auth/jwt_manager.py new file mode 100644 index 0000000..fa79d1a --- /dev/null +++ b/doris_mcp_server/auth/jwt_manager.py @@ -0,0 +1,471 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +JWT Manager Module +Provides comprehensive JWT token management including generation, validation, refresh and revocation +""" + +import time +import uuid +import asyncio +from typing import Dict, Any, Optional, Tuple +from datetime import datetime, timedelta + +try: + import jwt +except ImportError: + raise ImportError("PyJWT is required for JWT functionality. Install with: pip install PyJWT[crypto]") + +from .key_manager import KeyManager +from .token_validators import TokenValidator, TokenBlacklist +from ..utils.logger import get_logger + +logger = get_logger(__name__) + + +class JWTManager: + """JWT Token Manager + + Provides comprehensive JWT token lifecycle management, including: + - Token generation and signing + - Token validation and parsing + - Token refresh mechanism + - Token revocation and blacklist + - Automatic key rotation + """ + + def __init__(self, config): + """Initialize JWT manager + + Args: + config: DorisConfig configuration object (with security attribute) + """ + self.config = config + # Access JWT settings through the security configuration + if hasattr(config, 'security'): + security_config = config.security + else: + # Fallback if config is passed directly as SecurityConfig + security_config = config + + self.algorithm = security_config.jwt_algorithm + self.issuer = security_config.jwt_issuer + self.audience = security_config.jwt_audience + self.access_token_expiry = security_config.jwt_access_token_expiry + self.refresh_token_expiry = security_config.jwt_refresh_token_expiry + self.enable_refresh = security_config.enable_token_refresh + self.enable_revocation = security_config.enable_token_revocation + + # Initialize components + self.key_manager = KeyManager(config) + self.token_blacklist = TokenBlacklist() + self.validator = TokenValidator(config, self.token_blacklist) + + # Automatic key rotation task + self._key_rotation_task = None + + logger.info(f"JWTManager initialized with algorithm: {self.algorithm}") + + async def initialize(self) -> bool: + """Initialize JWT manager""" + try: + # Initialize key manager + if not await self.key_manager.initialize(): + logger.error("Failed to initialize key manager") + return False + + # Start token validator + await self.validator.start() + + # Start automatic key rotation + if self.key_manager.key_rotation_interval > 0: + self._key_rotation_task = asyncio.create_task(self._auto_key_rotation()) + + logger.info("JWTManager initialization completed") + return True + + except Exception as e: + logger.error(f"Failed to initialize JWTManager: {e}") + return False + + async def shutdown(self): + """Shutdown JWT manager""" + try: + # Stop key rotation task + if self._key_rotation_task: + self._key_rotation_task.cancel() + try: + await self._key_rotation_task + except asyncio.CancelledError: + pass + + # Stop validator + await self.validator.stop() + + logger.info("JWTManager shutdown completed") + + except Exception as e: + logger.error(f"Error during JWTManager shutdown: {e}") + + async def generate_tokens(self, user_info: Dict[str, Any], + custom_claims: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + """Generate access token and refresh token + + Args: + user_info: User information dictionary, containing user_id, roles, permissions, etc. + custom_claims: Custom claims + + Returns: + Dictionary containing access_token and refresh_token + """ + try: + current_time = int(time.time()) + jti = str(uuid.uuid4()) + + # Build base payload + base_payload = { + 'iss': self.issuer, + 'aud': self.audience, + 'iat': current_time, + 'jti': jti, + 'sub': user_info.get('user_id'), + 'roles': user_info.get('roles', []), + 'permissions': user_info.get('permissions', []), + 'security_level': user_info.get('security_level', 'internal') + } + + # Add custom claims + if custom_claims: + base_payload.update(custom_claims) + + # Generate access token + access_payload = base_payload.copy() + access_payload.update({ + 'exp': current_time + self.access_token_expiry, + 'token_type': 'access' + }) + + access_token = await self._sign_token(access_payload) + + result = { + 'access_token': access_token, + 'token_type': 'Bearer', + 'expires_in': self.access_token_expiry, + 'user_id': user_info.get('user_id'), + 'issued_at': current_time + } + + # Generate refresh token (if enabled) + if self.enable_refresh: + refresh_jti = str(uuid.uuid4()) + refresh_payload = { + 'iss': self.issuer, + 'aud': self.audience, + 'iat': current_time, + 'exp': current_time + self.refresh_token_expiry, + 'jti': refresh_jti, + 'sub': user_info.get('user_id'), + 'token_type': 'refresh', + 'access_jti': jti # Associated access token ID + } + + refresh_token = await self._sign_token(refresh_payload) + result.update({ + 'refresh_token': refresh_token, + 'refresh_expires_in': self.refresh_token_expiry + }) + + logger.info(f"Generated tokens for user: {user_info.get('user_id')}") + return result + + except Exception as e: + logger.error(f"Failed to generate tokens: {e}") + raise + + async def _sign_token(self, payload: Dict[str, Any]) -> str: + """Sign JWT token + + Args: + payload: JWT payload + + Returns: + Signed JWT token + """ + try: + signing_key = self.key_manager.get_private_key() + + if self.algorithm == "HS256": + # Symmetric key signing + token = jwt.encode(payload, signing_key, algorithm=self.algorithm) + else: + # Asymmetric key signing + token = jwt.encode(payload, signing_key, algorithm=self.algorithm) + + return token + + except Exception as e: + logger.error(f"Failed to sign token: {e}") + raise + + async def validate_token(self, token: str, token_type: str = 'access') -> Dict[str, Any]: + """Validate JWT token + + Args: + token: JWT token string + token_type: Token type ('access' or 'refresh') + + Returns: + Validation result and user information + + Raises: + ValueError: Token validation failed + """ + try: + # Decode token + verification_key = self.key_manager.get_public_key() + + # Get security configuration + if hasattr(self.config, 'security'): + security_config = self.config.security + else: + security_config = self.config + + # JWT decoding options + options = { + 'verify_signature': security_config.jwt_verify_signature, + 'verify_exp': security_config.jwt_require_exp, + 'verify_iat': security_config.jwt_require_iat, + 'verify_nbf': security_config.jwt_require_nbf, + 'verify_aud': security_config.jwt_verify_audience, + 'verify_iss': security_config.jwt_verify_issuer, + } + + # Decode JWT + payload = jwt.decode( + token, + verification_key, + algorithms=[self.algorithm], + audience=self.audience if security_config.jwt_verify_audience else None, + issuer=self.issuer if security_config.jwt_verify_issuer else None, + leeway=security_config.jwt_leeway, + options=options + ) + + # Check token type + if payload.get('token_type') != token_type: + raise ValueError(f"Invalid token type: expected {token_type}") + + # Use validator for additional checks + validation_result = await self.validator.validate_claims(payload) + + logger.info(f"Token validation successful for user: {payload.get('sub')}") + return validation_result + + except jwt.ExpiredSignatureError: + raise ValueError("Token has expired") + except jwt.InvalidTokenError as e: + raise ValueError(f"Invalid token: {str(e)}") + except Exception as e: + logger.error(f"Token validation failed: {e}") + raise ValueError(f"Token validation failed: {str(e)}") + + async def refresh_token(self, refresh_token: str) -> Dict[str, Any]: + """Refresh access token + + Args: + refresh_token: Refresh token + + Returns: + New token pair + """ + if not self.enable_refresh: + raise ValueError("Token refresh is disabled") + + try: + # Validate refresh token + refresh_result = await self.validate_token(refresh_token, 'refresh') + refresh_payload = refresh_result['payload'] + + # Revoke associated access token (if revocation is enabled) + if self.enable_revocation: + access_jti = refresh_payload.get('access_jti') + if access_jti: + # Should revoke old access token here, but since we don't know its expiration time, + # in practice might need to store more information or use different strategy + pass + + # Build new user information + user_info = { + 'user_id': refresh_payload.get('sub'), + 'roles': refresh_payload.get('roles', []), + 'permissions': refresh_payload.get('permissions', []), + 'security_level': refresh_payload.get('security_level', 'internal') + } + + # Generate new token pair + new_tokens = await self.generate_tokens(user_info) + + logger.info(f"Token refreshed for user: {user_info['user_id']}") + return new_tokens + + except Exception as e: + logger.error(f"Token refresh failed: {e}") + raise + + async def revoke_token(self, token: str) -> bool: + """Revoke token + + Args: + token: Token to revoke + + Returns: + Whether revocation was successful + """ + if not self.enable_revocation: + logger.warning("Token revocation is disabled") + return False + + try: + # Decode token to get JTI and expiration time + verification_key = self.key_manager.get_public_key() + payload = jwt.decode( + token, + verification_key, + algorithms=[self.algorithm], + options={'verify_exp': False} # Allow decoding expired tokens + ) + + jti = payload.get('jti') + exp = payload.get('exp') + + if not jti or not exp: + logger.error("Token missing required claims for revocation") + return False + + # Add to blacklist + await self.validator.revoke_token(jti, exp) + + logger.info(f"Token {jti} revoked successfully") + return True + + except Exception as e: + logger.error(f"Token revocation failed: {e}") + return False + + async def decode_token_unsafe(self, token: str) -> Dict[str, Any]: + """Decode token without verifying signature (for debugging only) + + Args: + token: JWT token + + Returns: + Token payload + """ + try: + payload = jwt.decode(token, options={'verify_signature': False}) + return payload + except Exception as e: + logger.error(f"Failed to decode token: {e}") + raise + + async def get_token_info(self, token: str) -> Dict[str, Any]: + """Get token information (without verifying signature) + + Args: + token: JWT token + + Returns: + Token information + """ + try: + payload = await self.decode_token_unsafe(token) + + return { + 'jti': payload.get('jti'), + 'sub': payload.get('sub'), + 'iss': payload.get('iss'), + 'aud': payload.get('aud'), + 'iat': payload.get('iat'), + 'exp': payload.get('exp'), + 'token_type': payload.get('token_type'), + 'roles': payload.get('roles'), + 'permissions': payload.get('permissions'), + 'security_level': payload.get('security_level'), + 'is_expired': payload.get('exp', 0) < time.time() if payload.get('exp') else None + } + + except Exception as e: + logger.error(f"Failed to get token info: {e}") + raise + + async def _auto_key_rotation(self): + """Automatic key rotation task""" + while True: + try: + # Check if key rotation is needed + if await self.key_manager.is_key_expired(): + logger.info("Key rotation needed, rotating keys...") + await self.key_manager.rotate_keys() + + # Wait until next check + await asyncio.sleep(3600) # Check every hour + + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Error in auto key rotation: {e}") + # Wait longer before retry after error + await asyncio.sleep(3600) + + async def get_public_key_info(self) -> Dict[str, Any]: + """Get public key information (for client verification) + + Returns: + Public key information + """ + key_info = await self.key_manager.get_key_info() + public_key_pem = await self.key_manager.export_public_key_pem() + + return { + 'algorithm': self.algorithm, + 'public_key_pem': public_key_pem, + 'key_info': key_info + } + + async def get_manager_stats(self) -> Dict[str, Any]: + """Get manager statistics + + Returns: + Statistics information + """ + key_info = await self.key_manager.get_key_info() + validation_stats = await self.validator.get_validation_stats() + + return { + 'jwt_config': { + 'algorithm': self.algorithm, + 'issuer': self.issuer, + 'audience': self.audience, + 'access_token_expiry': self.access_token_expiry, + 'refresh_token_expiry': self.refresh_token_expiry, + 'enable_refresh': self.enable_refresh, + 'enable_revocation': self.enable_revocation + }, + 'key_manager': key_info, + 'validator': validation_stats + } \ No newline at end of file diff --git a/doris_mcp_server/auth/key_manager.py b/doris_mcp_server/auth/key_manager.py new file mode 100644 index 0000000..9131aef --- /dev/null +++ b/doris_mcp_server/auth/key_manager.py @@ -0,0 +1,343 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +JWT Key Management Module +Provides secure key generation, loading, rotation and management for JWT tokens +""" + +import os +import time +import secrets +from pathlib import Path +from typing import Optional, Tuple, Union +from datetime import datetime, timedelta +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa, ec +from cryptography.hazmat.backends import default_backend + +from ..utils.logger import get_logger + +logger = get_logger(__name__) + + +class KeyManager: + """JWT Key Manager + + Responsible for generating, loading, rotating and securely storing JWT signing keys + Supports RSA and EC algorithms, provides automatic key rotation functionality + """ + + def __init__(self, config): + """Initialize key manager + + Args: + config: DorisConfig configuration object (with security attribute) + """ + self.config = config + # Access JWT settings through the security configuration + if hasattr(config, 'security'): + security_config = config.security + else: + # Fallback if config is passed directly as SecurityConfig + security_config = config + + self.algorithm = security_config.jwt_algorithm + self.key_rotation_interval = security_config.key_rotation_interval + self.private_key_path = security_config.jwt_private_key_path + self.public_key_path = security_config.jwt_public_key_path + self.secret_key = security_config.jwt_secret_key + + # Key storage + self._private_key = None + self._public_key = None + self._secret_key = None + self._key_generated_at = None + + logger.info(f"KeyManager initialized with algorithm: {self.algorithm}") + + async def initialize(self) -> bool: + """Initialize key manager, load or generate keys""" + try: + if self.algorithm == "HS256": + await self._initialize_symmetric_key() + else: + await self._initialize_asymmetric_keys() + + logger.info("KeyManager initialization completed") + return True + + except Exception as e: + logger.error(f"Failed to initialize KeyManager: {e}") + return False + + async def _initialize_symmetric_key(self): + """Initialize symmetric key (HS256)""" + if self.secret_key: + # Use configured key + self._secret_key = self.secret_key.encode() + logger.info("Loaded symmetric key from configuration") + else: + # Generate new key + self._secret_key = await self.generate_symmetric_key() + logger.info("Generated new symmetric key") + + self._key_generated_at = datetime.utcnow() + + async def _initialize_asymmetric_keys(self): + """Initialize asymmetric key pair (RS256/ES256)""" + # Try to load keys from files + if await self._load_keys_from_files(): + logger.info("Loaded asymmetric keys from files") + return + + # Try to load from environment variables + if await self._load_keys_from_env(): + logger.info("Loaded asymmetric keys from environment") + return + + # Generate new key pair + await self.generate_key_pair() + logger.info("Generated new asymmetric key pair") + + async def _load_keys_from_files(self) -> bool: + """Load keys from files""" + try: + if not self.private_key_path or not self.public_key_path: + return False + + private_path = Path(self.private_key_path) + public_path = Path(self.public_key_path) + + if not (private_path.exists() and public_path.exists()): + return False + + # Read private key + with open(private_path, 'rb') as f: + private_key_data = f.read() + self._private_key = serialization.load_pem_private_key( + private_key_data, password=None, backend=default_backend() + ) + + # Read public key + with open(public_path, 'rb') as f: + public_key_data = f.read() + self._public_key = serialization.load_pem_public_key( + public_key_data, backend=default_backend() + ) + + # Get key generation time (using file modification time) + self._key_generated_at = datetime.fromtimestamp(private_path.stat().st_mtime) + + return True + + except Exception as e: + logger.error(f"Failed to load keys from files: {e}") + return False + + async def _load_keys_from_env(self) -> bool: + """Load keys from environment variables""" + try: + private_key_env = os.getenv('JWT_PRIVATE_KEY') + public_key_env = os.getenv('JWT_PUBLIC_KEY') + + if not (private_key_env and public_key_env): + return False + + # Parse private key + self._private_key = serialization.load_pem_private_key( + private_key_env.encode(), password=None, backend=default_backend() + ) + + # Parse public key + self._public_key = serialization.load_pem_public_key( + public_key_env.encode(), backend=default_backend() + ) + + self._key_generated_at = datetime.utcnow() + return True + + except Exception as e: + logger.error(f"Failed to load keys from environment: {e}") + return False + + async def generate_symmetric_key(self, length: int = 32) -> bytes: + """Generate symmetric key + + Args: + length: Key length (bytes), default 32 bytes (256 bits) + + Returns: + Generated key + """ + return secrets.token_bytes(length) + + async def generate_key_pair(self) -> Tuple[bytes, bytes]: + """Generate asymmetric key pair + + Returns: + (private key PEM, public key PEM) tuple + """ + try: + if self.algorithm == "RS256": + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + backend=default_backend() + ) + elif self.algorithm == "ES256": + private_key = ec.generate_private_key( + ec.SECP256R1(), backend=default_backend() + ) + else: + raise ValueError(f"Unsupported algorithm for key generation: {self.algorithm}") + + # Get public key + public_key = private_key.public_key() + + # Serialize private key + private_pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption() + ) + + # Serialize public key + public_pem = public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo + ) + + # Store keys + self._private_key = private_key + self._public_key = public_key + self._key_generated_at = datetime.utcnow() + + # If file paths are configured, save to files + if self.private_key_path and self.public_key_path: + await self._save_keys_to_files(private_pem, public_pem) + + logger.info(f"Generated new {self.algorithm} key pair") + return private_pem, public_pem + + except Exception as e: + logger.error(f"Failed to generate key pair: {e}") + raise + + async def _save_keys_to_files(self, private_pem: bytes, public_pem: bytes): + """Save keys to files""" + try: + # Ensure directories exist + private_path = Path(self.private_key_path) + public_path = Path(self.public_key_path) + + private_path.parent.mkdir(parents=True, exist_ok=True) + public_path.parent.mkdir(parents=True, exist_ok=True) + + # Save private key (set secure permissions) + with open(private_path, 'wb') as f: + f.write(private_pem) + os.chmod(private_path, 0o600) # Only owner can read/write + + # Save public key + with open(public_path, 'wb') as f: + f.write(public_pem) + os.chmod(public_path, 0o644) # Owner read/write, others read only + + logger.info(f"Saved keys to files: {private_path}, {public_path}") + + except Exception as e: + logger.error(f"Failed to save keys to files: {e}") + raise + + def get_private_key(self): + """Get private key for signing""" + if self.algorithm == "HS256": + return self._secret_key + else: + return self._private_key + + def get_public_key(self): + """Get public key for verification""" + if self.algorithm == "HS256": + return self._secret_key + else: + return self._public_key + + def get_algorithm(self) -> str: + """Get signing algorithm""" + return self.algorithm + + async def is_key_expired(self) -> bool: + """Check if key is expired""" + if not self._key_generated_at: + return True + + expiry_time = self._key_generated_at + timedelta(seconds=self.key_rotation_interval) + return datetime.utcnow() > expiry_time + + async def rotate_keys(self) -> bool: + """Rotate keys""" + try: + logger.info("Starting key rotation") + + if self.algorithm == "HS256": + # Generate new symmetric key + self._secret_key = await self.generate_symmetric_key() + self._key_generated_at = datetime.utcnow() + else: + # Generate new asymmetric key pair + await self.generate_key_pair() + + logger.info("Key rotation completed successfully") + return True + + except Exception as e: + logger.error(f"Key rotation failed: {e}") + return False + + async def get_key_info(self) -> dict: + """Get key information""" + return { + "algorithm": self.algorithm, + "key_generated_at": self._key_generated_at.isoformat() if self._key_generated_at else None, + "key_expires_at": ( + self._key_generated_at + timedelta(seconds=self.key_rotation_interval) + ).isoformat() if self._key_generated_at else None, + "is_expired": await self.is_key_expired(), + "has_private_key": self._private_key is not None or self._secret_key is not None, + "has_public_key": self._public_key is not None or self._secret_key is not None + } + + async def export_public_key_pem(self) -> Optional[str]: + """Export public key in PEM format""" + if self.algorithm == "HS256": + return None # Symmetric key not exported + + if not self._public_key: + return None + + try: + public_pem = self._public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo + ) + return public_pem.decode() + + except Exception as e: + logger.error(f"Failed to export public key: {e}") + return None \ No newline at end of file diff --git a/doris_mcp_server/auth/oauth_client.py b/doris_mcp_server/auth/oauth_client.py new file mode 100644 index 0000000..4f8ca58 --- /dev/null +++ b/doris_mcp_server/auth/oauth_client.py @@ -0,0 +1,536 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +OAuth 2.0/OIDC Client Manager +Provides OAuth authentication client implementation with PKCE and OIDC support +""" + +import base64 +import hashlib +import secrets +import uuid +from datetime import datetime, timedelta +from typing import Dict, Optional, Any, Tuple +from urllib.parse import urlencode, parse_qs, urlparse +import asyncio +import json + +try: + import aiohttp +except ImportError: + raise ImportError("aiohttp is required for OAuth functionality. Install with: pip install aiohttp") + +from .oauth_types import ( + OAuthProvider, OAuthState, OAuthTokens, OAuthUserInfo, + OIDCDiscovery, OAuthError, OAuthProviderConfig, OAUTH_PROVIDERS +) +from ..utils.logger import get_logger + +logger = get_logger(__name__) + + +class OAuthStateManager: + """Manages OAuth state parameters for CSRF protection""" + + def __init__(self, state_expiry: int = 600): + """Initialize state manager + + Args: + state_expiry: State expiry time in seconds + """ + self.state_expiry = state_expiry + self._states: Dict[str, OAuthState] = {} + self._cleanup_task = None + + logger.info("OAuthStateManager initialized") + + async def start(self): + """Start periodic cleanup task""" + self._cleanup_task = asyncio.create_task(self._periodic_cleanup()) + logger.info("OAuth state manager started") + + async def stop(self): + """Stop periodic cleanup task""" + if self._cleanup_task: + self._cleanup_task.cancel() + try: + await self._cleanup_task + except asyncio.CancelledError: + pass + logger.info("OAuth state manager stopped") + + def create_state(self, redirect_uri: str, pkce_enabled: bool = True, + nonce_enabled: bool = True) -> OAuthState: + """Create new OAuth state + + Args: + redirect_uri: OAuth redirect URI + pkce_enabled: Whether to enable PKCE + nonce_enabled: Whether to enable nonce (for OIDC) + + Returns: + OAuth state object + """ + state = secrets.token_urlsafe(32) + nonce = secrets.token_urlsafe(32) if nonce_enabled else None + + pkce_verifier = None + pkce_challenge = None + if pkce_enabled: + pkce_verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).decode('utf-8').rstrip('=') + challenge_bytes = hashlib.sha256(pkce_verifier.encode()).digest() + pkce_challenge = base64.urlsafe_b64encode(challenge_bytes).decode('utf-8').rstrip('=') + + oauth_state = OAuthState( + state=state, + nonce=nonce, + pkce_verifier=pkce_verifier, + pkce_challenge=pkce_challenge, + redirect_uri=redirect_uri, + created_at=datetime.utcnow(), + expires_at=datetime.utcnow() + timedelta(seconds=self.state_expiry) + ) + + self._states[state] = oauth_state + logger.debug(f"Created OAuth state: {state}") + return oauth_state + + def get_state(self, state: str) -> Optional[OAuthState]: + """Get OAuth state by state parameter + + Args: + state: State parameter + + Returns: + OAuth state object or None if not found/expired + """ + oauth_state = self._states.get(state) + if oauth_state and oauth_state.expires_at > datetime.utcnow(): + return oauth_state + elif oauth_state: + # Remove expired state + del self._states[state] + logger.debug(f"Removed expired OAuth state: {state}") + return None + + def consume_state(self, state: str) -> Optional[OAuthState]: + """Get and remove OAuth state + + Args: + state: State parameter + + Returns: + OAuth state object or None if not found/expired + """ + oauth_state = self.get_state(state) + if oauth_state: + del self._states[state] + logger.debug(f"Consumed OAuth state: {state}") + return oauth_state + + async def _periodic_cleanup(self): + """Periodic cleanup of expired states""" + while True: + try: + await asyncio.sleep(300) # Clean up every 5 minutes + current_time = datetime.utcnow() + expired_states = [ + state for state, oauth_state in self._states.items() + if oauth_state.expires_at <= current_time + ] + + for state in expired_states: + del self._states[state] + + if expired_states: + logger.info(f"Cleaned up {len(expired_states)} expired OAuth states") + + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Error during OAuth state cleanup: {e}") + + +class OAuthClient: + """OAuth 2.0/OIDC Client implementation""" + + def __init__(self, config): + """Initialize OAuth client + + Args: + config: DorisConfig with OAuth configuration + """ + self.config = config + + # Access OAuth settings through security configuration + if hasattr(config, 'security'): + security_config = config.security + else: + security_config = config + + self.enabled = security_config.oauth_enabled + if not self.enabled: + logger.info("OAuth client disabled by configuration") + return + + # Build provider configuration + self.provider_config = self._build_provider_config(security_config) + self.state_manager = OAuthStateManager(security_config.oauth_state_expiry) + + # HTTP client session + self._session: Optional[aiohttp.ClientSession] = None + + # Discovery cache + self._discovery_cache: Optional[OIDCDiscovery] = None + self._discovery_cache_time: Optional[datetime] = None + + logger.info(f"OAuthClient initialized for provider: {self.provider_config.provider.value}") + + def _build_provider_config(self, security_config) -> OAuthProviderConfig: + """Build OAuth provider configuration + + Args: + security_config: Security configuration object + + Returns: + OAuth provider configuration + """ + try: + provider = OAuthProvider(security_config.oauth_provider) + except ValueError: + provider = OAuthProvider.CUSTOM + + # Get default configuration for known providers + defaults = OAUTH_PROVIDERS.get(provider, {}) + + return OAuthProviderConfig( + provider=provider, + client_id=security_config.oauth_client_id, + client_secret=security_config.oauth_client_secret, + redirect_uri=security_config.oauth_redirect_uri, + scopes=security_config.oauth_scopes or defaults.get("scopes", ["openid", "email", "profile"]), + + # Endpoints (use configured or defaults) + authorization_endpoint=security_config.oauth_authorization_endpoint or defaults.get("authorization_endpoint", ""), + token_endpoint=security_config.oauth_token_endpoint or defaults.get("token_endpoint", ""), + userinfo_endpoint=security_config.oauth_userinfo_endpoint or defaults.get("userinfo_endpoint"), + jwks_uri=security_config.oauth_jwks_uri or defaults.get("jwks_uri"), + + # Discovery + discovery_url=security_config.oidc_discovery_url or defaults.get("discovery_url"), + + # Settings + pkce_enabled=security_config.oauth_pkce_enabled, + nonce_enabled=security_config.oauth_nonce_enabled, + + # User mapping + user_id_claim=security_config.oauth_user_id_claim or defaults.get("user_id_claim", "sub"), + email_claim=security_config.oauth_email_claim or defaults.get("email_claim", "email"), + name_claim=security_config.oauth_name_claim or defaults.get("name_claim", "name"), + roles_claim=security_config.oauth_roles_claim, + default_roles=security_config.oauth_default_roles + ) + + async def initialize(self) -> bool: + """Initialize OAuth client + + Returns: + True if initialization successful + """ + if not self.enabled: + return True + + try: + # Create HTTP session + self._session = aiohttp.ClientSession() + + # Start state manager + await self.state_manager.start() + + # Perform OIDC discovery if configured + if self.provider_config.discovery_url: + await self._discover_oidc_endpoints() + + logger.info("OAuth client initialization completed") + return True + + except Exception as e: + logger.error(f"Failed to initialize OAuth client: {e}") + return False + + async def shutdown(self): + """Shutdown OAuth client""" + if not self.enabled: + return + + try: + # Stop state manager + await self.state_manager.stop() + + # Close HTTP session + if self._session: + await self._session.close() + + logger.info("OAuth client shutdown completed") + + except Exception as e: + logger.error(f"Error during OAuth client shutdown: {e}") + + async def _discover_oidc_endpoints(self): + """Discover OIDC endpoints using discovery URL""" + try: + # Check cache first + if (self._discovery_cache and self._discovery_cache_time and + datetime.utcnow() - self._discovery_cache_time < timedelta(hours=1)): + return self._discovery_cache + + logger.info(f"Discovering OIDC endpoints: {self.provider_config.discovery_url}") + + async with self._session.get(self.provider_config.discovery_url) as response: + response.raise_for_status() + data = await response.json() + + discovery = OIDCDiscovery( + issuer=data["issuer"], + authorization_endpoint=data["authorization_endpoint"], + token_endpoint=data["token_endpoint"], + userinfo_endpoint=data.get("userinfo_endpoint"), + jwks_uri=data.get("jwks_uri"), + scopes_supported=data.get("scopes_supported"), + response_types_supported=data.get("response_types_supported"), + subject_types_supported=data.get("subject_types_supported"), + id_token_signing_alg_values_supported=data.get("id_token_signing_alg_values_supported") + ) + + # Update provider configuration with discovered endpoints + if not self.provider_config.authorization_endpoint: + self.provider_config.authorization_endpoint = discovery.authorization_endpoint + if not self.provider_config.token_endpoint: + self.provider_config.token_endpoint = discovery.token_endpoint + if not self.provider_config.userinfo_endpoint: + self.provider_config.userinfo_endpoint = discovery.userinfo_endpoint + if not self.provider_config.jwks_uri: + self.provider_config.jwks_uri = discovery.jwks_uri + + # Cache discovery result + self._discovery_cache = discovery + self._discovery_cache_time = datetime.utcnow() + + logger.info("OIDC endpoint discovery completed successfully") + return discovery + + except Exception as e: + logger.error(f"OIDC endpoint discovery failed: {e}") + raise + + def build_authorization_url(self) -> Tuple[str, OAuthState]: + """Build OAuth authorization URL + + Returns: + Tuple of (authorization_url, oauth_state) + """ + if not self.enabled: + raise ValueError("OAuth client is not enabled") + + # Create state for CSRF protection + oauth_state = self.state_manager.create_state( + redirect_uri=self.provider_config.redirect_uri, + pkce_enabled=self.provider_config.pkce_enabled, + nonce_enabled=self.provider_config.nonce_enabled + ) + + # Build authorization parameters + params = { + 'response_type': 'code', + 'client_id': self.provider_config.client_id, + 'redirect_uri': self.provider_config.redirect_uri, + 'scope': ' '.join(self.provider_config.scopes), + 'state': oauth_state.state + } + + # Add PKCE challenge + if oauth_state.pkce_challenge: + params['code_challenge'] = oauth_state.pkce_challenge + params['code_challenge_method'] = 'S256' + + # Add nonce for OIDC + if oauth_state.nonce: + params['nonce'] = oauth_state.nonce + + # Build URL + authorization_url = f"{self.provider_config.authorization_endpoint}?{urlencode(params)}" + + logger.info(f"Built OAuth authorization URL for state: {oauth_state.state}") + return authorization_url, oauth_state + + async def exchange_code_for_tokens(self, code: str, state: str) -> Tuple[OAuthTokens, OAuthState]: + """Exchange authorization code for tokens + + Args: + code: Authorization code + state: State parameter + + Returns: + Tuple of (OAuth tokens, OAuth state) + + Raises: + ValueError: If state is invalid or exchange fails + """ + if not self.enabled: + raise ValueError("OAuth client is not enabled") + + # Validate and consume state + oauth_state = self.state_manager.consume_state(state) + if not oauth_state: + raise ValueError("Invalid or expired state parameter") + + try: + # Prepare token request + data = { + 'grant_type': 'authorization_code', + 'client_id': self.provider_config.client_id, + 'client_secret': self.provider_config.client_secret, + 'code': code, + 'redirect_uri': oauth_state.redirect_uri + } + + # Add PKCE verifier + if oauth_state.pkce_verifier: + data['code_verifier'] = oauth_state.pkce_verifier + + # Make token request + async with self._session.post( + self.provider_config.token_endpoint, + data=data, + headers={'Content-Type': 'application/x-www-form-urlencoded'} + ) as response: + response_data = await response.json() + + if response.status != 200: + error_msg = response_data.get('error_description', response_data.get('error', 'Token exchange failed')) + raise ValueError(f"Token exchange failed: {error_msg}") + + tokens = OAuthTokens( + access_token=response_data['access_token'], + token_type=response_data.get('token_type', 'Bearer'), + expires_in=response_data.get('expires_in'), + refresh_token=response_data.get('refresh_token'), + scope=response_data.get('scope'), + id_token=response_data.get('id_token') + ) + + logger.info("Successfully exchanged authorization code for tokens") + return tokens, oauth_state + + except Exception as e: + logger.error(f"Token exchange failed: {e}") + raise ValueError(f"Token exchange failed: {str(e)}") + + async def get_user_info(self, tokens: OAuthTokens) -> OAuthUserInfo: + """Get user information from OAuth provider + + Args: + tokens: OAuth tokens + + Returns: + OAuth user information + """ + if not self.enabled: + raise ValueError("OAuth client is not enabled") + + if not self.provider_config.userinfo_endpoint: + raise ValueError("Userinfo endpoint not configured") + + try: + # Make userinfo request + headers = {'Authorization': f'{tokens.token_type} {tokens.access_token}'} + + async with self._session.get( + self.provider_config.userinfo_endpoint, + headers=headers + ) as response: + response.raise_for_status() + user_data = await response.json() + + # Extract user information using configured claims + user_info = OAuthUserInfo( + sub=str(user_data.get(self.provider_config.user_id_claim, '')), + email=user_data.get(self.provider_config.email_claim), + name=user_data.get(self.provider_config.name_claim), + given_name=user_data.get('given_name'), + family_name=user_data.get('family_name'), + picture=user_data.get('picture'), + locale=user_data.get('locale'), + email_verified=user_data.get('email_verified'), + roles=user_data.get(self.provider_config.roles_claim, self.provider_config.default_roles.copy()), + raw_claims=user_data + ) + + logger.info(f"Retrieved user info for user: {user_info.sub}") + return user_info + + except Exception as e: + logger.error(f"Failed to get user info: {e}") + raise ValueError(f"Failed to get user info: {str(e)}") + + async def refresh_tokens(self, refresh_token: str) -> OAuthTokens: + """Refresh OAuth tokens + + Args: + refresh_token: Refresh token + + Returns: + New OAuth tokens + """ + if not self.enabled: + raise ValueError("OAuth client is not enabled") + + try: + data = { + 'grant_type': 'refresh_token', + 'client_id': self.provider_config.client_id, + 'client_secret': self.provider_config.client_secret, + 'refresh_token': refresh_token + } + + async with self._session.post( + self.provider_config.token_endpoint, + data=data, + headers={'Content-Type': 'application/x-www-form-urlencoded'} + ) as response: + response_data = await response.json() + + if response.status != 200: + error_msg = response_data.get('error_description', response_data.get('error', 'Token refresh failed')) + raise ValueError(f"Token refresh failed: {error_msg}") + + tokens = OAuthTokens( + access_token=response_data['access_token'], + token_type=response_data.get('token_type', 'Bearer'), + expires_in=response_data.get('expires_in'), + refresh_token=response_data.get('refresh_token', refresh_token), # Keep old if not provided + scope=response_data.get('scope'), + id_token=response_data.get('id_token') + ) + + logger.info("Successfully refreshed OAuth tokens") + return tokens + + except Exception as e: + logger.error(f"Token refresh failed: {e}") + raise ValueError(f"Token refresh failed: {str(e)}") \ No newline at end of file diff --git a/doris_mcp_server/auth/oauth_handlers.py b/doris_mcp_server/auth/oauth_handlers.py new file mode 100644 index 0000000..4b8a46d --- /dev/null +++ b/doris_mcp_server/auth/oauth_handlers.py @@ -0,0 +1,312 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +OAuth HTTP Handlers +Provides HTTP endpoints for OAuth authentication flow +""" + +from typing import Dict, Any +from urllib.parse import parse_qs, urlparse +import json + +from starlette.responses import JSONResponse, RedirectResponse, HTMLResponse +from starlette.requests import Request + +from ..utils.logger import get_logger + +logger = get_logger(__name__) + + +class OAuthHandlers: + """OAuth HTTP request handlers""" + + def __init__(self, security_manager): + """Initialize OAuth handlers + + Args: + security_manager: DorisSecurityManager instance + """ + self.security_manager = security_manager + logger.info("OAuth handlers initialized") + + async def handle_login(self, request: Request) -> JSONResponse: + """Handle OAuth login initiation + + Returns JSON with authorization URL and state + """ + try: + # Check if OAuth is enabled + oauth_info = self.security_manager.get_oauth_provider_info() + if not oauth_info.get("enabled"): + return JSONResponse( + {"error": "OAuth authentication is not enabled"}, + status_code=400 + ) + + # Get authorization URL + authorization_url, state = self.security_manager.get_oauth_authorization_url() + + return JSONResponse({ + "authorization_url": authorization_url, + "state": state, + "provider": oauth_info.get("provider"), + "message": "Navigate to authorization_url to complete OAuth login" + }) + + except Exception as e: + logger.error(f"OAuth login initiation failed: {e}") + return JSONResponse( + {"error": f"OAuth login failed: {str(e)}"}, + status_code=500 + ) + + async def handle_callback(self, request: Request) -> JSONResponse: + """Handle OAuth callback + + Processes the OAuth callback and returns authentication result + """ + try: + # Get query parameters + query_params = dict(request.query_params) + + # Check for error in callback + if "error" in query_params: + error_description = query_params.get("error_description", "Unknown error") + logger.warning(f"OAuth callback error: {query_params['error']} - {error_description}") + return JSONResponse( + { + "error": query_params["error"], + "error_description": error_description, + "error_uri": query_params.get("error_uri") + }, + status_code=400 + ) + + # Extract required parameters + code = query_params.get("code") + state = query_params.get("state") + + if not code or not state: + return JSONResponse( + {"error": "Missing required parameters: code and state"}, + status_code=400 + ) + + # Handle OAuth callback + auth_context = await self.security_manager.handle_oauth_callback(code, state) + + # Return successful authentication response + return JSONResponse({ + "success": True, + "user_id": auth_context.user_id, + "roles": auth_context.roles, + "permissions": auth_context.permissions, + "security_level": auth_context.security_level.value, + "session_id": auth_context.session_id, + "message": "OAuth authentication successful" + }) + + except Exception as e: + logger.error(f"OAuth callback handling failed: {e}") + return JSONResponse( + {"error": f"OAuth callback failed: {str(e)}"}, + status_code=500 + ) + + async def handle_provider_info(self, request: Request) -> JSONResponse: + """Handle OAuth provider information request + + Returns information about the configured OAuth provider + """ + try: + provider_info = self.security_manager.get_oauth_provider_info() + return JSONResponse(provider_info) + + except Exception as e: + logger.error(f"Failed to get OAuth provider info: {e}") + return JSONResponse( + {"error": f"Failed to get provider info: {str(e)}"}, + status_code=500 + ) + + async def handle_demo_page(self, request: Request) -> HTMLResponse: + """Handle OAuth demo page + + Returns a simple HTML page for testing OAuth flow + """ + oauth_info = self.security_manager.get_oauth_provider_info() + if not oauth_info.get("enabled"): + return HTMLResponse(""" + + OAuth Demo + +

OAuth Demo

+

OAuth authentication is not enabled.

+

Please configure OAuth settings in your security configuration.

+ + + """) + + html_content = f""" + + + + Doris MCP Server - OAuth Demo + + + +

Doris MCP Server - OAuth Demo

+ +
+

OAuth Configuration

+

Provider: {oauth_info.get('provider', 'N/A')}

+

Client ID: {oauth_info.get('client_id', 'N/A')}

+

Scopes: {', '.join(oauth_info.get('scopes', []))}

+

PKCE Enabled: {oauth_info.get('pkce_enabled', False)}

+
+ +
+

OAuth Authentication Test

+

Click the button below to start OAuth authentication flow:

+ +
+ +
+ +
+

API Endpoints

+ +
+ + + + + """ + + return HTMLResponse(html_content) \ No newline at end of file diff --git a/doris_mcp_server/auth/oauth_provider.py b/doris_mcp_server/auth/oauth_provider.py new file mode 100644 index 0000000..71ce095 --- /dev/null +++ b/doris_mcp_server/auth/oauth_provider.py @@ -0,0 +1,287 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +OAuth Authentication Provider +Integrates OAuth 2.0/OIDC authentication with the existing authentication framework +""" + +from typing import Dict, Any, Optional, Tuple +from datetime import datetime + +from .oauth_client import OAuthClient +from .oauth_types import OAuthTokens, OAuthUserInfo, OAuthState +from ..utils.security import AuthContext, SecurityLevel +from ..utils.logger import get_logger + +logger = get_logger(__name__) + + +class OAuthAuthenticationProvider: + """OAuth authentication provider for Doris MCP Server""" + + def __init__(self, config): + """Initialize OAuth authentication provider + + Args: + config: DorisConfig with OAuth configuration + """ + self.config = config + self.oauth_client = OAuthClient(config) + self.enabled = self.oauth_client.enabled + + logger.info(f"OAuthAuthenticationProvider initialized (enabled: {self.enabled})") + + async def initialize(self) -> bool: + """Initialize OAuth authentication provider + + Returns: + True if initialization successful + """ + if not self.enabled: + return True + + success = await self.oauth_client.initialize() + if success: + logger.info("OAuth authentication provider initialized successfully") + else: + logger.error("Failed to initialize OAuth authentication provider") + return success + + async def shutdown(self): + """Shutdown OAuth authentication provider""" + if self.enabled: + await self.oauth_client.shutdown() + logger.info("OAuth authentication provider shutdown completed") + + def get_authorization_url(self) -> Tuple[str, str]: + """Get OAuth authorization URL + + Returns: + Tuple of (authorization_url, state) + """ + if not self.enabled: + raise ValueError("OAuth authentication is not enabled") + + authorization_url, oauth_state = self.oauth_client.build_authorization_url() + return authorization_url, oauth_state.state + + async def handle_callback(self, code: str, state: str) -> AuthContext: + """Handle OAuth callback and create authentication context + + Args: + code: Authorization code from OAuth provider + state: State parameter for CSRF protection + + Returns: + AuthContext for the authenticated user + + Raises: + ValueError: If authentication fails + """ + if not self.enabled: + raise ValueError("OAuth authentication is not enabled") + + try: + # Exchange code for tokens + tokens, oauth_state = await self.oauth_client.exchange_code_for_tokens(code, state) + + # Get user information + user_info = await self.oauth_client.get_user_info(tokens) + + # Create authentication context + auth_context = await self._create_auth_context(user_info, tokens) + + logger.info(f"OAuth authentication successful for user: {auth_context.user_id}") + return auth_context + + except Exception as e: + logger.error(f"OAuth callback handling failed: {e}") + raise ValueError(f"OAuth authentication failed: {str(e)}") + + async def authenticate_with_token(self, access_token: str) -> AuthContext: + """Authenticate using OAuth access token + + Args: + access_token: OAuth access token + + Returns: + AuthContext for the authenticated user + """ + if not self.enabled: + raise ValueError("OAuth authentication is not enabled") + + try: + # Create token object + tokens = OAuthTokens(access_token=access_token) + + # Get user information + user_info = await self.oauth_client.get_user_info(tokens) + + # Create authentication context + auth_context = await self._create_auth_context(user_info, tokens) + + logger.info(f"OAuth token authentication successful for user: {auth_context.user_id}") + return auth_context + + except Exception as e: + logger.error(f"OAuth token authentication failed: {e}") + raise ValueError(f"OAuth token authentication failed: {str(e)}") + + async def refresh_authentication(self, refresh_token: str) -> Tuple[AuthContext, str]: + """Refresh OAuth authentication + + Args: + refresh_token: OAuth refresh token + + Returns: + Tuple of (AuthContext, new_access_token) + """ + if not self.enabled: + raise ValueError("OAuth authentication is not enabled") + + try: + # Refresh tokens + tokens = await self.oauth_client.refresh_tokens(refresh_token) + + # Get updated user information + user_info = await self.oauth_client.get_user_info(tokens) + + # Create authentication context + auth_context = await self._create_auth_context(user_info, tokens) + + logger.info(f"OAuth refresh successful for user: {auth_context.user_id}") + return auth_context, tokens.access_token + + except Exception as e: + logger.error(f"OAuth refresh failed: {e}") + raise ValueError(f"OAuth refresh failed: {str(e)}") + + async def _create_auth_context(self, user_info: OAuthUserInfo, tokens: OAuthTokens) -> AuthContext: + """Create authentication context from OAuth user info + + Args: + user_info: OAuth user information + tokens: OAuth tokens + + Returns: + AuthContext for the user + """ + # Determine security level based on roles or email domain + security_level = await self._determine_security_level(user_info) + + # Map OAuth roles to application permissions + permissions = await self._map_permissions(user_info.roles) + + # Generate session ID + session_id = f"oauth_{user_info.sub}_{datetime.utcnow().timestamp()}" + + return AuthContext( + user_id=user_info.sub, + roles=user_info.roles, + permissions=permissions, + session_id=session_id, + login_time=datetime.utcnow(), + last_activity=datetime.utcnow(), + security_level=security_level + ) + + async def _determine_security_level(self, user_info: OAuthUserInfo) -> SecurityLevel: + """Determine security level for OAuth user + + Args: + user_info: OAuth user information + + Returns: + SecurityLevel for the user + """ + # Check if user has admin roles + admin_roles = {"admin", "administrator", "data_admin", "super_admin"} + if any(role.lower() in admin_roles for role in user_info.roles): + return SecurityLevel.SECRET + + # Check email domain for internal users + if user_info.email: + # You can configure trusted domains for internal access + trusted_domains = ["yourcompany.com", "internal.org"] # Configure as needed + email_domain = user_info.email.split("@")[-1].lower() + if email_domain in trusted_domains: + return SecurityLevel.CONFIDENTIAL + + # Check for special roles + elevated_roles = {"data_analyst", "developer", "manager"} + if any(role.lower() in elevated_roles for role in user_info.roles): + return SecurityLevel.CONFIDENTIAL + + # Default to internal level for OAuth users + return SecurityLevel.INTERNAL + + async def _map_permissions(self, roles: list[str]) -> list[str]: + """Map OAuth roles to application permissions + + Args: + roles: OAuth user roles + + Returns: + List of application permissions + """ + permissions = set() + + # Role to permission mapping + role_permissions = { + "admin": ["admin", "read_data", "write_data", "manage_users"], + "administrator": ["admin", "read_data", "write_data", "manage_users"], + "data_admin": ["admin", "read_data", "write_data"], + "super_admin": ["admin", "read_data", "write_data", "manage_users", "system_admin"], + "data_analyst": ["read_data", "query_database"], + "developer": ["read_data", "query_database", "debug"], + "viewer": ["read_data"], + "user": ["read_data"], + "oauth_user": ["read_data"] # Default OAuth user permission + } + + # Map roles to permissions + for role in roles: + role_lower = role.lower() + if role_lower in role_permissions: + permissions.update(role_permissions[role_lower]) + + # Ensure OAuth users have at least basic permissions + if not permissions: + permissions.add("read_data") + + return list(permissions) + + def get_provider_info(self) -> Dict[str, Any]: + """Get OAuth provider information + + Returns: + Provider information dictionary + """ + if not self.enabled: + return {"enabled": False} + + config = self.oauth_client.provider_config + return { + "enabled": True, + "provider": config.provider.value, + "client_id": config.client_id, + "scopes": config.scopes, + "redirect_uri": config.redirect_uri, + "pkce_enabled": config.pkce_enabled, + "nonce_enabled": config.nonce_enabled + } \ No newline at end of file diff --git a/doris_mcp_server/auth/oauth_types.py b/doris_mcp_server/auth/oauth_types.py new file mode 100644 index 0000000..ce9252f --- /dev/null +++ b/doris_mcp_server/auth/oauth_types.py @@ -0,0 +1,196 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +OAuth 2.0/OIDC Type Definitions +Provides data types and models for OAuth authentication flow +""" + +from dataclasses import dataclass +from datetime import datetime +from enum import Enum +from typing import Dict, Any, Optional, List + + +class OAuthProvider(Enum): + """OAuth provider enumeration""" + GOOGLE = "google" + MICROSOFT = "microsoft" + GITHUB = "github" + CUSTOM = "custom" + + +class OAuthGrantType(Enum): + """OAuth grant type enumeration""" + AUTHORIZATION_CODE = "authorization_code" + REFRESH_TOKEN = "refresh_token" + + +@dataclass +class OAuthState: + """OAuth state parameter for CSRF protection""" + state: str + nonce: Optional[str] = None + pkce_verifier: Optional[str] = None + pkce_challenge: Optional[str] = None + redirect_uri: str = "" + created_at: datetime = None + expires_at: datetime = None + + def __post_init__(self): + if self.created_at is None: + self.created_at = datetime.utcnow() + + +@dataclass +class OAuthTokens: + """OAuth token response""" + access_token: str + token_type: str = "Bearer" + expires_in: Optional[int] = None + refresh_token: Optional[str] = None + scope: Optional[str] = None + id_token: Optional[str] = None # OIDC ID token + created_at: datetime = None + + def __post_init__(self): + if self.created_at is None: + self.created_at = datetime.utcnow() + + +@dataclass +class OAuthUserInfo: + """OAuth/OIDC user information""" + sub: str # Subject identifier + email: Optional[str] = None + email_verified: Optional[bool] = None + name: Optional[str] = None + given_name: Optional[str] = None + family_name: Optional[str] = None + picture: Optional[str] = None + locale: Optional[str] = None + roles: List[str] = None + raw_claims: Dict[str, Any] = None + + def __post_init__(self): + if self.roles is None: + self.roles = [] + if self.raw_claims is None: + self.raw_claims = {} + + +@dataclass +class OIDCDiscovery: + """OIDC Discovery document""" + issuer: str + authorization_endpoint: str + token_endpoint: str + userinfo_endpoint: Optional[str] = None + jwks_uri: Optional[str] = None + scopes_supported: List[str] = None + response_types_supported: List[str] = None + subject_types_supported: List[str] = None + id_token_signing_alg_values_supported: List[str] = None + + def __post_init__(self): + if self.scopes_supported is None: + self.scopes_supported = ["openid"] + if self.response_types_supported is None: + self.response_types_supported = ["code"] + if self.subject_types_supported is None: + self.subject_types_supported = ["public"] + if self.id_token_signing_alg_values_supported is None: + self.id_token_signing_alg_values_supported = ["RS256"] + + +@dataclass +class OAuthError: + """OAuth error response""" + error: str + error_description: Optional[str] = None + error_uri: Optional[str] = None + state: Optional[str] = None + + +@dataclass +class OAuthProviderConfig: + """OAuth provider configuration""" + provider: OAuthProvider + client_id: str + client_secret: str + redirect_uri: str + scopes: List[str] + + # Endpoints + authorization_endpoint: str + token_endpoint: str + userinfo_endpoint: Optional[str] = None + jwks_uri: Optional[str] = None + + # Discovery + discovery_url: Optional[str] = None + + # Settings + pkce_enabled: bool = True + nonce_enabled: bool = True + + # User mapping + user_id_claim: str = "sub" + email_claim: str = "email" + name_claim: str = "name" + roles_claim: str = "roles" + default_roles: List[str] = None + + def __post_init__(self): + if self.default_roles is None: + self.default_roles = ["oauth_user"] + + +# Pre-defined provider configurations +OAUTH_PROVIDERS = { + OAuthProvider.GOOGLE: { + "authorization_endpoint": "https://accounts.google.com/o/oauth2/auth", + "token_endpoint": "https://oauth2.googleapis.com/token", + "userinfo_endpoint": "https://openidconnect.googleapis.com/v1/userinfo", + "jwks_uri": "https://www.googleapis.com/oauth2/v3/certs", + "discovery_url": "https://accounts.google.com/.well-known/openid_configuration", + "scopes": ["openid", "email", "profile"], + "user_id_claim": "sub", + "email_claim": "email", + "name_claim": "name" + }, + OAuthProvider.MICROSOFT: { + "authorization_endpoint": "https://login.microsoftonline.com/common/oauth2/v2.0/authorize", + "token_endpoint": "https://login.microsoftonline.com/common/oauth2/v2.0/token", + "userinfo_endpoint": "https://graph.microsoft.com/v1.0/me", + "jwks_uri": "https://login.microsoftonline.com/common/discovery/v2.0/keys", + "discovery_url": "https://login.microsoftonline.com/common/v2.0/.well-known/openid_configuration", + "scopes": ["openid", "profile", "email", "User.Read"], + "user_id_claim": "sub", + "email_claim": "email", + "name_claim": "name" + }, + OAuthProvider.GITHUB: { + "authorization_endpoint": "https://github.com/login/oauth/authorize", + "token_endpoint": "https://github.com/login/oauth/access_token", + "userinfo_endpoint": "https://api.github.com/user", + "scopes": ["user:email", "read:user"], + "user_id_claim": "id", + "email_claim": "email", + "name_claim": "name" + } +} \ No newline at end of file diff --git a/doris_mcp_server/auth/token_handlers.py b/doris_mcp_server/auth/token_handlers.py new file mode 100644 index 0000000..7001d4d --- /dev/null +++ b/doris_mcp_server/auth/token_handlers.py @@ -0,0 +1,481 @@ +#!/usr/bin/env python3 +""" +Token Authentication HTTP Handlers + +Provides HTTP endpoints for token management including creation, revocation, +listing, and statistics. Used for administrative token management in HTTP mode. +""" + +import json +from typing import Dict, Any + +from starlette.requests import Request +from starlette.responses import JSONResponse, HTMLResponse + +from ..utils.logger import get_logger +from ..utils.security import SecurityLevel + + +class TokenHandlers: + """Token Authentication HTTP Handlers""" + + def __init__(self, security_manager): + self.security_manager = security_manager + self.logger = get_logger(__name__) + + async def handle_create_token(self, request: Request) -> JSONResponse: + """Handle token creation request""" + try: + # Check if token manager is available + if not self.security_manager.auth_provider.token_manager: + return JSONResponse({ + "error": "Token authentication is not enabled" + }, status_code=503) + + # Parse request data + if request.method == "GET": + # GET request with query parameters + query_params = dict(request.query_params) + token_id = query_params.get("token_id") + user_id = query_params.get("user_id") + roles = query_params.get("roles", "").split(",") if query_params.get("roles") else [] + permissions = query_params.get("permissions", "").split(",") if query_params.get("permissions") else [] + security_level_str = query_params.get("security_level", "internal") + expires_hours_str = query_params.get("expires_hours") + description = query_params.get("description", "") + custom_token = query_params.get("custom_token") + else: + # POST request with JSON body + try: + body = await request.json() + except: + return JSONResponse({ + "error": "Invalid JSON body" + }, status_code=400) + + token_id = body.get("token_id") + user_id = body.get("user_id") + roles = body.get("roles", []) + permissions = body.get("permissions", []) + security_level_str = body.get("security_level", "internal") + expires_hours_str = body.get("expires_hours") + description = body.get("description", "") + custom_token = body.get("custom_token") + + # Validate required fields + if not token_id or not user_id: + return JSONResponse({ + "error": "token_id and user_id are required" + }, status_code=400) + + # Parse security level + try: + security_level = SecurityLevel(security_level_str.lower()) + except ValueError: + security_level = SecurityLevel.INTERNAL + + # Parse expires_hours + expires_hours = None + if expires_hours_str: + try: + expires_hours = int(expires_hours_str) + except ValueError: + return JSONResponse({ + "error": "expires_hours must be an integer" + }, status_code=400) + + # Create token + try: + token = await self.security_manager.create_token( + token_id=token_id, + user_id=user_id, + roles=roles, + permissions=permissions, + security_level=security_level, + expires_hours=expires_hours, + description=description, + custom_token=custom_token + ) + + return JSONResponse({ + "success": True, + "token_id": token_id, + "user_id": user_id, + "token": token, + "roles": roles, + "permissions": permissions, + "security_level": security_level.value, + "expires_hours": expires_hours, + "description": description, + "message": "Token created successfully" + }) + + except Exception as e: + self.logger.error(f"Token creation failed: {e}") + return JSONResponse({ + "error": f"Token creation failed: {str(e)}" + }, status_code=400) + + except Exception as e: + self.logger.error(f"Error in handle_create_token: {e}") + return JSONResponse({ + "error": f"Internal server error: {str(e)}" + }, status_code=500) + + async def handle_revoke_token(self, request: Request) -> JSONResponse: + """Handle token revocation request""" + try: + # Check if token manager is available + if not self.security_manager.auth_provider.token_manager: + return JSONResponse({ + "error": "Token authentication is not enabled" + }, status_code=503) + + # Get token_id from query parameters or path + token_id = request.query_params.get("token_id") + if not token_id and request.method == "DELETE": + # Try to get from path: /token/revoke/{token_id} + path_parts = str(request.url.path).split("/") + if len(path_parts) >= 4: + token_id = path_parts[-1] + + if not token_id: + return JSONResponse({ + "error": "token_id is required" + }, status_code=400) + + # Revoke token + success = await self.security_manager.revoke_token(token_id) + + if success: + return JSONResponse({ + "success": True, + "token_id": token_id, + "message": "Token revoked successfully" + }) + else: + return JSONResponse({ + "success": False, + "token_id": token_id, + "message": "Token not found or already revoked" + }, status_code=404) + + except Exception as e: + self.logger.error(f"Error in handle_revoke_token: {e}") + return JSONResponse({ + "error": f"Internal server error: {str(e)}" + }, status_code=500) + + async def handle_list_tokens(self, request: Request) -> JSONResponse: + """Handle token listing request""" + try: + # Check if token manager is available + if not self.security_manager.auth_provider.token_manager: + return JSONResponse({ + "error": "Token authentication is not enabled" + }, status_code=503) + + # Get tokens list + tokens = await self.security_manager.list_tokens() + + return JSONResponse({ + "success": True, + "count": len(tokens), + "tokens": tokens + }) + + except Exception as e: + self.logger.error(f"Error in handle_list_tokens: {e}") + return JSONResponse({ + "error": f"Internal server error: {str(e)}" + }, status_code=500) + + async def handle_token_stats(self, request: Request) -> JSONResponse: + """Handle token statistics request""" + try: + # Check if token manager is available + if not self.security_manager.auth_provider.token_manager: + return JSONResponse({ + "error": "Token authentication is not enabled" + }, status_code=503) + + # Get token statistics + stats = self.security_manager.get_token_stats() + + return JSONResponse({ + "success": True, + "stats": stats + }) + + except Exception as e: + self.logger.error(f"Error in handle_token_stats: {e}") + return JSONResponse({ + "error": f"Internal server error: {str(e)}" + }, status_code=500) + + async def handle_cleanup_tokens(self, request: Request) -> JSONResponse: + """Handle expired tokens cleanup request""" + try: + # Check if token manager is available + if not self.security_manager.auth_provider.token_manager: + return JSONResponse({ + "error": "Token authentication is not enabled" + }, status_code=503) + + # Cleanup expired tokens + cleaned_count = await self.security_manager.cleanup_expired_tokens() + + return JSONResponse({ + "success": True, + "cleaned_count": cleaned_count, + "message": f"Cleaned up {cleaned_count} expired tokens" + }) + + except Exception as e: + self.logger.error(f"Error in handle_cleanup_tokens: {e}") + return JSONResponse({ + "error": f"Internal server error: {str(e)}" + }, status_code=500) + + async def handle_demo_page(self, request: Request) -> HTMLResponse: + """Handle token management demo page""" + try: + # Check if token manager is available + if not self.security_manager.auth_provider.token_manager: + html_content = """ + + + + Token Management - Not Available + + + +

Token Management

+
Token authentication is not enabled on this server.
+ + + """ + return HTMLResponse(html_content) + + # Get current stats for demo + stats = self.security_manager.get_token_stats() + + html_content = f""" + + + + Doris MCP Server - Token Management + + + +
+

🔐 Doris MCP Server - Token Management

+ +
+

📊 Token Statistics

+
+
+
{stats.get('total_tokens', 0)}
+
Total Tokens
+
+
+
{stats.get('active_tokens', 0)}
+
Active Tokens
+
+
+
{stats.get('expired_tokens', 0)}
+
Expired Tokens
+
+
+

Token Expiry: {'Enabled' if stats.get('expiry_enabled') else 'Disabled'}

+

Default Expiry: {stats.get('default_expiry_hours', 0)} hours

+
+ +
+

➕ Create New Token

+
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+ +
+
+
+ +
+

📋 Token Management

+ + +
+ +

Revoke Token

+
+ + +
+
+
+ +
+

🔧 API Endpoints

+

Use these endpoints for programmatic token management:

+ +
+
+ + + + + """ + + return HTMLResponse(html_content) + + except Exception as e: + self.logger.error(f"Error in handle_demo_page: {e}") + error_html = f""" + + + + Token Management Error + + + +

Token Management Error

+

Error loading token management page: {str(e)}

+ + + """ + return HTMLResponse(error_html, status_code=500) \ No newline at end of file diff --git a/doris_mcp_server/auth/token_manager.py b/doris_mcp_server/auth/token_manager.py new file mode 100644 index 0000000..ed742d9 --- /dev/null +++ b/doris_mcp_server/auth/token_manager.py @@ -0,0 +1,456 @@ +#!/usr/bin/env python3 +""" +Token Authentication Management Module + +Provides enterprise-grade token authentication system with configurable tokens, +expiration management, role-based access control and secure token storage. +""" + +import hashlib +import json +import os +import secrets +import time +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from typing import Dict, List, Optional, Any + +from ..utils.logger import get_logger +from ..utils.security import SecurityLevel + + +@dataclass +class TokenInfo: + """Token information structure""" + + token_id: str # Unique token identifier for audit and management + created_at: datetime = field(default_factory=datetime.utcnow) + expires_at: Optional[datetime] = None + last_used: Optional[datetime] = None + description: str = "" # Optional description for token purpose + is_active: bool = True + + +@dataclass +class TokenValidationResult: + """Token validation result""" + + is_valid: bool + token_info: Optional[TokenInfo] = None + error_message: Optional[str] = None + + +class TokenManager: + """Enterprise Token Authentication Manager + + Features: + - Configurable token storage (file-based or environment variables) + - Token expiration management + - Secure token hashing + - Role-based access control + - Token lifecycle management + """ + + def __init__(self, config): + self.config = config + self.logger = get_logger(__name__) + + # Token storage + self._tokens: Dict[str, TokenInfo] = {} # token_hash -> TokenInfo + self._token_ids: Dict[str, str] = {} # token_id -> token_hash + + # Configuration + self.token_file_path = getattr(config.security, 'token_file_path', 'tokens.json') + self.enable_token_expiry = getattr(config.security, 'enable_token_expiry', True) + self.default_token_expiry_hours = getattr(config.security, 'default_token_expiry_hours', 24 * 30) # 30 days + self.token_hash_algorithm = getattr(config.security, 'token_hash_algorithm', 'sha256') + + # Initialize with default tokens if none exist + self._initialize_default_tokens() + + # Load tokens from configuration + self._load_tokens() + + self.logger.info(f"TokenManager initialized with {len(self._tokens)} tokens") + + def _initialize_default_tokens(self): + """Initialize default tokens for basic authentication (configurable via environment)""" + # Default token configurations (can be overridden by environment variables) + default_tokens = [ + { + 'token_id': 'admin-token', + 'token': os.getenv('DEFAULT_ADMIN_TOKEN', 'doris_admin_token_123456'), + 'description': os.getenv('DEFAULT_ADMIN_DESCRIPTION', 'Default admin API access token'), + 'expires_hours': None # Never expires + }, + { + 'token_id': 'analyst-token', + 'token': os.getenv('DEFAULT_ANALYST_TOKEN', 'doris_analyst_token_123456'), + 'description': os.getenv('DEFAULT_ANALYST_DESCRIPTION', 'Default data analysis API access token'), + 'expires_hours': None # Never expires + }, + { + 'token_id': 'readonly-token', + 'token': os.getenv('DEFAULT_READONLY_TOKEN', 'doris_readonly_token_123456'), + 'description': os.getenv('DEFAULT_READONLY_DESCRIPTION', 'Default read-only API access token'), + 'expires_hours': None # Never expires + } + ] + + + # Only add default tokens if no custom tokens are defined via environment variables + # Check if any TOKEN_* environment variables exist (excluding system and legacy configs) + excluded_prefixes = ('DEFAULT_', 'TOKEN_FILE_PATH', 'TOKEN_HASH_') + excluded_vars = {'TOKEN_SECRET', 'TOKEN_EXPIRY'} + + custom_tokens_exist = any( + key.startswith('TOKEN_') and + not key.startswith(excluded_prefixes) and + not key.endswith(('_EXPIRES_HOURS', '_DESCRIPTION')) and + key not in excluded_vars + for key in os.environ.keys() + ) + + # Also check if token file exists and has content + token_file_exists = False + if os.path.exists(self.token_file_path): + try: + with open(self.token_file_path, 'r') as f: + content = f.read().strip() + if content and content != '{}': + token_file_exists = True + except: + pass + + # Add default tokens only if no custom configuration exists + if not custom_tokens_exist and not token_file_exists: + for token_config in default_tokens: + self._add_token_from_config(token_config) + + self.logger.info(f"Initialized {len(default_tokens)} default tokens (no custom config found)") + else: + self.logger.info("Skipped default tokens initialization (custom tokens detected)") + + def _add_token_from_config(self, token_config: Dict[str, Any]): + """Add token from configuration""" + try: + # Calculate expiration time + expires_at = None + if self.enable_token_expiry: + expires_hours = token_config.get('expires_hours', self.default_token_expiry_hours) + if expires_hours is not None: + expires_at = datetime.utcnow() + timedelta(hours=expires_hours) + + # Create token info + token_info = TokenInfo( + token_id=token_config['token_id'], + expires_at=expires_at, + description=token_config.get('description', ''), + is_active=token_config.get('is_active', True) + ) + + # Hash the token + raw_token = token_config['token'] + token_hash = self._hash_token(raw_token) + + # Store token + self._tokens[token_hash] = token_info + self._token_ids[token_info.token_id] = token_hash + + self.logger.debug(f"Added token '{token_info.token_id}'") + + except Exception as e: + self.logger.error(f"Failed to add token from config: {e}") + + def _load_tokens(self): + """Load tokens from configuration sources""" + # 1. Load from environment variables + self._load_tokens_from_env() + + # 2. Load from token file if exists + if os.path.exists(self.token_file_path): + self._load_tokens_from_file() + + self.logger.info(f"Token loading completed, total tokens: {len(self._tokens)}") + + def _load_tokens_from_env(self): + """Load tokens from environment variables + + Simplified format: + TOKEN_= + TOKEN__EXPIRES_HOURS= + TOKEN__DESCRIPTION= + """ + token_prefixes = set() + + # Find all TOKEN_ environment variables (exclude legacy and system variables) + excluded_token_vars = { + 'TOKEN_SECRET', # Legacy token secret + 'TOKEN_EXPIRY', # Legacy token expiry + 'TOKEN_FILE_PATH', # System config + 'TOKEN_HASH_ALGORITHM' # System config + } + + for key in os.environ: + if (key.startswith('TOKEN_') and + not key.endswith(('_EXPIRES_HOURS', '_DESCRIPTION')) and + key not in excluded_token_vars): + token_id = key[6:] # Remove 'TOKEN_' prefix + token_prefixes.add(token_id) + + # Load each token + for token_id in token_prefixes: + try: + token = os.environ.get(f'TOKEN_{token_id}') + if not token: + continue + + expires_hours_str = os.environ.get(f'TOKEN_{token_id}_EXPIRES_HOURS', str(self.default_token_expiry_hours)) + description = os.environ.get(f'TOKEN_{token_id}_DESCRIPTION', f'Environment token {token_id}') + + expires_hours = None + try: + if expires_hours_str and expires_hours_str.lower() != 'none': + expires_hours = int(expires_hours_str) + except ValueError: + expires_hours = self.default_token_expiry_hours + + # Add token + token_config = { + 'token_id': token_id.lower(), + 'token': token, + 'expires_hours': expires_hours, + 'description': description + } + + self._add_token_from_config(token_config) + + except Exception as e: + self.logger.error(f"Failed to load token {token_id} from environment: {e}") + + def _load_tokens_from_file(self): + """Load tokens from JSON file""" + try: + with open(self.token_file_path, 'r', encoding='utf-8') as f: + tokens_data = json.load(f) + + if isinstance(tokens_data, dict) and 'tokens' in tokens_data: + tokens_list = tokens_data['tokens'] + elif isinstance(tokens_data, list): + tokens_list = tokens_data + else: + self.logger.error(f"Invalid token file format: {self.token_file_path}") + return + + for token_config in tokens_list: + self._add_token_from_config(token_config) + + self.logger.info(f"Loaded {len(tokens_list)} tokens from file: {self.token_file_path}") + + except Exception as e: + self.logger.error(f"Failed to load tokens from file {self.token_file_path}: {e}") + + def _hash_token(self, token: str) -> str: + """Hash token for secure storage""" + if self.token_hash_algorithm == 'sha256': + return hashlib.sha256(token.encode('utf-8')).hexdigest() + elif self.token_hash_algorithm == 'sha512': + return hashlib.sha512(token.encode('utf-8')).hexdigest() + else: + # Fallback to sha256 + return hashlib.sha256(token.encode('utf-8')).hexdigest() + + async def validate_token(self, token: str) -> TokenValidationResult: + """Validate token and return user information""" + try: + # Hash the provided token + token_hash = self._hash_token(token) + + # Find token info + token_info = self._tokens.get(token_hash) + if not token_info: + return TokenValidationResult( + is_valid=False, + error_message="Invalid token" + ) + + # Check if token is active + if not token_info.is_active: + return TokenValidationResult( + is_valid=False, + error_message="Token is inactive" + ) + + # Check expiration + if token_info.expires_at and datetime.utcnow() > token_info.expires_at: + return TokenValidationResult( + is_valid=False, + error_message="Token has expired" + ) + + # Update last used time + token_info.last_used = datetime.utcnow() + + return TokenValidationResult( + is_valid=True, + token_info=token_info + ) + + except Exception as e: + self.logger.error(f"Token validation error: {e}") + return TokenValidationResult( + is_valid=False, + error_message=f"Token validation failed: {str(e)}" + ) + + def generate_token(self, length: int = 32) -> str: + """Generate a cryptographically secure random token""" + return secrets.token_urlsafe(length) + + async def create_token( + self, + token_id: str, + expires_hours: Optional[int] = None, + description: str = "", + custom_token: Optional[str] = None + ) -> str: + """Create a new token""" + try: + # Check if token_id already exists + if token_id in self._token_ids: + raise ValueError(f"Token ID '{token_id}' already exists") + + # Generate or use provided token + if custom_token: + raw_token = custom_token + else: + raw_token = self.generate_token() + + # Calculate expiration + expires_at = None + if expires_hours is not None: + expires_at = datetime.utcnow() + timedelta(hours=expires_hours) + elif self.enable_token_expiry: + expires_at = datetime.utcnow() + timedelta(hours=self.default_token_expiry_hours) + + # Create token info + token_info = TokenInfo( + token_id=token_id, + expires_at=expires_at, + description=description + ) + + # Hash and store token + token_hash = self._hash_token(raw_token) + self._tokens[token_hash] = token_info + self._token_ids[token_id] = token_hash + + self.logger.info(f"Created new token '{token_id}'") + + return raw_token + + except Exception as e: + self.logger.error(f"Failed to create token: {e}") + raise + + async def revoke_token(self, token_id: str) -> bool: + """Revoke a token by token ID""" + try: + if token_id not in self._token_ids: + self.logger.warning(f"Token ID '{token_id}' not found") + return False + + # Get token hash and remove from storage + token_hash = self._token_ids[token_id] + if token_hash in self._tokens: + del self._tokens[token_hash] + del self._token_ids[token_id] + + self.logger.info(f"Revoked token '{token_id}'") + return True + + except Exception as e: + self.logger.error(f"Failed to revoke token '{token_id}': {e}") + return False + + async def list_tokens(self) -> List[Dict[str, Any]]: + """List all tokens (without sensitive data)""" + tokens = [] + + for token_hash, token_info in self._tokens.items(): + tokens.append({ + 'token_id': token_info.token_id, + 'created_at': token_info.created_at.isoformat(), + 'expires_at': token_info.expires_at.isoformat() if token_info.expires_at else None, + 'last_used': token_info.last_used.isoformat() if token_info.last_used else None, + 'is_active': token_info.is_active, + 'description': token_info.description, + 'is_expired': token_info.expires_at and datetime.utcnow() > token_info.expires_at if token_info.expires_at else False + }) + + # Sort by creation time + tokens.sort(key=lambda x: x['created_at'], reverse=True) + + return tokens + + async def cleanup_expired_tokens(self) -> int: + """Remove expired tokens and return count""" + if not self.enable_token_expiry: + return 0 + + now = datetime.utcnow() + expired_tokens = [] + + # Find expired tokens + for token_hash, token_info in self._tokens.items(): + if token_info.expires_at and now > token_info.expires_at: + expired_tokens.append((token_hash, token_info.token_id)) + + # Remove expired tokens + for token_hash, token_id in expired_tokens: + del self._tokens[token_hash] + if token_id in self._token_ids: + del self._token_ids[token_id] + + if expired_tokens: + self.logger.info(f"Cleaned up {len(expired_tokens)} expired tokens") + + return len(expired_tokens) + + async def save_tokens_to_file(self, file_path: Optional[str] = None) -> bool: + """Save current tokens to JSON file""" + try: + file_path = file_path or self.token_file_path + tokens_list = await self.list_tokens() + + tokens_data = { + 'version': '1.0', + 'created_at': datetime.utcnow().isoformat(), + 'tokens': tokens_list + } + + with open(file_path, 'w', encoding='utf-8') as f: + json.dump(tokens_data, f, indent=2, ensure_ascii=False) + + self.logger.info(f"Saved {len(tokens_list)} tokens to file: {file_path}") + return True + + except Exception as e: + self.logger.error(f"Failed to save tokens to file: {e}") + return False + + def get_token_stats(self) -> Dict[str, Any]: + """Get token statistics""" + now = datetime.utcnow() + total_tokens = len(self._tokens) + active_tokens = sum(1 for info in self._tokens.values() if info.is_active) + expired_tokens = sum(1 for info in self._tokens.values() + if info.expires_at and now > info.expires_at) + + return { + 'total_tokens': total_tokens, + 'active_tokens': active_tokens, + 'expired_tokens': expired_tokens, + 'expiry_enabled': self.enable_token_expiry, + 'default_expiry_hours': self.default_token_expiry_hours + } \ No newline at end of file diff --git a/doris_mcp_server/auth/token_validators.py b/doris_mcp_server/auth/token_validators.py new file mode 100644 index 0000000..43731af --- /dev/null +++ b/doris_mcp_server/auth/token_validators.py @@ -0,0 +1,365 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +JWT Token Validation Module +Provides token validation, blacklist management and security features +""" + +import time +import asyncio +from typing import Dict, Set, Optional, Any +from datetime import datetime, timedelta +from collections import defaultdict + +from ..utils.logger import get_logger + +logger = get_logger(__name__) + + +class TokenBlacklist: + """JWT Token Blacklist Manager + + Manages revoked tokens to prevent revoked tokens from being used again + Supports both in-memory and persistent storage + """ + + def __init__(self, cleanup_interval: int = 3600): + """Initialize token blacklist + + Args: + cleanup_interval: Interval for cleaning up expired tokens (seconds) + """ + self.cleanup_interval = cleanup_interval + # Storage format: {token_jti: expiry_timestamp} + self._blacklisted_tokens: Dict[str, float] = {} + self._cleanup_task = None + + logger.info("TokenBlacklist initialized") + + async def start(self): + """Start blacklist manager""" + self._cleanup_task = asyncio.create_task(self._periodic_cleanup()) + logger.info("TokenBlacklist started with periodic cleanup") + + async def stop(self): + """Stop blacklist manager""" + if self._cleanup_task: + self._cleanup_task.cancel() + try: + await self._cleanup_task + except asyncio.CancelledError: + pass + logger.info("TokenBlacklist stopped") + + async def add_token(self, jti: str, exp: float): + """Add token to blacklist + + Args: + jti: JWT ID (unique identifier) + exp: Token expiration timestamp + """ + self._blacklisted_tokens[jti] = exp + logger.info(f"Token {jti} added to blacklist") + + async def is_blacklisted(self, jti: str) -> bool: + """Check if token is blacklisted + + Args: + jti: JWT ID + + Returns: + True if blacklisted, False otherwise + """ + return jti in self._blacklisted_tokens + + async def remove_token(self, jti: str) -> bool: + """Remove token from blacklist + + Args: + jti: JWT ID + + Returns: + True if removed, False if not found + """ + if jti in self._blacklisted_tokens: + del self._blacklisted_tokens[jti] + logger.info(f"Token {jti} removed from blacklist") + return True + return False + + async def cleanup_expired(self) -> int: + """Clean up expired blacklisted tokens + + Returns: + Number of tokens cleaned up + """ + current_time = time.time() + expired_tokens = [ + jti for jti, exp in self._blacklisted_tokens.items() + if exp <= current_time + ] + + for jti in expired_tokens: + del self._blacklisted_tokens[jti] + + if expired_tokens: + logger.info(f"Cleaned up {len(expired_tokens)} expired tokens from blacklist") + + return len(expired_tokens) + + async def get_stats(self) -> Dict[str, Any]: + """Get blacklist statistics""" + current_time = time.time() + active_tokens = sum(1 for exp in self._blacklisted_tokens.values() if exp > current_time) + + return { + "total_blacklisted": len(self._blacklisted_tokens), + "active_blacklisted": active_tokens, + "expired_blacklisted": len(self._blacklisted_tokens) - active_tokens, + "cleanup_interval": self.cleanup_interval + } + + async def _periodic_cleanup(self): + """Periodically clean up expired tokens""" + while True: + try: + await asyncio.sleep(self.cleanup_interval) + await self.cleanup_expired() + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Error during periodic cleanup: {e}") + + +class RateLimiter: + """Token usage rate limiter""" + + def __init__(self, max_requests: int = 100, time_window: int = 3600): + """Initialize rate limiter + + Args: + max_requests: Maximum requests within time window + time_window: Time window in seconds + """ + self.max_requests = max_requests + self.time_window = time_window + # Storage format: {user_id: [timestamp1, timestamp2, ...]} + self._request_history: Dict[str, list] = defaultdict(list) + + logger.info(f"RateLimiter initialized: {max_requests} requests per {time_window} seconds") + + async def is_allowed(self, user_id: str) -> bool: + """Check if user is allowed to make request + + Args: + user_id: User ID + + Returns: + True if allowed, False otherwise + """ + current_time = time.time() + user_requests = self._request_history[user_id] + + # Clean up expired request records + cutoff_time = current_time - self.time_window + user_requests[:] = [t for t in user_requests if t > cutoff_time] + + # Check if limit exceeded + if len(user_requests) >= self.max_requests: + logger.warning(f"Rate limit exceeded for user {user_id}") + return False + + # Record current request + user_requests.append(current_time) + return True + + async def get_usage(self, user_id: str) -> Dict[str, Any]: + """Get user usage information + + Args: + user_id: User ID + + Returns: + Usage statistics + """ + current_time = time.time() + user_requests = self._request_history[user_id] + + # Clean up expired records + cutoff_time = current_time - self.time_window + active_requests = [t for t in user_requests if t > cutoff_time] + + return { + "user_id": user_id, + "requests_in_window": len(active_requests), + "max_requests": self.max_requests, + "time_window": self.time_window, + "remaining_requests": max(0, self.max_requests - len(active_requests)) + } + + +class TokenValidator: + """JWT Token Validator + + Provides comprehensive JWT token validation functionality, including signature verification, + claim validation, blacklist checking and rate limiting + """ + + def __init__(self, config, blacklist: Optional[TokenBlacklist] = None): + """Initialize token validator + + Args: + config: DorisConfig configuration object (with security attribute) + blacklist: Token blacklist manager + """ + self.config = config + self.blacklist = blacklist or TokenBlacklist() + self.rate_limiter = RateLimiter() + + # Access JWT settings through the security configuration + if hasattr(config, 'security'): + security_config = config.security + else: + # Fallback if config is passed directly as SecurityConfig + security_config = config + + # Validation options + self.verify_signature = security_config.jwt_verify_signature + self.verify_audience = security_config.jwt_verify_audience + self.verify_issuer = security_config.jwt_verify_issuer + self.require_exp = security_config.jwt_require_exp + self.require_iat = security_config.jwt_require_iat + self.require_nbf = security_config.jwt_require_nbf + self.leeway = security_config.jwt_leeway + + # Expected values + self.expected_audience = security_config.jwt_audience + self.expected_issuer = security_config.jwt_issuer + + logger.info("TokenValidator initialized") + + async def validate_claims(self, payload: Dict[str, Any]) -> Dict[str, Any]: + """Validate JWT claims + + Args: + payload: JWT payload + + Returns: + Validation result + + Raises: + ValueError: Validation failed + """ + current_time = time.time() + + # Validate issuer + if self.verify_issuer: + if payload.get('iss') != self.expected_issuer: + raise ValueError(f"Invalid issuer: expected {self.expected_issuer}") + + # Validate audience + if self.verify_audience: + aud = payload.get('aud') + if isinstance(aud, list): + if self.expected_audience not in aud: + raise ValueError(f"Invalid audience: {self.expected_audience} not in {aud}") + elif aud != self.expected_audience: + raise ValueError(f"Invalid audience: expected {self.expected_audience}") + + # Validate expiration time + if self.require_exp or 'exp' in payload: + exp = payload.get('exp') + if not exp: + raise ValueError("Missing 'exp' claim") + if current_time > exp + self.leeway: + raise ValueError("Token has expired") + + # Validate not before time + if self.require_nbf or 'nbf' in payload: + nbf = payload.get('nbf') + if not nbf: + raise ValueError("Missing 'nbf' claim") + if current_time < nbf - self.leeway: + raise ValueError("Token not yet valid") + + # Validate issued at time + if self.require_iat or 'iat' in payload: + iat = payload.get('iat') + if not iat: + raise ValueError("Missing 'iat' claim") + # Allow some clock skew, but cannot be future time + if iat > current_time + self.leeway: + raise ValueError("Token issued in the future") + + # Check blacklist + jti = payload.get('jti') + if jti and await self.blacklist.is_blacklisted(jti): + raise ValueError("Token has been revoked") + + # Rate limit check + user_id = payload.get('sub') + if user_id: + if not await self.rate_limiter.is_allowed(user_id): + raise ValueError("Rate limit exceeded") + + return { + "valid": True, + "user_id": user_id, + "payload": payload + } + + async def start(self): + """Start validator""" + await self.blacklist.start() + logger.info("TokenValidator started") + + async def stop(self): + """Stop validator""" + await self.blacklist.stop() + logger.info("TokenValidator stopped") + + async def revoke_token(self, jti: str, exp: float): + """Revoke token + + Args: + jti: JWT ID + exp: Token expiration time + """ + await self.blacklist.add_token(jti, exp) + logger.info(f"Token {jti} has been revoked") + + async def get_validation_stats(self) -> Dict[str, Any]: + """Get validation statistics""" + blacklist_stats = await self.blacklist.get_stats() + + return { + "blacklist": blacklist_stats, + "validation_config": { + "verify_signature": self.verify_signature, + "verify_audience": self.verify_audience, + "verify_issuer": self.verify_issuer, + "require_exp": self.require_exp, + "require_iat": self.require_iat, + "require_nbf": self.require_nbf, + "leeway": self.leeway + } + } + + async def get_user_rate_limit_info(self, user_id: str) -> Dict[str, Any]: + """Get user rate limit information""" + return await self.rate_limiter.get_usage(user_id) \ No newline at end of file diff --git a/doris_mcp_server/main.py b/doris_mcp_server/main.py index ccf0c29..5b60fde 100644 --- a/doris_mcp_server/main.py +++ b/doris_mcp_server/main.py @@ -246,6 +246,40 @@ class DorisServer: self.logger = get_logger(f"{__name__}.DorisServer") self._setup_handlers() + async def _extract_auth_info_from_scope(self, scope, headers): + """Extract authentication information from ASGI scope and headers""" + auth_info = {} + + # Extract client IP + client = scope.get("client") + if client: + auth_info["client_ip"] = client[0] + else: + auth_info["client_ip"] = "unknown" + + # Extract token from Authorization header + authorization = headers.get(b'authorization', b'').decode('utf-8') + if authorization: + if authorization.startswith('Bearer '): + auth_info["token"] = authorization[7:] + auth_info["authorization"] = authorization + elif authorization.startswith('Token '): + auth_info["token"] = authorization[6:] + auth_info["authorization"] = authorization + + # Extract token from query parameters (for compatibility) + query_string = scope.get("query_string", b"").decode('utf-8') + if query_string and "token=" in query_string: + import urllib.parse + query_params = urllib.parse.parse_qs(query_string) + if "token" in query_params: + auth_info["token"] = query_params["token"][0] + + # If no token found, this will be handled by the authentication system + # (either return anonymous context if auth disabled, or raise error if auth enabled) + + return auth_info + def _get_mcp_capabilities(self): """Get MCP capabilities with version compatibility""" try: @@ -390,6 +424,10 @@ class DorisServer: self.logger.info("Starting Doris MCP Server (stdio mode)") try: + # Initialize security manager first (includes JWT setup if enabled) + await self.security_manager.initialize() + self.logger.info("Security manager initialization completed") + # Ensure connection manager is initialized await self.connection_manager.initialize() self.logger.info("Connection manager initialization completed") @@ -456,6 +494,10 @@ class DorisServer: self.logger.info(f"Starting Doris MCP Server (Streamable HTTP mode) - {host}:{port}, workers: {workers}") try: + # Initialize security manager first (includes JWT setup if enabled) + await self.security_manager.initialize() + self.logger.info("Security manager initialization completed") + # Ensure connection manager is initialized await self.connection_manager.initialize() @@ -482,6 +524,44 @@ class DorisServer: async def health_check(request): return JSONResponse({"status": "healthy", "service": "doris-mcp-server"}) + # OAuth endpoints + from .auth.oauth_handlers import OAuthHandlers + oauth_handlers = OAuthHandlers(self.security_manager) + + async def oauth_login(request): + return await oauth_handlers.handle_login(request) + + async def oauth_callback(request): + return await oauth_handlers.handle_callback(request) + + async def oauth_provider_info(request): + return await oauth_handlers.handle_provider_info(request) + + async def oauth_demo(request): + return await oauth_handlers.handle_demo_page(request) + + # Token management endpoints + from .auth.token_handlers import TokenHandlers + token_handlers = TokenHandlers(self.security_manager) + + async def token_create(request): + return await token_handlers.handle_create_token(request) + + async def token_revoke(request): + return await token_handlers.handle_revoke_token(request) + + async def token_list(request): + return await token_handlers.handle_list_tokens(request) + + async def token_stats(request): + return await token_handlers.handle_token_stats(request) + + async def token_cleanup(request): + return await token_handlers.handle_cleanup_tokens(request) + + async def token_demo(request): + return await token_handlers.handle_demo_page(request) + # Lifecycle manager - simplified since we manage session_manager externally @contextlib.asynccontextmanager async def lifespan(app: Starlette) -> AsyncIterator[None]: @@ -497,6 +577,18 @@ class DorisServer: debug=True, routes=[ Route("/health", health_check, methods=["GET"]), + # OAuth endpoints + Route("/auth/login", oauth_login, methods=["GET"]), + Route("/auth/callback", oauth_callback, methods=["GET"]), + Route("/auth/provider", oauth_provider_info, methods=["GET"]), + Route("/auth/demo", oauth_demo, methods=["GET"]), + # Token management endpoints + Route("/token/create", token_create, methods=["GET", "POST"]), + Route("/token/revoke", token_revoke, methods=["GET", "DELETE"]), + Route("/token/list", token_list, methods=["GET"]), + Route("/token/stats", token_stats, methods=["GET"]), + Route("/token/cleanup", token_cleanup, methods=["GET", "POST"]), + Route("/token/demo", token_demo, methods=["GET"]), ], lifespan=lifespan, ) @@ -514,8 +606,10 @@ class DorisServer: self.logger.info(f"Received request for path: {path}") try: - # Handle health check - if path.startswith("/health"): + # Handle health check, auth, and token management endpoints + if (path.startswith("/health") or + path.startswith("/auth/") or + path.startswith("/token/")): await starlette_app(scope, receive, send) return @@ -528,6 +622,29 @@ class DorisServer: self.logger.info(f"MCP Request - Method: {method}") self.logger.info(f"MCP Request - Headers: {headers}") + # Authentication check for MCP requests + try: + # Extract authentication information + auth_info = await self._extract_auth_info_from_scope(scope, headers) + + # Authenticate the request + auth_context = await self.security_manager.authenticate_request(auth_info) + self.logger.info(f"MCP request authenticated: token_id={auth_context.token_id}, client_ip={auth_context.client_ip}") + + # Store auth context in scope for potential use by tools/resources + scope["auth_context"] = auth_context + + except Exception as auth_error: + self.logger.error(f"MCP authentication failed: {auth_error}") + # Return 401 Unauthorized + from starlette.responses import JSONResponse + response = JSONResponse( + {"error": "Authentication required", "message": str(auth_error)}, + status_code=401 + ) + await response(scope, receive, send) + return + # Handle Dify compatibility for GET requests if method == "GET": accept_header = headers.get(b'accept', b'').decode('utf-8') @@ -619,6 +736,10 @@ class DorisServer: """Shutdown server""" self.logger.info("Shutting down Doris MCP Server") try: + # Shutdown security manager first (includes JWT cleanup) + await self.security_manager.shutdown() + self.logger.info("Security manager shutdown completed") + await self.connection_manager.close() self.logger.info("Doris MCP Server has been shut down") except Exception as e: diff --git a/doris_mcp_server/multiworker_app.py b/doris_mcp_server/multiworker_app.py index 9bf7c01..ba6b06c 100644 --- a/doris_mcp_server/multiworker_app.py +++ b/doris_mcp_server/multiworker_app.py @@ -209,6 +209,7 @@ from .utils.security import DorisSecurityManager _worker_server = None _worker_session_manager = None _worker_connection_manager = None +_worker_security_manager = None _worker_session_manager_context = None _worker_initialized = False @@ -242,7 +243,7 @@ def get_mcp_capabilities(): async def initialize_worker(): """Initialize MCP server and managers for this worker process""" - global _worker_server, _worker_session_manager, _worker_connection_manager, _worker_session_manager_context, _worker_initialized + global _worker_server, _worker_session_manager, _worker_connection_manager, _worker_security_manager, _worker_session_manager_context, _worker_initialized, _oauth_handlers, _token_handlers if _worker_initialized: return @@ -263,10 +264,14 @@ async def initialize_worker(): config_manager.setup_logging() # Create security manager - security_manager = DorisSecurityManager(config) + _worker_security_manager = DorisSecurityManager(config) + + # Initialize security manager first (includes JWT setup if enabled) + await _worker_security_manager.initialize() + logger.info(f"Worker {os.getpid()} security manager initialization completed") # Create connection manager - _worker_connection_manager = DorisConnectionManager(config, security_manager) + _worker_connection_manager = DorisConnectionManager(config, _worker_security_manager) await _worker_connection_manager.initialize() # Create MCP server @@ -382,6 +387,12 @@ async def initialize_worker(): _worker_session_manager_context = _worker_session_manager.run() await _worker_session_manager_context.__aenter__() + # Initialize OAuth and Token handlers + from .auth.oauth_handlers import OAuthHandlers + from .auth.token_handlers import TokenHandlers + _oauth_handlers = OAuthHandlers(_worker_security_manager) + _token_handlers = TokenHandlers(_worker_security_manager) + _worker_initialized = True logger.info(f"Worker {os.getpid()} MCP initialization completed successfully") @@ -405,6 +416,73 @@ async def health_check(request): "mcp_version": MCP_VERSION }) +# OAuth and Token handlers (initialize after worker setup) +_oauth_handlers = None +_token_handlers = None + +async def oauth_login(request): + """OAuth login endpoint""" + if not _oauth_handlers: + return JSONResponse({"error": "OAuth not initialized"}, status_code=503) + return await _oauth_handlers.handle_login(request) + +async def oauth_callback(request): + """OAuth callback endpoint""" + if not _oauth_handlers: + return JSONResponse({"error": "OAuth not initialized"}, status_code=503) + return await _oauth_handlers.handle_callback(request) + +async def oauth_provider_info(request): + """OAuth provider info endpoint""" + if not _oauth_handlers: + return JSONResponse({"error": "OAuth not initialized"}, status_code=503) + return await _oauth_handlers.handle_provider_info(request) + +async def oauth_demo(request): + """OAuth demo page endpoint""" + if not _oauth_handlers: + from starlette.responses import HTMLResponse + return HTMLResponse("

OAuth not initialized

") + return await _oauth_handlers.handle_demo_page(request) + +# Token management endpoints +async def token_create(request): + """Token creation endpoint""" + if not _token_handlers: + return JSONResponse({"error": "Token handlers not initialized"}, status_code=503) + return await _token_handlers.handle_create_token(request) + +async def token_revoke(request): + """Token revocation endpoint""" + if not _token_handlers: + return JSONResponse({"error": "Token handlers not initialized"}, status_code=503) + return await _token_handlers.handle_revoke_token(request) + +async def token_list(request): + """Token listing endpoint""" + if not _token_handlers: + return JSONResponse({"error": "Token handlers not initialized"}, status_code=503) + return await _token_handlers.handle_list_tokens(request) + +async def token_stats(request): + """Token statistics endpoint""" + if not _token_handlers: + return JSONResponse({"error": "Token handlers not initialized"}, status_code=503) + return await _token_handlers.handle_token_stats(request) + +async def token_cleanup(request): + """Token cleanup endpoint""" + if not _token_handlers: + return JSONResponse({"error": "Token handlers not initialized"}, status_code=503) + return await _token_handlers.handle_cleanup_tokens(request) + +async def token_demo(request): + """Token demo page endpoint""" + if not _token_handlers: + from starlette.responses import HTMLResponse + return HTMLResponse("

Token handlers not initialized

") + return await _token_handlers.handle_demo_page(request) + async def root_info(request): """Root endpoint""" return JSONResponse({ @@ -452,6 +530,13 @@ async def lifespan(app): except Exception as e: logger.error(f"Error closing worker connection manager: {e}") + if _worker_security_manager: + try: + await _worker_security_manager.shutdown() + logger.info(f"Worker {os.getpid()} security manager shutdown completed") + except Exception as e: + logger.error(f"Error shutting down worker security manager: {e}") + # Shutdown logging system try: from .utils.logger import shutdown_logging @@ -492,6 +577,18 @@ basic_app = Starlette( routes=[ Route("/", root_info, methods=["GET"]), Route("/health", health_check, methods=["GET"]), + # OAuth endpoints + Route("/auth/login", oauth_login, methods=["GET"]), + Route("/auth/callback", oauth_callback, methods=["GET"]), + Route("/auth/provider", oauth_provider_info, methods=["GET"]), + Route("/auth/demo", oauth_demo, methods=["GET"]), + # Token management endpoints + Route("/token/create", token_create, methods=["GET", "POST"]), + Route("/token/revoke", token_revoke, methods=["GET", "DELETE"]), + Route("/token/list", token_list, methods=["GET"]), + Route("/token/stats", token_stats, methods=["GET"]), + Route("/token/cleanup", token_cleanup, methods=["GET", "POST"]), + Route("/token/demo", token_demo, methods=["GET"]), ], lifespan=lifespan ) @@ -505,5 +602,5 @@ async def app(scope, receive, send): # Handle MCP requests with session manager await mcp_asgi_app(scope, receive, send) else: - # Handle other requests with basic Starlette app + # Handle other requests with basic Starlette app (includes auth endpoints) await basic_app(scope, receive, send) diff --git a/doris_mcp_server/utils/config.py b/doris_mcp_server/utils/config.py index 8ee50a9..fe52f66 100644 --- a/doris_mcp_server/utils/config.py +++ b/doris_mcp_server/utils/config.py @@ -77,10 +77,43 @@ class DatabaseConfig: class SecurityConfig: """Security configuration""" - # Authentication configuration - auth_type: str = "token" # token, basic, oauth - token_secret: str = "default_secret" + # Independent authentication switches - any one enabled allows that method + enable_token_auth: bool = False # Enable token-based authentication (default: disabled) + enable_jwt_auth: bool = False # Enable JWT authentication (default: disabled) + enable_oauth_auth: bool = False # Enable OAuth 2.0/OIDC authentication (default: disabled) + + # Legacy configuration (kept for backward compatibility) + auth_type: str = "token" # jwt, token, basic, oauth (deprecated: use individual switches) + token_secret: str = "default_secret" # Legacy token secret for backward compatibility token_expiry: int = 3600 + + # Enhanced Token Authentication Configuration + token_file_path: str = "tokens.json" # Path to token configuration file + enable_token_expiry: bool = True # Enable token expiration + default_token_expiry_hours: int = 24 * 30 # Default expiry: 30 days + token_hash_algorithm: str = "sha256" # Token hashing algorithm: sha256, sha512 + + # JWT Configuration + jwt_algorithm: str = "RS256" # RS256, ES256, HS256 + jwt_issuer: str = "doris-mcp-server" + jwt_audience: str = "doris-mcp-client" + jwt_private_key_path: str = "" + jwt_public_key_path: str = "" + jwt_secret_key: str = "" # Only used for HS256 algorithm + jwt_access_token_expiry: int = 3600 # 1 hour + jwt_refresh_token_expiry: int = 86400 # 24 hours + enable_token_refresh: bool = True + enable_token_revocation: bool = True + key_rotation_interval: int = 30 * 24 * 3600 # 30 days in seconds + + # JWT Security Features + jwt_require_iat: bool = True # Require "issued at" claim + jwt_require_exp: bool = True # Require "expires at" claim + jwt_require_nbf: bool = False # Require "not before" claim + jwt_leeway: int = 10 # Clock skew tolerance in seconds + jwt_verify_signature: bool = True # Verify JWT signature + jwt_verify_audience: bool = True # Verify audience claim + jwt_verify_issuer: bool = True # Verify issuer claim # SQL security configuration enable_security_check: bool = True # Main switch: whether to enable SQL security check @@ -115,6 +148,45 @@ class SecurityConfig: enable_masking: bool = True masking_rules: list[dict[str, Any]] = field(default_factory=list) + # OAuth 2.0/OIDC Configuration + oauth_enabled: bool = False + oauth_provider: str = "" # 'google', 'microsoft', 'github', 'custom' + oauth_client_id: str = "" + oauth_client_secret: str = "" + oauth_redirect_uri: str = "http://localhost:3000/auth/callback" + + # OIDC Discovery + oidc_discovery_url: str = "" # e.g., https://accounts.google.com/.well-known/openid_configuration + oauth_authorization_endpoint: str = "" + oauth_token_endpoint: str = "" + oauth_userinfo_endpoint: str = "" + oauth_jwks_uri: str = "" + + # OAuth Scopes and Settings + oauth_scopes: list[str] = field(default_factory=list) + oauth_state_expiry: int = 600 # State parameter expiry in seconds (10 minutes) + oauth_pkce_enabled: bool = True # Enable PKCE for better security + oauth_nonce_enabled: bool = True # Enable nonce for OIDC + + # User Mapping Configuration + oauth_user_id_claim: str = "sub" # JWT claim for user ID + oauth_email_claim: str = "email" + oauth_name_claim: str = "name" + oauth_roles_claim: str = "roles" # Custom claim for roles + oauth_default_roles: list[str] = field(default_factory=lambda: ["oauth_user"]) + + def __post_init__(self): + """Initialize default OAuth scopes based on provider""" + if not self.oauth_scopes and self.oauth_provider: + if self.oauth_provider == "google": + self.oauth_scopes = ["openid", "email", "profile"] + elif self.oauth_provider == "microsoft": + self.oauth_scopes = ["openid", "profile", "email", "User.Read"] + elif self.oauth_provider == "github": + self.oauth_scopes = ["user:email", "read:user"] + else: + self.oauth_scopes = ["openid", "email", "profile"] + @dataclass class PerformanceConfig: @@ -338,6 +410,10 @@ class DorisConfig: ) # Security configuration + # Independent authentication switches + config.security.enable_token_auth = os.getenv("ENABLE_TOKEN_AUTH", str(config.security.enable_token_auth)).lower() == "true" + config.security.enable_jwt_auth = os.getenv("ENABLE_JWT_AUTH", str(config.security.enable_jwt_auth)).lower() == "true" + config.security.enable_oauth_auth = os.getenv("ENABLE_OAUTH_AUTH", str(config.security.enable_oauth_auth)).lower() == "true" config.security.auth_type = os.getenv("AUTH_TYPE", config.security.auth_type) config.security.token_secret = os.getenv("TOKEN_SECRET", config.security.token_secret) config.security.token_expiry = int( @@ -368,6 +444,16 @@ class DorisConfig: config.security.enable_masking = ( os.getenv("ENABLE_MASKING", str(config.security.enable_masking).lower()).lower() == "true" ) + + # Enhanced Token Authentication configuration + config.security.token_file_path = os.getenv("TOKEN_FILE_PATH", config.security.token_file_path) + config.security.enable_token_expiry = ( + os.getenv("ENABLE_TOKEN_EXPIRY", str(config.security.enable_token_expiry).lower()).lower() == "true" + ) + config.security.default_token_expiry_hours = int( + os.getenv("DEFAULT_TOKEN_EXPIRY_HOURS", str(config.security.default_token_expiry_hours)) + ) + config.security.token_hash_algorithm = os.getenv("TOKEN_HASH_ALGORITHM", config.security.token_hash_algorithm) # Performance configuration config.performance.enable_query_cache = ( diff --git a/doris_mcp_server/utils/security.py b/doris_mcp_server/utils/security.py index c1c4dc8..9b319b0 100644 --- a/doris_mcp_server/utils/security.py +++ b/doris_mcp_server/utils/security.py @@ -22,10 +22,10 @@ Implements enterprise-level authentication, authorization, SQL security validati import logging import re -from dataclasses import dataclass +from dataclasses import dataclass, field from datetime import datetime from enum import Enum -from typing import Any +from typing import Any, Optional import sqlparse from sqlparse.sql import Statement @@ -45,15 +45,13 @@ class SecurityLevel(Enum): @dataclass class AuthContext: - """Authentication context""" + """Authentication context for audit and session tracking""" - user_id: str - roles: list[str] - permissions: list[str] - session_id: str - login_time: datetime | None = None + token_id: str # Token identifier for audit logging + client_ip: str = "unknown" # Client IP address + session_id: str = "" # Session identifier + login_time: datetime = field(default_factory=datetime.utcnow) last_activity: datetime | None = None - security_level: SecurityLevel = SecurityLevel.INTERNAL @dataclass @@ -100,6 +98,36 @@ class DorisSecurityManager: self.blocked_keywords = self._load_blocked_keywords() self.sensitive_tables = self._load_sensitive_tables() self.masking_rules = self._load_masking_rules() + + # Track initialization state + self._initialized = False + + async def initialize(self): + """Initialize security manager components""" + if self._initialized: + return + + try: + # Initialize authentication provider (for JWT setup) + await self.auth_provider.initialize() + + self._initialized = True + self.logger.info("DorisSecurityManager initialized successfully") + + except Exception as e: + self.logger.error(f"Failed to initialize DorisSecurityManager: {e}") + raise + + async def shutdown(self): + """Shutdown security manager components""" + try: + await self.auth_provider.shutdown() + self._initialized = False + self.logger.info("DorisSecurityManager shutdown completed") + + except Exception as e: + self.logger.error(f"Error during DorisSecurityManager shutdown: {e}") + raise def _load_blocked_keywords(self) -> set[str]: """Load blocked SQL keywords from configuration""" @@ -184,8 +212,55 @@ class DorisSecurityManager: return default_rules async def authenticate_request(self, auth_info: dict[str, Any]) -> AuthContext: - """Validate request authentication information""" - return await self.auth_provider.authenticate(auth_info) + """Validate request authentication information + + Tries authentication methods in order: Token -> JWT -> OAuth + Any one method succeeding allows access + If all methods are disabled, returns anonymous context + """ + # Check if any authentication method is enabled + if not (self.config.security.enable_token_auth or + self.config.security.enable_jwt_auth or + self.config.security.enable_oauth_auth): + self.logger.debug("All authentication methods are disabled") + # Return anonymous context when no authentication is enabled + return AuthContext( + token_id="anonymous", + client_ip=auth_info.get("client_ip", "unknown"), + session_id="anonymous_session" + ) + + # Try authentication methods in order of preference + last_error = None + + # 1. Try Token authentication first (most common) + if self.config.security.enable_token_auth: + try: + return await self.auth_provider.authenticate_token(auth_info) + except Exception as e: + self.logger.debug(f"Token authentication failed: {e}") + last_error = e + + # 2. Try JWT authentication + if self.config.security.enable_jwt_auth: + try: + return await self.auth_provider.authenticate_jwt(auth_info) + except Exception as e: + self.logger.debug(f"JWT authentication failed: {e}") + last_error = e + + # 3. Try OAuth authentication + if self.config.security.enable_oauth_auth: + try: + return await self.auth_provider.authenticate_oauth(auth_info) + except Exception as e: + self.logger.debug(f"OAuth authentication failed: {e}") + last_error = e + + # All enabled authentication methods failed + error_message = f"Authentication failed: {str(last_error)}" if last_error else "No authentication method succeeded" + self.logger.warning(f"Authentication failed for client {auth_info.get('client_ip', 'unknown')}: {error_message}") + raise ValueError(error_message) async def authorize_resource_access( self, auth_context: AuthContext, resource_uri: str @@ -207,6 +282,117 @@ class DorisSecurityManager: """Apply data masking processing""" return await self.masking_processor.process(data, auth_context) + # OAuth-specific methods + def get_oauth_authorization_url(self) -> tuple[str, str]: + """Get OAuth authorization URL + + Returns: + Tuple of (authorization_url, state) + """ + if not self.auth_provider.oauth_provider: + raise ValueError("OAuth is not enabled") + return self.auth_provider.oauth_provider.get_authorization_url() + + async def handle_oauth_callback(self, code: str, state: str) -> AuthContext: + """Handle OAuth callback + + Args: + code: Authorization code from OAuth provider + state: State parameter for CSRF protection + + Returns: + AuthContext for authenticated user + """ + if not self.auth_provider.oauth_provider: + raise ValueError("OAuth is not enabled") + return await self.auth_provider.oauth_provider.handle_callback(code, state) + + def get_oauth_provider_info(self) -> dict[str, Any]: + """Get OAuth provider information + + Returns: + OAuth provider information + """ + if not self.auth_provider.oauth_provider: + return {"enabled": False} + return self.auth_provider.oauth_provider.get_provider_info() + + # Token management methods + async def create_token( + self, + token_id: str, + expires_hours: Optional[int] = None, + description: str = "", + custom_token: Optional[str] = None + ) -> str: + """Create a new API access token + + Args: + token_id: Unique token identifier for audit and management + expires_hours: Token expiration in hours (None for no expiration) + description: Token description for management purposes + custom_token: Custom token string (if None, generates random token) + + Returns: + Generated token string + """ + if not self.auth_provider.token_manager: + raise ValueError("Token manager not initialized") + + return await self.auth_provider.token_manager.create_token( + token_id=token_id, + expires_hours=expires_hours, + description=description, + custom_token=custom_token + ) + + async def revoke_token(self, token_id: str) -> bool: + """Revoke a token by token ID + + Args: + token_id: Token ID to revoke + + Returns: + True if token was revoked successfully + """ + if not self.auth_provider.token_manager: + raise ValueError("Token manager not initialized") + + return await self.auth_provider.token_manager.revoke_token(token_id) + + async def list_tokens(self) -> list[dict[str, Any]]: + """List all tokens (without sensitive data) + + Returns: + List of token information + """ + if not self.auth_provider.token_manager: + raise ValueError("Token manager not initialized") + + return await self.auth_provider.token_manager.list_tokens() + + async def cleanup_expired_tokens(self) -> int: + """Remove expired tokens and return count + + Returns: + Number of expired tokens removed + """ + if not self.auth_provider.token_manager: + return 0 + + return await self.auth_provider.token_manager.cleanup_expired_tokens() + + def get_token_stats(self) -> dict[str, Any]: + """Get token statistics + + Returns: + Token statistics dictionary + """ + if not self.auth_provider.token_manager: + return {"error": "Token manager not initialized"} + + return self.auth_provider.token_manager.get_token_stats() + class AuthenticationProvider: """Authentication provider""" @@ -215,35 +401,199 @@ class AuthenticationProvider: self.config = config self.logger = get_logger(__name__) self.session_cache = {} - - async def authenticate(self, auth_info: dict[str, Any]) -> AuthContext: - """Perform identity authentication""" - auth_type = auth_info.get("type", "token") - - if auth_type == "token": - return await self._authenticate_token(auth_info) - elif auth_type == "basic": - return await self._authenticate_basic(auth_info) + self.jwt_manager = None + self.oauth_provider = None + self.token_manager = None + + # Initialize authentication providers based on individual switches + auth_methods_enabled = [] + + # Initialize Token manager if enabled + if config.security.enable_token_auth: + self._initialize_token_manager() + auth_methods_enabled.append("Token") + + # Initialize JWT manager if enabled + if config.security.enable_jwt_auth: + self._initialize_jwt_manager() + auth_methods_enabled.append("JWT") + + # Initialize OAuth provider if enabled + if config.security.enable_oauth_auth or (hasattr(config.security, 'oauth_enabled') and config.security.oauth_enabled): + self._initialize_oauth_provider() + auth_methods_enabled.append("OAuth") + + if auth_methods_enabled: + self.logger.info(f"Authentication enabled with methods: {', '.join(auth_methods_enabled)}") else: - raise ValueError(f"Unsupported authentication type: {auth_type}") + self.logger.info("All authentication methods are disabled - anonymous access allowed") + + def _initialize_jwt_manager(self): + """Initialize JWT manager""" + try: + from ..auth.jwt_manager import JWTManager + self.jwt_manager = JWTManager(self.config) + self.logger.info("JWT manager initialized") + except ImportError as e: + self.logger.error(f"Failed to import JWT manager: {e}") + raise + except Exception as e: + self.logger.error(f"Failed to initialize JWT manager: {e}") + raise + + def _initialize_token_manager(self): + """Initialize Token manager""" + try: + from ..auth.token_manager import TokenManager + self.token_manager = TokenManager(self.config) + self.logger.info("Token manager initialized") + except ImportError as e: + self.logger.error(f"Failed to import Token manager: {e}") + raise + except Exception as e: + self.logger.error(f"Failed to initialize Token manager: {e}") + raise + + def _initialize_oauth_provider(self): + """Initialize OAuth provider""" + try: + from ..auth.oauth_provider import OAuthAuthenticationProvider + self.oauth_provider = OAuthAuthenticationProvider(self.config) + self.logger.info("OAuth provider initialized") + except ImportError as e: + self.logger.error(f"Failed to import OAuth provider: {e}") + raise + except Exception as e: + self.logger.error(f"Failed to initialize OAuth provider: {e}") + raise + + async def initialize(self): + """Initialize authentication provider asynchronously""" + if self.jwt_manager: + success = await self.jwt_manager.initialize() + if not success: + raise RuntimeError("Failed to initialize JWT manager") + self.logger.info("JWT authentication provider initialized successfully") + + if self.token_manager: + # Token manager doesn't need async initialization, just log success + self.logger.info("Token authentication provider initialized successfully") + + if self.oauth_provider: + success = await self.oauth_provider.initialize() + if not success: + raise RuntimeError("Failed to initialize OAuth provider") + self.logger.info("OAuth authentication provider initialized successfully") + + async def shutdown(self): + """Shutdown authentication provider""" + if self.jwt_manager: + await self.jwt_manager.shutdown() + self.logger.info("JWT authentication provider shutdown completed") + + if self.token_manager: + # Token manager doesn't need async shutdown, just log + self.logger.info("Token authentication provider shutdown completed") + + if self.oauth_provider: + await self.oauth_provider.shutdown() + self.logger.info("OAuth authentication provider shutdown completed") + + async def authenticate_token(self, auth_info: dict[str, Any]) -> AuthContext: + """Perform token authentication""" + if not self.config.security.enable_token_auth: + raise ValueError("Token authentication is not enabled") + return await self._authenticate_token(auth_info) + + async def authenticate_jwt(self, auth_info: dict[str, Any]) -> AuthContext: + """Perform JWT authentication""" + if not self.config.security.enable_jwt_auth: + raise ValueError("JWT authentication is not enabled") + return await self._authenticate_jwt(auth_info) + + async def authenticate_oauth(self, auth_info: dict[str, Any]) -> AuthContext: + """Perform OAuth authentication""" + if not self.config.security.enable_oauth_auth: + raise ValueError("OAuth authentication is not enabled") + return await self._authenticate_oauth(auth_info) + + async def _authenticate_jwt(self, auth_info: dict[str, Any]) -> AuthContext: + """JWT authentication""" + if not self.jwt_manager: + raise ValueError("JWT manager not initialized") + + token = auth_info.get("token") + if not token: + # Try to extract from Authorization header + authorization = auth_info.get("authorization") + if authorization and authorization.startswith('Bearer '): + token = authorization[7:] + + if not token: + raise ValueError("Missing JWT token") + + try: + # Use JWT middleware for authentication + from ..auth.auth_middleware import AuthMiddleware + middleware = AuthMiddleware(self.jwt_manager) + return await middleware.authenticate_request(auth_info) + + except Exception as e: + self.logger.error(f"JWT authentication failed: {e}") + raise ValueError(f"JWT authentication failed: {str(e)}") + + async def _authenticate_oauth(self, auth_info: dict[str, Any]) -> AuthContext: + """OAuth authentication""" + if not self.oauth_provider: + raise ValueError("OAuth provider not initialized") + + # Handle different OAuth authentication scenarios + if "access_token" in auth_info: + # Direct OAuth access token authentication + return await self.oauth_provider.authenticate_with_token(auth_info["access_token"]) + elif "code" in auth_info and "state" in auth_info: + # OAuth callback authentication + return await self.oauth_provider.handle_callback(auth_info["code"], auth_info["state"]) + else: + raise ValueError("OAuth authentication requires either access_token or code+state") async def _authenticate_token(self, auth_info: dict[str, Any]) -> AuthContext: """Token authentication""" + if not self.token_manager: + raise ValueError("Token manager not initialized") + token = auth_info.get("token") + if not token: + # Try to extract from Authorization header + authorization = auth_info.get("authorization") + if authorization and authorization.startswith('Bearer '): + token = authorization[7:] + elif authorization and authorization.startswith('Token '): + token = authorization[6:] + if not token: raise ValueError("Missing authentication token") - # Validate token (simplified implementation, should validate JWT or query authentication service in practice) - user_info = await self._validate_token(token) - - return AuthContext( - user_id=user_info["user_id"], - roles=user_info["roles"], - permissions=user_info["permissions"], - session_id=auth_info.get("session_id", "default"), - login_time=datetime.utcnow(), - security_level=SecurityLevel(user_info.get("security_level", "internal")), - ) + try: + # Validate token using TokenManager + validation_result = await self.token_manager.validate_token(token) + + if not validation_result.is_valid: + raise ValueError(f"Token validation failed: {validation_result.error_message}") + + token_info = validation_result.token_info + + return AuthContext( + token_id=token_info.token_id, + client_ip=auth_info.get("client_ip", "unknown"), + session_id=auth_info.get("session_id", f"session_{token_info.token_id}"), + login_time=datetime.utcnow(), + last_activity=token_info.last_used + ) + + except Exception as e: + self.logger.error(f"Token authentication failed: {e}") + raise ValueError(f"Token authentication failed: {str(e)}") async def _authenticate_basic(self, auth_info: dict[str, Any]) -> AuthContext: """Basic authentication (username password)""" diff --git a/tokens.json b/tokens.json new file mode 100644 index 0000000..1524a49 --- /dev/null +++ b/tokens.json @@ -0,0 +1,37 @@ +{ + "version": "1.0", + "description": "Simplified Token configuration file for Doris MCP Server API access control", + "created_at": "2025-09-01T00:00:00Z", + "tokens": [ + { + "token_id": "admin-token", + "token": "doris_admin_token_123456", + "description": "Doris admin API access token", + "expires_hours": null, + "is_active": true + }, + { + "token_id": "analyst-token", + "token": "doris_analyst_token_123456", + "description": "Doris analyst API access token", + "expires_hours": 8760, + "is_active": true + }, + { + "token_id": "readonly-token", + "token": "doris_readonly_token_123456", + "description": "Doris readonly API access token", + "expires_hours": 4320, + "is_active": true + } + ], + "notes": [ + "The admin_token, analyst_token, readonly_token is default token,Please change the token before using in production!", + "The token_id is the key of the token,Please use the token_id to identify the token", + "The token is the value of the token,Please use the token to identify the token", + "The description is the description of the token,Please use the description to identify the token", + "The expires_hours is the expires hours of the token,Please use the expires_hours to identify the token", + "The is_active is the is active of the token,Please use the is_active to identify the token", + "The token_id, token, description, expires_hours, is_active is the metadata of the token,Please use the metadata to identify the token" + ] +} \ No newline at end of file