[Performance]Add complete Token, JWT, OAuth authentication system (#52)
* 0.5.1 Version * fix 0.5.1 schema async bug * fix security bug * fix security bug * Add complete Token, JWT, OAuth authentication system * Add complete Token, JWT, OAuth authentication system * Add complete Token, JWT, OAuth authentication system * Add complete Token, JWT, OAuth authentication system
This commit is contained in:
536
doris_mcp_server/auth/oauth_client.py
Normal file
536
doris_mcp_server/auth/oauth_client.py
Normal file
@@ -0,0 +1,536 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
OAuth 2.0/OIDC Client Manager
|
||||
Provides OAuth authentication client implementation with PKCE and OIDC support
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Optional, Any, Tuple
|
||||
from urllib.parse import urlencode, parse_qs, urlparse
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
except ImportError:
|
||||
raise ImportError("aiohttp is required for OAuth functionality. Install with: pip install aiohttp")
|
||||
|
||||
from .oauth_types import (
|
||||
OAuthProvider, OAuthState, OAuthTokens, OAuthUserInfo,
|
||||
OIDCDiscovery, OAuthError, OAuthProviderConfig, OAUTH_PROVIDERS
|
||||
)
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class OAuthStateManager:
|
||||
"""Manages OAuth state parameters for CSRF protection"""
|
||||
|
||||
def __init__(self, state_expiry: int = 600):
|
||||
"""Initialize state manager
|
||||
|
||||
Args:
|
||||
state_expiry: State expiry time in seconds
|
||||
"""
|
||||
self.state_expiry = state_expiry
|
||||
self._states: Dict[str, OAuthState] = {}
|
||||
self._cleanup_task = None
|
||||
|
||||
logger.info("OAuthStateManager initialized")
|
||||
|
||||
async def start(self):
|
||||
"""Start periodic cleanup task"""
|
||||
self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
|
||||
logger.info("OAuth state manager started")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop periodic cleanup task"""
|
||||
if self._cleanup_task:
|
||||
self._cleanup_task.cancel()
|
||||
try:
|
||||
await self._cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("OAuth state manager stopped")
|
||||
|
||||
def create_state(self, redirect_uri: str, pkce_enabled: bool = True,
|
||||
nonce_enabled: bool = True) -> OAuthState:
|
||||
"""Create new OAuth state
|
||||
|
||||
Args:
|
||||
redirect_uri: OAuth redirect URI
|
||||
pkce_enabled: Whether to enable PKCE
|
||||
nonce_enabled: Whether to enable nonce (for OIDC)
|
||||
|
||||
Returns:
|
||||
OAuth state object
|
||||
"""
|
||||
state = secrets.token_urlsafe(32)
|
||||
nonce = secrets.token_urlsafe(32) if nonce_enabled else None
|
||||
|
||||
pkce_verifier = None
|
||||
pkce_challenge = None
|
||||
if pkce_enabled:
|
||||
pkce_verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).decode('utf-8').rstrip('=')
|
||||
challenge_bytes = hashlib.sha256(pkce_verifier.encode()).digest()
|
||||
pkce_challenge = base64.urlsafe_b64encode(challenge_bytes).decode('utf-8').rstrip('=')
|
||||
|
||||
oauth_state = OAuthState(
|
||||
state=state,
|
||||
nonce=nonce,
|
||||
pkce_verifier=pkce_verifier,
|
||||
pkce_challenge=pkce_challenge,
|
||||
redirect_uri=redirect_uri,
|
||||
created_at=datetime.utcnow(),
|
||||
expires_at=datetime.utcnow() + timedelta(seconds=self.state_expiry)
|
||||
)
|
||||
|
||||
self._states[state] = oauth_state
|
||||
logger.debug(f"Created OAuth state: {state}")
|
||||
return oauth_state
|
||||
|
||||
def get_state(self, state: str) -> Optional[OAuthState]:
|
||||
"""Get OAuth state by state parameter
|
||||
|
||||
Args:
|
||||
state: State parameter
|
||||
|
||||
Returns:
|
||||
OAuth state object or None if not found/expired
|
||||
"""
|
||||
oauth_state = self._states.get(state)
|
||||
if oauth_state and oauth_state.expires_at > datetime.utcnow():
|
||||
return oauth_state
|
||||
elif oauth_state:
|
||||
# Remove expired state
|
||||
del self._states[state]
|
||||
logger.debug(f"Removed expired OAuth state: {state}")
|
||||
return None
|
||||
|
||||
def consume_state(self, state: str) -> Optional[OAuthState]:
|
||||
"""Get and remove OAuth state
|
||||
|
||||
Args:
|
||||
state: State parameter
|
||||
|
||||
Returns:
|
||||
OAuth state object or None if not found/expired
|
||||
"""
|
||||
oauth_state = self.get_state(state)
|
||||
if oauth_state:
|
||||
del self._states[state]
|
||||
logger.debug(f"Consumed OAuth state: {state}")
|
||||
return oauth_state
|
||||
|
||||
async def _periodic_cleanup(self):
|
||||
"""Periodic cleanup of expired states"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(300) # Clean up every 5 minutes
|
||||
current_time = datetime.utcnow()
|
||||
expired_states = [
|
||||
state for state, oauth_state in self._states.items()
|
||||
if oauth_state.expires_at <= current_time
|
||||
]
|
||||
|
||||
for state in expired_states:
|
||||
del self._states[state]
|
||||
|
||||
if expired_states:
|
||||
logger.info(f"Cleaned up {len(expired_states)} expired OAuth states")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error during OAuth state cleanup: {e}")
|
||||
|
||||
|
||||
class OAuthClient:
|
||||
"""OAuth 2.0/OIDC Client implementation"""
|
||||
|
||||
def __init__(self, config):
|
||||
"""Initialize OAuth client
|
||||
|
||||
Args:
|
||||
config: DorisConfig with OAuth configuration
|
||||
"""
|
||||
self.config = config
|
||||
|
||||
# Access OAuth settings through security configuration
|
||||
if hasattr(config, 'security'):
|
||||
security_config = config.security
|
||||
else:
|
||||
security_config = config
|
||||
|
||||
self.enabled = security_config.oauth_enabled
|
||||
if not self.enabled:
|
||||
logger.info("OAuth client disabled by configuration")
|
||||
return
|
||||
|
||||
# Build provider configuration
|
||||
self.provider_config = self._build_provider_config(security_config)
|
||||
self.state_manager = OAuthStateManager(security_config.oauth_state_expiry)
|
||||
|
||||
# HTTP client session
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
|
||||
# Discovery cache
|
||||
self._discovery_cache: Optional[OIDCDiscovery] = None
|
||||
self._discovery_cache_time: Optional[datetime] = None
|
||||
|
||||
logger.info(f"OAuthClient initialized for provider: {self.provider_config.provider.value}")
|
||||
|
||||
def _build_provider_config(self, security_config) -> OAuthProviderConfig:
|
||||
"""Build OAuth provider configuration
|
||||
|
||||
Args:
|
||||
security_config: Security configuration object
|
||||
|
||||
Returns:
|
||||
OAuth provider configuration
|
||||
"""
|
||||
try:
|
||||
provider = OAuthProvider(security_config.oauth_provider)
|
||||
except ValueError:
|
||||
provider = OAuthProvider.CUSTOM
|
||||
|
||||
# Get default configuration for known providers
|
||||
defaults = OAUTH_PROVIDERS.get(provider, {})
|
||||
|
||||
return OAuthProviderConfig(
|
||||
provider=provider,
|
||||
client_id=security_config.oauth_client_id,
|
||||
client_secret=security_config.oauth_client_secret,
|
||||
redirect_uri=security_config.oauth_redirect_uri,
|
||||
scopes=security_config.oauth_scopes or defaults.get("scopes", ["openid", "email", "profile"]),
|
||||
|
||||
# Endpoints (use configured or defaults)
|
||||
authorization_endpoint=security_config.oauth_authorization_endpoint or defaults.get("authorization_endpoint", ""),
|
||||
token_endpoint=security_config.oauth_token_endpoint or defaults.get("token_endpoint", ""),
|
||||
userinfo_endpoint=security_config.oauth_userinfo_endpoint or defaults.get("userinfo_endpoint"),
|
||||
jwks_uri=security_config.oauth_jwks_uri or defaults.get("jwks_uri"),
|
||||
|
||||
# Discovery
|
||||
discovery_url=security_config.oidc_discovery_url or defaults.get("discovery_url"),
|
||||
|
||||
# Settings
|
||||
pkce_enabled=security_config.oauth_pkce_enabled,
|
||||
nonce_enabled=security_config.oauth_nonce_enabled,
|
||||
|
||||
# User mapping
|
||||
user_id_claim=security_config.oauth_user_id_claim or defaults.get("user_id_claim", "sub"),
|
||||
email_claim=security_config.oauth_email_claim or defaults.get("email_claim", "email"),
|
||||
name_claim=security_config.oauth_name_claim or defaults.get("name_claim", "name"),
|
||||
roles_claim=security_config.oauth_roles_claim,
|
||||
default_roles=security_config.oauth_default_roles
|
||||
)
|
||||
|
||||
async def initialize(self) -> bool:
|
||||
"""Initialize OAuth client
|
||||
|
||||
Returns:
|
||||
True if initialization successful
|
||||
"""
|
||||
if not self.enabled:
|
||||
return True
|
||||
|
||||
try:
|
||||
# Create HTTP session
|
||||
self._session = aiohttp.ClientSession()
|
||||
|
||||
# Start state manager
|
||||
await self.state_manager.start()
|
||||
|
||||
# Perform OIDC discovery if configured
|
||||
if self.provider_config.discovery_url:
|
||||
await self._discover_oidc_endpoints()
|
||||
|
||||
logger.info("OAuth client initialization completed")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize OAuth client: {e}")
|
||||
return False
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shutdown OAuth client"""
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
try:
|
||||
# Stop state manager
|
||||
await self.state_manager.stop()
|
||||
|
||||
# Close HTTP session
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
|
||||
logger.info("OAuth client shutdown completed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during OAuth client shutdown: {e}")
|
||||
|
||||
async def _discover_oidc_endpoints(self):
|
||||
"""Discover OIDC endpoints using discovery URL"""
|
||||
try:
|
||||
# Check cache first
|
||||
if (self._discovery_cache and self._discovery_cache_time and
|
||||
datetime.utcnow() - self._discovery_cache_time < timedelta(hours=1)):
|
||||
return self._discovery_cache
|
||||
|
||||
logger.info(f"Discovering OIDC endpoints: {self.provider_config.discovery_url}")
|
||||
|
||||
async with self._session.get(self.provider_config.discovery_url) as response:
|
||||
response.raise_for_status()
|
||||
data = await response.json()
|
||||
|
||||
discovery = OIDCDiscovery(
|
||||
issuer=data["issuer"],
|
||||
authorization_endpoint=data["authorization_endpoint"],
|
||||
token_endpoint=data["token_endpoint"],
|
||||
userinfo_endpoint=data.get("userinfo_endpoint"),
|
||||
jwks_uri=data.get("jwks_uri"),
|
||||
scopes_supported=data.get("scopes_supported"),
|
||||
response_types_supported=data.get("response_types_supported"),
|
||||
subject_types_supported=data.get("subject_types_supported"),
|
||||
id_token_signing_alg_values_supported=data.get("id_token_signing_alg_values_supported")
|
||||
)
|
||||
|
||||
# Update provider configuration with discovered endpoints
|
||||
if not self.provider_config.authorization_endpoint:
|
||||
self.provider_config.authorization_endpoint = discovery.authorization_endpoint
|
||||
if not self.provider_config.token_endpoint:
|
||||
self.provider_config.token_endpoint = discovery.token_endpoint
|
||||
if not self.provider_config.userinfo_endpoint:
|
||||
self.provider_config.userinfo_endpoint = discovery.userinfo_endpoint
|
||||
if not self.provider_config.jwks_uri:
|
||||
self.provider_config.jwks_uri = discovery.jwks_uri
|
||||
|
||||
# Cache discovery result
|
||||
self._discovery_cache = discovery
|
||||
self._discovery_cache_time = datetime.utcnow()
|
||||
|
||||
logger.info("OIDC endpoint discovery completed successfully")
|
||||
return discovery
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OIDC endpoint discovery failed: {e}")
|
||||
raise
|
||||
|
||||
def build_authorization_url(self) -> Tuple[str, OAuthState]:
|
||||
"""Build OAuth authorization URL
|
||||
|
||||
Returns:
|
||||
Tuple of (authorization_url, oauth_state)
|
||||
"""
|
||||
if not self.enabled:
|
||||
raise ValueError("OAuth client is not enabled")
|
||||
|
||||
# Create state for CSRF protection
|
||||
oauth_state = self.state_manager.create_state(
|
||||
redirect_uri=self.provider_config.redirect_uri,
|
||||
pkce_enabled=self.provider_config.pkce_enabled,
|
||||
nonce_enabled=self.provider_config.nonce_enabled
|
||||
)
|
||||
|
||||
# Build authorization parameters
|
||||
params = {
|
||||
'response_type': 'code',
|
||||
'client_id': self.provider_config.client_id,
|
||||
'redirect_uri': self.provider_config.redirect_uri,
|
||||
'scope': ' '.join(self.provider_config.scopes),
|
||||
'state': oauth_state.state
|
||||
}
|
||||
|
||||
# Add PKCE challenge
|
||||
if oauth_state.pkce_challenge:
|
||||
params['code_challenge'] = oauth_state.pkce_challenge
|
||||
params['code_challenge_method'] = 'S256'
|
||||
|
||||
# Add nonce for OIDC
|
||||
if oauth_state.nonce:
|
||||
params['nonce'] = oauth_state.nonce
|
||||
|
||||
# Build URL
|
||||
authorization_url = f"{self.provider_config.authorization_endpoint}?{urlencode(params)}"
|
||||
|
||||
logger.info(f"Built OAuth authorization URL for state: {oauth_state.state}")
|
||||
return authorization_url, oauth_state
|
||||
|
||||
async def exchange_code_for_tokens(self, code: str, state: str) -> Tuple[OAuthTokens, OAuthState]:
|
||||
"""Exchange authorization code for tokens
|
||||
|
||||
Args:
|
||||
code: Authorization code
|
||||
state: State parameter
|
||||
|
||||
Returns:
|
||||
Tuple of (OAuth tokens, OAuth state)
|
||||
|
||||
Raises:
|
||||
ValueError: If state is invalid or exchange fails
|
||||
"""
|
||||
if not self.enabled:
|
||||
raise ValueError("OAuth client is not enabled")
|
||||
|
||||
# Validate and consume state
|
||||
oauth_state = self.state_manager.consume_state(state)
|
||||
if not oauth_state:
|
||||
raise ValueError("Invalid or expired state parameter")
|
||||
|
||||
try:
|
||||
# Prepare token request
|
||||
data = {
|
||||
'grant_type': 'authorization_code',
|
||||
'client_id': self.provider_config.client_id,
|
||||
'client_secret': self.provider_config.client_secret,
|
||||
'code': code,
|
||||
'redirect_uri': oauth_state.redirect_uri
|
||||
}
|
||||
|
||||
# Add PKCE verifier
|
||||
if oauth_state.pkce_verifier:
|
||||
data['code_verifier'] = oauth_state.pkce_verifier
|
||||
|
||||
# Make token request
|
||||
async with self._session.post(
|
||||
self.provider_config.token_endpoint,
|
||||
data=data,
|
||||
headers={'Content-Type': 'application/x-www-form-urlencoded'}
|
||||
) as response:
|
||||
response_data = await response.json()
|
||||
|
||||
if response.status != 200:
|
||||
error_msg = response_data.get('error_description', response_data.get('error', 'Token exchange failed'))
|
||||
raise ValueError(f"Token exchange failed: {error_msg}")
|
||||
|
||||
tokens = OAuthTokens(
|
||||
access_token=response_data['access_token'],
|
||||
token_type=response_data.get('token_type', 'Bearer'),
|
||||
expires_in=response_data.get('expires_in'),
|
||||
refresh_token=response_data.get('refresh_token'),
|
||||
scope=response_data.get('scope'),
|
||||
id_token=response_data.get('id_token')
|
||||
)
|
||||
|
||||
logger.info("Successfully exchanged authorization code for tokens")
|
||||
return tokens, oauth_state
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Token exchange failed: {e}")
|
||||
raise ValueError(f"Token exchange failed: {str(e)}")
|
||||
|
||||
async def get_user_info(self, tokens: OAuthTokens) -> OAuthUserInfo:
|
||||
"""Get user information from OAuth provider
|
||||
|
||||
Args:
|
||||
tokens: OAuth tokens
|
||||
|
||||
Returns:
|
||||
OAuth user information
|
||||
"""
|
||||
if not self.enabled:
|
||||
raise ValueError("OAuth client is not enabled")
|
||||
|
||||
if not self.provider_config.userinfo_endpoint:
|
||||
raise ValueError("Userinfo endpoint not configured")
|
||||
|
||||
try:
|
||||
# Make userinfo request
|
||||
headers = {'Authorization': f'{tokens.token_type} {tokens.access_token}'}
|
||||
|
||||
async with self._session.get(
|
||||
self.provider_config.userinfo_endpoint,
|
||||
headers=headers
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
user_data = await response.json()
|
||||
|
||||
# Extract user information using configured claims
|
||||
user_info = OAuthUserInfo(
|
||||
sub=str(user_data.get(self.provider_config.user_id_claim, '')),
|
||||
email=user_data.get(self.provider_config.email_claim),
|
||||
name=user_data.get(self.provider_config.name_claim),
|
||||
given_name=user_data.get('given_name'),
|
||||
family_name=user_data.get('family_name'),
|
||||
picture=user_data.get('picture'),
|
||||
locale=user_data.get('locale'),
|
||||
email_verified=user_data.get('email_verified'),
|
||||
roles=user_data.get(self.provider_config.roles_claim, self.provider_config.default_roles.copy()),
|
||||
raw_claims=user_data
|
||||
)
|
||||
|
||||
logger.info(f"Retrieved user info for user: {user_info.sub}")
|
||||
return user_info
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get user info: {e}")
|
||||
raise ValueError(f"Failed to get user info: {str(e)}")
|
||||
|
||||
async def refresh_tokens(self, refresh_token: str) -> OAuthTokens:
|
||||
"""Refresh OAuth tokens
|
||||
|
||||
Args:
|
||||
refresh_token: Refresh token
|
||||
|
||||
Returns:
|
||||
New OAuth tokens
|
||||
"""
|
||||
if not self.enabled:
|
||||
raise ValueError("OAuth client is not enabled")
|
||||
|
||||
try:
|
||||
data = {
|
||||
'grant_type': 'refresh_token',
|
||||
'client_id': self.provider_config.client_id,
|
||||
'client_secret': self.provider_config.client_secret,
|
||||
'refresh_token': refresh_token
|
||||
}
|
||||
|
||||
async with self._session.post(
|
||||
self.provider_config.token_endpoint,
|
||||
data=data,
|
||||
headers={'Content-Type': 'application/x-www-form-urlencoded'}
|
||||
) as response:
|
||||
response_data = await response.json()
|
||||
|
||||
if response.status != 200:
|
||||
error_msg = response_data.get('error_description', response_data.get('error', 'Token refresh failed'))
|
||||
raise ValueError(f"Token refresh failed: {error_msg}")
|
||||
|
||||
tokens = OAuthTokens(
|
||||
access_token=response_data['access_token'],
|
||||
token_type=response_data.get('token_type', 'Bearer'),
|
||||
expires_in=response_data.get('expires_in'),
|
||||
refresh_token=response_data.get('refresh_token', refresh_token), # Keep old if not provided
|
||||
scope=response_data.get('scope'),
|
||||
id_token=response_data.get('id_token')
|
||||
)
|
||||
|
||||
logger.info("Successfully refreshed OAuth tokens")
|
||||
return tokens
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Token refresh failed: {e}")
|
||||
raise ValueError(f"Token refresh failed: {str(e)}")
|
||||
Reference in New Issue
Block a user