[Performance]Add complete Token, JWT, OAuth authentication system (#52)
* 0.5.1 Version * fix 0.5.1 schema async bug * fix security bug * fix security bug * Add complete Token, JWT, OAuth authentication system * Add complete Token, JWT, OAuth authentication system * Add complete Token, JWT, OAuth authentication system * Add complete Token, JWT, OAuth authentication system
This commit is contained in:
56
doris_mcp_server/auth/__init__.py
Normal file
56
doris_mcp_server/auth/__init__.py
Normal file
@@ -0,0 +1,56 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Doris MCP Server Authentication Module
|
||||
Provides JWT-based, Token-based, and OAuth 2.0/OIDC authentication and authorization services
|
||||
"""
|
||||
|
||||
from .jwt_manager import JWTManager
|
||||
from .key_manager import KeyManager
|
||||
from .token_validators import TokenValidator, TokenBlacklist
|
||||
from .auth_middleware import AuthMiddleware
|
||||
from .token_manager import TokenManager, TokenInfo, TokenValidationResult
|
||||
from .token_handlers import TokenHandlers
|
||||
from .oauth_client import OAuthClient, OAuthStateManager
|
||||
from .oauth_provider import OAuthAuthenticationProvider
|
||||
from .oauth_types import (
|
||||
OAuthProvider, OAuthState, OAuthTokens, OAuthUserInfo,
|
||||
OIDCDiscovery, OAuthError, OAuthProviderConfig
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"JWTManager",
|
||||
"KeyManager",
|
||||
"TokenValidator",
|
||||
"TokenBlacklist",
|
||||
"AuthMiddleware",
|
||||
"TokenManager",
|
||||
"TokenInfo",
|
||||
"TokenValidationResult",
|
||||
"TokenHandlers",
|
||||
"OAuthClient",
|
||||
"OAuthStateManager",
|
||||
"OAuthAuthenticationProvider",
|
||||
"OAuthProvider",
|
||||
"OAuthState",
|
||||
"OAuthTokens",
|
||||
"OAuthUserInfo",
|
||||
"OIDCDiscovery",
|
||||
"OAuthError",
|
||||
"OAuthProviderConfig"
|
||||
]
|
||||
269
doris_mcp_server/auth/auth_middleware.py
Normal file
269
doris_mcp_server/auth/auth_middleware.py
Normal file
@@ -0,0 +1,269 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Authentication Middleware Module
|
||||
Provides middleware for JWT authentication in HTTP and MCP contexts
|
||||
"""
|
||||
|
||||
from typing import Optional, Dict, Any, Callable, Awaitable
|
||||
from datetime import datetime
|
||||
|
||||
from .jwt_manager import JWTManager
|
||||
from ..utils.security import AuthContext, SecurityLevel
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AuthMiddleware:
|
||||
"""Authentication Middleware
|
||||
|
||||
Provides JWT authentication functionality for HTTP and MCP requests
|
||||
"""
|
||||
|
||||
def __init__(self, jwt_manager: JWTManager):
|
||||
"""Initialize authentication middleware
|
||||
|
||||
Args:
|
||||
jwt_manager: JWT manager instance
|
||||
"""
|
||||
self.jwt_manager = jwt_manager
|
||||
logger.info("AuthMiddleware initialized")
|
||||
|
||||
def extract_token_from_header(self, authorization: str) -> Optional[str]:
|
||||
"""Extract JWT token from Authorization header
|
||||
|
||||
Args:
|
||||
authorization: Authorization header value
|
||||
|
||||
Returns:
|
||||
JWT token string, or None if not found
|
||||
"""
|
||||
if not authorization:
|
||||
return None
|
||||
|
||||
# Support Bearer format
|
||||
if authorization.startswith('Bearer '):
|
||||
return authorization[7:] # Remove "Bearer " prefix
|
||||
|
||||
# Support direct token format
|
||||
if not authorization.startswith('Basic '):
|
||||
return authorization
|
||||
|
||||
return None
|
||||
|
||||
async def authenticate_request(self, auth_info: Dict[str, Any]) -> AuthContext:
|
||||
"""Authenticate request and return authentication context
|
||||
|
||||
Args:
|
||||
auth_info: Authentication information dictionary
|
||||
|
||||
Returns:
|
||||
AuthContext authentication context
|
||||
|
||||
Raises:
|
||||
ValueError: Authentication failed
|
||||
"""
|
||||
try:
|
||||
auth_type = auth_info.get("type", "jwt")
|
||||
|
||||
if auth_type == "jwt" or auth_type == "token":
|
||||
return await self._authenticate_jwt(auth_info)
|
||||
else:
|
||||
raise ValueError(f"Unsupported authentication type: {auth_type}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Request authentication failed: {e}")
|
||||
raise
|
||||
|
||||
async def _authenticate_jwt(self, auth_info: Dict[str, Any]) -> AuthContext:
|
||||
"""JWT authentication processing
|
||||
|
||||
Args:
|
||||
auth_info: Authentication information containing JWT token
|
||||
|
||||
Returns:
|
||||
AuthContext authentication context
|
||||
"""
|
||||
# Get token
|
||||
token = auth_info.get("token")
|
||||
if not token:
|
||||
# Try to get from Authorization header
|
||||
authorization = auth_info.get("authorization")
|
||||
token = self.extract_token_from_header(authorization)
|
||||
|
||||
if not token:
|
||||
raise ValueError("Missing JWT token")
|
||||
|
||||
try:
|
||||
# Validate token
|
||||
validation_result = await self.jwt_manager.validate_token(token, 'access')
|
||||
payload = validation_result['payload']
|
||||
|
||||
# Build authentication context
|
||||
auth_context = AuthContext(
|
||||
user_id=payload.get('sub'),
|
||||
roles=payload.get('roles', []),
|
||||
permissions=payload.get('permissions', []),
|
||||
session_id=payload.get('jti'), # Use JWT ID as session ID
|
||||
login_time=datetime.fromtimestamp(payload.get('iat', 0)),
|
||||
last_activity=datetime.utcnow(),
|
||||
security_level=SecurityLevel(payload.get('security_level', 'internal'))
|
||||
)
|
||||
|
||||
logger.info(f"JWT authentication successful for user: {auth_context.user_id}")
|
||||
return auth_context
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"JWT authentication failed: {e}")
|
||||
raise ValueError(f"JWT authentication failed: {str(e)}")
|
||||
|
||||
async def create_auth_response_headers(self, auth_context: AuthContext) -> Dict[str, str]:
|
||||
"""Create authentication response headers
|
||||
|
||||
Args:
|
||||
auth_context: Authentication context
|
||||
|
||||
Returns:
|
||||
Response headers dictionary
|
||||
"""
|
||||
return {
|
||||
'X-Auth-User': auth_context.user_id,
|
||||
'X-Auth-Roles': ','.join(auth_context.roles),
|
||||
'X-Auth-Session': auth_context.session_id,
|
||||
'X-Auth-Security-Level': auth_context.security_level.value
|
||||
}
|
||||
|
||||
def create_http_middleware(self, skip_paths: Optional[list] = None):
|
||||
"""Create HTTP middleware function
|
||||
|
||||
Args:
|
||||
skip_paths: List of paths to skip authentication
|
||||
|
||||
Returns:
|
||||
ASGI middleware function
|
||||
"""
|
||||
skip_paths = skip_paths or ['/health', '/docs', '/openapi.json']
|
||||
|
||||
async def middleware(scope, receive, send):
|
||||
"""HTTP authentication middleware"""
|
||||
if scope['type'] != 'http':
|
||||
# Pass through non-HTTP requests directly
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
path = scope.get('path', '')
|
||||
|
||||
# Check if authentication should be skipped
|
||||
if any(path.startswith(skip) for skip in skip_paths):
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
# Extract authentication information
|
||||
headers = dict(scope.get('headers', []))
|
||||
authorization = headers.get(b'authorization', b'').decode()
|
||||
|
||||
try:
|
||||
# Perform authentication
|
||||
auth_info = {
|
||||
'type': 'jwt',
|
||||
'authorization': authorization
|
||||
}
|
||||
auth_context = await self.authenticate_request(auth_info)
|
||||
|
||||
# Add authentication context to scope
|
||||
scope['auth_context'] = auth_context
|
||||
|
||||
# Create response wrapper to add authentication headers
|
||||
async def send_wrapper(message):
|
||||
if message['type'] == 'http.response.start':
|
||||
headers = dict(message.get('headers', []))
|
||||
auth_headers = await self.create_auth_response_headers(auth_context)
|
||||
|
||||
for key, value in auth_headers.items():
|
||||
headers[key.encode()] = value.encode()
|
||||
|
||||
message['headers'] = list(headers.items())
|
||||
|
||||
await send(message)
|
||||
|
||||
return await self.app(scope, receive, send_wrapper)
|
||||
|
||||
except Exception as e:
|
||||
# Authentication failed, return 401 error
|
||||
response_body = f'{{"error": "Authentication failed", "message": "{str(e)}"}}'
|
||||
|
||||
await send({
|
||||
'type': 'http.response.start',
|
||||
'status': 401,
|
||||
'headers': [
|
||||
(b'content-type', b'application/json'),
|
||||
(b'www-authenticate', b'Bearer')
|
||||
]
|
||||
})
|
||||
await send({
|
||||
'type': 'http.response.body',
|
||||
'body': response_body.encode()
|
||||
})
|
||||
|
||||
return middleware
|
||||
|
||||
async def authenticate_mcp_request(self, headers: Dict[str, str]) -> AuthContext:
|
||||
"""Authenticate MCP request
|
||||
|
||||
Args:
|
||||
headers: MCP request headers
|
||||
|
||||
Returns:
|
||||
AuthContext authentication context
|
||||
"""
|
||||
try:
|
||||
# Extract authentication information from multiple possible header fields
|
||||
authorization = (
|
||||
headers.get('Authorization') or
|
||||
headers.get('authorization') or
|
||||
headers.get('X-Auth-Token') or
|
||||
headers.get('x-auth-token')
|
||||
)
|
||||
|
||||
auth_info = {
|
||||
'type': 'jwt',
|
||||
'authorization': authorization
|
||||
}
|
||||
|
||||
return await self.authenticate_request(auth_info)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"MCP request authentication failed: {e}")
|
||||
raise
|
||||
|
||||
|
||||
class AuthenticationError(Exception):
|
||||
"""Authentication error exception"""
|
||||
|
||||
def __init__(self, message: str, error_code: str = "AUTH_FAILED"):
|
||||
self.message = message
|
||||
self.error_code = error_code
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class AuthorizationError(Exception):
|
||||
"""Authorization error exception"""
|
||||
|
||||
def __init__(self, message: str, error_code: str = "ACCESS_DENIED"):
|
||||
self.message = message
|
||||
self.error_code = error_code
|
||||
super().__init__(message)
|
||||
471
doris_mcp_server/auth/jwt_manager.py
Normal file
471
doris_mcp_server/auth/jwt_manager.py
Normal file
@@ -0,0 +1,471 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
JWT Manager Module
|
||||
Provides comprehensive JWT token management including generation, validation, refresh and revocation
|
||||
"""
|
||||
|
||||
import time
|
||||
import uuid
|
||||
import asyncio
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
try:
|
||||
import jwt
|
||||
except ImportError:
|
||||
raise ImportError("PyJWT is required for JWT functionality. Install with: pip install PyJWT[crypto]")
|
||||
|
||||
from .key_manager import KeyManager
|
||||
from .token_validators import TokenValidator, TokenBlacklist
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class JWTManager:
|
||||
"""JWT Token Manager
|
||||
|
||||
Provides comprehensive JWT token lifecycle management, including:
|
||||
- Token generation and signing
|
||||
- Token validation and parsing
|
||||
- Token refresh mechanism
|
||||
- Token revocation and blacklist
|
||||
- Automatic key rotation
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
"""Initialize JWT manager
|
||||
|
||||
Args:
|
||||
config: DorisConfig configuration object (with security attribute)
|
||||
"""
|
||||
self.config = config
|
||||
# Access JWT settings through the security configuration
|
||||
if hasattr(config, 'security'):
|
||||
security_config = config.security
|
||||
else:
|
||||
# Fallback if config is passed directly as SecurityConfig
|
||||
security_config = config
|
||||
|
||||
self.algorithm = security_config.jwt_algorithm
|
||||
self.issuer = security_config.jwt_issuer
|
||||
self.audience = security_config.jwt_audience
|
||||
self.access_token_expiry = security_config.jwt_access_token_expiry
|
||||
self.refresh_token_expiry = security_config.jwt_refresh_token_expiry
|
||||
self.enable_refresh = security_config.enable_token_refresh
|
||||
self.enable_revocation = security_config.enable_token_revocation
|
||||
|
||||
# Initialize components
|
||||
self.key_manager = KeyManager(config)
|
||||
self.token_blacklist = TokenBlacklist()
|
||||
self.validator = TokenValidator(config, self.token_blacklist)
|
||||
|
||||
# Automatic key rotation task
|
||||
self._key_rotation_task = None
|
||||
|
||||
logger.info(f"JWTManager initialized with algorithm: {self.algorithm}")
|
||||
|
||||
async def initialize(self) -> bool:
|
||||
"""Initialize JWT manager"""
|
||||
try:
|
||||
# Initialize key manager
|
||||
if not await self.key_manager.initialize():
|
||||
logger.error("Failed to initialize key manager")
|
||||
return False
|
||||
|
||||
# Start token validator
|
||||
await self.validator.start()
|
||||
|
||||
# Start automatic key rotation
|
||||
if self.key_manager.key_rotation_interval > 0:
|
||||
self._key_rotation_task = asyncio.create_task(self._auto_key_rotation())
|
||||
|
||||
logger.info("JWTManager initialization completed")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize JWTManager: {e}")
|
||||
return False
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shutdown JWT manager"""
|
||||
try:
|
||||
# Stop key rotation task
|
||||
if self._key_rotation_task:
|
||||
self._key_rotation_task.cancel()
|
||||
try:
|
||||
await self._key_rotation_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Stop validator
|
||||
await self.validator.stop()
|
||||
|
||||
logger.info("JWTManager shutdown completed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during JWTManager shutdown: {e}")
|
||||
|
||||
async def generate_tokens(self, user_info: Dict[str, Any],
|
||||
custom_claims: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
"""Generate access token and refresh token
|
||||
|
||||
Args:
|
||||
user_info: User information dictionary, containing user_id, roles, permissions, etc.
|
||||
custom_claims: Custom claims
|
||||
|
||||
Returns:
|
||||
Dictionary containing access_token and refresh_token
|
||||
"""
|
||||
try:
|
||||
current_time = int(time.time())
|
||||
jti = str(uuid.uuid4())
|
||||
|
||||
# Build base payload
|
||||
base_payload = {
|
||||
'iss': self.issuer,
|
||||
'aud': self.audience,
|
||||
'iat': current_time,
|
||||
'jti': jti,
|
||||
'sub': user_info.get('user_id'),
|
||||
'roles': user_info.get('roles', []),
|
||||
'permissions': user_info.get('permissions', []),
|
||||
'security_level': user_info.get('security_level', 'internal')
|
||||
}
|
||||
|
||||
# Add custom claims
|
||||
if custom_claims:
|
||||
base_payload.update(custom_claims)
|
||||
|
||||
# Generate access token
|
||||
access_payload = base_payload.copy()
|
||||
access_payload.update({
|
||||
'exp': current_time + self.access_token_expiry,
|
||||
'token_type': 'access'
|
||||
})
|
||||
|
||||
access_token = await self._sign_token(access_payload)
|
||||
|
||||
result = {
|
||||
'access_token': access_token,
|
||||
'token_type': 'Bearer',
|
||||
'expires_in': self.access_token_expiry,
|
||||
'user_id': user_info.get('user_id'),
|
||||
'issued_at': current_time
|
||||
}
|
||||
|
||||
# Generate refresh token (if enabled)
|
||||
if self.enable_refresh:
|
||||
refresh_jti = str(uuid.uuid4())
|
||||
refresh_payload = {
|
||||
'iss': self.issuer,
|
||||
'aud': self.audience,
|
||||
'iat': current_time,
|
||||
'exp': current_time + self.refresh_token_expiry,
|
||||
'jti': refresh_jti,
|
||||
'sub': user_info.get('user_id'),
|
||||
'token_type': 'refresh',
|
||||
'access_jti': jti # Associated access token ID
|
||||
}
|
||||
|
||||
refresh_token = await self._sign_token(refresh_payload)
|
||||
result.update({
|
||||
'refresh_token': refresh_token,
|
||||
'refresh_expires_in': self.refresh_token_expiry
|
||||
})
|
||||
|
||||
logger.info(f"Generated tokens for user: {user_info.get('user_id')}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate tokens: {e}")
|
||||
raise
|
||||
|
||||
async def _sign_token(self, payload: Dict[str, Any]) -> str:
|
||||
"""Sign JWT token
|
||||
|
||||
Args:
|
||||
payload: JWT payload
|
||||
|
||||
Returns:
|
||||
Signed JWT token
|
||||
"""
|
||||
try:
|
||||
signing_key = self.key_manager.get_private_key()
|
||||
|
||||
if self.algorithm == "HS256":
|
||||
# Symmetric key signing
|
||||
token = jwt.encode(payload, signing_key, algorithm=self.algorithm)
|
||||
else:
|
||||
# Asymmetric key signing
|
||||
token = jwt.encode(payload, signing_key, algorithm=self.algorithm)
|
||||
|
||||
return token
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to sign token: {e}")
|
||||
raise
|
||||
|
||||
async def validate_token(self, token: str, token_type: str = 'access') -> Dict[str, Any]:
|
||||
"""Validate JWT token
|
||||
|
||||
Args:
|
||||
token: JWT token string
|
||||
token_type: Token type ('access' or 'refresh')
|
||||
|
||||
Returns:
|
||||
Validation result and user information
|
||||
|
||||
Raises:
|
||||
ValueError: Token validation failed
|
||||
"""
|
||||
try:
|
||||
# Decode token
|
||||
verification_key = self.key_manager.get_public_key()
|
||||
|
||||
# Get security configuration
|
||||
if hasattr(self.config, 'security'):
|
||||
security_config = self.config.security
|
||||
else:
|
||||
security_config = self.config
|
||||
|
||||
# JWT decoding options
|
||||
options = {
|
||||
'verify_signature': security_config.jwt_verify_signature,
|
||||
'verify_exp': security_config.jwt_require_exp,
|
||||
'verify_iat': security_config.jwt_require_iat,
|
||||
'verify_nbf': security_config.jwt_require_nbf,
|
||||
'verify_aud': security_config.jwt_verify_audience,
|
||||
'verify_iss': security_config.jwt_verify_issuer,
|
||||
}
|
||||
|
||||
# Decode JWT
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
verification_key,
|
||||
algorithms=[self.algorithm],
|
||||
audience=self.audience if security_config.jwt_verify_audience else None,
|
||||
issuer=self.issuer if security_config.jwt_verify_issuer else None,
|
||||
leeway=security_config.jwt_leeway,
|
||||
options=options
|
||||
)
|
||||
|
||||
# Check token type
|
||||
if payload.get('token_type') != token_type:
|
||||
raise ValueError(f"Invalid token type: expected {token_type}")
|
||||
|
||||
# Use validator for additional checks
|
||||
validation_result = await self.validator.validate_claims(payload)
|
||||
|
||||
logger.info(f"Token validation successful for user: {payload.get('sub')}")
|
||||
return validation_result
|
||||
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise ValueError("Token has expired")
|
||||
except jwt.InvalidTokenError as e:
|
||||
raise ValueError(f"Invalid token: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Token validation failed: {e}")
|
||||
raise ValueError(f"Token validation failed: {str(e)}")
|
||||
|
||||
async def refresh_token(self, refresh_token: str) -> Dict[str, Any]:
|
||||
"""Refresh access token
|
||||
|
||||
Args:
|
||||
refresh_token: Refresh token
|
||||
|
||||
Returns:
|
||||
New token pair
|
||||
"""
|
||||
if not self.enable_refresh:
|
||||
raise ValueError("Token refresh is disabled")
|
||||
|
||||
try:
|
||||
# Validate refresh token
|
||||
refresh_result = await self.validate_token(refresh_token, 'refresh')
|
||||
refresh_payload = refresh_result['payload']
|
||||
|
||||
# Revoke associated access token (if revocation is enabled)
|
||||
if self.enable_revocation:
|
||||
access_jti = refresh_payload.get('access_jti')
|
||||
if access_jti:
|
||||
# Should revoke old access token here, but since we don't know its expiration time,
|
||||
# in practice might need to store more information or use different strategy
|
||||
pass
|
||||
|
||||
# Build new user information
|
||||
user_info = {
|
||||
'user_id': refresh_payload.get('sub'),
|
||||
'roles': refresh_payload.get('roles', []),
|
||||
'permissions': refresh_payload.get('permissions', []),
|
||||
'security_level': refresh_payload.get('security_level', 'internal')
|
||||
}
|
||||
|
||||
# Generate new token pair
|
||||
new_tokens = await self.generate_tokens(user_info)
|
||||
|
||||
logger.info(f"Token refreshed for user: {user_info['user_id']}")
|
||||
return new_tokens
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Token refresh failed: {e}")
|
||||
raise
|
||||
|
||||
async def revoke_token(self, token: str) -> bool:
|
||||
"""Revoke token
|
||||
|
||||
Args:
|
||||
token: Token to revoke
|
||||
|
||||
Returns:
|
||||
Whether revocation was successful
|
||||
"""
|
||||
if not self.enable_revocation:
|
||||
logger.warning("Token revocation is disabled")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Decode token to get JTI and expiration time
|
||||
verification_key = self.key_manager.get_public_key()
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
verification_key,
|
||||
algorithms=[self.algorithm],
|
||||
options={'verify_exp': False} # Allow decoding expired tokens
|
||||
)
|
||||
|
||||
jti = payload.get('jti')
|
||||
exp = payload.get('exp')
|
||||
|
||||
if not jti or not exp:
|
||||
logger.error("Token missing required claims for revocation")
|
||||
return False
|
||||
|
||||
# Add to blacklist
|
||||
await self.validator.revoke_token(jti, exp)
|
||||
|
||||
logger.info(f"Token {jti} revoked successfully")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Token revocation failed: {e}")
|
||||
return False
|
||||
|
||||
async def decode_token_unsafe(self, token: str) -> Dict[str, Any]:
|
||||
"""Decode token without verifying signature (for debugging only)
|
||||
|
||||
Args:
|
||||
token: JWT token
|
||||
|
||||
Returns:
|
||||
Token payload
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(token, options={'verify_signature': False})
|
||||
return payload
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to decode token: {e}")
|
||||
raise
|
||||
|
||||
async def get_token_info(self, token: str) -> Dict[str, Any]:
|
||||
"""Get token information (without verifying signature)
|
||||
|
||||
Args:
|
||||
token: JWT token
|
||||
|
||||
Returns:
|
||||
Token information
|
||||
"""
|
||||
try:
|
||||
payload = await self.decode_token_unsafe(token)
|
||||
|
||||
return {
|
||||
'jti': payload.get('jti'),
|
||||
'sub': payload.get('sub'),
|
||||
'iss': payload.get('iss'),
|
||||
'aud': payload.get('aud'),
|
||||
'iat': payload.get('iat'),
|
||||
'exp': payload.get('exp'),
|
||||
'token_type': payload.get('token_type'),
|
||||
'roles': payload.get('roles'),
|
||||
'permissions': payload.get('permissions'),
|
||||
'security_level': payload.get('security_level'),
|
||||
'is_expired': payload.get('exp', 0) < time.time() if payload.get('exp') else None
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get token info: {e}")
|
||||
raise
|
||||
|
||||
async def _auto_key_rotation(self):
|
||||
"""Automatic key rotation task"""
|
||||
while True:
|
||||
try:
|
||||
# Check if key rotation is needed
|
||||
if await self.key_manager.is_key_expired():
|
||||
logger.info("Key rotation needed, rotating keys...")
|
||||
await self.key_manager.rotate_keys()
|
||||
|
||||
# Wait until next check
|
||||
await asyncio.sleep(3600) # Check every hour
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in auto key rotation: {e}")
|
||||
# Wait longer before retry after error
|
||||
await asyncio.sleep(3600)
|
||||
|
||||
async def get_public_key_info(self) -> Dict[str, Any]:
|
||||
"""Get public key information (for client verification)
|
||||
|
||||
Returns:
|
||||
Public key information
|
||||
"""
|
||||
key_info = await self.key_manager.get_key_info()
|
||||
public_key_pem = await self.key_manager.export_public_key_pem()
|
||||
|
||||
return {
|
||||
'algorithm': self.algorithm,
|
||||
'public_key_pem': public_key_pem,
|
||||
'key_info': key_info
|
||||
}
|
||||
|
||||
async def get_manager_stats(self) -> Dict[str, Any]:
|
||||
"""Get manager statistics
|
||||
|
||||
Returns:
|
||||
Statistics information
|
||||
"""
|
||||
key_info = await self.key_manager.get_key_info()
|
||||
validation_stats = await self.validator.get_validation_stats()
|
||||
|
||||
return {
|
||||
'jwt_config': {
|
||||
'algorithm': self.algorithm,
|
||||
'issuer': self.issuer,
|
||||
'audience': self.audience,
|
||||
'access_token_expiry': self.access_token_expiry,
|
||||
'refresh_token_expiry': self.refresh_token_expiry,
|
||||
'enable_refresh': self.enable_refresh,
|
||||
'enable_revocation': self.enable_revocation
|
||||
},
|
||||
'key_manager': key_info,
|
||||
'validator': validation_stats
|
||||
}
|
||||
343
doris_mcp_server/auth/key_manager.py
Normal file
343
doris_mcp_server/auth/key_manager.py
Normal file
@@ -0,0 +1,343 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
JWT Key Management Module
|
||||
Provides secure key generation, loading, rotation and management for JWT tokens
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import secrets
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Union
|
||||
from datetime import datetime, timedelta
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa, ec
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class KeyManager:
|
||||
"""JWT Key Manager
|
||||
|
||||
Responsible for generating, loading, rotating and securely storing JWT signing keys
|
||||
Supports RSA and EC algorithms, provides automatic key rotation functionality
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
"""Initialize key manager
|
||||
|
||||
Args:
|
||||
config: DorisConfig configuration object (with security attribute)
|
||||
"""
|
||||
self.config = config
|
||||
# Access JWT settings through the security configuration
|
||||
if hasattr(config, 'security'):
|
||||
security_config = config.security
|
||||
else:
|
||||
# Fallback if config is passed directly as SecurityConfig
|
||||
security_config = config
|
||||
|
||||
self.algorithm = security_config.jwt_algorithm
|
||||
self.key_rotation_interval = security_config.key_rotation_interval
|
||||
self.private_key_path = security_config.jwt_private_key_path
|
||||
self.public_key_path = security_config.jwt_public_key_path
|
||||
self.secret_key = security_config.jwt_secret_key
|
||||
|
||||
# Key storage
|
||||
self._private_key = None
|
||||
self._public_key = None
|
||||
self._secret_key = None
|
||||
self._key_generated_at = None
|
||||
|
||||
logger.info(f"KeyManager initialized with algorithm: {self.algorithm}")
|
||||
|
||||
async def initialize(self) -> bool:
|
||||
"""Initialize key manager, load or generate keys"""
|
||||
try:
|
||||
if self.algorithm == "HS256":
|
||||
await self._initialize_symmetric_key()
|
||||
else:
|
||||
await self._initialize_asymmetric_keys()
|
||||
|
||||
logger.info("KeyManager initialization completed")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize KeyManager: {e}")
|
||||
return False
|
||||
|
||||
async def _initialize_symmetric_key(self):
|
||||
"""Initialize symmetric key (HS256)"""
|
||||
if self.secret_key:
|
||||
# Use configured key
|
||||
self._secret_key = self.secret_key.encode()
|
||||
logger.info("Loaded symmetric key from configuration")
|
||||
else:
|
||||
# Generate new key
|
||||
self._secret_key = await self.generate_symmetric_key()
|
||||
logger.info("Generated new symmetric key")
|
||||
|
||||
self._key_generated_at = datetime.utcnow()
|
||||
|
||||
async def _initialize_asymmetric_keys(self):
|
||||
"""Initialize asymmetric key pair (RS256/ES256)"""
|
||||
# Try to load keys from files
|
||||
if await self._load_keys_from_files():
|
||||
logger.info("Loaded asymmetric keys from files")
|
||||
return
|
||||
|
||||
# Try to load from environment variables
|
||||
if await self._load_keys_from_env():
|
||||
logger.info("Loaded asymmetric keys from environment")
|
||||
return
|
||||
|
||||
# Generate new key pair
|
||||
await self.generate_key_pair()
|
||||
logger.info("Generated new asymmetric key pair")
|
||||
|
||||
async def _load_keys_from_files(self) -> bool:
|
||||
"""Load keys from files"""
|
||||
try:
|
||||
if not self.private_key_path or not self.public_key_path:
|
||||
return False
|
||||
|
||||
private_path = Path(self.private_key_path)
|
||||
public_path = Path(self.public_key_path)
|
||||
|
||||
if not (private_path.exists() and public_path.exists()):
|
||||
return False
|
||||
|
||||
# Read private key
|
||||
with open(private_path, 'rb') as f:
|
||||
private_key_data = f.read()
|
||||
self._private_key = serialization.load_pem_private_key(
|
||||
private_key_data, password=None, backend=default_backend()
|
||||
)
|
||||
|
||||
# Read public key
|
||||
with open(public_path, 'rb') as f:
|
||||
public_key_data = f.read()
|
||||
self._public_key = serialization.load_pem_public_key(
|
||||
public_key_data, backend=default_backend()
|
||||
)
|
||||
|
||||
# Get key generation time (using file modification time)
|
||||
self._key_generated_at = datetime.fromtimestamp(private_path.stat().st_mtime)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load keys from files: {e}")
|
||||
return False
|
||||
|
||||
async def _load_keys_from_env(self) -> bool:
|
||||
"""Load keys from environment variables"""
|
||||
try:
|
||||
private_key_env = os.getenv('JWT_PRIVATE_KEY')
|
||||
public_key_env = os.getenv('JWT_PUBLIC_KEY')
|
||||
|
||||
if not (private_key_env and public_key_env):
|
||||
return False
|
||||
|
||||
# Parse private key
|
||||
self._private_key = serialization.load_pem_private_key(
|
||||
private_key_env.encode(), password=None, backend=default_backend()
|
||||
)
|
||||
|
||||
# Parse public key
|
||||
self._public_key = serialization.load_pem_public_key(
|
||||
public_key_env.encode(), backend=default_backend()
|
||||
)
|
||||
|
||||
self._key_generated_at = datetime.utcnow()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load keys from environment: {e}")
|
||||
return False
|
||||
|
||||
async def generate_symmetric_key(self, length: int = 32) -> bytes:
|
||||
"""Generate symmetric key
|
||||
|
||||
Args:
|
||||
length: Key length (bytes), default 32 bytes (256 bits)
|
||||
|
||||
Returns:
|
||||
Generated key
|
||||
"""
|
||||
return secrets.token_bytes(length)
|
||||
|
||||
async def generate_key_pair(self) -> Tuple[bytes, bytes]:
|
||||
"""Generate asymmetric key pair
|
||||
|
||||
Returns:
|
||||
(private key PEM, public key PEM) tuple
|
||||
"""
|
||||
try:
|
||||
if self.algorithm == "RS256":
|
||||
private_key = rsa.generate_private_key(
|
||||
public_exponent=65537,
|
||||
key_size=2048,
|
||||
backend=default_backend()
|
||||
)
|
||||
elif self.algorithm == "ES256":
|
||||
private_key = ec.generate_private_key(
|
||||
ec.SECP256R1(), backend=default_backend()
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported algorithm for key generation: {self.algorithm}")
|
||||
|
||||
# Get public key
|
||||
public_key = private_key.public_key()
|
||||
|
||||
# Serialize private key
|
||||
private_pem = private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption()
|
||||
)
|
||||
|
||||
# Serialize public key
|
||||
public_pem = public_key.public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo
|
||||
)
|
||||
|
||||
# Store keys
|
||||
self._private_key = private_key
|
||||
self._public_key = public_key
|
||||
self._key_generated_at = datetime.utcnow()
|
||||
|
||||
# If file paths are configured, save to files
|
||||
if self.private_key_path and self.public_key_path:
|
||||
await self._save_keys_to_files(private_pem, public_pem)
|
||||
|
||||
logger.info(f"Generated new {self.algorithm} key pair")
|
||||
return private_pem, public_pem
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate key pair: {e}")
|
||||
raise
|
||||
|
||||
async def _save_keys_to_files(self, private_pem: bytes, public_pem: bytes):
|
||||
"""Save keys to files"""
|
||||
try:
|
||||
# Ensure directories exist
|
||||
private_path = Path(self.private_key_path)
|
||||
public_path = Path(self.public_key_path)
|
||||
|
||||
private_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
public_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save private key (set secure permissions)
|
||||
with open(private_path, 'wb') as f:
|
||||
f.write(private_pem)
|
||||
os.chmod(private_path, 0o600) # Only owner can read/write
|
||||
|
||||
# Save public key
|
||||
with open(public_path, 'wb') as f:
|
||||
f.write(public_pem)
|
||||
os.chmod(public_path, 0o644) # Owner read/write, others read only
|
||||
|
||||
logger.info(f"Saved keys to files: {private_path}, {public_path}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save keys to files: {e}")
|
||||
raise
|
||||
|
||||
def get_private_key(self):
|
||||
"""Get private key for signing"""
|
||||
if self.algorithm == "HS256":
|
||||
return self._secret_key
|
||||
else:
|
||||
return self._private_key
|
||||
|
||||
def get_public_key(self):
|
||||
"""Get public key for verification"""
|
||||
if self.algorithm == "HS256":
|
||||
return self._secret_key
|
||||
else:
|
||||
return self._public_key
|
||||
|
||||
def get_algorithm(self) -> str:
|
||||
"""Get signing algorithm"""
|
||||
return self.algorithm
|
||||
|
||||
async def is_key_expired(self) -> bool:
|
||||
"""Check if key is expired"""
|
||||
if not self._key_generated_at:
|
||||
return True
|
||||
|
||||
expiry_time = self._key_generated_at + timedelta(seconds=self.key_rotation_interval)
|
||||
return datetime.utcnow() > expiry_time
|
||||
|
||||
async def rotate_keys(self) -> bool:
|
||||
"""Rotate keys"""
|
||||
try:
|
||||
logger.info("Starting key rotation")
|
||||
|
||||
if self.algorithm == "HS256":
|
||||
# Generate new symmetric key
|
||||
self._secret_key = await self.generate_symmetric_key()
|
||||
self._key_generated_at = datetime.utcnow()
|
||||
else:
|
||||
# Generate new asymmetric key pair
|
||||
await self.generate_key_pair()
|
||||
|
||||
logger.info("Key rotation completed successfully")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Key rotation failed: {e}")
|
||||
return False
|
||||
|
||||
async def get_key_info(self) -> dict:
|
||||
"""Get key information"""
|
||||
return {
|
||||
"algorithm": self.algorithm,
|
||||
"key_generated_at": self._key_generated_at.isoformat() if self._key_generated_at else None,
|
||||
"key_expires_at": (
|
||||
self._key_generated_at + timedelta(seconds=self.key_rotation_interval)
|
||||
).isoformat() if self._key_generated_at else None,
|
||||
"is_expired": await self.is_key_expired(),
|
||||
"has_private_key": self._private_key is not None or self._secret_key is not None,
|
||||
"has_public_key": self._public_key is not None or self._secret_key is not None
|
||||
}
|
||||
|
||||
async def export_public_key_pem(self) -> Optional[str]:
|
||||
"""Export public key in PEM format"""
|
||||
if self.algorithm == "HS256":
|
||||
return None # Symmetric key not exported
|
||||
|
||||
if not self._public_key:
|
||||
return None
|
||||
|
||||
try:
|
||||
public_pem = self._public_key.public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo
|
||||
)
|
||||
return public_pem.decode()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to export public key: {e}")
|
||||
return None
|
||||
536
doris_mcp_server/auth/oauth_client.py
Normal file
536
doris_mcp_server/auth/oauth_client.py
Normal file
@@ -0,0 +1,536 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
OAuth 2.0/OIDC Client Manager
|
||||
Provides OAuth authentication client implementation with PKCE and OIDC support
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Optional, Any, Tuple
|
||||
from urllib.parse import urlencode, parse_qs, urlparse
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
except ImportError:
|
||||
raise ImportError("aiohttp is required for OAuth functionality. Install with: pip install aiohttp")
|
||||
|
||||
from .oauth_types import (
|
||||
OAuthProvider, OAuthState, OAuthTokens, OAuthUserInfo,
|
||||
OIDCDiscovery, OAuthError, OAuthProviderConfig, OAUTH_PROVIDERS
|
||||
)
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class OAuthStateManager:
|
||||
"""Manages OAuth state parameters for CSRF protection"""
|
||||
|
||||
def __init__(self, state_expiry: int = 600):
|
||||
"""Initialize state manager
|
||||
|
||||
Args:
|
||||
state_expiry: State expiry time in seconds
|
||||
"""
|
||||
self.state_expiry = state_expiry
|
||||
self._states: Dict[str, OAuthState] = {}
|
||||
self._cleanup_task = None
|
||||
|
||||
logger.info("OAuthStateManager initialized")
|
||||
|
||||
async def start(self):
|
||||
"""Start periodic cleanup task"""
|
||||
self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
|
||||
logger.info("OAuth state manager started")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop periodic cleanup task"""
|
||||
if self._cleanup_task:
|
||||
self._cleanup_task.cancel()
|
||||
try:
|
||||
await self._cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("OAuth state manager stopped")
|
||||
|
||||
def create_state(self, redirect_uri: str, pkce_enabled: bool = True,
|
||||
nonce_enabled: bool = True) -> OAuthState:
|
||||
"""Create new OAuth state
|
||||
|
||||
Args:
|
||||
redirect_uri: OAuth redirect URI
|
||||
pkce_enabled: Whether to enable PKCE
|
||||
nonce_enabled: Whether to enable nonce (for OIDC)
|
||||
|
||||
Returns:
|
||||
OAuth state object
|
||||
"""
|
||||
state = secrets.token_urlsafe(32)
|
||||
nonce = secrets.token_urlsafe(32) if nonce_enabled else None
|
||||
|
||||
pkce_verifier = None
|
||||
pkce_challenge = None
|
||||
if pkce_enabled:
|
||||
pkce_verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).decode('utf-8').rstrip('=')
|
||||
challenge_bytes = hashlib.sha256(pkce_verifier.encode()).digest()
|
||||
pkce_challenge = base64.urlsafe_b64encode(challenge_bytes).decode('utf-8').rstrip('=')
|
||||
|
||||
oauth_state = OAuthState(
|
||||
state=state,
|
||||
nonce=nonce,
|
||||
pkce_verifier=pkce_verifier,
|
||||
pkce_challenge=pkce_challenge,
|
||||
redirect_uri=redirect_uri,
|
||||
created_at=datetime.utcnow(),
|
||||
expires_at=datetime.utcnow() + timedelta(seconds=self.state_expiry)
|
||||
)
|
||||
|
||||
self._states[state] = oauth_state
|
||||
logger.debug(f"Created OAuth state: {state}")
|
||||
return oauth_state
|
||||
|
||||
def get_state(self, state: str) -> Optional[OAuthState]:
|
||||
"""Get OAuth state by state parameter
|
||||
|
||||
Args:
|
||||
state: State parameter
|
||||
|
||||
Returns:
|
||||
OAuth state object or None if not found/expired
|
||||
"""
|
||||
oauth_state = self._states.get(state)
|
||||
if oauth_state and oauth_state.expires_at > datetime.utcnow():
|
||||
return oauth_state
|
||||
elif oauth_state:
|
||||
# Remove expired state
|
||||
del self._states[state]
|
||||
logger.debug(f"Removed expired OAuth state: {state}")
|
||||
return None
|
||||
|
||||
def consume_state(self, state: str) -> Optional[OAuthState]:
|
||||
"""Get and remove OAuth state
|
||||
|
||||
Args:
|
||||
state: State parameter
|
||||
|
||||
Returns:
|
||||
OAuth state object or None if not found/expired
|
||||
"""
|
||||
oauth_state = self.get_state(state)
|
||||
if oauth_state:
|
||||
del self._states[state]
|
||||
logger.debug(f"Consumed OAuth state: {state}")
|
||||
return oauth_state
|
||||
|
||||
async def _periodic_cleanup(self):
|
||||
"""Periodic cleanup of expired states"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(300) # Clean up every 5 minutes
|
||||
current_time = datetime.utcnow()
|
||||
expired_states = [
|
||||
state for state, oauth_state in self._states.items()
|
||||
if oauth_state.expires_at <= current_time
|
||||
]
|
||||
|
||||
for state in expired_states:
|
||||
del self._states[state]
|
||||
|
||||
if expired_states:
|
||||
logger.info(f"Cleaned up {len(expired_states)} expired OAuth states")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error during OAuth state cleanup: {e}")
|
||||
|
||||
|
||||
class OAuthClient:
|
||||
"""OAuth 2.0/OIDC Client implementation"""
|
||||
|
||||
def __init__(self, config):
|
||||
"""Initialize OAuth client
|
||||
|
||||
Args:
|
||||
config: DorisConfig with OAuth configuration
|
||||
"""
|
||||
self.config = config
|
||||
|
||||
# Access OAuth settings through security configuration
|
||||
if hasattr(config, 'security'):
|
||||
security_config = config.security
|
||||
else:
|
||||
security_config = config
|
||||
|
||||
self.enabled = security_config.oauth_enabled
|
||||
if not self.enabled:
|
||||
logger.info("OAuth client disabled by configuration")
|
||||
return
|
||||
|
||||
# Build provider configuration
|
||||
self.provider_config = self._build_provider_config(security_config)
|
||||
self.state_manager = OAuthStateManager(security_config.oauth_state_expiry)
|
||||
|
||||
# HTTP client session
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
|
||||
# Discovery cache
|
||||
self._discovery_cache: Optional[OIDCDiscovery] = None
|
||||
self._discovery_cache_time: Optional[datetime] = None
|
||||
|
||||
logger.info(f"OAuthClient initialized for provider: {self.provider_config.provider.value}")
|
||||
|
||||
def _build_provider_config(self, security_config) -> OAuthProviderConfig:
|
||||
"""Build OAuth provider configuration
|
||||
|
||||
Args:
|
||||
security_config: Security configuration object
|
||||
|
||||
Returns:
|
||||
OAuth provider configuration
|
||||
"""
|
||||
try:
|
||||
provider = OAuthProvider(security_config.oauth_provider)
|
||||
except ValueError:
|
||||
provider = OAuthProvider.CUSTOM
|
||||
|
||||
# Get default configuration for known providers
|
||||
defaults = OAUTH_PROVIDERS.get(provider, {})
|
||||
|
||||
return OAuthProviderConfig(
|
||||
provider=provider,
|
||||
client_id=security_config.oauth_client_id,
|
||||
client_secret=security_config.oauth_client_secret,
|
||||
redirect_uri=security_config.oauth_redirect_uri,
|
||||
scopes=security_config.oauth_scopes or defaults.get("scopes", ["openid", "email", "profile"]),
|
||||
|
||||
# Endpoints (use configured or defaults)
|
||||
authorization_endpoint=security_config.oauth_authorization_endpoint or defaults.get("authorization_endpoint", ""),
|
||||
token_endpoint=security_config.oauth_token_endpoint or defaults.get("token_endpoint", ""),
|
||||
userinfo_endpoint=security_config.oauth_userinfo_endpoint or defaults.get("userinfo_endpoint"),
|
||||
jwks_uri=security_config.oauth_jwks_uri or defaults.get("jwks_uri"),
|
||||
|
||||
# Discovery
|
||||
discovery_url=security_config.oidc_discovery_url or defaults.get("discovery_url"),
|
||||
|
||||
# Settings
|
||||
pkce_enabled=security_config.oauth_pkce_enabled,
|
||||
nonce_enabled=security_config.oauth_nonce_enabled,
|
||||
|
||||
# User mapping
|
||||
user_id_claim=security_config.oauth_user_id_claim or defaults.get("user_id_claim", "sub"),
|
||||
email_claim=security_config.oauth_email_claim or defaults.get("email_claim", "email"),
|
||||
name_claim=security_config.oauth_name_claim or defaults.get("name_claim", "name"),
|
||||
roles_claim=security_config.oauth_roles_claim,
|
||||
default_roles=security_config.oauth_default_roles
|
||||
)
|
||||
|
||||
async def initialize(self) -> bool:
|
||||
"""Initialize OAuth client
|
||||
|
||||
Returns:
|
||||
True if initialization successful
|
||||
"""
|
||||
if not self.enabled:
|
||||
return True
|
||||
|
||||
try:
|
||||
# Create HTTP session
|
||||
self._session = aiohttp.ClientSession()
|
||||
|
||||
# Start state manager
|
||||
await self.state_manager.start()
|
||||
|
||||
# Perform OIDC discovery if configured
|
||||
if self.provider_config.discovery_url:
|
||||
await self._discover_oidc_endpoints()
|
||||
|
||||
logger.info("OAuth client initialization completed")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize OAuth client: {e}")
|
||||
return False
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shutdown OAuth client"""
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
try:
|
||||
# Stop state manager
|
||||
await self.state_manager.stop()
|
||||
|
||||
# Close HTTP session
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
|
||||
logger.info("OAuth client shutdown completed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during OAuth client shutdown: {e}")
|
||||
|
||||
async def _discover_oidc_endpoints(self):
|
||||
"""Discover OIDC endpoints using discovery URL"""
|
||||
try:
|
||||
# Check cache first
|
||||
if (self._discovery_cache and self._discovery_cache_time and
|
||||
datetime.utcnow() - self._discovery_cache_time < timedelta(hours=1)):
|
||||
return self._discovery_cache
|
||||
|
||||
logger.info(f"Discovering OIDC endpoints: {self.provider_config.discovery_url}")
|
||||
|
||||
async with self._session.get(self.provider_config.discovery_url) as response:
|
||||
response.raise_for_status()
|
||||
data = await response.json()
|
||||
|
||||
discovery = OIDCDiscovery(
|
||||
issuer=data["issuer"],
|
||||
authorization_endpoint=data["authorization_endpoint"],
|
||||
token_endpoint=data["token_endpoint"],
|
||||
userinfo_endpoint=data.get("userinfo_endpoint"),
|
||||
jwks_uri=data.get("jwks_uri"),
|
||||
scopes_supported=data.get("scopes_supported"),
|
||||
response_types_supported=data.get("response_types_supported"),
|
||||
subject_types_supported=data.get("subject_types_supported"),
|
||||
id_token_signing_alg_values_supported=data.get("id_token_signing_alg_values_supported")
|
||||
)
|
||||
|
||||
# Update provider configuration with discovered endpoints
|
||||
if not self.provider_config.authorization_endpoint:
|
||||
self.provider_config.authorization_endpoint = discovery.authorization_endpoint
|
||||
if not self.provider_config.token_endpoint:
|
||||
self.provider_config.token_endpoint = discovery.token_endpoint
|
||||
if not self.provider_config.userinfo_endpoint:
|
||||
self.provider_config.userinfo_endpoint = discovery.userinfo_endpoint
|
||||
if not self.provider_config.jwks_uri:
|
||||
self.provider_config.jwks_uri = discovery.jwks_uri
|
||||
|
||||
# Cache discovery result
|
||||
self._discovery_cache = discovery
|
||||
self._discovery_cache_time = datetime.utcnow()
|
||||
|
||||
logger.info("OIDC endpoint discovery completed successfully")
|
||||
return discovery
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OIDC endpoint discovery failed: {e}")
|
||||
raise
|
||||
|
||||
def build_authorization_url(self) -> Tuple[str, OAuthState]:
|
||||
"""Build OAuth authorization URL
|
||||
|
||||
Returns:
|
||||
Tuple of (authorization_url, oauth_state)
|
||||
"""
|
||||
if not self.enabled:
|
||||
raise ValueError("OAuth client is not enabled")
|
||||
|
||||
# Create state for CSRF protection
|
||||
oauth_state = self.state_manager.create_state(
|
||||
redirect_uri=self.provider_config.redirect_uri,
|
||||
pkce_enabled=self.provider_config.pkce_enabled,
|
||||
nonce_enabled=self.provider_config.nonce_enabled
|
||||
)
|
||||
|
||||
# Build authorization parameters
|
||||
params = {
|
||||
'response_type': 'code',
|
||||
'client_id': self.provider_config.client_id,
|
||||
'redirect_uri': self.provider_config.redirect_uri,
|
||||
'scope': ' '.join(self.provider_config.scopes),
|
||||
'state': oauth_state.state
|
||||
}
|
||||
|
||||
# Add PKCE challenge
|
||||
if oauth_state.pkce_challenge:
|
||||
params['code_challenge'] = oauth_state.pkce_challenge
|
||||
params['code_challenge_method'] = 'S256'
|
||||
|
||||
# Add nonce for OIDC
|
||||
if oauth_state.nonce:
|
||||
params['nonce'] = oauth_state.nonce
|
||||
|
||||
# Build URL
|
||||
authorization_url = f"{self.provider_config.authorization_endpoint}?{urlencode(params)}"
|
||||
|
||||
logger.info(f"Built OAuth authorization URL for state: {oauth_state.state}")
|
||||
return authorization_url, oauth_state
|
||||
|
||||
async def exchange_code_for_tokens(self, code: str, state: str) -> Tuple[OAuthTokens, OAuthState]:
|
||||
"""Exchange authorization code for tokens
|
||||
|
||||
Args:
|
||||
code: Authorization code
|
||||
state: State parameter
|
||||
|
||||
Returns:
|
||||
Tuple of (OAuth tokens, OAuth state)
|
||||
|
||||
Raises:
|
||||
ValueError: If state is invalid or exchange fails
|
||||
"""
|
||||
if not self.enabled:
|
||||
raise ValueError("OAuth client is not enabled")
|
||||
|
||||
# Validate and consume state
|
||||
oauth_state = self.state_manager.consume_state(state)
|
||||
if not oauth_state:
|
||||
raise ValueError("Invalid or expired state parameter")
|
||||
|
||||
try:
|
||||
# Prepare token request
|
||||
data = {
|
||||
'grant_type': 'authorization_code',
|
||||
'client_id': self.provider_config.client_id,
|
||||
'client_secret': self.provider_config.client_secret,
|
||||
'code': code,
|
||||
'redirect_uri': oauth_state.redirect_uri
|
||||
}
|
||||
|
||||
# Add PKCE verifier
|
||||
if oauth_state.pkce_verifier:
|
||||
data['code_verifier'] = oauth_state.pkce_verifier
|
||||
|
||||
# Make token request
|
||||
async with self._session.post(
|
||||
self.provider_config.token_endpoint,
|
||||
data=data,
|
||||
headers={'Content-Type': 'application/x-www-form-urlencoded'}
|
||||
) as response:
|
||||
response_data = await response.json()
|
||||
|
||||
if response.status != 200:
|
||||
error_msg = response_data.get('error_description', response_data.get('error', 'Token exchange failed'))
|
||||
raise ValueError(f"Token exchange failed: {error_msg}")
|
||||
|
||||
tokens = OAuthTokens(
|
||||
access_token=response_data['access_token'],
|
||||
token_type=response_data.get('token_type', 'Bearer'),
|
||||
expires_in=response_data.get('expires_in'),
|
||||
refresh_token=response_data.get('refresh_token'),
|
||||
scope=response_data.get('scope'),
|
||||
id_token=response_data.get('id_token')
|
||||
)
|
||||
|
||||
logger.info("Successfully exchanged authorization code for tokens")
|
||||
return tokens, oauth_state
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Token exchange failed: {e}")
|
||||
raise ValueError(f"Token exchange failed: {str(e)}")
|
||||
|
||||
async def get_user_info(self, tokens: OAuthTokens) -> OAuthUserInfo:
|
||||
"""Get user information from OAuth provider
|
||||
|
||||
Args:
|
||||
tokens: OAuth tokens
|
||||
|
||||
Returns:
|
||||
OAuth user information
|
||||
"""
|
||||
if not self.enabled:
|
||||
raise ValueError("OAuth client is not enabled")
|
||||
|
||||
if not self.provider_config.userinfo_endpoint:
|
||||
raise ValueError("Userinfo endpoint not configured")
|
||||
|
||||
try:
|
||||
# Make userinfo request
|
||||
headers = {'Authorization': f'{tokens.token_type} {tokens.access_token}'}
|
||||
|
||||
async with self._session.get(
|
||||
self.provider_config.userinfo_endpoint,
|
||||
headers=headers
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
user_data = await response.json()
|
||||
|
||||
# Extract user information using configured claims
|
||||
user_info = OAuthUserInfo(
|
||||
sub=str(user_data.get(self.provider_config.user_id_claim, '')),
|
||||
email=user_data.get(self.provider_config.email_claim),
|
||||
name=user_data.get(self.provider_config.name_claim),
|
||||
given_name=user_data.get('given_name'),
|
||||
family_name=user_data.get('family_name'),
|
||||
picture=user_data.get('picture'),
|
||||
locale=user_data.get('locale'),
|
||||
email_verified=user_data.get('email_verified'),
|
||||
roles=user_data.get(self.provider_config.roles_claim, self.provider_config.default_roles.copy()),
|
||||
raw_claims=user_data
|
||||
)
|
||||
|
||||
logger.info(f"Retrieved user info for user: {user_info.sub}")
|
||||
return user_info
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get user info: {e}")
|
||||
raise ValueError(f"Failed to get user info: {str(e)}")
|
||||
|
||||
async def refresh_tokens(self, refresh_token: str) -> OAuthTokens:
|
||||
"""Refresh OAuth tokens
|
||||
|
||||
Args:
|
||||
refresh_token: Refresh token
|
||||
|
||||
Returns:
|
||||
New OAuth tokens
|
||||
"""
|
||||
if not self.enabled:
|
||||
raise ValueError("OAuth client is not enabled")
|
||||
|
||||
try:
|
||||
data = {
|
||||
'grant_type': 'refresh_token',
|
||||
'client_id': self.provider_config.client_id,
|
||||
'client_secret': self.provider_config.client_secret,
|
||||
'refresh_token': refresh_token
|
||||
}
|
||||
|
||||
async with self._session.post(
|
||||
self.provider_config.token_endpoint,
|
||||
data=data,
|
||||
headers={'Content-Type': 'application/x-www-form-urlencoded'}
|
||||
) as response:
|
||||
response_data = await response.json()
|
||||
|
||||
if response.status != 200:
|
||||
error_msg = response_data.get('error_description', response_data.get('error', 'Token refresh failed'))
|
||||
raise ValueError(f"Token refresh failed: {error_msg}")
|
||||
|
||||
tokens = OAuthTokens(
|
||||
access_token=response_data['access_token'],
|
||||
token_type=response_data.get('token_type', 'Bearer'),
|
||||
expires_in=response_data.get('expires_in'),
|
||||
refresh_token=response_data.get('refresh_token', refresh_token), # Keep old if not provided
|
||||
scope=response_data.get('scope'),
|
||||
id_token=response_data.get('id_token')
|
||||
)
|
||||
|
||||
logger.info("Successfully refreshed OAuth tokens")
|
||||
return tokens
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Token refresh failed: {e}")
|
||||
raise ValueError(f"Token refresh failed: {str(e)}")
|
||||
312
doris_mcp_server/auth/oauth_handlers.py
Normal file
312
doris_mcp_server/auth/oauth_handlers.py
Normal file
@@ -0,0 +1,312 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
OAuth HTTP Handlers
|
||||
Provides HTTP endpoints for OAuth authentication flow
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
import json
|
||||
|
||||
from starlette.responses import JSONResponse, RedirectResponse, HTMLResponse
|
||||
from starlette.requests import Request
|
||||
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class OAuthHandlers:
|
||||
"""OAuth HTTP request handlers"""
|
||||
|
||||
def __init__(self, security_manager):
|
||||
"""Initialize OAuth handlers
|
||||
|
||||
Args:
|
||||
security_manager: DorisSecurityManager instance
|
||||
"""
|
||||
self.security_manager = security_manager
|
||||
logger.info("OAuth handlers initialized")
|
||||
|
||||
async def handle_login(self, request: Request) -> JSONResponse:
|
||||
"""Handle OAuth login initiation
|
||||
|
||||
Returns JSON with authorization URL and state
|
||||
"""
|
||||
try:
|
||||
# Check if OAuth is enabled
|
||||
oauth_info = self.security_manager.get_oauth_provider_info()
|
||||
if not oauth_info.get("enabled"):
|
||||
return JSONResponse(
|
||||
{"error": "OAuth authentication is not enabled"},
|
||||
status_code=400
|
||||
)
|
||||
|
||||
# Get authorization URL
|
||||
authorization_url, state = self.security_manager.get_oauth_authorization_url()
|
||||
|
||||
return JSONResponse({
|
||||
"authorization_url": authorization_url,
|
||||
"state": state,
|
||||
"provider": oauth_info.get("provider"),
|
||||
"message": "Navigate to authorization_url to complete OAuth login"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OAuth login initiation failed: {e}")
|
||||
return JSONResponse(
|
||||
{"error": f"OAuth login failed: {str(e)}"},
|
||||
status_code=500
|
||||
)
|
||||
|
||||
async def handle_callback(self, request: Request) -> JSONResponse:
|
||||
"""Handle OAuth callback
|
||||
|
||||
Processes the OAuth callback and returns authentication result
|
||||
"""
|
||||
try:
|
||||
# Get query parameters
|
||||
query_params = dict(request.query_params)
|
||||
|
||||
# Check for error in callback
|
||||
if "error" in query_params:
|
||||
error_description = query_params.get("error_description", "Unknown error")
|
||||
logger.warning(f"OAuth callback error: {query_params['error']} - {error_description}")
|
||||
return JSONResponse(
|
||||
{
|
||||
"error": query_params["error"],
|
||||
"error_description": error_description,
|
||||
"error_uri": query_params.get("error_uri")
|
||||
},
|
||||
status_code=400
|
||||
)
|
||||
|
||||
# Extract required parameters
|
||||
code = query_params.get("code")
|
||||
state = query_params.get("state")
|
||||
|
||||
if not code or not state:
|
||||
return JSONResponse(
|
||||
{"error": "Missing required parameters: code and state"},
|
||||
status_code=400
|
||||
)
|
||||
|
||||
# Handle OAuth callback
|
||||
auth_context = await self.security_manager.handle_oauth_callback(code, state)
|
||||
|
||||
# Return successful authentication response
|
||||
return JSONResponse({
|
||||
"success": True,
|
||||
"user_id": auth_context.user_id,
|
||||
"roles": auth_context.roles,
|
||||
"permissions": auth_context.permissions,
|
||||
"security_level": auth_context.security_level.value,
|
||||
"session_id": auth_context.session_id,
|
||||
"message": "OAuth authentication successful"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OAuth callback handling failed: {e}")
|
||||
return JSONResponse(
|
||||
{"error": f"OAuth callback failed: {str(e)}"},
|
||||
status_code=500
|
||||
)
|
||||
|
||||
async def handle_provider_info(self, request: Request) -> JSONResponse:
|
||||
"""Handle OAuth provider information request
|
||||
|
||||
Returns information about the configured OAuth provider
|
||||
"""
|
||||
try:
|
||||
provider_info = self.security_manager.get_oauth_provider_info()
|
||||
return JSONResponse(provider_info)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get OAuth provider info: {e}")
|
||||
return JSONResponse(
|
||||
{"error": f"Failed to get provider info: {str(e)}"},
|
||||
status_code=500
|
||||
)
|
||||
|
||||
async def handle_demo_page(self, request: Request) -> HTMLResponse:
|
||||
"""Handle OAuth demo page
|
||||
|
||||
Returns a simple HTML page for testing OAuth flow
|
||||
"""
|
||||
oauth_info = self.security_manager.get_oauth_provider_info()
|
||||
if not oauth_info.get("enabled"):
|
||||
return HTMLResponse("""
|
||||
<html>
|
||||
<head><title>OAuth Demo</title></head>
|
||||
<body>
|
||||
<h1>OAuth Demo</h1>
|
||||
<p style="color: red;">OAuth authentication is not enabled.</p>
|
||||
<p>Please configure OAuth settings in your security configuration.</p>
|
||||
</body>
|
||||
</html>
|
||||
""")
|
||||
|
||||
html_content = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Doris MCP Server - OAuth Demo</title>
|
||||
<style>
|
||||
body {{
|
||||
font-family: Arial, sans-serif;
|
||||
max-width: 800px;
|
||||
margin: 0 auto;
|
||||
padding: 20px;
|
||||
}}
|
||||
.info {{
|
||||
background-color: #f0f8ff;
|
||||
padding: 15px;
|
||||
border-left: 4px solid #0066cc;
|
||||
margin: 20px 0;
|
||||
}}
|
||||
.error {{
|
||||
background-color: #ffe6e6;
|
||||
padding: 15px;
|
||||
border-left: 4px solid #cc0000;
|
||||
margin: 20px 0;
|
||||
}}
|
||||
.success {{
|
||||
background-color: #e6ffe6;
|
||||
padding: 15px;
|
||||
border-left: 4px solid #00cc00;
|
||||
margin: 20px 0;
|
||||
}}
|
||||
button {{
|
||||
background-color: #0066cc;
|
||||
color: white;
|
||||
padding: 10px 20px;
|
||||
border: none;
|
||||
border-radius: 4px;
|
||||
cursor: pointer;
|
||||
font-size: 16px;
|
||||
}}
|
||||
button:hover {{
|
||||
background-color: #0052a3;
|
||||
}}
|
||||
pre {{
|
||||
background-color: #f5f5f5;
|
||||
padding: 10px;
|
||||
border-radius: 4px;
|
||||
overflow-x: auto;
|
||||
}}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Doris MCP Server - OAuth Demo</h1>
|
||||
|
||||
<div class="info">
|
||||
<h3>OAuth Configuration</h3>
|
||||
<p><strong>Provider:</strong> {oauth_info.get('provider', 'N/A')}</p>
|
||||
<p><strong>Client ID:</strong> {oauth_info.get('client_id', 'N/A')}</p>
|
||||
<p><strong>Scopes:</strong> {', '.join(oauth_info.get('scopes', []))}</p>
|
||||
<p><strong>PKCE Enabled:</strong> {oauth_info.get('pkce_enabled', False)}</p>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<h3>OAuth Authentication Test</h3>
|
||||
<p>Click the button below to start OAuth authentication flow:</p>
|
||||
<button onclick="startOAuthFlow()">Start OAuth Login</button>
|
||||
</div>
|
||||
|
||||
<div id="result" style="margin-top: 20px;"></div>
|
||||
|
||||
<div>
|
||||
<h3>API Endpoints</h3>
|
||||
<ul>
|
||||
<li><code>GET /auth/login</code> - Initiate OAuth login</li>
|
||||
<li><code>GET /auth/callback</code> - OAuth callback handler</li>
|
||||
<li><code>GET /auth/provider</code> - Provider information</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
async function startOAuthFlow() {{
|
||||
const resultDiv = document.getElementById('result');
|
||||
resultDiv.innerHTML = '<div class="info">Initiating OAuth flow...</div>';
|
||||
|
||||
try {{
|
||||
const response = await fetch('/auth/login');
|
||||
const data = await response.json();
|
||||
|
||||
if (response.ok) {{
|
||||
resultDiv.innerHTML = `
|
||||
<div class="success">
|
||||
<h4>OAuth URL Generated Successfully</h4>
|
||||
<p><strong>State:</strong> ${{data.state}}</p>
|
||||
<p><strong>Provider:</strong> ${{data.provider}}</p>
|
||||
<p><a href="${{data.authorization_url}}" target="_blank">Click here to authenticate</a></p>
|
||||
<p><em>Note: After authentication, you will be redirected to the callback URL.</em></p>
|
||||
</div>
|
||||
`;
|
||||
|
||||
// Automatically redirect to OAuth provider
|
||||
// window.open(data.authorization_url, '_blank');
|
||||
}} else {{
|
||||
resultDiv.innerHTML = `
|
||||
<div class="error">
|
||||
<h4>Error</h4>
|
||||
<p>${{data.error}}</p>
|
||||
</div>
|
||||
`;
|
||||
}}
|
||||
}} catch (error) {{
|
||||
resultDiv.innerHTML = `
|
||||
<div class="error">
|
||||
<h4>Network Error</h4>
|
||||
<p>${{error.message}}</p>
|
||||
</div>
|
||||
`;
|
||||
}}
|
||||
}}
|
||||
|
||||
// Handle OAuth callback result if present in URL
|
||||
window.addEventListener('load', function() {{
|
||||
const urlParams = new URLSearchParams(window.location.search);
|
||||
if (urlParams.has('code') && urlParams.has('state')) {{
|
||||
const resultDiv = document.getElementById('result');
|
||||
resultDiv.innerHTML = `
|
||||
<div class="success">
|
||||
<h4>OAuth Callback Received</h4>
|
||||
<p>Code: ${{urlParams.get('code')}}</p>
|
||||
<p>State: ${{urlParams.get('state')}}</p>
|
||||
<p>The authentication was successful!</p>
|
||||
</div>
|
||||
`;
|
||||
}} else if (urlParams.has('error')) {{
|
||||
const resultDiv = document.getElementById('result');
|
||||
resultDiv.innerHTML = `
|
||||
<div class="error">
|
||||
<h4>OAuth Error</h4>
|
||||
<p>Error: ${{urlParams.get('error')}}</p>
|
||||
<p>Description: ${{urlParams.get('error_description') || 'No description'}}</p>
|
||||
</div>
|
||||
`;
|
||||
}}
|
||||
}});
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
return HTMLResponse(html_content)
|
||||
287
doris_mcp_server/auth/oauth_provider.py
Normal file
287
doris_mcp_server/auth/oauth_provider.py
Normal file
@@ -0,0 +1,287 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
OAuth Authentication Provider
|
||||
Integrates OAuth 2.0/OIDC authentication with the existing authentication framework
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
from datetime import datetime
|
||||
|
||||
from .oauth_client import OAuthClient
|
||||
from .oauth_types import OAuthTokens, OAuthUserInfo, OAuthState
|
||||
from ..utils.security import AuthContext, SecurityLevel
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class OAuthAuthenticationProvider:
|
||||
"""OAuth authentication provider for Doris MCP Server"""
|
||||
|
||||
def __init__(self, config):
|
||||
"""Initialize OAuth authentication provider
|
||||
|
||||
Args:
|
||||
config: DorisConfig with OAuth configuration
|
||||
"""
|
||||
self.config = config
|
||||
self.oauth_client = OAuthClient(config)
|
||||
self.enabled = self.oauth_client.enabled
|
||||
|
||||
logger.info(f"OAuthAuthenticationProvider initialized (enabled: {self.enabled})")
|
||||
|
||||
async def initialize(self) -> bool:
|
||||
"""Initialize OAuth authentication provider
|
||||
|
||||
Returns:
|
||||
True if initialization successful
|
||||
"""
|
||||
if not self.enabled:
|
||||
return True
|
||||
|
||||
success = await self.oauth_client.initialize()
|
||||
if success:
|
||||
logger.info("OAuth authentication provider initialized successfully")
|
||||
else:
|
||||
logger.error("Failed to initialize OAuth authentication provider")
|
||||
return success
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shutdown OAuth authentication provider"""
|
||||
if self.enabled:
|
||||
await self.oauth_client.shutdown()
|
||||
logger.info("OAuth authentication provider shutdown completed")
|
||||
|
||||
def get_authorization_url(self) -> Tuple[str, str]:
|
||||
"""Get OAuth authorization URL
|
||||
|
||||
Returns:
|
||||
Tuple of (authorization_url, state)
|
||||
"""
|
||||
if not self.enabled:
|
||||
raise ValueError("OAuth authentication is not enabled")
|
||||
|
||||
authorization_url, oauth_state = self.oauth_client.build_authorization_url()
|
||||
return authorization_url, oauth_state.state
|
||||
|
||||
async def handle_callback(self, code: str, state: str) -> AuthContext:
|
||||
"""Handle OAuth callback and create authentication context
|
||||
|
||||
Args:
|
||||
code: Authorization code from OAuth provider
|
||||
state: State parameter for CSRF protection
|
||||
|
||||
Returns:
|
||||
AuthContext for the authenticated user
|
||||
|
||||
Raises:
|
||||
ValueError: If authentication fails
|
||||
"""
|
||||
if not self.enabled:
|
||||
raise ValueError("OAuth authentication is not enabled")
|
||||
|
||||
try:
|
||||
# Exchange code for tokens
|
||||
tokens, oauth_state = await self.oauth_client.exchange_code_for_tokens(code, state)
|
||||
|
||||
# Get user information
|
||||
user_info = await self.oauth_client.get_user_info(tokens)
|
||||
|
||||
# Create authentication context
|
||||
auth_context = await self._create_auth_context(user_info, tokens)
|
||||
|
||||
logger.info(f"OAuth authentication successful for user: {auth_context.user_id}")
|
||||
return auth_context
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OAuth callback handling failed: {e}")
|
||||
raise ValueError(f"OAuth authentication failed: {str(e)}")
|
||||
|
||||
async def authenticate_with_token(self, access_token: str) -> AuthContext:
|
||||
"""Authenticate using OAuth access token
|
||||
|
||||
Args:
|
||||
access_token: OAuth access token
|
||||
|
||||
Returns:
|
||||
AuthContext for the authenticated user
|
||||
"""
|
||||
if not self.enabled:
|
||||
raise ValueError("OAuth authentication is not enabled")
|
||||
|
||||
try:
|
||||
# Create token object
|
||||
tokens = OAuthTokens(access_token=access_token)
|
||||
|
||||
# Get user information
|
||||
user_info = await self.oauth_client.get_user_info(tokens)
|
||||
|
||||
# Create authentication context
|
||||
auth_context = await self._create_auth_context(user_info, tokens)
|
||||
|
||||
logger.info(f"OAuth token authentication successful for user: {auth_context.user_id}")
|
||||
return auth_context
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OAuth token authentication failed: {e}")
|
||||
raise ValueError(f"OAuth token authentication failed: {str(e)}")
|
||||
|
||||
async def refresh_authentication(self, refresh_token: str) -> Tuple[AuthContext, str]:
|
||||
"""Refresh OAuth authentication
|
||||
|
||||
Args:
|
||||
refresh_token: OAuth refresh token
|
||||
|
||||
Returns:
|
||||
Tuple of (AuthContext, new_access_token)
|
||||
"""
|
||||
if not self.enabled:
|
||||
raise ValueError("OAuth authentication is not enabled")
|
||||
|
||||
try:
|
||||
# Refresh tokens
|
||||
tokens = await self.oauth_client.refresh_tokens(refresh_token)
|
||||
|
||||
# Get updated user information
|
||||
user_info = await self.oauth_client.get_user_info(tokens)
|
||||
|
||||
# Create authentication context
|
||||
auth_context = await self._create_auth_context(user_info, tokens)
|
||||
|
||||
logger.info(f"OAuth refresh successful for user: {auth_context.user_id}")
|
||||
return auth_context, tokens.access_token
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OAuth refresh failed: {e}")
|
||||
raise ValueError(f"OAuth refresh failed: {str(e)}")
|
||||
|
||||
async def _create_auth_context(self, user_info: OAuthUserInfo, tokens: OAuthTokens) -> AuthContext:
|
||||
"""Create authentication context from OAuth user info
|
||||
|
||||
Args:
|
||||
user_info: OAuth user information
|
||||
tokens: OAuth tokens
|
||||
|
||||
Returns:
|
||||
AuthContext for the user
|
||||
"""
|
||||
# Determine security level based on roles or email domain
|
||||
security_level = await self._determine_security_level(user_info)
|
||||
|
||||
# Map OAuth roles to application permissions
|
||||
permissions = await self._map_permissions(user_info.roles)
|
||||
|
||||
# Generate session ID
|
||||
session_id = f"oauth_{user_info.sub}_{datetime.utcnow().timestamp()}"
|
||||
|
||||
return AuthContext(
|
||||
user_id=user_info.sub,
|
||||
roles=user_info.roles,
|
||||
permissions=permissions,
|
||||
session_id=session_id,
|
||||
login_time=datetime.utcnow(),
|
||||
last_activity=datetime.utcnow(),
|
||||
security_level=security_level
|
||||
)
|
||||
|
||||
async def _determine_security_level(self, user_info: OAuthUserInfo) -> SecurityLevel:
|
||||
"""Determine security level for OAuth user
|
||||
|
||||
Args:
|
||||
user_info: OAuth user information
|
||||
|
||||
Returns:
|
||||
SecurityLevel for the user
|
||||
"""
|
||||
# Check if user has admin roles
|
||||
admin_roles = {"admin", "administrator", "data_admin", "super_admin"}
|
||||
if any(role.lower() in admin_roles for role in user_info.roles):
|
||||
return SecurityLevel.SECRET
|
||||
|
||||
# Check email domain for internal users
|
||||
if user_info.email:
|
||||
# You can configure trusted domains for internal access
|
||||
trusted_domains = ["yourcompany.com", "internal.org"] # Configure as needed
|
||||
email_domain = user_info.email.split("@")[-1].lower()
|
||||
if email_domain in trusted_domains:
|
||||
return SecurityLevel.CONFIDENTIAL
|
||||
|
||||
# Check for special roles
|
||||
elevated_roles = {"data_analyst", "developer", "manager"}
|
||||
if any(role.lower() in elevated_roles for role in user_info.roles):
|
||||
return SecurityLevel.CONFIDENTIAL
|
||||
|
||||
# Default to internal level for OAuth users
|
||||
return SecurityLevel.INTERNAL
|
||||
|
||||
async def _map_permissions(self, roles: list[str]) -> list[str]:
|
||||
"""Map OAuth roles to application permissions
|
||||
|
||||
Args:
|
||||
roles: OAuth user roles
|
||||
|
||||
Returns:
|
||||
List of application permissions
|
||||
"""
|
||||
permissions = set()
|
||||
|
||||
# Role to permission mapping
|
||||
role_permissions = {
|
||||
"admin": ["admin", "read_data", "write_data", "manage_users"],
|
||||
"administrator": ["admin", "read_data", "write_data", "manage_users"],
|
||||
"data_admin": ["admin", "read_data", "write_data"],
|
||||
"super_admin": ["admin", "read_data", "write_data", "manage_users", "system_admin"],
|
||||
"data_analyst": ["read_data", "query_database"],
|
||||
"developer": ["read_data", "query_database", "debug"],
|
||||
"viewer": ["read_data"],
|
||||
"user": ["read_data"],
|
||||
"oauth_user": ["read_data"] # Default OAuth user permission
|
||||
}
|
||||
|
||||
# Map roles to permissions
|
||||
for role in roles:
|
||||
role_lower = role.lower()
|
||||
if role_lower in role_permissions:
|
||||
permissions.update(role_permissions[role_lower])
|
||||
|
||||
# Ensure OAuth users have at least basic permissions
|
||||
if not permissions:
|
||||
permissions.add("read_data")
|
||||
|
||||
return list(permissions)
|
||||
|
||||
def get_provider_info(self) -> Dict[str, Any]:
|
||||
"""Get OAuth provider information
|
||||
|
||||
Returns:
|
||||
Provider information dictionary
|
||||
"""
|
||||
if not self.enabled:
|
||||
return {"enabled": False}
|
||||
|
||||
config = self.oauth_client.provider_config
|
||||
return {
|
||||
"enabled": True,
|
||||
"provider": config.provider.value,
|
||||
"client_id": config.client_id,
|
||||
"scopes": config.scopes,
|
||||
"redirect_uri": config.redirect_uri,
|
||||
"pkce_enabled": config.pkce_enabled,
|
||||
"nonce_enabled": config.nonce_enabled
|
||||
}
|
||||
196
doris_mcp_server/auth/oauth_types.py
Normal file
196
doris_mcp_server/auth/oauth_types.py
Normal file
@@ -0,0 +1,196 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
OAuth 2.0/OIDC Type Definitions
|
||||
Provides data types and models for OAuth authentication flow
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Dict, Any, Optional, List
|
||||
|
||||
|
||||
class OAuthProvider(Enum):
|
||||
"""OAuth provider enumeration"""
|
||||
GOOGLE = "google"
|
||||
MICROSOFT = "microsoft"
|
||||
GITHUB = "github"
|
||||
CUSTOM = "custom"
|
||||
|
||||
|
||||
class OAuthGrantType(Enum):
|
||||
"""OAuth grant type enumeration"""
|
||||
AUTHORIZATION_CODE = "authorization_code"
|
||||
REFRESH_TOKEN = "refresh_token"
|
||||
|
||||
|
||||
@dataclass
|
||||
class OAuthState:
|
||||
"""OAuth state parameter for CSRF protection"""
|
||||
state: str
|
||||
nonce: Optional[str] = None
|
||||
pkce_verifier: Optional[str] = None
|
||||
pkce_challenge: Optional[str] = None
|
||||
redirect_uri: str = ""
|
||||
created_at: datetime = None
|
||||
expires_at: datetime = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.created_at is None:
|
||||
self.created_at = datetime.utcnow()
|
||||
|
||||
|
||||
@dataclass
|
||||
class OAuthTokens:
|
||||
"""OAuth token response"""
|
||||
access_token: str
|
||||
token_type: str = "Bearer"
|
||||
expires_in: Optional[int] = None
|
||||
refresh_token: Optional[str] = None
|
||||
scope: Optional[str] = None
|
||||
id_token: Optional[str] = None # OIDC ID token
|
||||
created_at: datetime = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.created_at is None:
|
||||
self.created_at = datetime.utcnow()
|
||||
|
||||
|
||||
@dataclass
|
||||
class OAuthUserInfo:
|
||||
"""OAuth/OIDC user information"""
|
||||
sub: str # Subject identifier
|
||||
email: Optional[str] = None
|
||||
email_verified: Optional[bool] = None
|
||||
name: Optional[str] = None
|
||||
given_name: Optional[str] = None
|
||||
family_name: Optional[str] = None
|
||||
picture: Optional[str] = None
|
||||
locale: Optional[str] = None
|
||||
roles: List[str] = None
|
||||
raw_claims: Dict[str, Any] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.roles is None:
|
||||
self.roles = []
|
||||
if self.raw_claims is None:
|
||||
self.raw_claims = {}
|
||||
|
||||
|
||||
@dataclass
|
||||
class OIDCDiscovery:
|
||||
"""OIDC Discovery document"""
|
||||
issuer: str
|
||||
authorization_endpoint: str
|
||||
token_endpoint: str
|
||||
userinfo_endpoint: Optional[str] = None
|
||||
jwks_uri: Optional[str] = None
|
||||
scopes_supported: List[str] = None
|
||||
response_types_supported: List[str] = None
|
||||
subject_types_supported: List[str] = None
|
||||
id_token_signing_alg_values_supported: List[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.scopes_supported is None:
|
||||
self.scopes_supported = ["openid"]
|
||||
if self.response_types_supported is None:
|
||||
self.response_types_supported = ["code"]
|
||||
if self.subject_types_supported is None:
|
||||
self.subject_types_supported = ["public"]
|
||||
if self.id_token_signing_alg_values_supported is None:
|
||||
self.id_token_signing_alg_values_supported = ["RS256"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class OAuthError:
|
||||
"""OAuth error response"""
|
||||
error: str
|
||||
error_description: Optional[str] = None
|
||||
error_uri: Optional[str] = None
|
||||
state: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class OAuthProviderConfig:
|
||||
"""OAuth provider configuration"""
|
||||
provider: OAuthProvider
|
||||
client_id: str
|
||||
client_secret: str
|
||||
redirect_uri: str
|
||||
scopes: List[str]
|
||||
|
||||
# Endpoints
|
||||
authorization_endpoint: str
|
||||
token_endpoint: str
|
||||
userinfo_endpoint: Optional[str] = None
|
||||
jwks_uri: Optional[str] = None
|
||||
|
||||
# Discovery
|
||||
discovery_url: Optional[str] = None
|
||||
|
||||
# Settings
|
||||
pkce_enabled: bool = True
|
||||
nonce_enabled: bool = True
|
||||
|
||||
# User mapping
|
||||
user_id_claim: str = "sub"
|
||||
email_claim: str = "email"
|
||||
name_claim: str = "name"
|
||||
roles_claim: str = "roles"
|
||||
default_roles: List[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.default_roles is None:
|
||||
self.default_roles = ["oauth_user"]
|
||||
|
||||
|
||||
# Pre-defined provider configurations
|
||||
OAUTH_PROVIDERS = {
|
||||
OAuthProvider.GOOGLE: {
|
||||
"authorization_endpoint": "https://accounts.google.com/o/oauth2/auth",
|
||||
"token_endpoint": "https://oauth2.googleapis.com/token",
|
||||
"userinfo_endpoint": "https://openidconnect.googleapis.com/v1/userinfo",
|
||||
"jwks_uri": "https://www.googleapis.com/oauth2/v3/certs",
|
||||
"discovery_url": "https://accounts.google.com/.well-known/openid_configuration",
|
||||
"scopes": ["openid", "email", "profile"],
|
||||
"user_id_claim": "sub",
|
||||
"email_claim": "email",
|
||||
"name_claim": "name"
|
||||
},
|
||||
OAuthProvider.MICROSOFT: {
|
||||
"authorization_endpoint": "https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
|
||||
"token_endpoint": "https://login.microsoftonline.com/common/oauth2/v2.0/token",
|
||||
"userinfo_endpoint": "https://graph.microsoft.com/v1.0/me",
|
||||
"jwks_uri": "https://login.microsoftonline.com/common/discovery/v2.0/keys",
|
||||
"discovery_url": "https://login.microsoftonline.com/common/v2.0/.well-known/openid_configuration",
|
||||
"scopes": ["openid", "profile", "email", "User.Read"],
|
||||
"user_id_claim": "sub",
|
||||
"email_claim": "email",
|
||||
"name_claim": "name"
|
||||
},
|
||||
OAuthProvider.GITHUB: {
|
||||
"authorization_endpoint": "https://github.com/login/oauth/authorize",
|
||||
"token_endpoint": "https://github.com/login/oauth/access_token",
|
||||
"userinfo_endpoint": "https://api.github.com/user",
|
||||
"scopes": ["user:email", "read:user"],
|
||||
"user_id_claim": "id",
|
||||
"email_claim": "email",
|
||||
"name_claim": "name"
|
||||
}
|
||||
}
|
||||
481
doris_mcp_server/auth/token_handlers.py
Normal file
481
doris_mcp_server/auth/token_handlers.py
Normal file
@@ -0,0 +1,481 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Token Authentication HTTP Handlers
|
||||
|
||||
Provides HTTP endpoints for token management including creation, revocation,
|
||||
listing, and statistics. Used for administrative token management in HTTP mode.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Dict, Any
|
||||
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, HTMLResponse
|
||||
|
||||
from ..utils.logger import get_logger
|
||||
from ..utils.security import SecurityLevel
|
||||
|
||||
|
||||
class TokenHandlers:
|
||||
"""Token Authentication HTTP Handlers"""
|
||||
|
||||
def __init__(self, security_manager):
|
||||
self.security_manager = security_manager
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
async def handle_create_token(self, request: Request) -> JSONResponse:
|
||||
"""Handle token creation request"""
|
||||
try:
|
||||
# Check if token manager is available
|
||||
if not self.security_manager.auth_provider.token_manager:
|
||||
return JSONResponse({
|
||||
"error": "Token authentication is not enabled"
|
||||
}, status_code=503)
|
||||
|
||||
# Parse request data
|
||||
if request.method == "GET":
|
||||
# GET request with query parameters
|
||||
query_params = dict(request.query_params)
|
||||
token_id = query_params.get("token_id")
|
||||
user_id = query_params.get("user_id")
|
||||
roles = query_params.get("roles", "").split(",") if query_params.get("roles") else []
|
||||
permissions = query_params.get("permissions", "").split(",") if query_params.get("permissions") else []
|
||||
security_level_str = query_params.get("security_level", "internal")
|
||||
expires_hours_str = query_params.get("expires_hours")
|
||||
description = query_params.get("description", "")
|
||||
custom_token = query_params.get("custom_token")
|
||||
else:
|
||||
# POST request with JSON body
|
||||
try:
|
||||
body = await request.json()
|
||||
except:
|
||||
return JSONResponse({
|
||||
"error": "Invalid JSON body"
|
||||
}, status_code=400)
|
||||
|
||||
token_id = body.get("token_id")
|
||||
user_id = body.get("user_id")
|
||||
roles = body.get("roles", [])
|
||||
permissions = body.get("permissions", [])
|
||||
security_level_str = body.get("security_level", "internal")
|
||||
expires_hours_str = body.get("expires_hours")
|
||||
description = body.get("description", "")
|
||||
custom_token = body.get("custom_token")
|
||||
|
||||
# Validate required fields
|
||||
if not token_id or not user_id:
|
||||
return JSONResponse({
|
||||
"error": "token_id and user_id are required"
|
||||
}, status_code=400)
|
||||
|
||||
# Parse security level
|
||||
try:
|
||||
security_level = SecurityLevel(security_level_str.lower())
|
||||
except ValueError:
|
||||
security_level = SecurityLevel.INTERNAL
|
||||
|
||||
# Parse expires_hours
|
||||
expires_hours = None
|
||||
if expires_hours_str:
|
||||
try:
|
||||
expires_hours = int(expires_hours_str)
|
||||
except ValueError:
|
||||
return JSONResponse({
|
||||
"error": "expires_hours must be an integer"
|
||||
}, status_code=400)
|
||||
|
||||
# Create token
|
||||
try:
|
||||
token = await self.security_manager.create_token(
|
||||
token_id=token_id,
|
||||
user_id=user_id,
|
||||
roles=roles,
|
||||
permissions=permissions,
|
||||
security_level=security_level,
|
||||
expires_hours=expires_hours,
|
||||
description=description,
|
||||
custom_token=custom_token
|
||||
)
|
||||
|
||||
return JSONResponse({
|
||||
"success": True,
|
||||
"token_id": token_id,
|
||||
"user_id": user_id,
|
||||
"token": token,
|
||||
"roles": roles,
|
||||
"permissions": permissions,
|
||||
"security_level": security_level.value,
|
||||
"expires_hours": expires_hours,
|
||||
"description": description,
|
||||
"message": "Token created successfully"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Token creation failed: {e}")
|
||||
return JSONResponse({
|
||||
"error": f"Token creation failed: {str(e)}"
|
||||
}, status_code=400)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in handle_create_token: {e}")
|
||||
return JSONResponse({
|
||||
"error": f"Internal server error: {str(e)}"
|
||||
}, status_code=500)
|
||||
|
||||
async def handle_revoke_token(self, request: Request) -> JSONResponse:
|
||||
"""Handle token revocation request"""
|
||||
try:
|
||||
# Check if token manager is available
|
||||
if not self.security_manager.auth_provider.token_manager:
|
||||
return JSONResponse({
|
||||
"error": "Token authentication is not enabled"
|
||||
}, status_code=503)
|
||||
|
||||
# Get token_id from query parameters or path
|
||||
token_id = request.query_params.get("token_id")
|
||||
if not token_id and request.method == "DELETE":
|
||||
# Try to get from path: /token/revoke/{token_id}
|
||||
path_parts = str(request.url.path).split("/")
|
||||
if len(path_parts) >= 4:
|
||||
token_id = path_parts[-1]
|
||||
|
||||
if not token_id:
|
||||
return JSONResponse({
|
||||
"error": "token_id is required"
|
||||
}, status_code=400)
|
||||
|
||||
# Revoke token
|
||||
success = await self.security_manager.revoke_token(token_id)
|
||||
|
||||
if success:
|
||||
return JSONResponse({
|
||||
"success": True,
|
||||
"token_id": token_id,
|
||||
"message": "Token revoked successfully"
|
||||
})
|
||||
else:
|
||||
return JSONResponse({
|
||||
"success": False,
|
||||
"token_id": token_id,
|
||||
"message": "Token not found or already revoked"
|
||||
}, status_code=404)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in handle_revoke_token: {e}")
|
||||
return JSONResponse({
|
||||
"error": f"Internal server error: {str(e)}"
|
||||
}, status_code=500)
|
||||
|
||||
async def handle_list_tokens(self, request: Request) -> JSONResponse:
|
||||
"""Handle token listing request"""
|
||||
try:
|
||||
# Check if token manager is available
|
||||
if not self.security_manager.auth_provider.token_manager:
|
||||
return JSONResponse({
|
||||
"error": "Token authentication is not enabled"
|
||||
}, status_code=503)
|
||||
|
||||
# Get tokens list
|
||||
tokens = await self.security_manager.list_tokens()
|
||||
|
||||
return JSONResponse({
|
||||
"success": True,
|
||||
"count": len(tokens),
|
||||
"tokens": tokens
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in handle_list_tokens: {e}")
|
||||
return JSONResponse({
|
||||
"error": f"Internal server error: {str(e)}"
|
||||
}, status_code=500)
|
||||
|
||||
async def handle_token_stats(self, request: Request) -> JSONResponse:
|
||||
"""Handle token statistics request"""
|
||||
try:
|
||||
# Check if token manager is available
|
||||
if not self.security_manager.auth_provider.token_manager:
|
||||
return JSONResponse({
|
||||
"error": "Token authentication is not enabled"
|
||||
}, status_code=503)
|
||||
|
||||
# Get token statistics
|
||||
stats = self.security_manager.get_token_stats()
|
||||
|
||||
return JSONResponse({
|
||||
"success": True,
|
||||
"stats": stats
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in handle_token_stats: {e}")
|
||||
return JSONResponse({
|
||||
"error": f"Internal server error: {str(e)}"
|
||||
}, status_code=500)
|
||||
|
||||
async def handle_cleanup_tokens(self, request: Request) -> JSONResponse:
|
||||
"""Handle expired tokens cleanup request"""
|
||||
try:
|
||||
# Check if token manager is available
|
||||
if not self.security_manager.auth_provider.token_manager:
|
||||
return JSONResponse({
|
||||
"error": "Token authentication is not enabled"
|
||||
}, status_code=503)
|
||||
|
||||
# Cleanup expired tokens
|
||||
cleaned_count = await self.security_manager.cleanup_expired_tokens()
|
||||
|
||||
return JSONResponse({
|
||||
"success": True,
|
||||
"cleaned_count": cleaned_count,
|
||||
"message": f"Cleaned up {cleaned_count} expired tokens"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in handle_cleanup_tokens: {e}")
|
||||
return JSONResponse({
|
||||
"error": f"Internal server error: {str(e)}"
|
||||
}, status_code=500)
|
||||
|
||||
async def handle_demo_page(self, request: Request) -> HTMLResponse:
|
||||
"""Handle token management demo page"""
|
||||
try:
|
||||
# Check if token manager is available
|
||||
if not self.security_manager.auth_provider.token_manager:
|
||||
html_content = """
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Token Management - Not Available</title>
|
||||
<style>
|
||||
body { font-family: Arial, sans-serif; margin: 50px; }
|
||||
.error { color: red; font-size: 18px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Token Management</h1>
|
||||
<div class="error">Token authentication is not enabled on this server.</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
return HTMLResponse(html_content)
|
||||
|
||||
# Get current stats for demo
|
||||
stats = self.security_manager.get_token_stats()
|
||||
|
||||
html_content = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Doris MCP Server - Token Management</title>
|
||||
<style>
|
||||
body {{ font-family: Arial, sans-serif; margin: 50px; background: #f5f5f5; }}
|
||||
.container {{ max-width: 1200px; margin: 0 auto; background: white; padding: 30px; border-radius: 8px; }}
|
||||
h1 {{ color: #333; }}
|
||||
.section {{ margin: 30px 0; padding: 20px; border: 1px solid #ddd; border-radius: 5px; }}
|
||||
.stats {{ display: grid; grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); gap: 15px; }}
|
||||
.stat-item {{ padding: 15px; background: #f8f9fa; border-radius: 5px; text-align: center; }}
|
||||
.stat-value {{ font-size: 24px; font-weight: bold; color: #007bff; }}
|
||||
.form-group {{ margin: 15px 0; }}
|
||||
.form-group label {{ display: block; margin-bottom: 5px; font-weight: bold; }}
|
||||
.form-group input, .form-group textarea {{ width: 100%; padding: 8px; border: 1px solid #ddd; border-radius: 4px; }}
|
||||
button {{ padding: 10px 20px; margin: 5px; border: none; border-radius: 4px; cursor: pointer; }}
|
||||
.btn-primary {{ background: #007bff; color: white; }}
|
||||
.btn-danger {{ background: #dc3545; color: white; }}
|
||||
.btn-success {{ background: #28a745; color: white; }}
|
||||
.response {{ margin: 15px 0; padding: 15px; border-radius: 5px; }}
|
||||
.response.success {{ background: #d4edda; border: 1px solid #c3e6cb; }}
|
||||
.response.error {{ background: #f8d7da; border: 1px solid #f5c6cb; }}
|
||||
.token-list {{ margin: 15px 0; }}
|
||||
.token-item {{ padding: 10px; margin: 5px 0; background: #f8f9fa; border-radius: 4px; }}
|
||||
pre {{ background: #f8f9fa; padding: 10px; border-radius: 4px; overflow-x: auto; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>🔐 Doris MCP Server - Token Management</h1>
|
||||
|
||||
<div class="section">
|
||||
<h2>📊 Token Statistics</h2>
|
||||
<div class="stats">
|
||||
<div class="stat-item">
|
||||
<div class="stat-value">{stats.get('total_tokens', 0)}</div>
|
||||
<div>Total Tokens</div>
|
||||
</div>
|
||||
<div class="stat-item">
|
||||
<div class="stat-value">{stats.get('active_tokens', 0)}</div>
|
||||
<div>Active Tokens</div>
|
||||
</div>
|
||||
<div class="stat-item">
|
||||
<div class="stat-value">{stats.get('expired_tokens', 0)}</div>
|
||||
<div>Expired Tokens</div>
|
||||
</div>
|
||||
</div>
|
||||
<p><strong>Token Expiry:</strong> {'Enabled' if stats.get('expiry_enabled') else 'Disabled'}</p>
|
||||
<p><strong>Default Expiry:</strong> {stats.get('default_expiry_hours', 0)} hours</p>
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>➕ Create New Token</h2>
|
||||
<form id="createTokenForm">
|
||||
<div class="form-group">
|
||||
<label for="token_id">Token ID (required):</label>
|
||||
<input type="text" id="token_id" name="token_id" placeholder="e.g., my-app-token" required>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="user_id">User ID (required):</label>
|
||||
<input type="text" id="user_id" name="user_id" placeholder="e.g., john_doe" required>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="roles">Roles (comma-separated):</label>
|
||||
<input type="text" id="roles" name="roles" placeholder="e.g., data_analyst,viewer">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="permissions">Permissions (comma-separated):</label>
|
||||
<input type="text" id="permissions" name="permissions" placeholder="e.g., read_data,query_database">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="security_level">Security Level:</label>
|
||||
<select id="security_level" name="security_level">
|
||||
<option value="public">Public</option>
|
||||
<option value="internal" selected>Internal</option>
|
||||
<option value="confidential">Confidential</option>
|
||||
<option value="secret">Secret</option>
|
||||
</select>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="expires_hours">Expires Hours (leave empty for default):</label>
|
||||
<input type="number" id="expires_hours" name="expires_hours" placeholder="e.g., 720 (30 days)">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="description">Description:</label>
|
||||
<textarea id="description" name="description" placeholder="Token description"></textarea>
|
||||
</div>
|
||||
<button type="submit" class="btn-primary">Create Token</button>
|
||||
</form>
|
||||
<div id="createTokenResponse"></div>
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>📋 Token Management</h2>
|
||||
<button id="listTokensBtn" class="btn-success">Refresh Token List</button>
|
||||
<button id="cleanupTokensBtn" class="btn-primary">Cleanup Expired Tokens</button>
|
||||
<div id="tokenListResponse"></div>
|
||||
|
||||
<h3>Revoke Token</h3>
|
||||
<div class="form-group">
|
||||
<input type="text" id="revokeTokenId" placeholder="Enter token ID to revoke">
|
||||
<button id="revokeTokenBtn" class="btn-danger">Revoke Token</button>
|
||||
</div>
|
||||
<div id="revokeTokenResponse"></div>
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>🔧 API Endpoints</h2>
|
||||
<p>Use these endpoints for programmatic token management:</p>
|
||||
<ul>
|
||||
<li><strong>POST /token/create</strong> - Create new token</li>
|
||||
<li><strong>DELETE /token/revoke?token_id=...</strong> - Revoke token</li>
|
||||
<li><strong>GET /token/list</strong> - List all tokens</li>
|
||||
<li><strong>GET /token/stats</strong> - Get token statistics</li>
|
||||
<li><strong>POST /token/cleanup</strong> - Cleanup expired tokens</li>
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
function showResponse(elementId, data, isSuccess = true) {{
|
||||
const element = document.getElementById(elementId);
|
||||
element.innerHTML = '<pre>' + JSON.stringify(data, null, 2) + '</pre>';
|
||||
element.className = 'response ' + (isSuccess ? 'success' : 'error');
|
||||
}}
|
||||
|
||||
// Create token form
|
||||
document.getElementById('createTokenForm').addEventListener('submit', async (e) => {{
|
||||
e.preventDefault();
|
||||
const formData = new FormData(e.target);
|
||||
const data = Object.fromEntries(formData.entries());
|
||||
|
||||
// Convert comma-separated values to arrays
|
||||
data.roles = data.roles ? data.roles.split(',').map(r => r.trim()).filter(r => r) : [];
|
||||
data.permissions = data.permissions ? data.permissions.split(',').map(p => p.trim()).filter(p => p) : [];
|
||||
|
||||
try {{
|
||||
const response = await fetch('/token/create', {{
|
||||
method: 'POST',
|
||||
headers: {{'Content-Type': 'application/json'}},
|
||||
body: JSON.stringify(data)
|
||||
}});
|
||||
const result = await response.json();
|
||||
showResponse('createTokenResponse', result, response.ok);
|
||||
}} catch (error) {{
|
||||
showResponse('createTokenResponse', {{error: error.message}}, false);
|
||||
}}
|
||||
}});
|
||||
|
||||
// List tokens
|
||||
document.getElementById('listTokensBtn').addEventListener('click', async () => {{
|
||||
try {{
|
||||
const response = await fetch('/token/list');
|
||||
const result = await response.json();
|
||||
showResponse('tokenListResponse', result, response.ok);
|
||||
}} catch (error) {{
|
||||
showResponse('tokenListResponse', {{error: error.message}}, false);
|
||||
}}
|
||||
}});
|
||||
|
||||
// Cleanup tokens
|
||||
document.getElementById('cleanupTokensBtn').addEventListener('click', async () => {{
|
||||
try {{
|
||||
const response = await fetch('/token/cleanup', {{method: 'POST'}});
|
||||
const result = await response.json();
|
||||
showResponse('tokenListResponse', result, response.ok);
|
||||
}} catch (error) {{
|
||||
showResponse('tokenListResponse', {{error: error.message}}, false);
|
||||
}}
|
||||
}});
|
||||
|
||||
// Revoke token
|
||||
document.getElementById('revokeTokenBtn').addEventListener('click', async () => {{
|
||||
const tokenId = document.getElementById('revokeTokenId').value;
|
||||
if (!tokenId) {{
|
||||
showResponse('revokeTokenResponse', {{error: 'Token ID is required'}}, false);
|
||||
return;
|
||||
}}
|
||||
|
||||
try {{
|
||||
const response = await fetch(`/token/revoke?token_id=${{encodeURIComponent(tokenId)}}`, {{
|
||||
method: 'DELETE'
|
||||
}});
|
||||
const result = await response.json();
|
||||
showResponse('revokeTokenResponse', result, response.ok);
|
||||
}} catch (error) {{
|
||||
showResponse('revokeTokenResponse', {{error: error.message}}, false);
|
||||
}}
|
||||
}});
|
||||
|
||||
// Load token list on page load
|
||||
document.getElementById('listTokensBtn').click();
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
return HTMLResponse(html_content)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in handle_demo_page: {e}")
|
||||
error_html = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Token Management Error</title>
|
||||
<style>body {{ font-family: Arial, sans-serif; margin: 50px; }}</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Token Management Error</h1>
|
||||
<p>Error loading token management page: {str(e)}</p>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
return HTMLResponse(error_html, status_code=500)
|
||||
456
doris_mcp_server/auth/token_manager.py
Normal file
456
doris_mcp_server/auth/token_manager.py
Normal file
@@ -0,0 +1,456 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Token Authentication Management Module
|
||||
|
||||
Provides enterprise-grade token authentication system with configurable tokens,
|
||||
expiration management, role-based access control and secure token storage.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import secrets
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
from ..utils.logger import get_logger
|
||||
from ..utils.security import SecurityLevel
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenInfo:
|
||||
"""Token information structure"""
|
||||
|
||||
token_id: str # Unique token identifier for audit and management
|
||||
created_at: datetime = field(default_factory=datetime.utcnow)
|
||||
expires_at: Optional[datetime] = None
|
||||
last_used: Optional[datetime] = None
|
||||
description: str = "" # Optional description for token purpose
|
||||
is_active: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenValidationResult:
|
||||
"""Token validation result"""
|
||||
|
||||
is_valid: bool
|
||||
token_info: Optional[TokenInfo] = None
|
||||
error_message: Optional[str] = None
|
||||
|
||||
|
||||
class TokenManager:
|
||||
"""Enterprise Token Authentication Manager
|
||||
|
||||
Features:
|
||||
- Configurable token storage (file-based or environment variables)
|
||||
- Token expiration management
|
||||
- Secure token hashing
|
||||
- Role-based access control
|
||||
- Token lifecycle management
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
# Token storage
|
||||
self._tokens: Dict[str, TokenInfo] = {} # token_hash -> TokenInfo
|
||||
self._token_ids: Dict[str, str] = {} # token_id -> token_hash
|
||||
|
||||
# Configuration
|
||||
self.token_file_path = getattr(config.security, 'token_file_path', 'tokens.json')
|
||||
self.enable_token_expiry = getattr(config.security, 'enable_token_expiry', True)
|
||||
self.default_token_expiry_hours = getattr(config.security, 'default_token_expiry_hours', 24 * 30) # 30 days
|
||||
self.token_hash_algorithm = getattr(config.security, 'token_hash_algorithm', 'sha256')
|
||||
|
||||
# Initialize with default tokens if none exist
|
||||
self._initialize_default_tokens()
|
||||
|
||||
# Load tokens from configuration
|
||||
self._load_tokens()
|
||||
|
||||
self.logger.info(f"TokenManager initialized with {len(self._tokens)} tokens")
|
||||
|
||||
def _initialize_default_tokens(self):
|
||||
"""Initialize default tokens for basic authentication (configurable via environment)"""
|
||||
# Default token configurations (can be overridden by environment variables)
|
||||
default_tokens = [
|
||||
{
|
||||
'token_id': 'admin-token',
|
||||
'token': os.getenv('DEFAULT_ADMIN_TOKEN', 'doris_admin_token_123456'),
|
||||
'description': os.getenv('DEFAULT_ADMIN_DESCRIPTION', 'Default admin API access token'),
|
||||
'expires_hours': None # Never expires
|
||||
},
|
||||
{
|
||||
'token_id': 'analyst-token',
|
||||
'token': os.getenv('DEFAULT_ANALYST_TOKEN', 'doris_analyst_token_123456'),
|
||||
'description': os.getenv('DEFAULT_ANALYST_DESCRIPTION', 'Default data analysis API access token'),
|
||||
'expires_hours': None # Never expires
|
||||
},
|
||||
{
|
||||
'token_id': 'readonly-token',
|
||||
'token': os.getenv('DEFAULT_READONLY_TOKEN', 'doris_readonly_token_123456'),
|
||||
'description': os.getenv('DEFAULT_READONLY_DESCRIPTION', 'Default read-only API access token'),
|
||||
'expires_hours': None # Never expires
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
# Only add default tokens if no custom tokens are defined via environment variables
|
||||
# Check if any TOKEN_* environment variables exist (excluding system and legacy configs)
|
||||
excluded_prefixes = ('DEFAULT_', 'TOKEN_FILE_PATH', 'TOKEN_HASH_')
|
||||
excluded_vars = {'TOKEN_SECRET', 'TOKEN_EXPIRY'}
|
||||
|
||||
custom_tokens_exist = any(
|
||||
key.startswith('TOKEN_') and
|
||||
not key.startswith(excluded_prefixes) and
|
||||
not key.endswith(('_EXPIRES_HOURS', '_DESCRIPTION')) and
|
||||
key not in excluded_vars
|
||||
for key in os.environ.keys()
|
||||
)
|
||||
|
||||
# Also check if token file exists and has content
|
||||
token_file_exists = False
|
||||
if os.path.exists(self.token_file_path):
|
||||
try:
|
||||
with open(self.token_file_path, 'r') as f:
|
||||
content = f.read().strip()
|
||||
if content and content != '{}':
|
||||
token_file_exists = True
|
||||
except:
|
||||
pass
|
||||
|
||||
# Add default tokens only if no custom configuration exists
|
||||
if not custom_tokens_exist and not token_file_exists:
|
||||
for token_config in default_tokens:
|
||||
self._add_token_from_config(token_config)
|
||||
|
||||
self.logger.info(f"Initialized {len(default_tokens)} default tokens (no custom config found)")
|
||||
else:
|
||||
self.logger.info("Skipped default tokens initialization (custom tokens detected)")
|
||||
|
||||
def _add_token_from_config(self, token_config: Dict[str, Any]):
|
||||
"""Add token from configuration"""
|
||||
try:
|
||||
# Calculate expiration time
|
||||
expires_at = None
|
||||
if self.enable_token_expiry:
|
||||
expires_hours = token_config.get('expires_hours', self.default_token_expiry_hours)
|
||||
if expires_hours is not None:
|
||||
expires_at = datetime.utcnow() + timedelta(hours=expires_hours)
|
||||
|
||||
# Create token info
|
||||
token_info = TokenInfo(
|
||||
token_id=token_config['token_id'],
|
||||
expires_at=expires_at,
|
||||
description=token_config.get('description', ''),
|
||||
is_active=token_config.get('is_active', True)
|
||||
)
|
||||
|
||||
# Hash the token
|
||||
raw_token = token_config['token']
|
||||
token_hash = self._hash_token(raw_token)
|
||||
|
||||
# Store token
|
||||
self._tokens[token_hash] = token_info
|
||||
self._token_ids[token_info.token_id] = token_hash
|
||||
|
||||
self.logger.debug(f"Added token '{token_info.token_id}'")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to add token from config: {e}")
|
||||
|
||||
def _load_tokens(self):
|
||||
"""Load tokens from configuration sources"""
|
||||
# 1. Load from environment variables
|
||||
self._load_tokens_from_env()
|
||||
|
||||
# 2. Load from token file if exists
|
||||
if os.path.exists(self.token_file_path):
|
||||
self._load_tokens_from_file()
|
||||
|
||||
self.logger.info(f"Token loading completed, total tokens: {len(self._tokens)}")
|
||||
|
||||
def _load_tokens_from_env(self):
|
||||
"""Load tokens from environment variables
|
||||
|
||||
Simplified format:
|
||||
TOKEN_<ID>=<token>
|
||||
TOKEN_<ID>_EXPIRES_HOURS=<hours>
|
||||
TOKEN_<ID>_DESCRIPTION=<description>
|
||||
"""
|
||||
token_prefixes = set()
|
||||
|
||||
# Find all TOKEN_ environment variables (exclude legacy and system variables)
|
||||
excluded_token_vars = {
|
||||
'TOKEN_SECRET', # Legacy token secret
|
||||
'TOKEN_EXPIRY', # Legacy token expiry
|
||||
'TOKEN_FILE_PATH', # System config
|
||||
'TOKEN_HASH_ALGORITHM' # System config
|
||||
}
|
||||
|
||||
for key in os.environ:
|
||||
if (key.startswith('TOKEN_') and
|
||||
not key.endswith(('_EXPIRES_HOURS', '_DESCRIPTION')) and
|
||||
key not in excluded_token_vars):
|
||||
token_id = key[6:] # Remove 'TOKEN_' prefix
|
||||
token_prefixes.add(token_id)
|
||||
|
||||
# Load each token
|
||||
for token_id in token_prefixes:
|
||||
try:
|
||||
token = os.environ.get(f'TOKEN_{token_id}')
|
||||
if not token:
|
||||
continue
|
||||
|
||||
expires_hours_str = os.environ.get(f'TOKEN_{token_id}_EXPIRES_HOURS', str(self.default_token_expiry_hours))
|
||||
description = os.environ.get(f'TOKEN_{token_id}_DESCRIPTION', f'Environment token {token_id}')
|
||||
|
||||
expires_hours = None
|
||||
try:
|
||||
if expires_hours_str and expires_hours_str.lower() != 'none':
|
||||
expires_hours = int(expires_hours_str)
|
||||
except ValueError:
|
||||
expires_hours = self.default_token_expiry_hours
|
||||
|
||||
# Add token
|
||||
token_config = {
|
||||
'token_id': token_id.lower(),
|
||||
'token': token,
|
||||
'expires_hours': expires_hours,
|
||||
'description': description
|
||||
}
|
||||
|
||||
self._add_token_from_config(token_config)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to load token {token_id} from environment: {e}")
|
||||
|
||||
def _load_tokens_from_file(self):
|
||||
"""Load tokens from JSON file"""
|
||||
try:
|
||||
with open(self.token_file_path, 'r', encoding='utf-8') as f:
|
||||
tokens_data = json.load(f)
|
||||
|
||||
if isinstance(tokens_data, dict) and 'tokens' in tokens_data:
|
||||
tokens_list = tokens_data['tokens']
|
||||
elif isinstance(tokens_data, list):
|
||||
tokens_list = tokens_data
|
||||
else:
|
||||
self.logger.error(f"Invalid token file format: {self.token_file_path}")
|
||||
return
|
||||
|
||||
for token_config in tokens_list:
|
||||
self._add_token_from_config(token_config)
|
||||
|
||||
self.logger.info(f"Loaded {len(tokens_list)} tokens from file: {self.token_file_path}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to load tokens from file {self.token_file_path}: {e}")
|
||||
|
||||
def _hash_token(self, token: str) -> str:
|
||||
"""Hash token for secure storage"""
|
||||
if self.token_hash_algorithm == 'sha256':
|
||||
return hashlib.sha256(token.encode('utf-8')).hexdigest()
|
||||
elif self.token_hash_algorithm == 'sha512':
|
||||
return hashlib.sha512(token.encode('utf-8')).hexdigest()
|
||||
else:
|
||||
# Fallback to sha256
|
||||
return hashlib.sha256(token.encode('utf-8')).hexdigest()
|
||||
|
||||
async def validate_token(self, token: str) -> TokenValidationResult:
|
||||
"""Validate token and return user information"""
|
||||
try:
|
||||
# Hash the provided token
|
||||
token_hash = self._hash_token(token)
|
||||
|
||||
# Find token info
|
||||
token_info = self._tokens.get(token_hash)
|
||||
if not token_info:
|
||||
return TokenValidationResult(
|
||||
is_valid=False,
|
||||
error_message="Invalid token"
|
||||
)
|
||||
|
||||
# Check if token is active
|
||||
if not token_info.is_active:
|
||||
return TokenValidationResult(
|
||||
is_valid=False,
|
||||
error_message="Token is inactive"
|
||||
)
|
||||
|
||||
# Check expiration
|
||||
if token_info.expires_at and datetime.utcnow() > token_info.expires_at:
|
||||
return TokenValidationResult(
|
||||
is_valid=False,
|
||||
error_message="Token has expired"
|
||||
)
|
||||
|
||||
# Update last used time
|
||||
token_info.last_used = datetime.utcnow()
|
||||
|
||||
return TokenValidationResult(
|
||||
is_valid=True,
|
||||
token_info=token_info
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Token validation error: {e}")
|
||||
return TokenValidationResult(
|
||||
is_valid=False,
|
||||
error_message=f"Token validation failed: {str(e)}"
|
||||
)
|
||||
|
||||
def generate_token(self, length: int = 32) -> str:
|
||||
"""Generate a cryptographically secure random token"""
|
||||
return secrets.token_urlsafe(length)
|
||||
|
||||
async def create_token(
|
||||
self,
|
||||
token_id: str,
|
||||
expires_hours: Optional[int] = None,
|
||||
description: str = "",
|
||||
custom_token: Optional[str] = None
|
||||
) -> str:
|
||||
"""Create a new token"""
|
||||
try:
|
||||
# Check if token_id already exists
|
||||
if token_id in self._token_ids:
|
||||
raise ValueError(f"Token ID '{token_id}' already exists")
|
||||
|
||||
# Generate or use provided token
|
||||
if custom_token:
|
||||
raw_token = custom_token
|
||||
else:
|
||||
raw_token = self.generate_token()
|
||||
|
||||
# Calculate expiration
|
||||
expires_at = None
|
||||
if expires_hours is not None:
|
||||
expires_at = datetime.utcnow() + timedelta(hours=expires_hours)
|
||||
elif self.enable_token_expiry:
|
||||
expires_at = datetime.utcnow() + timedelta(hours=self.default_token_expiry_hours)
|
||||
|
||||
# Create token info
|
||||
token_info = TokenInfo(
|
||||
token_id=token_id,
|
||||
expires_at=expires_at,
|
||||
description=description
|
||||
)
|
||||
|
||||
# Hash and store token
|
||||
token_hash = self._hash_token(raw_token)
|
||||
self._tokens[token_hash] = token_info
|
||||
self._token_ids[token_id] = token_hash
|
||||
|
||||
self.logger.info(f"Created new token '{token_id}'")
|
||||
|
||||
return raw_token
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to create token: {e}")
|
||||
raise
|
||||
|
||||
async def revoke_token(self, token_id: str) -> bool:
|
||||
"""Revoke a token by token ID"""
|
||||
try:
|
||||
if token_id not in self._token_ids:
|
||||
self.logger.warning(f"Token ID '{token_id}' not found")
|
||||
return False
|
||||
|
||||
# Get token hash and remove from storage
|
||||
token_hash = self._token_ids[token_id]
|
||||
if token_hash in self._tokens:
|
||||
del self._tokens[token_hash]
|
||||
del self._token_ids[token_id]
|
||||
|
||||
self.logger.info(f"Revoked token '{token_id}'")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to revoke token '{token_id}': {e}")
|
||||
return False
|
||||
|
||||
async def list_tokens(self) -> List[Dict[str, Any]]:
|
||||
"""List all tokens (without sensitive data)"""
|
||||
tokens = []
|
||||
|
||||
for token_hash, token_info in self._tokens.items():
|
||||
tokens.append({
|
||||
'token_id': token_info.token_id,
|
||||
'created_at': token_info.created_at.isoformat(),
|
||||
'expires_at': token_info.expires_at.isoformat() if token_info.expires_at else None,
|
||||
'last_used': token_info.last_used.isoformat() if token_info.last_used else None,
|
||||
'is_active': token_info.is_active,
|
||||
'description': token_info.description,
|
||||
'is_expired': token_info.expires_at and datetime.utcnow() > token_info.expires_at if token_info.expires_at else False
|
||||
})
|
||||
|
||||
# Sort by creation time
|
||||
tokens.sort(key=lambda x: x['created_at'], reverse=True)
|
||||
|
||||
return tokens
|
||||
|
||||
async def cleanup_expired_tokens(self) -> int:
|
||||
"""Remove expired tokens and return count"""
|
||||
if not self.enable_token_expiry:
|
||||
return 0
|
||||
|
||||
now = datetime.utcnow()
|
||||
expired_tokens = []
|
||||
|
||||
# Find expired tokens
|
||||
for token_hash, token_info in self._tokens.items():
|
||||
if token_info.expires_at and now > token_info.expires_at:
|
||||
expired_tokens.append((token_hash, token_info.token_id))
|
||||
|
||||
# Remove expired tokens
|
||||
for token_hash, token_id in expired_tokens:
|
||||
del self._tokens[token_hash]
|
||||
if token_id in self._token_ids:
|
||||
del self._token_ids[token_id]
|
||||
|
||||
if expired_tokens:
|
||||
self.logger.info(f"Cleaned up {len(expired_tokens)} expired tokens")
|
||||
|
||||
return len(expired_tokens)
|
||||
|
||||
async def save_tokens_to_file(self, file_path: Optional[str] = None) -> bool:
|
||||
"""Save current tokens to JSON file"""
|
||||
try:
|
||||
file_path = file_path or self.token_file_path
|
||||
tokens_list = await self.list_tokens()
|
||||
|
||||
tokens_data = {
|
||||
'version': '1.0',
|
||||
'created_at': datetime.utcnow().isoformat(),
|
||||
'tokens': tokens_list
|
||||
}
|
||||
|
||||
with open(file_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(tokens_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
self.logger.info(f"Saved {len(tokens_list)} tokens to file: {file_path}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to save tokens to file: {e}")
|
||||
return False
|
||||
|
||||
def get_token_stats(self) -> Dict[str, Any]:
|
||||
"""Get token statistics"""
|
||||
now = datetime.utcnow()
|
||||
total_tokens = len(self._tokens)
|
||||
active_tokens = sum(1 for info in self._tokens.values() if info.is_active)
|
||||
expired_tokens = sum(1 for info in self._tokens.values()
|
||||
if info.expires_at and now > info.expires_at)
|
||||
|
||||
return {
|
||||
'total_tokens': total_tokens,
|
||||
'active_tokens': active_tokens,
|
||||
'expired_tokens': expired_tokens,
|
||||
'expiry_enabled': self.enable_token_expiry,
|
||||
'default_expiry_hours': self.default_token_expiry_hours
|
||||
}
|
||||
365
doris_mcp_server/auth/token_validators.py
Normal file
365
doris_mcp_server/auth/token_validators.py
Normal file
@@ -0,0 +1,365 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
JWT Token Validation Module
|
||||
Provides token validation, blacklist management and security features
|
||||
"""
|
||||
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Dict, Set, Optional, Any
|
||||
from datetime import datetime, timedelta
|
||||
from collections import defaultdict
|
||||
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class TokenBlacklist:
|
||||
"""JWT Token Blacklist Manager
|
||||
|
||||
Manages revoked tokens to prevent revoked tokens from being used again
|
||||
Supports both in-memory and persistent storage
|
||||
"""
|
||||
|
||||
def __init__(self, cleanup_interval: int = 3600):
|
||||
"""Initialize token blacklist
|
||||
|
||||
Args:
|
||||
cleanup_interval: Interval for cleaning up expired tokens (seconds)
|
||||
"""
|
||||
self.cleanup_interval = cleanup_interval
|
||||
# Storage format: {token_jti: expiry_timestamp}
|
||||
self._blacklisted_tokens: Dict[str, float] = {}
|
||||
self._cleanup_task = None
|
||||
|
||||
logger.info("TokenBlacklist initialized")
|
||||
|
||||
async def start(self):
|
||||
"""Start blacklist manager"""
|
||||
self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
|
||||
logger.info("TokenBlacklist started with periodic cleanup")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop blacklist manager"""
|
||||
if self._cleanup_task:
|
||||
self._cleanup_task.cancel()
|
||||
try:
|
||||
await self._cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("TokenBlacklist stopped")
|
||||
|
||||
async def add_token(self, jti: str, exp: float):
|
||||
"""Add token to blacklist
|
||||
|
||||
Args:
|
||||
jti: JWT ID (unique identifier)
|
||||
exp: Token expiration timestamp
|
||||
"""
|
||||
self._blacklisted_tokens[jti] = exp
|
||||
logger.info(f"Token {jti} added to blacklist")
|
||||
|
||||
async def is_blacklisted(self, jti: str) -> bool:
|
||||
"""Check if token is blacklisted
|
||||
|
||||
Args:
|
||||
jti: JWT ID
|
||||
|
||||
Returns:
|
||||
True if blacklisted, False otherwise
|
||||
"""
|
||||
return jti in self._blacklisted_tokens
|
||||
|
||||
async def remove_token(self, jti: str) -> bool:
|
||||
"""Remove token from blacklist
|
||||
|
||||
Args:
|
||||
jti: JWT ID
|
||||
|
||||
Returns:
|
||||
True if removed, False if not found
|
||||
"""
|
||||
if jti in self._blacklisted_tokens:
|
||||
del self._blacklisted_tokens[jti]
|
||||
logger.info(f"Token {jti} removed from blacklist")
|
||||
return True
|
||||
return False
|
||||
|
||||
async def cleanup_expired(self) -> int:
|
||||
"""Clean up expired blacklisted tokens
|
||||
|
||||
Returns:
|
||||
Number of tokens cleaned up
|
||||
"""
|
||||
current_time = time.time()
|
||||
expired_tokens = [
|
||||
jti for jti, exp in self._blacklisted_tokens.items()
|
||||
if exp <= current_time
|
||||
]
|
||||
|
||||
for jti in expired_tokens:
|
||||
del self._blacklisted_tokens[jti]
|
||||
|
||||
if expired_tokens:
|
||||
logger.info(f"Cleaned up {len(expired_tokens)} expired tokens from blacklist")
|
||||
|
||||
return len(expired_tokens)
|
||||
|
||||
async def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get blacklist statistics"""
|
||||
current_time = time.time()
|
||||
active_tokens = sum(1 for exp in self._blacklisted_tokens.values() if exp > current_time)
|
||||
|
||||
return {
|
||||
"total_blacklisted": len(self._blacklisted_tokens),
|
||||
"active_blacklisted": active_tokens,
|
||||
"expired_blacklisted": len(self._blacklisted_tokens) - active_tokens,
|
||||
"cleanup_interval": self.cleanup_interval
|
||||
}
|
||||
|
||||
async def _periodic_cleanup(self):
|
||||
"""Periodically clean up expired tokens"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(self.cleanup_interval)
|
||||
await self.cleanup_expired()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error during periodic cleanup: {e}")
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""Token usage rate limiter"""
|
||||
|
||||
def __init__(self, max_requests: int = 100, time_window: int = 3600):
|
||||
"""Initialize rate limiter
|
||||
|
||||
Args:
|
||||
max_requests: Maximum requests within time window
|
||||
time_window: Time window in seconds
|
||||
"""
|
||||
self.max_requests = max_requests
|
||||
self.time_window = time_window
|
||||
# Storage format: {user_id: [timestamp1, timestamp2, ...]}
|
||||
self._request_history: Dict[str, list] = defaultdict(list)
|
||||
|
||||
logger.info(f"RateLimiter initialized: {max_requests} requests per {time_window} seconds")
|
||||
|
||||
async def is_allowed(self, user_id: str) -> bool:
|
||||
"""Check if user is allowed to make request
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
True if allowed, False otherwise
|
||||
"""
|
||||
current_time = time.time()
|
||||
user_requests = self._request_history[user_id]
|
||||
|
||||
# Clean up expired request records
|
||||
cutoff_time = current_time - self.time_window
|
||||
user_requests[:] = [t for t in user_requests if t > cutoff_time]
|
||||
|
||||
# Check if limit exceeded
|
||||
if len(user_requests) >= self.max_requests:
|
||||
logger.warning(f"Rate limit exceeded for user {user_id}")
|
||||
return False
|
||||
|
||||
# Record current request
|
||||
user_requests.append(current_time)
|
||||
return True
|
||||
|
||||
async def get_usage(self, user_id: str) -> Dict[str, Any]:
|
||||
"""Get user usage information
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
Usage statistics
|
||||
"""
|
||||
current_time = time.time()
|
||||
user_requests = self._request_history[user_id]
|
||||
|
||||
# Clean up expired records
|
||||
cutoff_time = current_time - self.time_window
|
||||
active_requests = [t for t in user_requests if t > cutoff_time]
|
||||
|
||||
return {
|
||||
"user_id": user_id,
|
||||
"requests_in_window": len(active_requests),
|
||||
"max_requests": self.max_requests,
|
||||
"time_window": self.time_window,
|
||||
"remaining_requests": max(0, self.max_requests - len(active_requests))
|
||||
}
|
||||
|
||||
|
||||
class TokenValidator:
|
||||
"""JWT Token Validator
|
||||
|
||||
Provides comprehensive JWT token validation functionality, including signature verification,
|
||||
claim validation, blacklist checking and rate limiting
|
||||
"""
|
||||
|
||||
def __init__(self, config, blacklist: Optional[TokenBlacklist] = None):
|
||||
"""Initialize token validator
|
||||
|
||||
Args:
|
||||
config: DorisConfig configuration object (with security attribute)
|
||||
blacklist: Token blacklist manager
|
||||
"""
|
||||
self.config = config
|
||||
self.blacklist = blacklist or TokenBlacklist()
|
||||
self.rate_limiter = RateLimiter()
|
||||
|
||||
# Access JWT settings through the security configuration
|
||||
if hasattr(config, 'security'):
|
||||
security_config = config.security
|
||||
else:
|
||||
# Fallback if config is passed directly as SecurityConfig
|
||||
security_config = config
|
||||
|
||||
# Validation options
|
||||
self.verify_signature = security_config.jwt_verify_signature
|
||||
self.verify_audience = security_config.jwt_verify_audience
|
||||
self.verify_issuer = security_config.jwt_verify_issuer
|
||||
self.require_exp = security_config.jwt_require_exp
|
||||
self.require_iat = security_config.jwt_require_iat
|
||||
self.require_nbf = security_config.jwt_require_nbf
|
||||
self.leeway = security_config.jwt_leeway
|
||||
|
||||
# Expected values
|
||||
self.expected_audience = security_config.jwt_audience
|
||||
self.expected_issuer = security_config.jwt_issuer
|
||||
|
||||
logger.info("TokenValidator initialized")
|
||||
|
||||
async def validate_claims(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate JWT claims
|
||||
|
||||
Args:
|
||||
payload: JWT payload
|
||||
|
||||
Returns:
|
||||
Validation result
|
||||
|
||||
Raises:
|
||||
ValueError: Validation failed
|
||||
"""
|
||||
current_time = time.time()
|
||||
|
||||
# Validate issuer
|
||||
if self.verify_issuer:
|
||||
if payload.get('iss') != self.expected_issuer:
|
||||
raise ValueError(f"Invalid issuer: expected {self.expected_issuer}")
|
||||
|
||||
# Validate audience
|
||||
if self.verify_audience:
|
||||
aud = payload.get('aud')
|
||||
if isinstance(aud, list):
|
||||
if self.expected_audience not in aud:
|
||||
raise ValueError(f"Invalid audience: {self.expected_audience} not in {aud}")
|
||||
elif aud != self.expected_audience:
|
||||
raise ValueError(f"Invalid audience: expected {self.expected_audience}")
|
||||
|
||||
# Validate expiration time
|
||||
if self.require_exp or 'exp' in payload:
|
||||
exp = payload.get('exp')
|
||||
if not exp:
|
||||
raise ValueError("Missing 'exp' claim")
|
||||
if current_time > exp + self.leeway:
|
||||
raise ValueError("Token has expired")
|
||||
|
||||
# Validate not before time
|
||||
if self.require_nbf or 'nbf' in payload:
|
||||
nbf = payload.get('nbf')
|
||||
if not nbf:
|
||||
raise ValueError("Missing 'nbf' claim")
|
||||
if current_time < nbf - self.leeway:
|
||||
raise ValueError("Token not yet valid")
|
||||
|
||||
# Validate issued at time
|
||||
if self.require_iat or 'iat' in payload:
|
||||
iat = payload.get('iat')
|
||||
if not iat:
|
||||
raise ValueError("Missing 'iat' claim")
|
||||
# Allow some clock skew, but cannot be future time
|
||||
if iat > current_time + self.leeway:
|
||||
raise ValueError("Token issued in the future")
|
||||
|
||||
# Check blacklist
|
||||
jti = payload.get('jti')
|
||||
if jti and await self.blacklist.is_blacklisted(jti):
|
||||
raise ValueError("Token has been revoked")
|
||||
|
||||
# Rate limit check
|
||||
user_id = payload.get('sub')
|
||||
if user_id:
|
||||
if not await self.rate_limiter.is_allowed(user_id):
|
||||
raise ValueError("Rate limit exceeded")
|
||||
|
||||
return {
|
||||
"valid": True,
|
||||
"user_id": user_id,
|
||||
"payload": payload
|
||||
}
|
||||
|
||||
async def start(self):
|
||||
"""Start validator"""
|
||||
await self.blacklist.start()
|
||||
logger.info("TokenValidator started")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop validator"""
|
||||
await self.blacklist.stop()
|
||||
logger.info("TokenValidator stopped")
|
||||
|
||||
async def revoke_token(self, jti: str, exp: float):
|
||||
"""Revoke token
|
||||
|
||||
Args:
|
||||
jti: JWT ID
|
||||
exp: Token expiration time
|
||||
"""
|
||||
await self.blacklist.add_token(jti, exp)
|
||||
logger.info(f"Token {jti} has been revoked")
|
||||
|
||||
async def get_validation_stats(self) -> Dict[str, Any]:
|
||||
"""Get validation statistics"""
|
||||
blacklist_stats = await self.blacklist.get_stats()
|
||||
|
||||
return {
|
||||
"blacklist": blacklist_stats,
|
||||
"validation_config": {
|
||||
"verify_signature": self.verify_signature,
|
||||
"verify_audience": self.verify_audience,
|
||||
"verify_issuer": self.verify_issuer,
|
||||
"require_exp": self.require_exp,
|
||||
"require_iat": self.require_iat,
|
||||
"require_nbf": self.require_nbf,
|
||||
"leeway": self.leeway
|
||||
}
|
||||
}
|
||||
|
||||
async def get_user_rate_limit_info(self, user_id: str) -> Dict[str, Any]:
|
||||
"""Get user rate limit information"""
|
||||
return await self.rate_limiter.get_usage(user_id)
|
||||
Reference in New Issue
Block a user