343 lines
12 KiB
Python
343 lines
12 KiB
Python
|
|
#!/usr/bin/env python3
|
||
|
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||
|
|
# or more contributor license agreements. See the NOTICE file
|
||
|
|
# distributed with this work for additional information
|
||
|
|
# regarding copyright ownership. The ASF licenses this file
|
||
|
|
# to you under the Apache License, Version 2.0 (the
|
||
|
|
# "License"); you may not use this file except in compliance
|
||
|
|
# with the License. You may obtain a copy of the License at
|
||
|
|
#
|
||
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
|
#
|
||
|
|
# Unless required by applicable law or agreed to in writing,
|
||
|
|
# software distributed under the License is distributed on an
|
||
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||
|
|
# KIND, either express or implied. See the License for the
|
||
|
|
# specific language governing permissions and limitations
|
||
|
|
# under the License.
|
||
|
|
"""
|
||
|
|
JWT Key Management Module
|
||
|
|
Provides secure key generation, loading, rotation and management for JWT tokens
|
||
|
|
"""
|
||
|
|
|
||
|
|
import os
|
||
|
|
import time
|
||
|
|
import secrets
|
||
|
|
from pathlib import Path
|
||
|
|
from typing import Optional, Tuple, Union
|
||
|
|
from datetime import datetime, timedelta
|
||
|
|
from cryptography.hazmat.primitives import serialization
|
||
|
|
from cryptography.hazmat.primitives.asymmetric import rsa, ec
|
||
|
|
from cryptography.hazmat.backends import default_backend
|
||
|
|
|
||
|
|
from ..utils.logger import get_logger
|
||
|
|
|
||
|
|
logger = get_logger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
class KeyManager:
|
||
|
|
"""JWT Key Manager
|
||
|
|
|
||
|
|
Responsible for generating, loading, rotating and securely storing JWT signing keys
|
||
|
|
Supports RSA and EC algorithms, provides automatic key rotation functionality
|
||
|
|
"""
|
||
|
|
|
||
|
|
def __init__(self, config):
|
||
|
|
"""Initialize key manager
|
||
|
|
|
||
|
|
Args:
|
||
|
|
config: DorisConfig configuration object (with security attribute)
|
||
|
|
"""
|
||
|
|
self.config = config
|
||
|
|
# Access JWT settings through the security configuration
|
||
|
|
if hasattr(config, 'security'):
|
||
|
|
security_config = config.security
|
||
|
|
else:
|
||
|
|
# Fallback if config is passed directly as SecurityConfig
|
||
|
|
security_config = config
|
||
|
|
|
||
|
|
self.algorithm = security_config.jwt_algorithm
|
||
|
|
self.key_rotation_interval = security_config.key_rotation_interval
|
||
|
|
self.private_key_path = security_config.jwt_private_key_path
|
||
|
|
self.public_key_path = security_config.jwt_public_key_path
|
||
|
|
self.secret_key = security_config.jwt_secret_key
|
||
|
|
|
||
|
|
# Key storage
|
||
|
|
self._private_key = None
|
||
|
|
self._public_key = None
|
||
|
|
self._secret_key = None
|
||
|
|
self._key_generated_at = None
|
||
|
|
|
||
|
|
logger.info(f"KeyManager initialized with algorithm: {self.algorithm}")
|
||
|
|
|
||
|
|
async def initialize(self) -> bool:
|
||
|
|
"""Initialize key manager, load or generate keys"""
|
||
|
|
try:
|
||
|
|
if self.algorithm == "HS256":
|
||
|
|
await self._initialize_symmetric_key()
|
||
|
|
else:
|
||
|
|
await self._initialize_asymmetric_keys()
|
||
|
|
|
||
|
|
logger.info("KeyManager initialization completed")
|
||
|
|
return True
|
||
|
|
|
||
|
|
except Exception as e:
|
||
|
|
logger.error(f"Failed to initialize KeyManager: {e}")
|
||
|
|
return False
|
||
|
|
|
||
|
|
async def _initialize_symmetric_key(self):
|
||
|
|
"""Initialize symmetric key (HS256)"""
|
||
|
|
if self.secret_key:
|
||
|
|
# Use configured key
|
||
|
|
self._secret_key = self.secret_key.encode()
|
||
|
|
logger.info("Loaded symmetric key from configuration")
|
||
|
|
else:
|
||
|
|
# Generate new key
|
||
|
|
self._secret_key = await self.generate_symmetric_key()
|
||
|
|
logger.info("Generated new symmetric key")
|
||
|
|
|
||
|
|
self._key_generated_at = datetime.utcnow()
|
||
|
|
|
||
|
|
async def _initialize_asymmetric_keys(self):
|
||
|
|
"""Initialize asymmetric key pair (RS256/ES256)"""
|
||
|
|
# Try to load keys from files
|
||
|
|
if await self._load_keys_from_files():
|
||
|
|
logger.info("Loaded asymmetric keys from files")
|
||
|
|
return
|
||
|
|
|
||
|
|
# Try to load from environment variables
|
||
|
|
if await self._load_keys_from_env():
|
||
|
|
logger.info("Loaded asymmetric keys from environment")
|
||
|
|
return
|
||
|
|
|
||
|
|
# Generate new key pair
|
||
|
|
await self.generate_key_pair()
|
||
|
|
logger.info("Generated new asymmetric key pair")
|
||
|
|
|
||
|
|
async def _load_keys_from_files(self) -> bool:
|
||
|
|
"""Load keys from files"""
|
||
|
|
try:
|
||
|
|
if not self.private_key_path or not self.public_key_path:
|
||
|
|
return False
|
||
|
|
|
||
|
|
private_path = Path(self.private_key_path)
|
||
|
|
public_path = Path(self.public_key_path)
|
||
|
|
|
||
|
|
if not (private_path.exists() and public_path.exists()):
|
||
|
|
return False
|
||
|
|
|
||
|
|
# Read private key
|
||
|
|
with open(private_path, 'rb') as f:
|
||
|
|
private_key_data = f.read()
|
||
|
|
self._private_key = serialization.load_pem_private_key(
|
||
|
|
private_key_data, password=None, backend=default_backend()
|
||
|
|
)
|
||
|
|
|
||
|
|
# Read public key
|
||
|
|
with open(public_path, 'rb') as f:
|
||
|
|
public_key_data = f.read()
|
||
|
|
self._public_key = serialization.load_pem_public_key(
|
||
|
|
public_key_data, backend=default_backend()
|
||
|
|
)
|
||
|
|
|
||
|
|
# Get key generation time (using file modification time)
|
||
|
|
self._key_generated_at = datetime.fromtimestamp(private_path.stat().st_mtime)
|
||
|
|
|
||
|
|
return True
|
||
|
|
|
||
|
|
except Exception as e:
|
||
|
|
logger.error(f"Failed to load keys from files: {e}")
|
||
|
|
return False
|
||
|
|
|
||
|
|
async def _load_keys_from_env(self) -> bool:
|
||
|
|
"""Load keys from environment variables"""
|
||
|
|
try:
|
||
|
|
private_key_env = os.getenv('JWT_PRIVATE_KEY')
|
||
|
|
public_key_env = os.getenv('JWT_PUBLIC_KEY')
|
||
|
|
|
||
|
|
if not (private_key_env and public_key_env):
|
||
|
|
return False
|
||
|
|
|
||
|
|
# Parse private key
|
||
|
|
self._private_key = serialization.load_pem_private_key(
|
||
|
|
private_key_env.encode(), password=None, backend=default_backend()
|
||
|
|
)
|
||
|
|
|
||
|
|
# Parse public key
|
||
|
|
self._public_key = serialization.load_pem_public_key(
|
||
|
|
public_key_env.encode(), backend=default_backend()
|
||
|
|
)
|
||
|
|
|
||
|
|
self._key_generated_at = datetime.utcnow()
|
||
|
|
return True
|
||
|
|
|
||
|
|
except Exception as e:
|
||
|
|
logger.error(f"Failed to load keys from environment: {e}")
|
||
|
|
return False
|
||
|
|
|
||
|
|
async def generate_symmetric_key(self, length: int = 32) -> bytes:
|
||
|
|
"""Generate symmetric key
|
||
|
|
|
||
|
|
Args:
|
||
|
|
length: Key length (bytes), default 32 bytes (256 bits)
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
Generated key
|
||
|
|
"""
|
||
|
|
return secrets.token_bytes(length)
|
||
|
|
|
||
|
|
async def generate_key_pair(self) -> Tuple[bytes, bytes]:
|
||
|
|
"""Generate asymmetric key pair
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
(private key PEM, public key PEM) tuple
|
||
|
|
"""
|
||
|
|
try:
|
||
|
|
if self.algorithm == "RS256":
|
||
|
|
private_key = rsa.generate_private_key(
|
||
|
|
public_exponent=65537,
|
||
|
|
key_size=2048,
|
||
|
|
backend=default_backend()
|
||
|
|
)
|
||
|
|
elif self.algorithm == "ES256":
|
||
|
|
private_key = ec.generate_private_key(
|
||
|
|
ec.SECP256R1(), backend=default_backend()
|
||
|
|
)
|
||
|
|
else:
|
||
|
|
raise ValueError(f"Unsupported algorithm for key generation: {self.algorithm}")
|
||
|
|
|
||
|
|
# Get public key
|
||
|
|
public_key = private_key.public_key()
|
||
|
|
|
||
|
|
# Serialize private key
|
||
|
|
private_pem = private_key.private_bytes(
|
||
|
|
encoding=serialization.Encoding.PEM,
|
||
|
|
format=serialization.PrivateFormat.PKCS8,
|
||
|
|
encryption_algorithm=serialization.NoEncryption()
|
||
|
|
)
|
||
|
|
|
||
|
|
# Serialize public key
|
||
|
|
public_pem = public_key.public_bytes(
|
||
|
|
encoding=serialization.Encoding.PEM,
|
||
|
|
format=serialization.PublicFormat.SubjectPublicKeyInfo
|
||
|
|
)
|
||
|
|
|
||
|
|
# Store keys
|
||
|
|
self._private_key = private_key
|
||
|
|
self._public_key = public_key
|
||
|
|
self._key_generated_at = datetime.utcnow()
|
||
|
|
|
||
|
|
# If file paths are configured, save to files
|
||
|
|
if self.private_key_path and self.public_key_path:
|
||
|
|
await self._save_keys_to_files(private_pem, public_pem)
|
||
|
|
|
||
|
|
logger.info(f"Generated new {self.algorithm} key pair")
|
||
|
|
return private_pem, public_pem
|
||
|
|
|
||
|
|
except Exception as e:
|
||
|
|
logger.error(f"Failed to generate key pair: {e}")
|
||
|
|
raise
|
||
|
|
|
||
|
|
async def _save_keys_to_files(self, private_pem: bytes, public_pem: bytes):
|
||
|
|
"""Save keys to files"""
|
||
|
|
try:
|
||
|
|
# Ensure directories exist
|
||
|
|
private_path = Path(self.private_key_path)
|
||
|
|
public_path = Path(self.public_key_path)
|
||
|
|
|
||
|
|
private_path.parent.mkdir(parents=True, exist_ok=True)
|
||
|
|
public_path.parent.mkdir(parents=True, exist_ok=True)
|
||
|
|
|
||
|
|
# Save private key (set secure permissions)
|
||
|
|
with open(private_path, 'wb') as f:
|
||
|
|
f.write(private_pem)
|
||
|
|
os.chmod(private_path, 0o600) # Only owner can read/write
|
||
|
|
|
||
|
|
# Save public key
|
||
|
|
with open(public_path, 'wb') as f:
|
||
|
|
f.write(public_pem)
|
||
|
|
os.chmod(public_path, 0o644) # Owner read/write, others read only
|
||
|
|
|
||
|
|
logger.info(f"Saved keys to files: {private_path}, {public_path}")
|
||
|
|
|
||
|
|
except Exception as e:
|
||
|
|
logger.error(f"Failed to save keys to files: {e}")
|
||
|
|
raise
|
||
|
|
|
||
|
|
def get_private_key(self):
|
||
|
|
"""Get private key for signing"""
|
||
|
|
if self.algorithm == "HS256":
|
||
|
|
return self._secret_key
|
||
|
|
else:
|
||
|
|
return self._private_key
|
||
|
|
|
||
|
|
def get_public_key(self):
|
||
|
|
"""Get public key for verification"""
|
||
|
|
if self.algorithm == "HS256":
|
||
|
|
return self._secret_key
|
||
|
|
else:
|
||
|
|
return self._public_key
|
||
|
|
|
||
|
|
def get_algorithm(self) -> str:
|
||
|
|
"""Get signing algorithm"""
|
||
|
|
return self.algorithm
|
||
|
|
|
||
|
|
async def is_key_expired(self) -> bool:
|
||
|
|
"""Check if key is expired"""
|
||
|
|
if not self._key_generated_at:
|
||
|
|
return True
|
||
|
|
|
||
|
|
expiry_time = self._key_generated_at + timedelta(seconds=self.key_rotation_interval)
|
||
|
|
return datetime.utcnow() > expiry_time
|
||
|
|
|
||
|
|
async def rotate_keys(self) -> bool:
|
||
|
|
"""Rotate keys"""
|
||
|
|
try:
|
||
|
|
logger.info("Starting key rotation")
|
||
|
|
|
||
|
|
if self.algorithm == "HS256":
|
||
|
|
# Generate new symmetric key
|
||
|
|
self._secret_key = await self.generate_symmetric_key()
|
||
|
|
self._key_generated_at = datetime.utcnow()
|
||
|
|
else:
|
||
|
|
# Generate new asymmetric key pair
|
||
|
|
await self.generate_key_pair()
|
||
|
|
|
||
|
|
logger.info("Key rotation completed successfully")
|
||
|
|
return True
|
||
|
|
|
||
|
|
except Exception as e:
|
||
|
|
logger.error(f"Key rotation failed: {e}")
|
||
|
|
return False
|
||
|
|
|
||
|
|
async def get_key_info(self) -> dict:
|
||
|
|
"""Get key information"""
|
||
|
|
return {
|
||
|
|
"algorithm": self.algorithm,
|
||
|
|
"key_generated_at": self._key_generated_at.isoformat() if self._key_generated_at else None,
|
||
|
|
"key_expires_at": (
|
||
|
|
self._key_generated_at + timedelta(seconds=self.key_rotation_interval)
|
||
|
|
).isoformat() if self._key_generated_at else None,
|
||
|
|
"is_expired": await self.is_key_expired(),
|
||
|
|
"has_private_key": self._private_key is not None or self._secret_key is not None,
|
||
|
|
"has_public_key": self._public_key is not None or self._secret_key is not None
|
||
|
|
}
|
||
|
|
|
||
|
|
async def export_public_key_pem(self) -> Optional[str]:
|
||
|
|
"""Export public key in PEM format"""
|
||
|
|
if self.algorithm == "HS256":
|
||
|
|
return None # Symmetric key not exported
|
||
|
|
|
||
|
|
if not self._public_key:
|
||
|
|
return None
|
||
|
|
|
||
|
|
try:
|
||
|
|
public_pem = self._public_key.public_bytes(
|
||
|
|
encoding=serialization.Encoding.PEM,
|
||
|
|
format=serialization.PublicFormat.SubjectPublicKeyInfo
|
||
|
|
)
|
||
|
|
return public_pem.decode()
|
||
|
|
|
||
|
|
except Exception as e:
|
||
|
|
logger.error(f"Failed to export public key: {e}")
|
||
|
|
return None
|