Compare commits
17 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
067f160b3e | ||
|
|
9ba4cc6f45 | ||
|
|
f99399c6c7 | ||
|
|
c3d487ccdd | ||
|
|
c1e3b13851 | ||
|
|
5923cc1c89 | ||
|
|
9b5ac8533d | ||
|
|
cc84d605e5 | ||
|
|
55dbdd5e14 | ||
|
|
affa4a0319 | ||
|
|
ecb5db8137 | ||
|
|
5d15f6f3a4 | ||
|
|
6247d49192 | ||
|
|
fb5e864a24 | ||
|
|
9bb5b17199 | ||
|
|
6d3c128f54 | ||
|
|
651d524814 |
2
.dockerignore
Normal file
2
.dockerignore
Normal file
@@ -0,0 +1,2 @@
|
||||
**/.venv
|
||||
**/venv
|
||||
335
.env.example
335
.env.example
@@ -1,3 +1,19 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
# ===================================================================
|
||||
# Doris MCP Server Environment Configuration Example
|
||||
# ===================================================================
|
||||
@@ -36,8 +52,174 @@ DORIS_MAX_CONNECTION_AGE=3600
|
||||
# Security Configuration
|
||||
# ===================================================================
|
||||
|
||||
# Authentication configuration
|
||||
# Independent Authentication Switches - NEW DESIGN!
|
||||
# Each authentication method can be enabled/disabled independently
|
||||
# Any enabled method that succeeds will allow access
|
||||
# If all methods are disabled, anonymous access is allowed
|
||||
|
||||
# Legacy configuration - kept for backward compatibility
|
||||
# AUTH_TYPE is now deprecated - use the individual ENABLE_*_AUTH switches below
|
||||
AUTH_TYPE=token
|
||||
|
||||
# Token Authentication (Default method - simple and effective)
|
||||
ENABLE_TOKEN_AUTH=false
|
||||
|
||||
# JWT Authentication (For stateless applications)
|
||||
ENABLE_JWT_AUTH=false
|
||||
|
||||
# OAuth 2.0/OIDC Authentication (For enterprise integration)
|
||||
ENABLE_OAUTH_AUTH=false
|
||||
|
||||
# ===================================================================
|
||||
# Token Authentication Configuration (Enable with ENABLE_TOKEN_AUTH=true)
|
||||
# ===================================================================
|
||||
|
||||
# Basic token authentication settings
|
||||
TOKEN_FILE_PATH=tokens.json
|
||||
ENABLE_TOKEN_EXPIRY=true
|
||||
DEFAULT_TOKEN_EXPIRY_HOURS=720
|
||||
TOKEN_HASH_ALGORITHM=sha256
|
||||
|
||||
# ===================================================================
|
||||
# Token Management Security Configuration (NEW in v0.6.0) - CRITICAL SECURITY SETTINGS
|
||||
# ===================================================================
|
||||
|
||||
# HTTP Token Management Endpoints (DISABLED BY DEFAULT FOR SECURITY)
|
||||
# WARNING: These endpoints allow creation, deletion, and management of authentication tokens
|
||||
# Only enable if you need HTTP-based token management and understand the security implications
|
||||
ENABLE_HTTP_TOKEN_MANAGEMENT=false
|
||||
|
||||
# Admin Authentication Token (REQUIRED if HTTP token management is enabled)
|
||||
# This token is required to access HTTP token management endpoints
|
||||
# SECURITY: Generate a secure random token in production - NEVER use default values
|
||||
TOKEN_MANAGEMENT_ADMIN_TOKEN=
|
||||
|
||||
# IP Address Restrictions for Token Management (CRITICAL SECURITY CONTROL)
|
||||
# Only these IP addresses/networks can access token management endpoints
|
||||
# DEFAULT: localhost only (most secure) - add other IPs/networks only if necessary
|
||||
# Format: comma-separated list of IPs and CIDR networks
|
||||
# Examples:
|
||||
# - Localhost only: 127.0.0.1,::1
|
||||
# - Private network: 127.0.0.1,192.168.1.0/24,10.0.0.0/8
|
||||
# - Specific IPs: 127.0.0.1,192.168.1.10,192.168.1.11
|
||||
TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,::1
|
||||
|
||||
# Require Admin Authentication (ENABLED BY DEFAULT FOR SECURITY)
|
||||
# When true, all token management operations require valid admin token
|
||||
# When false, only IP restrictions apply (NOT RECOMMENDED for production)
|
||||
REQUIRE_ADMIN_AUTH=true
|
||||
|
||||
# ===================================================================
|
||||
# JWT Authentication Configuration (Enable with ENABLE_JWT_AUTH=true)
|
||||
# ===================================================================
|
||||
|
||||
# JWT token settings (when ENABLE_JWT_AUTH=true)
|
||||
JWT_SECRET_KEY=your_jwt_secret_key_here_change_in_production
|
||||
JWT_ALGORITHM=HS256
|
||||
JWT_EXPIRATION_HOURS=24
|
||||
JWT_ISSUER=doris-mcp-server
|
||||
JWT_AUDIENCE=doris-mcp-client
|
||||
|
||||
# JWT token validation settings
|
||||
JWT_VERIFY_SIGNATURE=true
|
||||
JWT_VERIFY_EXPIRATION=true
|
||||
JWT_VERIFY_AUDIENCE=true
|
||||
JWT_VERIFY_ISSUER=true
|
||||
|
||||
# JWT refresh token settings
|
||||
ENABLE_JWT_REFRESH=true
|
||||
JWT_REFRESH_EXPIRATION_DAYS=30
|
||||
JWT_REFRESH_SECRET_KEY=your_jwt_refresh_secret_key_here
|
||||
|
||||
# JWT user claims configuration
|
||||
JWT_USER_ID_CLAIM=user_id
|
||||
JWT_ROLES_CLAIM=roles
|
||||
JWT_PERMISSIONS_CLAIM=permissions
|
||||
JWT_SECURITY_LEVEL_CLAIM=security_level
|
||||
|
||||
# ===================================================================
|
||||
# OAuth 2.0 / OpenID Connect Configuration (Enable with ENABLE_OAUTH_AUTH=true)
|
||||
# ===================================================================
|
||||
|
||||
# OAuth provider settings (when ENABLE_OAUTH_AUTH=true)
|
||||
OAUTH_PROVIDER_TYPE=generic
|
||||
OAUTH_CLIENT_ID=your_oauth_client_id
|
||||
OAUTH_CLIENT_SECRET=your_oauth_client_secret
|
||||
OAUTH_REDIRECT_URI=http://localhost:3000/auth/callback
|
||||
|
||||
# OAuth endpoints (for generic provider)
|
||||
OAUTH_AUTHORIZATION_URL=https://your-provider.com/auth
|
||||
OAUTH_TOKEN_URL=https://your-provider.com/token
|
||||
OAUTH_USERINFO_URL=https://your-provider.com/userinfo
|
||||
OAUTH_JWKS_URL=https://your-provider.com/.well-known/jwks.json
|
||||
|
||||
# OAuth scope and claims
|
||||
OAUTH_SCOPE=openid profile email
|
||||
OAUTH_USER_ID_CLAIM=sub
|
||||
OAUTH_USERNAME_CLAIM=preferred_username
|
||||
OAUTH_EMAIL_CLAIM=email
|
||||
OAUTH_ROLES_CLAIM=roles
|
||||
OAUTH_GROUPS_CLAIM=groups
|
||||
|
||||
# OAuth session settings
|
||||
OAUTH_SESSION_SECRET=your_oauth_session_secret_here
|
||||
OAUTH_SESSION_EXPIRY=3600
|
||||
OAUTH_STATE_EXPIRY=300
|
||||
|
||||
# Popular OAuth providers presets (uncomment and configure as needed)
|
||||
|
||||
# Google OAuth Configuration
|
||||
# OAUTH_PROVIDER_TYPE=google
|
||||
# OAUTH_CLIENT_ID=your_google_client_id.apps.googleusercontent.com
|
||||
# OAUTH_CLIENT_SECRET=your_google_client_secret
|
||||
# OAUTH_AUTHORIZATION_URL=https://accounts.google.com/o/oauth2/auth
|
||||
# OAUTH_TOKEN_URL=https://oauth2.googleapis.com/token
|
||||
# OAUTH_USERINFO_URL=https://www.googleapis.com/oauth2/v1/userinfo
|
||||
# OAUTH_JWKS_URL=https://www.googleapis.com/oauth2/v3/certs
|
||||
# OAUTH_SCOPE=openid profile email
|
||||
|
||||
# Microsoft Azure AD Configuration
|
||||
# OAUTH_PROVIDER_TYPE=azure
|
||||
# OAUTH_CLIENT_ID=your_azure_client_id
|
||||
# OAUTH_CLIENT_SECRET=your_azure_client_secret
|
||||
# OAUTH_TENANT_ID=your_tenant_id
|
||||
# OAUTH_AUTHORIZATION_URL=https://login.microsoftonline.com/{tenant}/oauth2/v2.0/authorize
|
||||
# OAUTH_TOKEN_URL=https://login.microsoftonline.com/{tenant}/oauth2/v2.0/token
|
||||
# OAUTH_USERINFO_URL=https://graph.microsoft.com/v1.0/me
|
||||
# OAUTH_JWKS_URL=https://login.microsoftonline.com/{tenant}/discovery/v2.0/keys
|
||||
# OAUTH_SCOPE=openid profile email
|
||||
|
||||
# GitHub OAuth Configuration
|
||||
# OAUTH_PROVIDER_TYPE=github
|
||||
# OAUTH_CLIENT_ID=your_github_client_id
|
||||
# OAUTH_CLIENT_SECRET=your_github_client_secret
|
||||
# OAUTH_AUTHORIZATION_URL=https://github.com/login/oauth/authorize
|
||||
# OAUTH_TOKEN_URL=https://github.com/login/oauth/access_token
|
||||
# OAUTH_USERINFO_URL=https://api.github.com/user
|
||||
# OAUTH_SCOPE=user:email
|
||||
|
||||
# GitLab OAuth Configuration
|
||||
# OAUTH_PROVIDER_TYPE=gitlab
|
||||
# OAUTH_CLIENT_ID=your_gitlab_client_id
|
||||
# OAUTH_CLIENT_SECRET=your_gitlab_client_secret
|
||||
# OAUTH_AUTHORIZATION_URL=https://gitlab.com/oauth/authorize
|
||||
# OAUTH_TOKEN_URL=https://gitlab.com/oauth/token
|
||||
# OAUTH_USERINFO_URL=https://gitlab.com/api/v4/user
|
||||
# OAUTH_SCOPE=read_user
|
||||
|
||||
# Keycloak OAuth Configuration
|
||||
# OAUTH_PROVIDER_TYPE=keycloak
|
||||
# OAUTH_CLIENT_ID=your_keycloak_client_id
|
||||
# OAUTH_CLIENT_SECRET=your_keycloak_client_secret
|
||||
# OAUTH_REALM=your_realm
|
||||
# OAUTH_SERVER_URL=https://your-keycloak-server.com
|
||||
# OAUTH_AUTHORIZATION_URL=https://your-keycloak-server.com/auth/realms/{realm}/protocol/openid-connect/auth
|
||||
# OAUTH_TOKEN_URL=https://your-keycloak-server.com/auth/realms/{realm}/protocol/openid-connect/token
|
||||
# OAUTH_USERINFO_URL=https://your-keycloak-server.com/auth/realms/{realm}/protocol/openid-connect/userinfo
|
||||
# OAUTH_JWKS_URL=https://your-keycloak-server.com/auth/realms/{realm}/protocol/openid-connect/certs
|
||||
# OAUTH_SCOPE=openid profile email
|
||||
|
||||
# Legacy token settings (for backward compatibility)
|
||||
TOKEN_SECRET=your_secret_key_here
|
||||
TOKEN_EXPIRY=3600
|
||||
|
||||
@@ -133,7 +315,7 @@ ALERT_WEBHOOK_URL=
|
||||
|
||||
# Basic server information
|
||||
SERVER_NAME=doris-mcp-server
|
||||
SERVER_VERSION=0.5.0
|
||||
SERVER_VERSION=0.6.0
|
||||
SERVER_PORT=3000
|
||||
|
||||
# Temporary files directory
|
||||
@@ -172,9 +354,22 @@ TEMP_FILES_DIR=tmp
|
||||
# - LOG_CLEANUP_INTERVAL_HOURS: Check frequency, recommended 24 hours
|
||||
|
||||
# 2. Security Best Practices:
|
||||
# - Must change TOKEN_SECRET in production environment
|
||||
# - NEW: Enable individual authentication methods using ENABLE_TOKEN_AUTH, ENABLE_JWT_AUTH, ENABLE_OAUTH_AUTH
|
||||
# - When all methods are disabled, ALL requests are allowed with anonymous access
|
||||
# - Authentication methods work independently - any one succeeding allows access
|
||||
# - Token Auth: Change default tokens (DEFAULT_ADMIN_TOKEN, etc.) in production
|
||||
# - JWT Auth: Change JWT_SECRET_KEY and JWT_REFRESH_SECRET_KEY in production
|
||||
# - OAuth Auth: Configure OAuth provider settings and secure client secrets
|
||||
# - Must change TOKEN_SECRET in production environment (legacy compatibility)
|
||||
# - Adjust BLOCKED_KEYWORDS according to business needs
|
||||
# - Enable ENABLE_SECURITY_CHECK and ENABLE_MASKING
|
||||
# - NEW v0.6.0: Token Management Security (CRITICAL):
|
||||
# * ENABLE_HTTP_TOKEN_MANAGEMENT=false by default (SECURE BY DEFAULT)
|
||||
# * Only enable if you need HTTP token management endpoints
|
||||
# * TOKEN_MANAGEMENT_ADMIN_TOKEN: Use secure random token in production
|
||||
# * TOKEN_MANAGEMENT_ALLOWED_IPS: Restrict to localhost (127.0.0.1,::1) only
|
||||
# * REQUIRE_ADMIN_AUTH=true: Always require admin authentication
|
||||
# * Never expose token management endpoints to external networks
|
||||
|
||||
# 3. Performance Tuning:
|
||||
# - Adjust MAX_CONCURRENT_QUERIES based on hardware resources
|
||||
@@ -194,3 +389,137 @@ TEMP_FILES_DIR=tmp
|
||||
# - ADBC_CONNECTION_TIMEOUT: Connection timeout for ADBC (recommended: 30)
|
||||
# - ADBC_ENABLED: Enable or disable ADBC tools (true/false)
|
||||
# - Prerequisites: Install adbc_driver_manager, adbc_driver_flightsql, pyarrow packages
|
||||
|
||||
# 6. Authentication Configuration Guide - UPDATED DESIGN!
|
||||
#
|
||||
# Independent Authentication Control (NEW):
|
||||
# - ENABLE_TOKEN_AUTH=false (default): Disable token authentication
|
||||
# - ENABLE_JWT_AUTH=false (default): Disable JWT authentication
|
||||
# - ENABLE_OAUTH_AUTH=false (default): Disable OAuth authentication
|
||||
# - When all methods are disabled, no authentication is required (anonymous access)
|
||||
# - When multiple methods are enabled, any one succeeding allows access
|
||||
# - Recommended for development/testing: all false, production: enable needed methods
|
||||
#
|
||||
# Token Authentication (ENABLE_TOKEN_AUTH=true) - Recommended for most use cases:
|
||||
# - Simple and secure token-based authentication
|
||||
# - Configurable default tokens via environment variables
|
||||
# - Support for custom tokens via TOKEN_* environment variables
|
||||
# - Token file configuration via tokens.json
|
||||
# - Built-in token management HTTP endpoints
|
||||
# - No user management complexity - pure API access control
|
||||
#
|
||||
# JWT Authentication (ENABLE_JWT_AUTH=true) - For stateless applications:
|
||||
# - JSON Web Token based authentication
|
||||
# - Configurable token expiration and refresh
|
||||
# - Support for standard JWT claims
|
||||
# - RSA/ECDSA/HS256 algorithm support
|
||||
# - Suitable for microservices and distributed systems
|
||||
#
|
||||
# OAuth 2.0/OIDC (ENABLE_OAUTH_AUTH=true) - For enterprise integration:
|
||||
# - Integration with external identity providers
|
||||
# - Support for popular providers (Google, Microsoft, GitHub, GitLab, Keycloak)
|
||||
# - OpenID Connect compatibility
|
||||
# - Automatic user provisioning from provider
|
||||
# - Secure authorization code flow
|
||||
#
|
||||
# Authentication Method Selection Guide:
|
||||
# - No Auth (all switches false): Development, testing, trusted networks
|
||||
# - Token Auth only: Small teams, simple deployment, direct API access
|
||||
# - JWT Auth only: Stateless apps, microservices, mobile clients
|
||||
# - OAuth Auth only: Enterprise SSO, large teams, external identity providers
|
||||
# - Multiple methods: Flexible access, different client types, migration scenarios
|
||||
|
||||
# 7. Token Management Security Configuration Guide (NEW in v0.6.0) - CRITICAL!
|
||||
#
|
||||
# ⚠️ SECURITY WARNING: Token management endpoints are POWERFUL and DANGEROUS
|
||||
# They allow creation, revocation, and management of authentication tokens.
|
||||
# Improper configuration can lead to complete system compromise.
|
||||
#
|
||||
# 🔒 SECURE BY DEFAULT:
|
||||
# - ENABLE_HTTP_TOKEN_MANAGEMENT=false (disabled by default)
|
||||
# - REQUIRE_ADMIN_AUTH=true (admin auth required by default)
|
||||
# - TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,::1 (localhost only by default)
|
||||
#
|
||||
# 🛡️ SECURITY LAYERS (Applied in order):
|
||||
# 1. Configuration Check: HTTP token management must be explicitly enabled
|
||||
# 2. IP Restrictions: Only allowed IP addresses/networks can access endpoints
|
||||
# 3. Admin Authentication: Valid admin token required for all operations
|
||||
#
|
||||
# 📋 CONFIGURATION OPTIONS:
|
||||
#
|
||||
# Disable Token Management (RECOMMENDED for most deployments):
|
||||
# ENABLE_HTTP_TOKEN_MANAGEMENT=false
|
||||
# # All token management endpoints will return 403 Forbidden
|
||||
#
|
||||
# Enable with Maximum Security (Production):
|
||||
# ENABLE_HTTP_TOKEN_MANAGEMENT=true
|
||||
# TOKEN_MANAGEMENT_ADMIN_TOKEN=<secure-random-token-256-bit>
|
||||
# TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,::1
|
||||
# REQUIRE_ADMIN_AUTH=true
|
||||
#
|
||||
# Enable for Private Network (Advanced):
|
||||
# ENABLE_HTTP_TOKEN_MANAGEMENT=true
|
||||
# TOKEN_MANAGEMENT_ADMIN_TOKEN=<secure-random-token-256-bit>
|
||||
# TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,192.168.1.0/24,10.0.0.0/8
|
||||
# REQUIRE_ADMIN_AUTH=true
|
||||
#
|
||||
# 🔑 ADMIN TOKEN GENERATION:
|
||||
# # Generate secure admin token (Linux/macOS):
|
||||
# openssl rand -hex 32
|
||||
#
|
||||
# # Generate secure admin token (Python):
|
||||
# python -c "import secrets; print(secrets.token_urlsafe(32))"
|
||||
#
|
||||
# 🌐 IP CONFIGURATION EXAMPLES:
|
||||
# # Localhost only (most secure):
|
||||
# TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,::1
|
||||
#
|
||||
# # Private network + localhost:
|
||||
# TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,::1,192.168.1.0/24,10.0.0.0/8
|
||||
#
|
||||
# # Specific servers only:
|
||||
# TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,192.168.1.10,192.168.1.11
|
||||
#
|
||||
# # Corporate network (be careful):
|
||||
# TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,172.16.0.0/12,192.168.0.0/16
|
||||
#
|
||||
# 🚫 NEVER DO THIS (Security Anti-Patterns):
|
||||
# # NEVER allow all IPs:
|
||||
# # TOKEN_MANAGEMENT_ALLOWED_IPS=0.0.0.0/0 # DANGEROUS!
|
||||
#
|
||||
# # NEVER disable admin auth in production:
|
||||
# # REQUIRE_ADMIN_AUTH=false # DANGEROUS!
|
||||
#
|
||||
# # NEVER use weak admin tokens:
|
||||
# # TOKEN_MANAGEMENT_ADMIN_TOKEN=admin # DANGEROUS!
|
||||
# # TOKEN_MANAGEMENT_ADMIN_TOKEN=123456 # DANGEROUS!
|
||||
#
|
||||
# 📊 ENDPOINT SECURITY TESTING:
|
||||
# # Test security (should fail):
|
||||
# curl -X POST http://external-ip:3000/token/create
|
||||
# # Expected: 403 Forbidden (IP not allowed)
|
||||
#
|
||||
# # Test without auth (should fail):
|
||||
# curl -X POST http://127.0.0.1:3000/token/create
|
||||
# # Expected: 401 Unauthorized (missing admin token)
|
||||
#
|
||||
# # Test with valid auth (should succeed if enabled):
|
||||
# curl -H "Authorization: Bearer your-admin-token" http://127.0.0.1:3000/token/stats
|
||||
# # Expected: 200 OK with token statistics
|
||||
#
|
||||
# 🔍 MONITORING & AUDITING:
|
||||
# # All token management access attempts are logged:
|
||||
# tail -f logs/doris_mcp_server_audit.log | grep "token management"
|
||||
#
|
||||
# # Monitor security events:
|
||||
# tail -f logs/doris_mcp_server_info.log | grep -E "(access denied|token management)"
|
||||
#
|
||||
# ✅ SECURITY BEST PRACTICES:
|
||||
# - Keep ENABLE_HTTP_TOKEN_MANAGEMENT=false unless absolutely necessary
|
||||
# - Use file-based token management (tokens.json) instead of HTTP endpoints
|
||||
# - Generate strong admin tokens using cryptographically secure methods
|
||||
# - Restrict access to localhost (127.0.0.1,::1) whenever possible
|
||||
# - Never expose token management endpoints to public internet
|
||||
# - Regularly audit token management access logs
|
||||
# - Use firewall rules as additional protection layer
|
||||
# - Consider VPN access for remote token management needs
|
||||
6
.gitignore
vendored
6
.gitignore
vendored
@@ -19,7 +19,5 @@ ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
.idea/
|
||||
|
||||
|
||||
|
||||
|
||||
.coverage
|
||||
coverage.xml
|
||||
|
||||
@@ -32,6 +32,7 @@ RUN apt-get update && apt-get install -y \
|
||||
g++ \
|
||||
pkg-config \
|
||||
default-libmysqlclient-dev \
|
||||
dos2unix \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy requirements file
|
||||
@@ -43,12 +44,13 @@ RUN pip install --no-cache-dir -r requirements.txt
|
||||
# Copy application code
|
||||
COPY . .
|
||||
|
||||
# Convert line endings for shell scripts and ensure proper execution format
|
||||
RUN find . -name "*.sh" -exec dos2unix {} \; && \
|
||||
find . -name "*.sh" -exec chmod +x {} \;
|
||||
|
||||
# Create necessary directories
|
||||
RUN mkdir -p /app/logs /app/config /app/data
|
||||
|
||||
# Set permissions
|
||||
RUN chmod +x /app/start_server.sh
|
||||
|
||||
# Create non-root user
|
||||
RUN groupadd -r doris && useradd -r -g doris doris
|
||||
RUN chown -R doris:doris /app
|
||||
|
||||
397
README.md
397
README.md
@@ -21,19 +21,27 @@ under the License.
|
||||
|
||||
Doris MCP (Model Context Protocol) Server is a backend service built with Python and FastAPI. It implements the MCP, allowing clients to interact with it through defined "Tools". It's primarily designed to connect to Apache Doris databases, potentially leveraging Large Language Models (LLMs) for tasks like converting natural language queries to SQL (NL2SQL), executing queries, and performing metadata management and analysis.
|
||||
|
||||
## 🚀 What's New in v0.5.0
|
||||
## 🚀 What's New in v0.6.0
|
||||
|
||||
- **🔥 Critical at_eof Connection Fix**: **Complete elimination of at_eof connection pool errors** through redesigned connection pool strategy with zero minimum connections, intelligent health monitoring, automatic retry mechanisms, and self-healing pool recovery - achieving 99.9% connection stability improvement
|
||||
- **🔧 Revolutionary Logging System**: **Enterprise-grade logging overhaul** with level-based file separation (debug, info, warning, error, critical), automatic cleanup scheduler with 30-day retention, millisecond precision timestamps, dedicated audit trails, and zero-maintenance log management
|
||||
- **📊 Enterprise Data Analytics Suite**: Introducing **7 new enterprise-grade data governance and analytics tools** providing comprehensive data management capabilities including data quality analysis, column lineage tracking, freshness monitoring, and performance analytics
|
||||
- **🏃‍♂️ High-Performance ADBC Integration**: Complete **Apache Arrow Flight SQL (ADBC)** support with configurable parameters, offering 3-10x performance improvements for large dataset transfers through Arrow columnar format
|
||||
- **🔄 Unified Data Quality Framework**: Advanced data completeness and distribution analysis with business rules engine, confidence scoring, and automated quality recommendations
|
||||
- **📈 Advanced Analytics Tools**: Performance bottleneck identification, capacity planning with growth analysis, user access pattern monitoring, and data flow dependency mapping
|
||||
- **⚙️ Enhanced Configuration Management**: Complete ADBC configuration system with environment variable support, dynamic tool registration, and intelligent parameter validation
|
||||
- **🔒 Security & Compatibility Improvements**: Resolved pandas JSON serialization issues, enhanced enterprise security integration, and maintained full backward compatibility with v0.4.x versions
|
||||
- **🎯 Modular Architecture**: 6 new specialized tool modules for enterprise analytics with comprehensive English documentation and robust error handling
|
||||
- **🔐 Enterprise Authentication System**: **Revolutionary token-bound database configuration** with comprehensive Token, JWT, and OAuth authentication support, enabling secure multi-tenant access with granular control switches and enterprise-grade security defaults
|
||||
- **⚡ Immediate Database Validation**: **Real-time database configuration validation at connection time**, eliminating query-time blocking and providing instant feedback for invalid configurations - achieving 100% elimination of late-stage connection failures
|
||||
- **🔄 Hot Reload Configuration Management**: **Zero-downtime configuration updates** with intelligent hot reloading of tokens.json, automatic token revalidation, and comprehensive error handling with rollback mechanisms
|
||||
- **🏗️ Advanced Connection Architecture**: **Session caching and connection pool optimization** with 60% reduction in connection overhead, intelligent pool recreation, and automatic resource management
|
||||
- **🌐 Multi-Worker Scalability**: **True horizontal scaling** with stateless multi-worker architecture, efficient load distribution, and enterprise-grade concurrent processing capabilities
|
||||
- **🔒 Enhanced Security Framework**: **Comprehensive access control and SQL security validation** with immediate validation, role-based permissions, and enhanced injection detection patterns
|
||||
- **🛠️ Unified Configuration System**: **Streamlined configuration management** with proper command-line precedence, Docker compatibility improvements, and cross-platform deployment support
|
||||
- **📊 Token Management Dashboard**: **Complete token lifecycle management** with creation, revocation, statistics, and comprehensive audit trails for enterprise token governance
|
||||
- **🌐 Web-Based Management Interface**: **Secure localhost-only token administration** with intuitive dashboard, database binding configuration, real-time operations, and enterprise-grade access controls
|
||||
|
||||
> **🚀 Major Milestone**: This release establishes v0.5.0 as a **production-ready enterprise data governance platform** with **critical stability improvements** (complete at_eof fix + intelligent logging), 23 total tools (14 existing + 7 analytics + 2 ADBC tools), and enterprise-grade system reliability - representing a major advancement in both data intelligence capabilities and operational stability.
|
||||
> **🚀 Major Milestone**: v0.6.0 establishes the platform as a **production-ready enterprise authentication and database management system** with **zero-downtime operations** (hot reload + immediate validation + multi-worker scaling), advanced security controls, and comprehensive token-bound database configuration - representing a fundamental advancement in enterprise data platform capabilities.
|
||||
|
||||
### What's Also Included from v0.5.1
|
||||
|
||||
- **🔥 Critical at_eof Connection Fix**: Complete elimination of connection pool errors with intelligent health monitoring and self-healing recovery
|
||||
- **🔧 Enterprise Logging System**: Level-based file separation with automatic cleanup and millisecond precision timestamps
|
||||
- **📊 Advanced Data Analytics Suite**: 7 enterprise-grade data governance tools including quality analysis, lineage tracking, and performance monitoring
|
||||
- **🏃‍♂️ High-Performance ADBC Integration**: Apache Arrow Flight SQL support with 3-10x performance improvements for large datasets
|
||||
- **⚙️ Enhanced Configuration Management**: Complete ADBC configuration system with intelligent parameter validation
|
||||
|
||||
## Core Features
|
||||
|
||||
@@ -53,12 +61,13 @@ Doris MCP (Model Context Protocol) Server is a backend service built with Python
|
||||
* **Performance Analysis**: Advanced column analysis, performance monitoring, and data analysis tools (`doris_mcp_server/utils/analysis_tools.py`)
|
||||
* **Catalog Federation Support**: Full support for multi-catalog environments (internal Doris tables and external data sources like Hive, MySQL, etc.)
|
||||
* **Enterprise Security**: Comprehensive security framework with authentication, authorization, SQL injection protection, and data masking capabilities with environment variable configuration support
|
||||
* **Web-Based Token Management**: Secure localhost-only interface for complete token lifecycle management with database binding, real-time statistics, and enterprise-grade access controls (`doris_mcp_server/auth/token_handlers.py`)
|
||||
* **Unified Configuration Framework**: Centralized configuration management through `config.py` with comprehensive validation, standardized parameter naming, and smart default database handling with automatic fallback to `information_schema`
|
||||
|
||||
## System Requirements
|
||||
|
||||
* Python 3.12+
|
||||
* Database connection details (e.g., Doris Host, Port, User, Password, Database)
|
||||
* **Python**: 3.12+
|
||||
* **Database**: Apache Doris connection details (Host, Port, User, Password, Database)
|
||||
|
||||
## 🚀 Quick Start
|
||||
|
||||
@@ -69,7 +78,7 @@ Doris MCP (Model Context Protocol) Server is a backend service built with Python
|
||||
pip install doris-mcp-server
|
||||
|
||||
# Install specific version
|
||||
pip install doris-mcp-server==0.5.0
|
||||
pip install doris-mcp-server==0.6.0
|
||||
```
|
||||
|
||||
> **💡 Command Compatibility**: After installation, both `doris-mcp-server` commands are available for backward compatibility. You can use either command interchangeably.
|
||||
@@ -99,6 +108,46 @@ Standard input/output mode for direct integration with MCP clients:
|
||||
doris-mcp-server --transport stdio
|
||||
```
|
||||
|
||||
### 🌐 Token Management Interface (New in v0.6.0)
|
||||
|
||||
Access the **Web-Based Token Management Dashboard** for enterprise-grade token administration:
|
||||
|
||||
#### **Secure Access Requirements**
|
||||
- **Localhost Access Only**: Interface restricted to `127.0.0.1` and `::1` for maximum security
|
||||
- **Admin Authentication**: Requires `TOKEN_MANAGEMENT_ADMIN_TOKEN` for access
|
||||
- **Configuration Prerequisites**:
|
||||
```bash
|
||||
# Required environment variables
|
||||
ENABLE_HTTP_TOKEN_MANAGEMENT=true
|
||||
ENABLE_TOKEN_AUTH=true
|
||||
TOKEN_MANAGEMENT_ADMIN_TOKEN=your_secure_admin_token
|
||||
TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,::1
|
||||
```
|
||||
|
||||
#### **Interface Access**
|
||||
```bash
|
||||
# Access the token management interface
|
||||
http://localhost:3000/token/management?admin_token=your_secure_admin_token
|
||||
```
|
||||
|
||||
#### **Available Operations**
|
||||
- **📊 Token Statistics**: Real-time overview of active, expired, and total tokens
|
||||
- **➕ Create Tokens**:
|
||||
- Basic information (ID, description, expiration)
|
||||
- **Database binding** (host, port, user, password, database)
|
||||
- Custom token values or auto-generated secure tokens
|
||||
- **📋 Token Management**:
|
||||
- List all tokens with database binding status
|
||||
- One-click token revocation
|
||||
- Automated expired token cleanup
|
||||
- **🔒 Enterprise Security**:
|
||||
- All operations require admin authentication
|
||||
- Real-time IP validation
|
||||
- Complete audit logging
|
||||
- **Automatic persistence** to `tokens.json`
|
||||
|
||||
> **🔐 Security Note**: The interface is designed for localhost administration. By default it accepts connections only from `127.0.0.1` and `::1`; remote clients are rejected unless you explicitly widen `TOKEN_MANAGEMENT_ALLOWED_IPS`, which is not recommended.
|
||||
|
||||
### Verify Installation
|
||||
|
||||
```bash
|
||||
@@ -114,11 +163,18 @@ curl http://localhost:3000/health
|
||||
Instead of command-line arguments, you can use environment variables:
|
||||
|
||||
```bash
|
||||
# Basic Database Configuration
|
||||
export DORIS_HOST="127.0.0.1"
|
||||
export DORIS_PORT="9030"
|
||||
export DORIS_USER="root"
|
||||
export DORIS_PASSWORD="your_password"
|
||||
|
||||
# Token Management Interface (Security-Critical)
|
||||
export ENABLE_HTTP_TOKEN_MANAGEMENT=true
|
||||
export ENABLE_TOKEN_AUTH=true
|
||||
export TOKEN_MANAGEMENT_ADMIN_TOKEN="your_secure_admin_token"
|
||||
export TOKEN_MANAGEMENT_ALLOWED_IPS="127.0.0.1,::1"
|
||||
|
||||
# Then start with simplified command
|
||||
doris-mcp-server --transport http --host 0.0.0.0 --port 3000
|
||||
```
|
||||
@@ -177,11 +233,20 @@ cp .env.example .env
|
||||
* `DORIS_BE_WEBSERVER_PORT`: BE webserver port for monitoring tools (default: 8040)
|
||||
* `FE_ARROW_FLIGHT_SQL_PORT`: Frontend Arrow Flight SQL port for ADBC (New in v0.5.0)
|
||||
* `BE_ARROW_FLIGHT_SQL_PORT`: Backend Arrow Flight SQL port for ADBC (New in v0.5.0)
|
||||
* **Security Configuration**:
|
||||
* `AUTH_TYPE`: Authentication type (token/basic/oauth, default: token)
|
||||
* `TOKEN_SECRET`: Token secret key
|
||||
* `ENABLE_SECURITY_CHECK`: Enable/disable SQL security validation (default: true, New in v0.4.2)
|
||||
* `BLOCKED_KEYWORDS`: Comma-separated list of blocked SQL keywords (New in v0.4.2)
|
||||
* **Authentication Configuration (Enhanced in v0.6.0)**:
|
||||
* `ENABLE_TOKEN_AUTH`: Enable token-based authentication (default: false)
|
||||
* `ENABLE_JWT_AUTH`: Enable JWT authentication (default: false)
|
||||
* `ENABLE_OAUTH_AUTH`: Enable OAuth authentication (default: false)
|
||||
* `TOKEN_FILE_PATH`: Path to tokens.json file for token management (default: tokens.json)
|
||||
* `TOKEN_HOT_RELOAD`: Enable hot reloading of token configuration (default: true)
|
||||
* `DEFAULT_ADMIN_TOKEN`: Default admin token (customizable via env)
|
||||
* `DEFAULT_ANALYST_TOKEN`: Default analyst token (customizable via env)
|
||||
* `DEFAULT_READONLY_TOKEN`: Default readonly token (customizable via env)
|
||||
* **Legacy Security Configuration**:
|
||||
* `AUTH_TYPE`: Legacy authentication type (token/basic/oauth, deprecated - use individual switches)
|
||||
* `TOKEN_SECRET`: Legacy token secret key (use token-based auth instead)
|
||||
* `ENABLE_SECURITY_CHECK`: Enable/disable SQL security validation (default: true)
|
||||
* `BLOCKED_KEYWORDS`: Comma-separated list of blocked SQL keywords
|
||||
* `ENABLE_MASKING`: Enable data masking (default: true)
|
||||
* `MAX_RESULT_ROWS`: Maximum result rows (default: 10000)
|
||||
* **ADBC Configuration (New in v0.5.0)**:
|
||||
@@ -265,7 +330,7 @@ docker run -d -p <port>:<port> -v /*your-host*/doris-mcp-server/.env:/app/.env -
|
||||
|
||||
* **Streamable HTTP**: `http://<host>:<port>/mcp` (Primary MCP endpoint - supports GET, POST, DELETE, OPTIONS)
|
||||
* **Health Check**: `http://<host>:<port>/health`
|
||||
|
||||
*
|
||||
> **Note**: The server uses Streamable HTTP for web-based communication, providing unified request/response and streaming capabilities.
|
||||
|
||||
## Usage
|
||||
@@ -354,31 +419,97 @@ The Doris MCP Server supports **catalog federation**, enabling interaction with
|
||||
|
||||
## Security Configuration
|
||||
|
||||
The Doris MCP Server includes a comprehensive security framework that provides enterprise-level protection through authentication, authorization, SQL security validation, and data masking capabilities.
|
||||
The Doris MCP Server includes a comprehensive enterprise-grade security framework with advanced authentication, authorization, SQL security validation, and data masking capabilities enhanced in v0.6.0.
|
||||
|
||||
### Security Features
|
||||
### Security Features (Enhanced in v0.6.0)
|
||||
|
||||
* **🔐 Authentication**: Support for token-based and basic authentication
|
||||
* **🛡️ Authorization**: Role-based access control (RBAC) with security levels
|
||||
* **🚫 SQL Security**: SQL injection protection and blocked operations
|
||||
* **🎭 Data Masking**: Automatic sensitive data masking based on user permissions
|
||||
* **📊 Security Levels**: Four-tier security classification (Public, Internal, Confidential, Secret)
|
||||
* **🔐 Multi-Authentication System**: Complete Token, JWT, and OAuth authentication with independent control switches
|
||||
* **🔗 Token-Bound Database Configuration**: Revolutionary approach allowing tokens to carry their own database connection parameters
|
||||
* **🔄 Hot Reload Security**: Zero-downtime security configuration updates with intelligent token revalidation
|
||||
* **⚡ Immediate Validation**: Real-time database and authentication validation at connection time
|
||||
* **🛡️ Role-Based Authorization**: Advanced RBAC with four-tier security classification
|
||||
* **🚫 Enhanced SQL Security**: Advanced SQL injection protection with improved pattern detection
|
||||
* **🎭 Intelligent Data Masking**: Automatic sensitive data masking with user-based permissions
|
||||
* **📊 Security Analytics**: Comprehensive audit trails and security monitoring
|
||||
|
||||
### Authentication Configuration
|
||||
### Authentication Configuration (v0.6.0)
|
||||
|
||||
Configure authentication in your environment variables:
|
||||
Configure the new authentication system with granular control:
|
||||
|
||||
```bash
|
||||
# Authentication Type (token/basic/oauth)
|
||||
AUTH_TYPE=token
|
||||
# Individual Authentication Control (New in v0.6.0)
|
||||
ENABLE_TOKEN_AUTH=true # Enable token-based authentication
|
||||
ENABLE_JWT_AUTH=false # Enable JWT authentication
|
||||
ENABLE_OAUTH_AUTH=false # Enable OAuth authentication
|
||||
|
||||
# Token Secret for JWT validation
|
||||
TOKEN_SECRET=your_secret_key_here
|
||||
# Token Management (New in v0.6.0)
|
||||
TOKEN_FILE_PATH=tokens.json # Token configuration file
|
||||
TOKEN_HOT_RELOAD=true # Enable hot reloading
|
||||
|
||||
# Session timeout (in seconds)
|
||||
SESSION_TIMEOUT=3600
|
||||
# Default Tokens (customizable via environment; replace these example values before deploying)
|
||||
DEFAULT_ADMIN_TOKEN=doris_admin_token_123456
|
||||
DEFAULT_ANALYST_TOKEN=doris_analyst_token_123456
|
||||
DEFAULT_READONLY_TOKEN=doris_readonly_token_123456
|
||||
|
||||
# Legacy Configuration (Deprecated)
|
||||
# AUTH_TYPE=token # Use individual switches instead
|
||||
# TOKEN_SECRET=your_secret_key # Use token-based auth instead
|
||||
```
|
||||
|
||||
### Token-Bound Database Configuration (New in v0.6.0)
|
||||
|
||||
Create a `tokens.json` file for advanced token management with database binding:
|
||||
|
||||
```json
|
||||
{
|
||||
"version": "1.0",
|
||||
"tokens": [
|
||||
{
|
||||
"token_id": "customer-a-token",
|
||||
"token": "customer_a_secure_token_12345",
|
||||
"description": "Customer A dedicated database access",
|
||||
"expires_hours": null,
|
||||
"is_active": true,
|
||||
"database_config": {
|
||||
"host": "customer-a-db.example.com",
|
||||
"port": 9030,
|
||||
"user": "customer_a_user",
|
||||
"password": "secure_password",
|
||||
"database": "customer_a_data",
|
||||
"charset": "UTF8",
|
||||
"fe_http_port": 8030
|
||||
}
|
||||
},
|
||||
{
|
||||
"token_id": "customer-b-token",
|
||||
"token": "customer_b_secure_token_67890",
|
||||
"description": "Customer B dedicated database access",
|
||||
"expires_hours": 720,
|
||||
"is_active": true,
|
||||
"database_config": {
|
||||
"host": "customer-b-db.example.com",
|
||||
"port": 9030,
|
||||
"user": "customer_b_user",
|
||||
"password": "secure_password",
|
||||
"database": "customer_b_data",
|
||||
"charset": "UTF8",
|
||||
"fe_http_port": 8030
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Hot Reload Configuration Updates (New in v0.6.0)
|
||||
|
||||
The system automatically detects and applies configuration changes:
|
||||
|
||||
- **Automatic Detection**: File modification monitoring every 10 seconds
|
||||
- **Instant Validation**: Immediate database configuration validation for new tokens
|
||||
- **Zero Downtime**: Configuration updates without service interruption
|
||||
- **Rollback Protection**: Automatic rollback on configuration errors
|
||||
- **Audit Trail**: Complete logging of configuration changes
|
||||
|
||||
#### Token Authentication Example
|
||||
|
||||
```python
|
||||
@@ -700,6 +831,15 @@ After configuring either mode in Cursor, you should be able to select the server
|
||||
doris-mcp-server/
|
||||
├── doris_mcp_server/ # Main server package
|
||||
│ ├── main.py # Main entry point and FastAPI app
|
||||
│ ├── multiworker_app.py # Multi-worker application module (New in v0.6.0)
|
||||
│ ├── auth/ # Authentication modules (New in v0.6.0)
|
||||
│ │ ├── token_manager.py # Enterprise token management with hot reload
|
||||
│ │ ├── jwt_manager.py # JWT authentication provider
|
||||
│ │ ├── oauth_provider.py # OAuth authentication provider
|
||||
│ │ ├── oauth_handlers.py # OAuth HTTP endpoint handlers
|
||||
│ │ ├── token_handlers.py # Token management HTTP endpoints
|
||||
│ │ ├── auth_middleware.py # Authentication middleware
|
||||
│ │ └── __init__.py
|
||||
│ ├── tools/ # MCP tools implementation
|
||||
│ │ ├── tools_manager.py # Centralized tools management and registration
|
||||
│ │ ├── resources_manager.py # Resource management and metadata exposure
|
||||
@@ -707,18 +847,18 @@ doris-mcp-server/
|
||||
│ │ └── __init__.py
|
||||
│ ├── utils/ # Core utility modules
|
||||
│ │ ├── config.py # Configuration management with validation
|
||||
│ │ ├── db.py # Database connection management with pooling
|
||||
│ │ ├── db.py # Enhanced database connection management with token binding (Enhanced in v0.6.0)
|
||||
│ │ ├── query_executor.py # High-performance SQL execution with caching
|
||||
│ │ ├── security.py # Security management and data masking
|
||||
│ │ ├── security.py # Advanced security management and authentication (Enhanced in v0.6.0)
|
||||
│ │ ├── schema_extractor.py # Metadata extraction with catalog federation
|
||||
│ │ ├── analysis_tools.py # Data analysis and performance monitoring
|
||||
│ │ ├── data_governance_tools.py # Data lineage and freshness monitoring (New in v0.5.0)
|
||||
│ │ ├── data_quality_tools.py # Comprehensive data quality analysis (New in v0.5.0)
|
||||
│ │ ├── data_exploration_tools.py # Advanced statistical analysis (New in v0.5.0)
|
||||
│ │ ├── security_analytics_tools.py # Access pattern analysis (New in v0.5.0)
|
||||
│ │ ├── dependency_analysis_tools.py # Impact analysis and dependency mapping (New in v0.5.0)
|
||||
│ │ ├── performance_analytics_tools.py # Query optimization and capacity planning (New in v0.5.0)
|
||||
│ │ ├── adbc_query_tools.py # High-performance Arrow Flight SQL operations (New in v0.5.0)
|
||||
│ │ ├── data_governance_tools.py # Data lineage and freshness monitoring (v0.5.0)
|
||||
│ │ ├── data_quality_tools.py # Comprehensive data quality analysis (v0.5.0)
|
||||
│ │ ├── data_exploration_tools.py # Advanced statistical analysis (v0.5.0)
|
||||
│ │ ├── security_analytics_tools.py # Access pattern analysis (v0.5.0)
|
||||
│ │ ├── dependency_analysis_tools.py # Impact analysis and dependency mapping (v0.5.0)
|
||||
│ │ ├── performance_analytics_tools.py # Query optimization and capacity planning (v0.5.0)
|
||||
│ │ ├── adbc_query_tools.py # High-performance Arrow Flight SQL operations (v0.5.0)
|
||||
│ │ ├── logger.py # Logging configuration
|
||||
│ │ └── __init__.py
|
||||
│ └── __init__.py
|
||||
@@ -727,7 +867,9 @@ doris-mcp-server/
|
||||
│ ├── README.md # Client documentation
|
||||
│ └── __init__.py
|
||||
├── logs/ # Log files directory
|
||||
├── tokens.json # Token configuration file (New in v0.6.0)
|
||||
├── README.md # This documentation
|
||||
├── RELEASE_NOTES_v0.6.0.md # Release notes for v0.6.0
|
||||
├── .env.example # Environment variables template
|
||||
├── requirements.txt # Python dependencies
|
||||
├── pyproject.toml # Project configuration and entry points
|
||||
@@ -1273,4 +1415,171 @@ cat logs/doris_mcp_server_critical.log
|
||||
- **Backup**: Keeps 5 backup files for each log level
|
||||
- **Performance**: Minimal impact on server performance
|
||||
|
||||
### Q: How to use the new Token-Bound Database Configuration? (New in v0.6.0)
|
||||
|
||||
**A:** The revolutionary token-bound database configuration allows each token to carry its own database connection parameters for secure multi-tenant access:
|
||||
|
||||
1. **Enable Token Authentication**:
|
||||
```bash
|
||||
# In your .env file
|
||||
ENABLE_TOKEN_AUTH=true
|
||||
TOKEN_HOT_RELOAD=true
|
||||
TOKEN_FILE_PATH=tokens.json
|
||||
```
|
||||
|
||||
2. **Create tokens.json Configuration**:
|
||||
```json
|
||||
{
|
||||
"version": "1.0",
|
||||
"tokens": [
|
||||
{
|
||||
"token_id": "tenant-alpha",
|
||||
"token": "tenant_alpha_secure_token_123",
|
||||
"description": "Tenant Alpha database access",
|
||||
"expires_hours": null,
|
||||
"is_active": true,
|
||||
"database_config": {
|
||||
"host": "tenant-alpha-db.company.com",
|
||||
"port": 9030,
|
||||
"user": "alpha_user",
|
||||
"password": "secure_password",
|
||||
"database": "alpha_analytics",
|
||||
"charset": "UTF8"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
3. **Configuration Priority** (New in v0.6.0):
|
||||
- **Token-bound DB config** (highest priority)
|
||||
- **Environment variables (.env)**
|
||||
- **An error is raised if neither is available**
|
||||
|
||||
4. **Hot Reload Benefits**:
|
||||
- Add new tenants without service restart
|
||||
- Update database credentials in real-time
|
||||
- Automatic validation and rollback on errors
|
||||
- Complete audit trail of changes
|
||||
|
||||
5. **Multi-Tenant Usage**:
|
||||
```bash
|
||||
# Different tokens access different databases automatically (the tenant_beta token assumes a second, analogous entry in tokens.json)
|
||||
curl -H "Authorization: Bearer tenant_alpha_secure_token_123" http://localhost:3000/mcp
|
||||
curl -H "Authorization: Bearer tenant_beta_secure_token_456" http://localhost:3000/mcp
|
||||
```
|
||||
|
||||
### Q: How does Hot Reload work and is it safe? (New in v0.6.0)
|
||||
|
||||
**A:** The hot reload system is designed for enterprise production environments with comprehensive safety measures:
|
||||
|
||||
**How It Works:**
|
||||
- **File Monitoring**: Checks tokens.json every 10 seconds for modifications
|
||||
- **Immediate Validation**: New tokens are validated including database connectivity
|
||||
- **Atomic Updates**: All-or-nothing configuration updates
|
||||
- **Rollback Protection**: Automatic rollback if any token validation fails
|
||||
|
||||
**Safety Features:**
|
||||
- **Backup and Restore**: Current configuration backed up before changes
|
||||
- **Connection Testing**: Database connections tested before applying changes
|
||||
- **Error Isolation**: Invalid tokens don't affect existing valid tokens
|
||||
- **Audit Logging**: Complete trail of all configuration changes
|
||||
|
||||
**Best Practices:**
|
||||
```bash
|
||||
# Monitor hot reload activity
|
||||
tail -f logs/doris_mcp_server_info.log | grep "hot reload"
|
||||
|
||||
# Back up the current configuration before editing
|
||||
cp tokens.json tokens.json.backup
|
||||
# Make changes to tokens.json
|
||||
# System will automatically validate and apply or rollback
|
||||
```
|
||||
|
||||
### Q: How to manage Token lifecycle and security? (New in v0.6.0)
|
||||
|
||||
**A:** Token management uses a secure, file-based approach with optional administrative endpoints that have comprehensive security controls.
|
||||
|
||||
**Primary Token Management Method (Recommended):**
|
||||
```bash
|
||||
# 1. Edit tokens.json file directly (safest method)
|
||||
nano tokens.json
|
||||
|
||||
# 2. Hot reload will automatically detect changes
|
||||
# No server restart required - changes applied within 10 seconds
|
||||
|
||||
# 3. Monitor hot reload in logs
|
||||
tail -f logs/doris_mcp_server_info.log | grep "hot reload"
|
||||
```
|
||||
|
||||
**Administrative Endpoints (Secure, Local Access Only):**
|
||||
|
||||
🛡️ **SECURITY**: These endpoints are protected by comprehensive security controls and are **disabled by default**.
|
||||
|
||||
```bash
|
||||
# Security Requirements (ALL must be met):
|
||||
# ✓ HTTP token management explicitly enabled in configuration
|
||||
# ✓ Access only from localhost (127.0.0.1/::1) - IP restrictions enforced
|
||||
# ✓ Valid admin authentication token required
|
||||
# ✓ Admin authentication enabled in configuration
|
||||
|
||||
# Enable HTTP token management (disabled by default)
|
||||
export ENABLE_HTTP_TOKEN_MANAGEMENT=true
|
||||
export TOKEN_MANAGEMENT_ADMIN_TOKEN=your_secure_admin_token
|
||||
export REQUIRE_ADMIN_AUTH=true
|
||||
export TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,::1
|
||||
|
||||
# Access with proper authentication
|
||||
curl -H "Authorization: Bearer your_secure_admin_token" http://127.0.0.1:3000/token/stats
|
||||
|
||||
# Demo page (local access only, with authentication)
|
||||
# Access: http://127.0.0.1:3000/token/demo
|
||||
```
|
||||
|
||||
**Recommended Token Management Workflow:**
|
||||
|
||||
1. **Development/Testing**:
|
||||
```json
|
||||
// tokens.json
|
||||
{
|
||||
"version": "1.0",
|
||||
"tokens": [
|
||||
{
|
||||
"token_id": "dev-token",
|
||||
"token": "dev_secure_token_123",
|
||||
"description": "Development environment access",
|
||||
"expires_hours": 24,
|
||||
"is_active": true
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
2. **Production Deployment**:
|
||||
```bash
|
||||
# Use secure token generation
|
||||
openssl rand -hex 32 # Generate secure token
|
||||
|
||||
# Store in secure configuration management
|
||||
# Never commit tokens to version control
|
||||
# Use environment variables for sensitive tokens
|
||||
```
|
||||
|
||||
**Security Features:**
|
||||
- **File-Based Management**: Primary management through secured configuration files
|
||||
- **Hot Reload**: Automatic configuration updates without service interruption
|
||||
- **Token Hashing**: Tokens stored as SHA-256 hashes internally
|
||||
- **Audit Trail**: Complete logging of all token operations and changes
|
||||
- **Expiration Management**: Automatic cleanup of expired tokens
|
||||
- **Local Admin Only**: Management endpoints restricted to localhost access
|
||||
- **Configuration Validation**: Immediate validation of token and database configurations
|
||||
|
||||
**Security Best Practices:**
|
||||
- Always manage tokens through secure configuration files
|
||||
- Never expose token management endpoints to external networks
|
||||
- Use strong, randomly generated tokens for production
|
||||
- Implement proper file permissions for tokens.json (600 or 640)
|
||||
- Regularly audit active tokens and their usage patterns
|
||||
- Monitor hot reload logs for unauthorized configuration changes
|
||||
|
||||
For other issues, please check GitHub Issues or submit a new issue.
|
||||
|
||||
@@ -323,7 +323,7 @@ class DorisUnifiedClient:
|
||||
async with streamablehttp_client(
|
||||
self.config.server_url,
|
||||
timeout=timedelta(seconds=self.config.timeout)
|
||||
) as (read, write):
|
||||
) as (read, write, _):
|
||||
async with ClientSession(read, write) as session:
|
||||
self.session = session
|
||||
self._init_sub_clients()
|
||||
@@ -463,7 +463,7 @@ async def create_http_client(server_url: str, timeout: int = 60) -> DorisUnified
|
||||
# Example usage
|
||||
async def example_stdio():
|
||||
"""stdio mode example"""
|
||||
client = await create_stdio_client("python", ["doris_mcp_server/main.py"])
|
||||
client = await create_stdio_client("python", ["-m", "doris_mcp_server.main", "--transport", "stdio"])
|
||||
|
||||
async def test_client(client: DorisUnifiedClient):
|
||||
# Get server capabilities
|
||||
|
||||
56
doris_mcp_server/auth/__init__.py
Normal file
56
doris_mcp_server/auth/__init__.py
Normal file
@@ -0,0 +1,56 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Doris MCP Server Authentication Module
|
||||
Provides JWT-based, Token-based, and OAuth 2.0/OIDC authentication and authorization services
|
||||
"""
|
||||
|
||||
from .jwt_manager import JWTManager
|
||||
from .key_manager import KeyManager
|
||||
from .token_validators import TokenValidator, TokenBlacklist
|
||||
from .auth_middleware import AuthMiddleware
|
||||
from .token_manager import TokenManager, TokenInfo, TokenValidationResult
|
||||
from .token_handlers import TokenHandlers
|
||||
from .oauth_client import OAuthClient, OAuthStateManager
|
||||
from .oauth_provider import OAuthAuthenticationProvider
|
||||
from .oauth_types import (
|
||||
OAuthProvider, OAuthState, OAuthTokens, OAuthUserInfo,
|
||||
OIDCDiscovery, OAuthError, OAuthProviderConfig
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"JWTManager",
|
||||
"KeyManager",
|
||||
"TokenValidator",
|
||||
"TokenBlacklist",
|
||||
"AuthMiddleware",
|
||||
"TokenManager",
|
||||
"TokenInfo",
|
||||
"TokenValidationResult",
|
||||
"TokenHandlers",
|
||||
"OAuthClient",
|
||||
"OAuthStateManager",
|
||||
"OAuthAuthenticationProvider",
|
||||
"OAuthProvider",
|
||||
"OAuthState",
|
||||
"OAuthTokens",
|
||||
"OAuthUserInfo",
|
||||
"OIDCDiscovery",
|
||||
"OAuthError",
|
||||
"OAuthProviderConfig"
|
||||
]
|
||||
271
doris_mcp_server/auth/auth_middleware.py
Normal file
271
doris_mcp_server/auth/auth_middleware.py
Normal file
@@ -0,0 +1,271 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Authentication Middleware Module
|
||||
Provides middleware for JWT authentication in HTTP and MCP contexts
|
||||
"""
|
||||
|
||||
from typing import Optional, Dict, Any, Callable, Awaitable
|
||||
from datetime import datetime
|
||||
|
||||
from .jwt_manager import JWTManager
|
||||
from ..utils.security import AuthContext, SecurityLevel
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AuthMiddleware:
    """JWT authentication middleware for HTTP (ASGI) and MCP requests.

    The middleware validates a bearer token through the supplied JWT manager
    and turns the token payload into an ``AuthContext``. For HTTP it can also
    produce an ASGI callable that rejects unauthenticated requests with 401
    and forwards authenticated ones to a downstream ASGI application.
    """

    def __init__(self, jwt_manager: JWTManager, app=None):
        """Initialize authentication middleware

        Args:
            jwt_manager: JWT manager instance used to validate tokens
            app: Optional downstream ASGI application that the HTTP
                middleware forwards to. It may also be supplied later via
                ``create_http_middleware(app=...)`` or by assigning
                ``self.app``.
        """
        self.jwt_manager = jwt_manager
        # The HTTP middleware forwards to ``self.app``; define it up front so
        # the attribute always exists instead of failing on first request.
        self.app = app
        logger.info("AuthMiddleware initialized")

    def extract_token_from_header(self, authorization: str) -> Optional[str]:
        """Extract JWT token from Authorization header

        Accepted forms are ``Bearer <token>`` (scheme matched
        case-insensitively, as HTTP auth schemes are case-insensitive) and a
        bare token with no scheme. ``Basic`` credentials are never treated as
        a token.

        Args:
            authorization: Authorization header value (may be None or empty)

        Returns:
            JWT token string, or None if not found
        """
        if not authorization:
            return None

        value = authorization.strip()
        if not value:
            return None

        scheme, _, credentials = value.partition(' ')
        scheme = scheme.lower()

        # Support Bearer format
        if scheme == 'bearer':
            return credentials.strip() or None

        # Basic credentials are not a token
        if scheme == 'basic':
            return None

        # Support direct token format
        return value

    async def authenticate_request(self, auth_info: Dict[str, Any]) -> AuthContext:
        """Authenticate request and return authentication context

        Args:
            auth_info: Authentication information dictionary. ``type``
                selects the scheme (defaults to ``"jwt"``; ``"token"`` is
                handled the same way); the token itself is read from
                ``token`` or ``authorization``.

        Returns:
            AuthContext authentication context

        Raises:
            ValueError: Authentication failed or the type is unsupported
        """
        try:
            auth_type = auth_info.get("type", "jwt")

            if auth_type in ("jwt", "token"):
                return await self._authenticate_jwt(auth_info)

            raise ValueError(f"Unsupported authentication type: {auth_type}")

        except Exception as e:
            logger.error(f"Request authentication failed: {e}")
            raise

    async def _authenticate_jwt(self, auth_info: Dict[str, Any]) -> AuthContext:
        """JWT authentication processing

        Args:
            auth_info: Authentication information containing JWT token, either
                directly under ``token`` or inside an ``authorization`` header
                value

        Returns:
            AuthContext authentication context

        Raises:
            ValueError: The token is missing or fails validation
        """
        # Get token, falling back to the Authorization header
        token = auth_info.get("token")
        if not token:
            token = self.extract_token_from_header(auth_info.get("authorization"))

        if not token:
            raise ValueError("Missing JWT token")

        try:
            # Validate token
            validation_result = await self.jwt_manager.validate_token(token, 'access')
            payload = validation_result['payload']

            # Build authentication context
            auth_context = AuthContext(
                token_id=payload.get('jti', ''),
                user_id=payload.get('sub'),
                roles=payload.get('roles', []),
                permissions=payload.get('permissions', []),
                security_level=SecurityLevel(payload.get('security_level', 'internal')),
                session_id=payload.get('jti'),  # Use JWT ID as session ID
                # 'iat' is seconds since the epoch; convert it as UTC so that
                # login_time and last_activity share the same (naive UTC) clock.
                login_time=datetime.utcfromtimestamp(payload.get('iat', 0)),
                last_activity=datetime.utcnow(),
                token=token  # Store raw token for token-bound database configuration
            )

            logger.info(f"JWT authentication successful for user: {auth_context.user_id}")
            return auth_context

        except Exception as e:
            logger.error(f"JWT authentication failed: {e}")
            # Chain the original error so the root cause stays in tracebacks
            raise ValueError(f"JWT authentication failed: {str(e)}") from e

    async def create_auth_response_headers(self, auth_context: AuthContext) -> Dict[str, str]:
        """Create authentication response headers

        Args:
            auth_context: Authentication context

        Returns:
            Response headers dictionary; every value is a string
        """
        return {
            # user_id / session_id come from optional JWT claims ('sub', 'jti')
            # and may be None; header values must be strings.
            'X-Auth-User': str(auth_context.user_id or ''),
            'X-Auth-Roles': ','.join(auth_context.roles),
            'X-Auth-Session': str(auth_context.session_id or ''),
            'X-Auth-Security-Level': auth_context.security_level.value
        }

    def create_http_middleware(self, skip_paths: Optional[list] = None, app=None):
        """Create HTTP middleware function

        Args:
            skip_paths: List of paths to skip authentication. A path is
                skipped when it equals an entry or lies below it
                (``/docs`` matches ``/docs`` and ``/docs/x`` but not
                ``/docs-private``).
            app: Optional downstream ASGI application; overrides ``self.app``
                for the returned middleware.

        Returns:
            ASGI middleware function ``middleware(scope, receive, send)``
        """
        # Local import: the file's import header is left untouched.
        import json

        skip_paths = skip_paths or ['/health', '/docs', '/openapi.json']

        def _downstream():
            # Resolved per request so an app assigned to self.app after the
            # middleware was created is still honoured.
            target = app if app is not None else getattr(self, 'app', None)
            if target is None:
                raise RuntimeError(
                    "AuthMiddleware has no downstream ASGI app; pass app= to "
                    "AuthMiddleware(...) or create_http_middleware(...)"
                )
            return target

        def _is_skipped(path: str) -> bool:
            # Match on a path-segment boundary so '/health' does not also
            # exempt unrelated routes such as '/healthcheck-admin'.
            for skip in skip_paths:
                prefix = skip.rstrip('/')
                if path == skip or path == prefix or path.startswith(prefix + '/'):
                    return True
            return False

        async def middleware(scope, receive, send):
            """HTTP authentication middleware"""
            if scope['type'] != 'http':
                # Pass through non-HTTP requests directly
                return await _downstream()(scope, receive, send)

            path = scope.get('path', '')

            # Check if authentication should be skipped
            if _is_skipped(path):
                return await _downstream()(scope, receive, send)

            # Extract authentication information. ASGI header values are raw
            # bytes; latin-1 decoding cannot fail on arbitrary input.
            request_headers = dict(scope.get('headers', []))
            authorization = request_headers.get(b'authorization', b'').decode('latin-1')

            # Only the authentication step is guarded: failures of the
            # downstream application must not be reported as 401.
            try:
                auth_context = await self.authenticate_request({
                    'type': 'jwt',
                    'authorization': authorization
                })
            except Exception as e:
                # Authentication failed, return 401 error. json.dumps keeps
                # the body valid JSON even when the message contains quotes.
                response_body = json.dumps({
                    "error": "Authentication failed",
                    "message": str(e)
                })

                await send({
                    'type': 'http.response.start',
                    'status': 401,
                    'headers': [
                        (b'content-type', b'application/json'),
                        (b'www-authenticate', b'Bearer')
                    ]
                })
                await send({
                    'type': 'http.response.body',
                    'body': response_body.encode()
                })
                return

            # Add authentication context to scope
            scope['auth_context'] = auth_context

            auth_headers = await self.create_auth_response_headers(auth_context)
            # ASGI requires lowercased header names in response start events.
            extra_headers = [
                (key.lower().encode('latin-1'), value.encode())
                for key, value in auth_headers.items()
            ]

            # Create response wrapper to add authentication headers
            async def send_wrapper(message):
                if message['type'] == 'http.response.start':
                    # Append instead of rebuilding through a dict so repeated
                    # headers (e.g. several set-cookie entries) survive.
                    message['headers'] = list(message.get('headers', [])) + extra_headers

                await send(message)

            return await _downstream()(scope, receive, send_wrapper)

        return middleware

    async def authenticate_mcp_request(self, headers: Dict[str, str]) -> AuthContext:
        """Authenticate MCP request

        The token is read from the ``Authorization`` header or, failing that,
        from ``X-Auth-Token``; header names are matched case-insensitively.

        Args:
            headers: MCP request headers

        Returns:
            AuthContext authentication context

        Raises:
            ValueError: Authentication failed
        """
        try:
            # Extract authentication information from multiple possible header fields
            normalized = {str(name).lower(): value for name, value in (headers or {}).items()}
            authorization = (
                normalized.get('authorization') or
                normalized.get('x-auth-token')
            )

            auth_info = {
                'type': 'jwt',
                'authorization': authorization
            }

            return await self.authenticate_request(auth_info)

        except Exception as e:
            logger.error(f"MCP request authentication failed: {e}")
            raise
|
||||
|
||||
|
||||
class AuthenticationError(Exception):
    """Raised when a caller's identity could not be verified.

    Carries a human-readable ``message`` and a machine-readable
    ``error_code`` (``"AUTH_FAILED"`` unless overridden).
    """

    def __init__(self, message: str, error_code: str = "AUTH_FAILED"):
        # Hand the message to Exception so str(exc) and exc.args carry it.
        super().__init__(message)
        self.error_code = error_code
        self.message = message
|
||||
|
||||
|
||||
class AuthorizationError(Exception):
    """Raised when an authenticated caller lacks permission for an action.

    Carries a human-readable ``message`` and a machine-readable
    ``error_code`` (``"ACCESS_DENIED"`` unless overridden).
    """

    def __init__(self, message: str, error_code: str = "ACCESS_DENIED"):
        # Hand the message to Exception so str(exc) and exc.args carry it.
        super().__init__(message)
        self.error_code = error_code
        self.message = message
|
||||
471
doris_mcp_server/auth/jwt_manager.py
Normal file
471
doris_mcp_server/auth/jwt_manager.py
Normal file
@@ -0,0 +1,471 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
JWT Manager Module
|
||||
Provides comprehensive JWT token management including generation, validation, refresh and revocation
|
||||
"""
|
||||
|
||||
import time
|
||||
import uuid
|
||||
import asyncio
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
try:
|
||||
import jwt
|
||||
except ImportError:
|
||||
raise ImportError("PyJWT is required for JWT functionality. Install with: pip install PyJWT[crypto]")
|
||||
|
||||
from .key_manager import KeyManager
|
||||
from .token_validators import TokenValidator, TokenBlacklist
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class JWTManager:
|
||||
"""JWT Token Manager
|
||||
|
||||
Provides comprehensive JWT token lifecycle management, including:
|
||||
- Token generation and signing
|
||||
- Token validation and parsing
|
||||
- Token refresh mechanism
|
||||
- Token revocation and blacklist
|
||||
- Automatic key rotation
|
||||
"""
|
||||
|
||||
def __init__(self, config):
    """Set up the JWT manager from configuration.

    Args:
        config: Either a configuration object exposing a ``security``
            attribute, or the security configuration itself. The JWT
            settings (algorithm, issuer, audience, expiries, refresh and
            revocation switches) are read from the security configuration.
    """
    self.config = config

    # Use config.security when present; otherwise treat config itself as
    # the security configuration.
    jwt_settings = getattr(config, 'security', config)

    self.algorithm = jwt_settings.jwt_algorithm
    self.issuer = jwt_settings.jwt_issuer
    self.audience = jwt_settings.jwt_audience
    self.access_token_expiry = jwt_settings.jwt_access_token_expiry
    self.refresh_token_expiry = jwt_settings.jwt_refresh_token_expiry
    self.enable_refresh = jwt_settings.enable_token_refresh
    self.enable_revocation = jwt_settings.enable_token_revocation

    # Collaborators: key storage, revocation list, and the validator that
    # consults the revocation list.
    self.key_manager = KeyManager(config)
    self.token_blacklist = TokenBlacklist()
    self.validator = TokenValidator(config, self.token_blacklist)

    # Handle of the background key-rotation task; created in initialize().
    self._key_rotation_task = None

    logger.info(f"JWTManager initialized with algorithm: {self.algorithm}")
|
||||
|
||||
async def initialize(self) -> bool:
|
||||
"""Initialize JWT manager"""
|
||||
try:
|
||||
# Initialize key manager
|
||||
if not await self.key_manager.initialize():
|
||||
logger.error("Failed to initialize key manager")
|
||||
return False
|
||||
|
||||
# Start token validator
|
||||
await self.validator.start()
|
||||
|
||||
# Start automatic key rotation
|
||||
if self.key_manager.key_rotation_interval > 0:
|
||||
self._key_rotation_task = asyncio.create_task(self._auto_key_rotation())
|
||||
|
||||
logger.info("JWTManager initialization completed")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize JWTManager: {e}")
|
||||
return False
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shutdown JWT manager"""
|
||||
try:
|
||||
# Stop key rotation task
|
||||
if self._key_rotation_task:
|
||||
self._key_rotation_task.cancel()
|
||||
try:
|
||||
await self._key_rotation_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Stop validator
|
||||
await self.validator.stop()
|
||||
|
||||
logger.info("JWTManager shutdown completed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during JWTManager shutdown: {e}")
|
||||
|
||||
async def generate_tokens(self, user_info: Dict[str, Any],
|
||||
custom_claims: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
"""Generate access token and refresh token
|
||||
|
||||
Args:
|
||||
user_info: User information dictionary, containing user_id, roles, permissions, etc.
|
||||
custom_claims: Custom claims
|
||||
|
||||
Returns:
|
||||
Dictionary containing access_token and refresh_token
|
||||
"""
|
||||
try:
|
||||
current_time = int(time.time())
|
||||
jti = str(uuid.uuid4())
|
||||
|
||||
# Build base payload
|
||||
base_payload = {
|
||||
'iss': self.issuer,
|
||||
'aud': self.audience,
|
||||
'iat': current_time,
|
||||
'jti': jti,
|
||||
'sub': user_info.get('user_id'),
|
||||
'roles': user_info.get('roles', []),
|
||||
'permissions': user_info.get('permissions', []),
|
||||
'security_level': user_info.get('security_level', 'internal')
|
||||
}
|
||||
|
||||
# Add custom claims
|
||||
if custom_claims:
|
||||
base_payload.update(custom_claims)
|
||||
|
||||
# Generate access token
|
||||
access_payload = base_payload.copy()
|
||||
access_payload.update({
|
||||
'exp': current_time + self.access_token_expiry,
|
||||
'token_type': 'access'
|
||||
})
|
||||
|
||||
access_token = await self._sign_token(access_payload)
|
||||
|
||||
result = {
|
||||
'access_token': access_token,
|
||||
'token_type': 'Bearer',
|
||||
'expires_in': self.access_token_expiry,
|
||||
'user_id': user_info.get('user_id'),
|
||||
'issued_at': current_time
|
||||
}
|
||||
|
||||
# Generate refresh token (if enabled)
|
||||
if self.enable_refresh:
|
||||
refresh_jti = str(uuid.uuid4())
|
||||
refresh_payload = {
|
||||
'iss': self.issuer,
|
||||
'aud': self.audience,
|
||||
'iat': current_time,
|
||||
'exp': current_time + self.refresh_token_expiry,
|
||||
'jti': refresh_jti,
|
||||
'sub': user_info.get('user_id'),
|
||||
'token_type': 'refresh',
|
||||
'access_jti': jti # Associated access token ID
|
||||
}
|
||||
|
||||
refresh_token = await self._sign_token(refresh_payload)
|
||||
result.update({
|
||||
'refresh_token': refresh_token,
|
||||
'refresh_expires_in': self.refresh_token_expiry
|
||||
})
|
||||
|
||||
logger.info(f"Generated tokens for user: {user_info.get('user_id')}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate tokens: {e}")
|
||||
raise
|
||||
|
||||
async def _sign_token(self, payload: Dict[str, Any]) -> str:
|
||||
"""Sign JWT token
|
||||
|
||||
Args:
|
||||
payload: JWT payload
|
||||
|
||||
Returns:
|
||||
Signed JWT token
|
||||
"""
|
||||
try:
|
||||
signing_key = self.key_manager.get_private_key()
|
||||
|
||||
if self.algorithm == "HS256":
|
||||
# Symmetric key signing
|
||||
token = jwt.encode(payload, signing_key, algorithm=self.algorithm)
|
||||
else:
|
||||
# Asymmetric key signing
|
||||
token = jwt.encode(payload, signing_key, algorithm=self.algorithm)
|
||||
|
||||
return token
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to sign token: {e}")
|
||||
raise
|
||||
|
||||
async def validate_token(self, token: str, token_type: str = 'access') -> Dict[str, Any]:
|
||||
"""Validate JWT token
|
||||
|
||||
Args:
|
||||
token: JWT token string
|
||||
token_type: Token type ('access' or 'refresh')
|
||||
|
||||
Returns:
|
||||
Validation result and user information
|
||||
|
||||
Raises:
|
||||
ValueError: Token validation failed
|
||||
"""
|
||||
try:
|
||||
# Decode token
|
||||
verification_key = self.key_manager.get_public_key()
|
||||
|
||||
# Get security configuration
|
||||
if hasattr(self.config, 'security'):
|
||||
security_config = self.config.security
|
||||
else:
|
||||
security_config = self.config
|
||||
|
||||
# JWT decoding options
|
||||
options = {
|
||||
'verify_signature': security_config.jwt_verify_signature,
|
||||
'verify_exp': security_config.jwt_require_exp,
|
||||
'verify_iat': security_config.jwt_require_iat,
|
||||
'verify_nbf': security_config.jwt_require_nbf,
|
||||
'verify_aud': security_config.jwt_verify_audience,
|
||||
'verify_iss': security_config.jwt_verify_issuer,
|
||||
}
|
||||
|
||||
# Decode JWT
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
verification_key,
|
||||
algorithms=[self.algorithm],
|
||||
audience=self.audience if security_config.jwt_verify_audience else None,
|
||||
issuer=self.issuer if security_config.jwt_verify_issuer else None,
|
||||
leeway=security_config.jwt_leeway,
|
||||
options=options
|
||||
)
|
||||
|
||||
# Check token type
|
||||
if payload.get('token_type') != token_type:
|
||||
raise ValueError(f"Invalid token type: expected {token_type}")
|
||||
|
||||
# Use validator for additional checks
|
||||
validation_result = await self.validator.validate_claims(payload)
|
||||
|
||||
logger.info(f"Token validation successful for user: {payload.get('sub')}")
|
||||
return validation_result
|
||||
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise ValueError("Token has expired")
|
||||
except jwt.InvalidTokenError as e:
|
||||
raise ValueError(f"Invalid token: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Token validation failed: {e}")
|
||||
raise ValueError(f"Token validation failed: {str(e)}")
|
||||
|
||||
async def refresh_token(self, refresh_token: str) -> Dict[str, Any]:
|
||||
"""Refresh access token
|
||||
|
||||
Args:
|
||||
refresh_token: Refresh token
|
||||
|
||||
Returns:
|
||||
New token pair
|
||||
"""
|
||||
if not self.enable_refresh:
|
||||
raise ValueError("Token refresh is disabled")
|
||||
|
||||
try:
|
||||
# Validate refresh token
|
||||
refresh_result = await self.validate_token(refresh_token, 'refresh')
|
||||
refresh_payload = refresh_result['payload']
|
||||
|
||||
# Revoke associated access token (if revocation is enabled)
|
||||
if self.enable_revocation:
|
||||
access_jti = refresh_payload.get('access_jti')
|
||||
if access_jti:
|
||||
# Should revoke old access token here, but since we don't know its expiration time,
|
||||
# in practice might need to store more information or use different strategy
|
||||
pass
|
||||
|
||||
# Build new user information
|
||||
user_info = {
|
||||
'user_id': refresh_payload.get('sub'),
|
||||
'roles': refresh_payload.get('roles', []),
|
||||
'permissions': refresh_payload.get('permissions', []),
|
||||
'security_level': refresh_payload.get('security_level', 'internal')
|
||||
}
|
||||
|
||||
# Generate new token pair
|
||||
new_tokens = await self.generate_tokens(user_info)
|
||||
|
||||
logger.info(f"Token refreshed for user: {user_info['user_id']}")
|
||||
return new_tokens
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Token refresh failed: {e}")
|
||||
raise
|
||||
|
||||
async def revoke_token(self, token: str) -> bool:
|
||||
"""Revoke token
|
||||
|
||||
Args:
|
||||
token: Token to revoke
|
||||
|
||||
Returns:
|
||||
Whether revocation was successful
|
||||
"""
|
||||
if not self.enable_revocation:
|
||||
logger.warning("Token revocation is disabled")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Decode token to get JTI and expiration time
|
||||
verification_key = self.key_manager.get_public_key()
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
verification_key,
|
||||
algorithms=[self.algorithm],
|
||||
options={'verify_exp': False} # Allow decoding expired tokens
|
||||
)
|
||||
|
||||
jti = payload.get('jti')
|
||||
exp = payload.get('exp')
|
||||
|
||||
if not jti or not exp:
|
||||
logger.error("Token missing required claims for revocation")
|
||||
return False
|
||||
|
||||
# Add to blacklist
|
||||
await self.validator.revoke_token(jti, exp)
|
||||
|
||||
logger.info(f"Token {jti} revoked successfully")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Token revocation failed: {e}")
|
||||
return False
|
||||
|
||||
async def decode_token_unsafe(self, token: str) -> Dict[str, Any]:
|
||||
"""Decode token without verifying signature (for debugging only)
|
||||
|
||||
Args:
|
||||
token: JWT token
|
||||
|
||||
Returns:
|
||||
Token payload
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(token, options={'verify_signature': False})
|
||||
return payload
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to decode token: {e}")
|
||||
raise
|
||||
|
||||
async def get_token_info(self, token: str) -> Dict[str, Any]:
|
||||
"""Get token information (without verifying signature)
|
||||
|
||||
Args:
|
||||
token: JWT token
|
||||
|
||||
Returns:
|
||||
Token information
|
||||
"""
|
||||
try:
|
||||
payload = await self.decode_token_unsafe(token)
|
||||
|
||||
return {
|
||||
'jti': payload.get('jti'),
|
||||
'sub': payload.get('sub'),
|
||||
'iss': payload.get('iss'),
|
||||
'aud': payload.get('aud'),
|
||||
'iat': payload.get('iat'),
|
||||
'exp': payload.get('exp'),
|
||||
'token_type': payload.get('token_type'),
|
||||
'roles': payload.get('roles'),
|
||||
'permissions': payload.get('permissions'),
|
||||
'security_level': payload.get('security_level'),
|
||||
'is_expired': payload.get('exp', 0) < time.time() if payload.get('exp') else None
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get token info: {e}")
|
||||
raise
|
||||
|
||||
async def _auto_key_rotation(self):
|
||||
"""Automatic key rotation task"""
|
||||
while True:
|
||||
try:
|
||||
# Check if key rotation is needed
|
||||
if await self.key_manager.is_key_expired():
|
||||
logger.info("Key rotation needed, rotating keys...")
|
||||
await self.key_manager.rotate_keys()
|
||||
|
||||
# Wait until next check
|
||||
await asyncio.sleep(3600) # Check every hour
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in auto key rotation: {e}")
|
||||
# Wait longer before retry after error
|
||||
await asyncio.sleep(3600)
|
||||
|
||||
async def get_public_key_info(self) -> Dict[str, Any]:
|
||||
"""Get public key information (for client verification)
|
||||
|
||||
Returns:
|
||||
Public key information
|
||||
"""
|
||||
key_info = await self.key_manager.get_key_info()
|
||||
public_key_pem = await self.key_manager.export_public_key_pem()
|
||||
|
||||
return {
|
||||
'algorithm': self.algorithm,
|
||||
'public_key_pem': public_key_pem,
|
||||
'key_info': key_info
|
||||
}
|
||||
|
||||
async def get_manager_stats(self) -> Dict[str, Any]:
|
||||
"""Get manager statistics
|
||||
|
||||
Returns:
|
||||
Statistics information
|
||||
"""
|
||||
key_info = await self.key_manager.get_key_info()
|
||||
validation_stats = await self.validator.get_validation_stats()
|
||||
|
||||
return {
|
||||
'jwt_config': {
|
||||
'algorithm': self.algorithm,
|
||||
'issuer': self.issuer,
|
||||
'audience': self.audience,
|
||||
'access_token_expiry': self.access_token_expiry,
|
||||
'refresh_token_expiry': self.refresh_token_expiry,
|
||||
'enable_refresh': self.enable_refresh,
|
||||
'enable_revocation': self.enable_revocation
|
||||
},
|
||||
'key_manager': key_info,
|
||||
'validator': validation_stats
|
||||
}
|
||||
343
doris_mcp_server/auth/key_manager.py
Normal file
343
doris_mcp_server/auth/key_manager.py
Normal file
@@ -0,0 +1,343 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
JWT Key Management Module
|
||||
Provides secure key generation, loading, rotation and management for JWT tokens
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import secrets
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Union
|
||||
from datetime import datetime, timedelta
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa, ec
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class KeyManager:
    """JWT Key Manager

    Responsible for generating, loading, rotating and securely storing JWT signing keys
    Supports RSA and EC algorithms, provides automatic key rotation functionality

    All key timestamps are naive datetimes in UTC so that they can be compared
    with ``datetime.utcnow()``.
    """

    def __init__(self, config):
        """Initialize key manager

        Args:
            config: DorisConfig configuration object (with security attribute),
                or the security configuration object itself
        """
        self.config = config
        # Access JWT settings through the security configuration; fall back to
        # the object itself when it was passed directly as the security config
        security_config = config.security if hasattr(config, 'security') else config

        self.algorithm = security_config.jwt_algorithm
        self.key_rotation_interval = security_config.key_rotation_interval
        self.private_key_path = security_config.jwt_private_key_path
        self.public_key_path = security_config.jwt_public_key_path
        self.secret_key = security_config.jwt_secret_key

        # Key storage
        self._private_key = None
        self._public_key = None
        self._secret_key = None
        self._key_generated_at = None

        logger.info(f"KeyManager initialized with algorithm: {self.algorithm}")

    async def initialize(self) -> bool:
        """Initialize key manager, load or generate keys

        Returns:
            True on success, False if loading/generating keys failed
        """
        try:
            if self.algorithm == "HS256":
                await self._initialize_symmetric_key()
            else:
                await self._initialize_asymmetric_keys()

            logger.info("KeyManager initialization completed")
            return True

        except Exception as e:
            logger.error(f"Failed to initialize KeyManager: {e}")
            return False

    async def _initialize_symmetric_key(self):
        """Initialize symmetric key (HS256)

        Uses the configured secret when present, otherwise generates a random one.
        """
        if self.secret_key:
            # Use configured key; accept it whether it was configured as text
            # or already as bytes
            if isinstance(self.secret_key, bytes):
                self._secret_key = self.secret_key
            else:
                self._secret_key = self.secret_key.encode()
            logger.info("Loaded symmetric key from configuration")
        else:
            # Generate new key
            self._secret_key = await self.generate_symmetric_key()
            logger.info("Generated new symmetric key")

        self._key_generated_at = datetime.utcnow()

    async def _initialize_asymmetric_keys(self):
        """Initialize asymmetric key pair (RS256/ES256)

        Sources are tried in order: key files, environment variables, and
        finally a freshly generated pair.
        """
        # Try to load keys from files
        if await self._load_keys_from_files():
            logger.info("Loaded asymmetric keys from files")
            return

        # Try to load from environment variables
        if await self._load_keys_from_env():
            logger.info("Loaded asymmetric keys from environment")
            return

        # Generate new key pair
        await self.generate_key_pair()
        logger.info("Generated new asymmetric key pair")

    async def _load_keys_from_files(self) -> bool:
        """Load keys from files

        Returns:
            True if both keys were loaded; False if the paths are not
            configured, a file is missing, or parsing failed
        """
        try:
            if not self.private_key_path or not self.public_key_path:
                return False

            private_path = Path(self.private_key_path)
            public_path = Path(self.public_key_path)

            if not (private_path.exists() and public_path.exists()):
                return False

            # Parse both keys into locals first so that a failure on the
            # second one cannot leave a half-loaded pair on the instance.
            with open(private_path, 'rb') as f:
                private_key = serialization.load_pem_private_key(
                    f.read(), password=None, backend=default_backend()
                )

            with open(public_path, 'rb') as f:
                public_key = serialization.load_pem_public_key(
                    f.read(), backend=default_backend()
                )

            self._private_key = private_key
            self._public_key = public_key

            # Get key generation time (using file modification time). It must
            # be expressed in UTC: expiry is checked against datetime.utcnow(),
            # and a local-time value would shift the key age by the UTC offset.
            self._key_generated_at = datetime.utcfromtimestamp(private_path.stat().st_mtime)

            return True

        except Exception as e:
            logger.error(f"Failed to load keys from files: {e}")
            return False

    async def _load_keys_from_env(self) -> bool:
        """Load keys from environment variables

        Reads PEM text from JWT_PRIVATE_KEY and JWT_PUBLIC_KEY.

        Returns:
            True if both keys were loaded; False if either variable is unset
            or parsing failed
        """
        try:
            private_key_env = os.getenv('JWT_PRIVATE_KEY')
            public_key_env = os.getenv('JWT_PUBLIC_KEY')

            if not (private_key_env and public_key_env):
                return False

            # Parse both keys before touching instance state
            private_key = serialization.load_pem_private_key(
                private_key_env.encode(), password=None, backend=default_backend()
            )
            public_key = serialization.load_pem_public_key(
                public_key_env.encode(), backend=default_backend()
            )

            self._private_key = private_key
            self._public_key = public_key
            self._key_generated_at = datetime.utcnow()
            return True

        except Exception as e:
            logger.error(f"Failed to load keys from environment: {e}")
            return False

    async def generate_symmetric_key(self, length: int = 32) -> bytes:
        """Generate symmetric key

        Args:
            length: Key length (bytes), default 32 bytes (256 bits)

        Returns:
            Generated key
        """
        return secrets.token_bytes(length)

    async def generate_key_pair(self) -> Tuple[bytes, bytes]:
        """Generate asymmetric key pair

        The new pair becomes the active pair and, when both key paths are
        configured, is also written to those files.

        Returns:
            (private key PEM, public key PEM) tuple

        Raises:
            ValueError: The configured algorithm is neither RS256 nor ES256
            Exception: Any generation or file error is logged and re-raised
        """
        try:
            if self.algorithm == "RS256":
                private_key = rsa.generate_private_key(
                    public_exponent=65537,
                    key_size=2048,
                    backend=default_backend()
                )
            elif self.algorithm == "ES256":
                private_key = ec.generate_private_key(
                    ec.SECP256R1(), backend=default_backend()
                )
            else:
                raise ValueError(f"Unsupported algorithm for key generation: {self.algorithm}")

            # Get public key
            public_key = private_key.public_key()

            # Serialize private key
            private_pem = private_key.private_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PrivateFormat.PKCS8,
                encryption_algorithm=serialization.NoEncryption()
            )

            # Serialize public key
            public_pem = public_key.public_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PublicFormat.SubjectPublicKeyInfo
            )

            # If file paths are configured, save to files. This happens before
            # the in-memory keys are replaced so that a failed write leaves the
            # active keys and the files on disk consistent with each other.
            if self.private_key_path and self.public_key_path:
                await self._save_keys_to_files(private_pem, public_pem)

            # Store keys
            self._private_key = private_key
            self._public_key = public_key
            self._key_generated_at = datetime.utcnow()

            logger.info(f"Generated new {self.algorithm} key pair")
            return private_pem, public_pem

        except Exception as e:
            logger.error(f"Failed to generate key pair: {e}")
            raise

    async def _save_keys_to_files(self, private_pem: bytes, public_pem: bytes):
        """Save keys to files

        Args:
            private_pem: PEM-encoded private key
            public_pem: PEM-encoded public key

        Raises:
            Exception: Any filesystem error is logged and re-raised
        """
        try:
            # Ensure directories exist
            private_path = Path(self.private_key_path)
            public_path = Path(self.public_key_path)

            private_path.parent.mkdir(parents=True, exist_ok=True)
            public_path.parent.mkdir(parents=True, exist_ok=True)

            # Save private key. The file is created with owner-only mode so the
            # key material is never readable by others, not even briefly; the
            # chmod before writing also tightens a pre-existing file.
            fd = os.open(private_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
            with os.fdopen(fd, 'wb') as f:
                os.chmod(private_path, 0o600)  # Only owner can read/write
                f.write(private_pem)

            # Save public key
            with open(public_path, 'wb') as f:
                f.write(public_pem)
            os.chmod(public_path, 0o644)  # Owner read/write, others read only

            logger.info(f"Saved keys to files: {private_path}, {public_path}")

        except Exception as e:
            logger.error(f"Failed to save keys to files: {e}")
            raise

    def get_private_key(self):
        """Get private key for signing (the shared secret for HS256)"""
        if self.algorithm == "HS256":
            return self._secret_key
        return self._private_key

    def get_public_key(self):
        """Get public key for verification (the shared secret for HS256)"""
        if self.algorithm == "HS256":
            return self._secret_key
        return self._public_key

    def get_algorithm(self) -> str:
        """Get signing algorithm"""
        return self.algorithm

    def _key_expiry_time(self):
        """Return when the current key expires (naive UTC), or None.

        None means there is no key yet, or rotation is disabled because the
        rotation interval is unset or not positive.
        """
        if not self._key_generated_at:
            return None
        if not self.key_rotation_interval or self.key_rotation_interval <= 0:
            return None
        return self._key_generated_at + timedelta(seconds=self.key_rotation_interval)

    async def is_key_expired(self) -> bool:
        """Check if key is expired

        Returns:
            True when no key has been loaded/generated yet or the rotation
            interval has elapsed; False otherwise, including when rotation is
            disabled (interval <= 0), in which case a key never expires
        """
        if not self._key_generated_at:
            return True

        expiry_time = self._key_expiry_time()
        if expiry_time is None:
            return False
        return datetime.utcnow() > expiry_time

    async def rotate_keys(self) -> bool:
        """Rotate keys

        NOTE(review): the previous key is discarded immediately, so tokens
        signed with it can no longer be verified by this manager, and for
        HS256 a configured secret is replaced by a random one. Confirm this is
        intended.

        Returns:
            True if new keys were generated, False on failure (logged)
        """
        try:
            logger.info("Starting key rotation")

            if self.algorithm == "HS256":
                # Generate new symmetric key
                self._secret_key = await self.generate_symmetric_key()
                self._key_generated_at = datetime.utcnow()
            else:
                # Generate new asymmetric key pair
                await self.generate_key_pair()

            logger.info("Key rotation completed successfully")
            return True

        except Exception as e:
            logger.error(f"Key rotation failed: {e}")
            return False

    async def get_key_info(self) -> dict:
        """Get key information

        Returns:
            Dictionary describing the active key; no key material is included.
            'key_expires_at' is None when there is no key or rotation is
            disabled.
        """
        expires_at = self._key_expiry_time()
        has_secret = self._secret_key is not None

        return {
            "algorithm": self.algorithm,
            "key_generated_at": self._key_generated_at.isoformat() if self._key_generated_at else None,
            "key_expires_at": expires_at.isoformat() if expires_at else None,
            "is_expired": await self.is_key_expired(),
            "has_private_key": self._private_key is not None or has_secret,
            "has_public_key": self._public_key is not None or has_secret
        }

    async def export_public_key_pem(self) -> Optional[str]:
        """Export public key in PEM format

        Returns:
            PEM text, or None for HS256 (the shared secret is never exported),
            when no public key is loaded, or when serialization fails
        """
        if self.algorithm == "HS256":
            return None  # Symmetric key not exported

        if not self._public_key:
            return None

        try:
            public_pem = self._public_key.public_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PublicFormat.SubjectPublicKeyInfo
            )
            return public_pem.decode()

        except Exception as e:
            logger.error(f"Failed to export public key: {e}")
            return None
|
||||
536
doris_mcp_server/auth/oauth_client.py
Normal file
536
doris_mcp_server/auth/oauth_client.py
Normal file
@@ -0,0 +1,536 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
OAuth 2.0/OIDC Client Manager
|
||||
Provides OAuth authentication client implementation with PKCE and OIDC support
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Optional, Any, Tuple
|
||||
from urllib.parse import urlencode, parse_qs, urlparse
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
except ImportError:
|
||||
raise ImportError("aiohttp is required for OAuth functionality. Install with: pip install aiohttp")
|
||||
|
||||
from .oauth_types import (
|
||||
OAuthProvider, OAuthState, OAuthTokens, OAuthUserInfo,
|
||||
OIDCDiscovery, OAuthError, OAuthProviderConfig, OAUTH_PROVIDERS
|
||||
)
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class OAuthStateManager:
    """Manages OAuth state parameters for CSRF protection

    States are kept in an in-memory dictionary keyed by the state string and
    expire after ``state_expiry`` seconds. Timestamps are naive UTC datetimes.
    """

    def __init__(self, state_expiry: int = 600):
        """Initialize state manager

        Args:
            state_expiry: State expiry time in seconds
        """
        self.state_expiry = state_expiry
        self._states: Dict[str, OAuthState] = {}
        self._cleanup_task = None

        logger.info("OAuthStateManager initialized")

    async def start(self):
        """Start periodic cleanup task

        Calling start() again while the task is running does nothing;
        otherwise the first task would be orphaned and could never be
        cancelled by stop().
        """
        if self._cleanup_task is not None and not self._cleanup_task.done():
            return

        self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
        logger.info("OAuth state manager started")

    async def stop(self):
        """Stop periodic cleanup task"""
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            # Forget the finished task so start() can be called again
            self._cleanup_task = None
        logger.info("OAuth state manager stopped")

    def create_state(self, redirect_uri: str, pkce_enabled: bool = True,
                     nonce_enabled: bool = True) -> OAuthState:
        """Create new OAuth state

        Args:
            redirect_uri: OAuth redirect URI
            pkce_enabled: Whether to enable PKCE
            nonce_enabled: Whether to enable nonce (for OIDC)

        Returns:
            OAuth state object (also stored for later lookup by its state value)
        """
        state = secrets.token_urlsafe(32)
        nonce = secrets.token_urlsafe(32) if nonce_enabled else None

        pkce_verifier = None
        pkce_challenge = None
        if pkce_enabled:
            # Verifier: 32 random bytes, base64url without padding.
            # Challenge: base64url(SHA-256(verifier)) without padding.
            pkce_verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).decode('utf-8').rstrip('=')
            challenge_bytes = hashlib.sha256(pkce_verifier.encode()).digest()
            pkce_challenge = base64.urlsafe_b64encode(challenge_bytes).decode('utf-8').rstrip('=')

        # Take the clock once so expires_at - created_at is exactly state_expiry
        now = datetime.utcnow()
        oauth_state = OAuthState(
            state=state,
            nonce=nonce,
            pkce_verifier=pkce_verifier,
            pkce_challenge=pkce_challenge,
            redirect_uri=redirect_uri,
            created_at=now,
            expires_at=now + timedelta(seconds=self.state_expiry)
        )

        self._states[state] = oauth_state
        logger.debug(f"Created OAuth state: {state}")
        return oauth_state

    def get_state(self, state: str) -> Optional[OAuthState]:
        """Get OAuth state by state parameter

        An expired entry is removed as a side effect of the lookup.

        Args:
            state: State parameter

        Returns:
            OAuth state object or None if not found/expired
        """
        oauth_state = self._states.get(state)
        if oauth_state is None:
            return None

        if oauth_state.expires_at > datetime.utcnow():
            return oauth_state

        # Remove expired state
        del self._states[state]
        logger.debug(f"Removed expired OAuth state: {state}")
        return None

    def consume_state(self, state: str) -> Optional[OAuthState]:
        """Get and remove OAuth state

        A state can be consumed at most once.

        Args:
            state: State parameter

        Returns:
            OAuth state object or None if not found/expired
        """
        oauth_state = self.get_state(state)
        if oauth_state:
            del self._states[state]
            logger.debug(f"Consumed OAuth state: {state}")
        return oauth_state

    async def _periodic_cleanup(self):
        """Periodic cleanup of expired states

        Runs until cancelled; errors in one pass are logged and the loop
        continues.
        """
        while True:
            try:
                await asyncio.sleep(300)  # Clean up every 5 minutes
                current_time = datetime.utcnow()
                expired_states = [
                    state for state, oauth_state in self._states.items()
                    if oauth_state.expires_at <= current_time
                ]

                for state in expired_states:
                    del self._states[state]

                if expired_states:
                    logger.info(f"Cleaned up {len(expired_states)} expired OAuth states")

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Error during OAuth state cleanup: {e}")
|
||||
|
||||
|
||||
class OAuthClient:
|
||||
"""OAuth 2.0/OIDC Client implementation"""
|
||||
|
||||
def __init__(self, config):
|
||||
"""Initialize OAuth client
|
||||
|
||||
Args:
|
||||
config: DorisConfig with OAuth configuration
|
||||
"""
|
||||
self.config = config
|
||||
|
||||
# Access OAuth settings through security configuration
|
||||
if hasattr(config, 'security'):
|
||||
security_config = config.security
|
||||
else:
|
||||
security_config = config
|
||||
|
||||
self.enabled = security_config.oauth_enabled
|
||||
if not self.enabled:
|
||||
logger.info("OAuth client disabled by configuration")
|
||||
return
|
||||
|
||||
# Build provider configuration
|
||||
self.provider_config = self._build_provider_config(security_config)
|
||||
self.state_manager = OAuthStateManager(security_config.oauth_state_expiry)
|
||||
|
||||
# HTTP client session
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
|
||||
# Discovery cache
|
||||
self._discovery_cache: Optional[OIDCDiscovery] = None
|
||||
self._discovery_cache_time: Optional[datetime] = None
|
||||
|
||||
logger.info(f"OAuthClient initialized for provider: {self.provider_config.provider.value}")
|
||||
|
||||
def _build_provider_config(self, security_config) -> OAuthProviderConfig:
    """Build OAuth provider configuration

    Explicitly configured values win; anything left unset falls back to the
    built-in defaults registered for the provider in ``OAUTH_PROVIDERS``.

    Args:
        security_config: Security configuration object

    Returns:
        OAuth provider configuration
    """
    try:
        provider = OAuthProvider(security_config.oauth_provider)
    except ValueError:
        # An unrecognised name is usually a typo; say so instead of silently
        # continuing with a custom provider that has no default endpoints.
        logger.warning(
            f"Unknown OAuth provider '{security_config.oauth_provider}', "
            "falling back to custom provider configuration"
        )
        provider = OAuthProvider.CUSTOM

    # Get default configuration for known providers
    defaults = OAUTH_PROVIDERS.get(provider, {})

    return OAuthProviderConfig(
        provider=provider,
        client_id=security_config.oauth_client_id,
        client_secret=security_config.oauth_client_secret,
        redirect_uri=security_config.oauth_redirect_uri,
        scopes=security_config.oauth_scopes or defaults.get("scopes", ["openid", "email", "profile"]),

        # Endpoints (use configured or defaults)
        authorization_endpoint=security_config.oauth_authorization_endpoint or defaults.get("authorization_endpoint", ""),
        token_endpoint=security_config.oauth_token_endpoint or defaults.get("token_endpoint", ""),
        userinfo_endpoint=security_config.oauth_userinfo_endpoint or defaults.get("userinfo_endpoint"),
        jwks_uri=security_config.oauth_jwks_uri or defaults.get("jwks_uri"),

        # Discovery
        discovery_url=security_config.oidc_discovery_url or defaults.get("discovery_url"),

        # Settings
        pkce_enabled=security_config.oauth_pkce_enabled,
        nonce_enabled=security_config.oauth_nonce_enabled,

        # User mapping
        user_id_claim=security_config.oauth_user_id_claim or defaults.get("user_id_claim", "sub"),
        email_claim=security_config.oauth_email_claim or defaults.get("email_claim", "email"),
        name_claim=security_config.oauth_name_claim or defaults.get("name_claim", "name"),
        roles_claim=security_config.oauth_roles_claim,
        default_roles=security_config.oauth_default_roles
    )
|
||||
|
||||
async def initialize(self) -> bool:
    """Initialize OAuth client

    Creates the HTTP session, starts the state manager and, when a discovery
    URL is configured, performs OIDC discovery. If any step fails, whatever
    was already set up is torn down again so a failed initialization does not
    leak an open HTTP session or a running state manager.

    Returns:
        True if initialization successful (or the client is disabled),
        False otherwise
    """
    if not self.enabled:
        return True

    try:
        # Create HTTP session
        self._session = aiohttp.ClientSession()

        # Start state manager
        await self.state_manager.start()

        # Perform OIDC discovery if configured
        if self.provider_config.discovery_url:
            await self._discover_oidc_endpoints()

        logger.info("OAuth client initialization completed")
        return True

    except Exception as e:
        logger.error(f"Failed to initialize OAuth client: {e}")

    # Roll back partial setup. Each step is best-effort: a cleanup error
    # must not mask the original failure already logged above.
    try:
        await self.state_manager.stop()
    except Exception as cleanup_error:
        logger.warning(f"Failed to stop OAuth state manager after init failure: {cleanup_error}")

    session, self._session = self._session, None
    if session is not None:
        try:
            await session.close()
        except Exception as cleanup_error:
            logger.warning(f"Failed to close OAuth HTTP session after init failure: {cleanup_error}")

    return False
|
||||
|
||||
async def shutdown(self):
    """Shutdown OAuth client

    Stops the state manager and closes the HTTP session. The two steps are
    handled independently so that a failure while stopping the state manager
    can no longer prevent the HTTP session from being closed. Errors are
    logged, never raised.
    """
    if not self.enabled:
        return

    clean = True

    # Stop state manager
    try:
        await self.state_manager.stop()
    except Exception as e:
        clean = False
        logger.error(f"Error during OAuth client shutdown: {e}")

    # Close HTTP session. Detach it first so the client never keeps a
    # reference to a closed (or half-closed) session.
    session, self._session = self._session, None
    if session:
        try:
            await session.close()
        except Exception as e:
            clean = False
            logger.error(f"Error during OAuth client shutdown: {e}")

    if clean:
        logger.info("OAuth client shutdown completed")
|
||||
|
||||
async def _discover_oidc_endpoints(self):
    """Fetch the OIDC discovery document and fill in unset endpoints.

    A discovery result cached less than one hour ago is returned as-is.
    Otherwise the configured discovery URL is requested, the document is
    turned into an ``OIDCDiscovery`` object, and any endpoint that is still
    empty on ``self.provider_config`` is populated from it. Endpoints that
    are already configured are left untouched.

    Returns:
        The (possibly cached) OIDCDiscovery object

    Raises:
        Any error from the request or from parsing the document, after it
        has been logged
    """
    try:
        # Serve from cache while it is younger than one hour
        cached, cached_at = self._discovery_cache, self._discovery_cache_time
        if cached and cached_at:
            if datetime.utcnow() - cached_at < timedelta(hours=1):
                return cached

        logger.info(f"Discovering OIDC endpoints: {self.provider_config.discovery_url}")

        async with self._session.get(self.provider_config.discovery_url) as response:
            response.raise_for_status()
            data = await response.json()

        # issuer / authorization_endpoint / token_endpoint are mandatory here
        # (a missing key raises); the remaining fields are optional.
        optional_fields = (
            "userinfo_endpoint",
            "jwks_uri",
            "scopes_supported",
            "response_types_supported",
            "subject_types_supported",
            "id_token_signing_alg_values_supported",
        )
        discovery = OIDCDiscovery(
            issuer=data["issuer"],
            authorization_endpoint=data["authorization_endpoint"],
            token_endpoint=data["token_endpoint"],
            **{field: data.get(field) for field in optional_fields}
        )

        # Update provider configuration with discovered endpoints
        for field in ("authorization_endpoint", "token_endpoint", "userinfo_endpoint", "jwks_uri"):
            if not getattr(self.provider_config, field):
                setattr(self.provider_config, field, getattr(discovery, field))

        # Cache discovery result
        self._discovery_cache = discovery
        self._discovery_cache_time = datetime.utcnow()

        logger.info("OIDC endpoint discovery completed successfully")
        return discovery

    except Exception as e:
        logger.error(f"OIDC endpoint discovery failed: {e}")
        raise
|
||||
|
||||
def build_authorization_url(self) -> Tuple[str, OAuthState]:
    """Build OAuth authorization URL

    Returns:
        Tuple of (authorization_url, oauth_state)

    Raises:
        ValueError: If the client is disabled or no authorization endpoint
            is configured
    """
    if not self.enabled:
        raise ValueError("OAuth client is not enabled")

    # Fail early, before a state entry is created, rather than handing the
    # caller a relative "?response_type=..." URL.
    endpoint = self.provider_config.authorization_endpoint
    if not endpoint:
        raise ValueError("Authorization endpoint not configured")

    # Create state for CSRF protection
    oauth_state = self.state_manager.create_state(
        redirect_uri=self.provider_config.redirect_uri,
        pkce_enabled=self.provider_config.pkce_enabled,
        nonce_enabled=self.provider_config.nonce_enabled
    )

    # Build authorization parameters
    params = {
        'response_type': 'code',
        'client_id': self.provider_config.client_id,
        'redirect_uri': self.provider_config.redirect_uri,
        'scope': ' '.join(self.provider_config.scopes),
        'state': oauth_state.state
    }

    # Add PKCE challenge
    if oauth_state.pkce_challenge:
        params['code_challenge'] = oauth_state.pkce_challenge
        params['code_challenge_method'] = 'S256'

    # Add nonce for OIDC
    if oauth_state.nonce:
        params['nonce'] = oauth_state.nonce

    # Build URL. An endpoint that already carries a query string must be
    # extended with '&', not a second '?'.
    separator = '&' if '?' in endpoint else '?'
    authorization_url = f"{endpoint}{separator}{urlencode(params)}"

    logger.info(f"Built OAuth authorization URL for state: {oauth_state.state}")
    return authorization_url, oauth_state
|
||||
|
||||
async def exchange_code_for_tokens(self, code: str, state: str) -> Tuple[OAuthTokens, OAuthState]:
    """Exchange authorization code for tokens

    Args:
        code: Authorization code
        state: State parameter

    Returns:
        Tuple of (OAuth tokens, OAuth state)

    Raises:
        ValueError: If state is invalid or exchange fails
    """
    if not self.enabled:
        raise ValueError("OAuth client is not enabled")

    # Validate and consume state
    oauth_state = self.state_manager.consume_state(state)
    if not oauth_state:
        raise ValueError("Invalid or expired state parameter")

    # Prepare token request
    data = {
        'grant_type': 'authorization_code',
        'client_id': self.provider_config.client_id,
        'code': code,
        'redirect_uri': oauth_state.redirect_uri
    }

    # Only send a client secret when one is configured; public (PKCE-only)
    # clients must not submit an empty or "None" secret.
    if self.provider_config.client_secret:
        data['client_secret'] = self.provider_config.client_secret

    # Add PKCE verifier
    if oauth_state.pkce_verifier:
        data['code_verifier'] = oauth_state.pkce_verifier

    # Make token request. Transport and parsing errors are wrapped exactly
    # once here; protocol-level errors are reported below, outside this try,
    # so their message is not prefixed twice.
    try:
        async with self._session.post(
            self.provider_config.token_endpoint,
            data=data,
            headers={
                'Content-Type': 'application/x-www-form-urlencoded',
                # Some providers answer form-encoded unless JSON is requested
                'Accept': 'application/json'
            }
        ) as response:
            status = response.status
            # content_type=None: accept JSON bodies served with a
            # non-standard Content-Type header
            response_data = await response.json(content_type=None)

        if not isinstance(response_data, dict):
            raise TypeError("token endpoint did not return a JSON object")

    except Exception as e:
        logger.error(f"Token exchange failed: {e}")
        raise ValueError(f"Token exchange failed: {str(e)}") from e

    # Providers may report errors with HTTP 200, so also require a token
    if status != 200 or not response_data.get('access_token'):
        error_msg = response_data.get(
            'error_description',
            response_data.get('error', f'unexpected token response (HTTP {status})')
        )
        logger.error(f"Token exchange failed: {error_msg}")
        raise ValueError(f"Token exchange failed: {error_msg}")

    tokens = OAuthTokens(
        access_token=response_data['access_token'],
        token_type=response_data.get('token_type', 'Bearer'),
        expires_in=response_data.get('expires_in'),
        refresh_token=response_data.get('refresh_token'),
        scope=response_data.get('scope'),
        id_token=response_data.get('id_token')
    )

    logger.info("Successfully exchanged authorization code for tokens")
    return tokens, oauth_state
|
||||
|
||||
async def get_user_info(self, tokens: OAuthTokens) -> OAuthUserInfo:
    """Get user information from OAuth provider

    Args:
        tokens: OAuth tokens

    Returns:
        OAuth user information

    Raises:
        ValueError: If the client is disabled, no userinfo endpoint is
            configured, the request fails, or the response carries no
            user identifier
    """
    if not self.enabled:
        raise ValueError("OAuth client is not enabled")

    if not self.provider_config.userinfo_endpoint:
        raise ValueError("Userinfo endpoint not configured")

    try:
        # Make userinfo request
        headers = {'Authorization': f'{tokens.token_type} {tokens.access_token}'}

        async with self._session.get(
            self.provider_config.userinfo_endpoint,
            headers=headers
        ) as response:
            response.raise_for_status()
            user_data = await response.json()

        if not isinstance(user_data, dict):
            raise TypeError("userinfo endpoint did not return a JSON object")

    except Exception as e:
        logger.error(f"Failed to get user info: {e}")
        raise ValueError(f"Failed to get user info: {str(e)}") from e

    # A response without the identifier claim must not authenticate: falling
    # back to '' would make every such user share one (empty) identity.
    id_claim = self.provider_config.user_id_claim
    subject = user_data.get(id_claim)
    if subject is None or str(subject) == '':
        message = f"Failed to get user info: missing '{id_claim}' claim in userinfo response"
        logger.error(message)
        raise ValueError(message)

    # Normalise roles to a fresh list of strings. The claim may be absent,
    # null, a single string or a collection; a bare string must not be
    # iterated character by character downstream.
    roles_claim = self.provider_config.roles_claim
    claimed_roles = user_data.get(roles_claim) if roles_claim else None
    if isinstance(claimed_roles, str):
        roles = [claimed_roles]
    elif isinstance(claimed_roles, (list, tuple, set, frozenset)):
        roles = [str(role) for role in claimed_roles]
    else:
        roles = list(self.provider_config.default_roles or [])

    # Extract user information using configured claims
    user_info = OAuthUserInfo(
        sub=str(subject),
        email=user_data.get(self.provider_config.email_claim),
        name=user_data.get(self.provider_config.name_claim),
        given_name=user_data.get('given_name'),
        family_name=user_data.get('family_name'),
        picture=user_data.get('picture'),
        locale=user_data.get('locale'),
        email_verified=user_data.get('email_verified'),
        roles=roles,
        raw_claims=user_data
    )

    logger.info(f"Retrieved user info for user: {user_info.sub}")
    return user_info
|
||||
|
||||
async def refresh_tokens(self, refresh_token: str) -> OAuthTokens:
    """Refresh OAuth tokens

    Args:
        refresh_token: Refresh token

    Returns:
        New OAuth tokens. When the provider does not issue a new refresh
        token, the one passed in is carried over.

    Raises:
        ValueError: If the client is disabled or the refresh fails
    """
    if not self.enabled:
        raise ValueError("OAuth client is not enabled")

    data = {
        'grant_type': 'refresh_token',
        'client_id': self.provider_config.client_id,
        'refresh_token': refresh_token
    }

    # Only send a client secret when one is configured; public clients
    # must not submit an empty or "None" secret.
    if self.provider_config.client_secret:
        data['client_secret'] = self.provider_config.client_secret

    # Transport and parsing errors are wrapped exactly once here;
    # protocol-level errors are reported below, outside this try, so their
    # message is not prefixed twice.
    try:
        async with self._session.post(
            self.provider_config.token_endpoint,
            data=data,
            headers={
                'Content-Type': 'application/x-www-form-urlencoded',
                # Some providers answer form-encoded unless JSON is requested
                'Accept': 'application/json'
            }
        ) as response:
            status = response.status
            # content_type=None: accept JSON bodies served with a
            # non-standard Content-Type header
            response_data = await response.json(content_type=None)

        if not isinstance(response_data, dict):
            raise TypeError("token endpoint did not return a JSON object")

    except Exception as e:
        logger.error(f"Token refresh failed: {e}")
        raise ValueError(f"Token refresh failed: {str(e)}") from e

    # Providers may report errors with HTTP 200, so also require a token
    if status != 200 or not response_data.get('access_token'):
        error_msg = response_data.get(
            'error_description',
            response_data.get('error', f'unexpected token response (HTTP {status})')
        )
        logger.error(f"Token refresh failed: {error_msg}")
        raise ValueError(f"Token refresh failed: {error_msg}")

    tokens = OAuthTokens(
        access_token=response_data['access_token'],
        token_type=response_data.get('token_type', 'Bearer'),
        expires_in=response_data.get('expires_in'),
        # Keep old if not provided (also when the provider sends null)
        refresh_token=response_data.get('refresh_token') or refresh_token,
        scope=response_data.get('scope'),
        id_token=response_data.get('id_token')
    )

    logger.info("Successfully refreshed OAuth tokens")
    return tokens
|
||||
312
doris_mcp_server/auth/oauth_handlers.py
Normal file
312
doris_mcp_server/auth/oauth_handlers.py
Normal file
@@ -0,0 +1,312 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
OAuth HTTP Handlers
|
||||
Provides HTTP endpoints for OAuth authentication flow
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
import json
|
||||
|
||||
from starlette.responses import JSONResponse, RedirectResponse, HTMLResponse
|
||||
from starlette.requests import Request
|
||||
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class OAuthHandlers:
|
||||
"""OAuth HTTP request handlers"""
|
||||
|
||||
def __init__(self, security_manager):
    """Create the handler set bound to a security manager.

    Args:
        security_manager: DorisSecurityManager instance that every handler
            delegates its OAuth operations to
    """
    # The manager is the only state these handlers keep.
    self.security_manager = security_manager
    logger.info("OAuth handlers initialized")
|
||||
|
||||
async def handle_login(self, request: Request) -> JSONResponse:
    """Start an OAuth login.

    Responds with JSON carrying the provider's authorization URL and the
    state value. Answers 400 when OAuth is not enabled and 500 when building
    the URL fails.
    """
    try:
        provider_details = self.security_manager.get_oauth_provider_info()

        # Only hand out an authorization URL when OAuth is switched on
        if provider_details.get("enabled"):
            auth_url, state_value = self.security_manager.get_oauth_authorization_url()
            payload = {
                "authorization_url": auth_url,
                "state": state_value,
                "provider": provider_details.get("provider"),
                "message": "Navigate to authorization_url to complete OAuth login"
            }
            return JSONResponse(payload)

        disabled_body = {"error": "OAuth authentication is not enabled"}
        return JSONResponse(disabled_body, status_code=400)

    except Exception as e:
        logger.error(f"OAuth login initiation failed: {e}")
        failure_body = {"error": f"OAuth login failed: {str(e)}"}
        return JSONResponse(failure_body, status_code=500)
|
||||
|
||||
async def handle_callback(self, request: Request) -> JSONResponse:
    """Process the redirect back from the OAuth provider.

    Provider-reported errors and a missing ``code``/``state`` are answered
    with 400. Otherwise the code is handed to the security manager and the
    resulting authentication context is returned as JSON. Any exception is
    answered with 500.
    """
    try:
        params = dict(request.query_params)

        # The provider signalled a failure (e.g. the user denied access)
        if "error" in params:
            description = params.get("error_description", "Unknown error")
            logger.warning(f"OAuth callback error: {params['error']} - {description}")
            provider_error = {
                "error": params["error"],
                "error_description": description,
                "error_uri": params.get("error_uri")
            }
            return JSONResponse(provider_error, status_code=400)

        # Both values are required to complete the flow
        code, state = params.get("code"), params.get("state")
        if not (code and state):
            missing_body = {"error": "Missing required parameters: code and state"}
            return JSONResponse(missing_body, status_code=400)

        context = await self.security_manager.handle_oauth_callback(code, state)

        success_body = {
            "success": True,
            "user_id": context.user_id,
            "roles": context.roles,
            "permissions": context.permissions,
            "security_level": context.security_level.value,
            "session_id": context.session_id,
            "message": "OAuth authentication successful"
        }
        return JSONResponse(success_body)

    except Exception as e:
        logger.error(f"OAuth callback handling failed: {e}")
        failure_body = {"error": f"OAuth callback failed: {str(e)}"}
        return JSONResponse(failure_body, status_code=500)
|
||||
|
||||
async def handle_provider_info(self, request: Request) -> JSONResponse:
    """Report the configured OAuth provider.

    Returns the security manager's provider information as JSON, or a 500
    response with an error message when it cannot be obtained.
    """
    try:
        details = self.security_manager.get_oauth_provider_info()
        response = JSONResponse(details)
    except Exception as e:
        logger.error(f"Failed to get OAuth provider info: {e}")
        failure_body = {"error": f"Failed to get provider info: {str(e)}"}
        response = JSONResponse(failure_body, status_code=500)
    return response
|
||||
|
||||
async def handle_demo_page(self, request: Request) -> HTMLResponse:
    """Handle OAuth demo page

    Returns a simple HTML page for testing OAuth flow. When OAuth is not
    enabled, a short notice page is returned instead.

    Every dynamic value is HTML-escaped before it reaches the page: the
    configuration values on the server side, and the login response and URL
    query parameters in the page script. Without this, a crafted callback
    URL such as ``?error=<img src=x onerror=...>`` would run script in the
    page, because the values are inserted through ``innerHTML``.
    """
    # Local import: html.escape is stdlib and only needed by this handler
    from html import escape

    oauth_info = self.security_manager.get_oauth_provider_info()
    if not oauth_info.get("enabled"):
        return HTMLResponse("""
        <html>
        <head><title>OAuth Demo</title></head>
        <body>
            <h1>OAuth Demo</h1>
            <p style="color: red;">OAuth authentication is not enabled.</p>
            <p>Please configure OAuth settings in your security configuration.</p>
        </body>
        </html>
        """)

    # Escape configuration values before embedding them in markup
    provider = escape(str(oauth_info.get('provider', 'N/A')))
    client_id = escape(str(oauth_info.get('client_id', 'N/A')))
    scopes = escape(', '.join(str(scope) for scope in (oauth_info.get('scopes') or [])))
    pkce_enabled = escape(str(oauth_info.get('pkce_enabled', False)))

    html_content = f"""
    <!DOCTYPE html>
    <html>
    <head>
        <title>Doris MCP Server - OAuth Demo</title>
        <style>
            body {{
                font-family: Arial, sans-serif;
                max-width: 800px;
                margin: 0 auto;
                padding: 20px;
            }}
            .info {{
                background-color: #f0f8ff;
                padding: 15px;
                border-left: 4px solid #0066cc;
                margin: 20px 0;
            }}
            .error {{
                background-color: #ffe6e6;
                padding: 15px;
                border-left: 4px solid #cc0000;
                margin: 20px 0;
            }}
            .success {{
                background-color: #e6ffe6;
                padding: 15px;
                border-left: 4px solid #00cc00;
                margin: 20px 0;
            }}
            button {{
                background-color: #0066cc;
                color: white;
                padding: 10px 20px;
                border: none;
                border-radius: 4px;
                cursor: pointer;
                font-size: 16px;
            }}
            button:hover {{
                background-color: #0052a3;
            }}
            pre {{
                background-color: #f5f5f5;
                padding: 10px;
                border-radius: 4px;
                overflow-x: auto;
            }}
        </style>
    </head>
    <body>
        <h1>Doris MCP Server - OAuth Demo</h1>

        <div class="info">
            <h3>OAuth Configuration</h3>
            <p><strong>Provider:</strong> {provider}</p>
            <p><strong>Client ID:</strong> {client_id}</p>
            <p><strong>Scopes:</strong> {scopes}</p>
            <p><strong>PKCE Enabled:</strong> {pkce_enabled}</p>
        </div>

        <div>
            <h3>OAuth Authentication Test</h3>
            <p>Click the button below to start OAuth authentication flow:</p>
            <button onclick="startOAuthFlow()">Start OAuth Login</button>
        </div>

        <div id="result" style="margin-top: 20px;"></div>

        <div>
            <h3>API Endpoints</h3>
            <ul>
                <li><code>GET /auth/login</code> - Initiate OAuth login</li>
                <li><code>GET /auth/callback</code> - OAuth callback handler</li>
                <li><code>GET /auth/provider</code> - Provider information</li>
            </ul>
        </div>

        <script>
            // Escape text before it is placed into innerHTML so that values
            // taken from the URL or from server responses cannot inject
            // markup or script into the page.
            function escapeHtml(value) {{
                return String(value).replace(/[&<>"']/g, function (ch) {{
                    return {{'&': '&amp;', '<': '&lt;', '>': '&gt;', '"': '&quot;', "'": '&#39;'}}[ch];
                }});
            }}

            async function startOAuthFlow() {{
                const resultDiv = document.getElementById('result');
                resultDiv.innerHTML = '<div class="info">Initiating OAuth flow...</div>';

                try {{
                    const response = await fetch('/auth/login');
                    const data = await response.json();

                    if (response.ok) {{
                        resultDiv.innerHTML = `
                            <div class="success">
                                <h4>OAuth URL Generated Successfully</h4>
                                <p><strong>State:</strong> ${{escapeHtml(data.state)}}</p>
                                <p><strong>Provider:</strong> ${{escapeHtml(data.provider)}}</p>
                                <p><a href="${{escapeHtml(data.authorization_url)}}" target="_blank" rel="noopener noreferrer">Click here to authenticate</a></p>
                                <p><em>Note: After authentication, you will be redirected to the callback URL.</em></p>
                            </div>
                        `;

                        // Automatically redirect to OAuth provider
                        // window.open(data.authorization_url, '_blank');
                    }} else {{
                        resultDiv.innerHTML = `
                            <div class="error">
                                <h4>Error</h4>
                                <p>${{escapeHtml(data.error)}}</p>
                            </div>
                        `;
                    }}
                }} catch (error) {{
                    resultDiv.innerHTML = `
                        <div class="error">
                            <h4>Network Error</h4>
                            <p>${{escapeHtml(error.message)}}</p>
                        </div>
                    `;
                }}
            }}

            // Handle OAuth callback result if present in URL
            window.addEventListener('load', function() {{
                const urlParams = new URLSearchParams(window.location.search);
                if (urlParams.has('code') && urlParams.has('state')) {{
                    const resultDiv = document.getElementById('result');
                    resultDiv.innerHTML = `
                        <div class="success">
                            <h4>OAuth Callback Received</h4>
                            <p>Code: ${{escapeHtml(urlParams.get('code'))}}</p>
                            <p>State: ${{escapeHtml(urlParams.get('state'))}}</p>
                            <p>The authentication was successful!</p>
                        </div>
                    `;
                }} else if (urlParams.has('error')) {{
                    const resultDiv = document.getElementById('result');
                    resultDiv.innerHTML = `
                        <div class="error">
                            <h4>OAuth Error</h4>
                            <p>Error: ${{escapeHtml(urlParams.get('error'))}}</p>
                            <p>Description: ${{escapeHtml(urlParams.get('error_description') || 'No description')}}</p>
                        </div>
                    `;
                }}
            }});
        </script>
    </body>
    </html>
    """

    return HTMLResponse(html_content)
|
||||
289
doris_mcp_server/auth/oauth_provider.py
Normal file
289
doris_mcp_server/auth/oauth_provider.py
Normal file
@@ -0,0 +1,289 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
OAuth Authentication Provider
|
||||
Integrates OAuth 2.0/OIDC authentication with the existing authentication framework
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
from datetime import datetime
|
||||
|
||||
from .oauth_client import OAuthClient
|
||||
from .oauth_types import OAuthTokens, OAuthUserInfo, OAuthState
|
||||
from ..utils.security import AuthContext, SecurityLevel
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class OAuthAuthenticationProvider:
|
||||
"""OAuth authentication provider for Doris MCP Server"""
|
||||
|
||||
def __init__(self, config):
    """Set up the provider around a freshly built OAuth client.

    Args:
        config: DorisConfig with OAuth configuration
    """
    client = OAuthClient(config)

    self.config = config
    self.oauth_client = client
    # The provider is enabled exactly when the underlying client is
    self.enabled = client.enabled

    logger.info(f"OAuthAuthenticationProvider initialized (enabled: {self.enabled})")
|
||||
|
||||
async def initialize(self) -> bool:
    """Bring the provider into a usable state.

    A disabled provider has nothing to set up and reports success.
    Otherwise the result of initializing the OAuth client is logged and
    passed through.

    Returns:
        True if initialization successful
    """
    if not self.enabled:
        return True

    client_ready = await self.oauth_client.initialize()
    if not client_ready:
        logger.error("Failed to initialize OAuth authentication provider")
        return client_ready

    logger.info("OAuth authentication provider initialized successfully")
    return client_ready
|
||||
|
||||
async def shutdown(self):
    """Shut down the underlying OAuth client when the provider is enabled."""
    # Nothing was started for a disabled provider
    if not self.enabled:
        return

    await self.oauth_client.shutdown()
    logger.info("OAuth authentication provider shutdown completed")
|
||||
|
||||
def get_authorization_url(self) -> Tuple[str, str]:
    """Produce the URL the user must visit to authorize, plus its state.

    Returns:
        Tuple of (authorization_url, state), where state is the plain state
        string taken from the state object the client created

    Raises:
        ValueError: If OAuth authentication is not enabled
    """
    if self.enabled:
        url_and_state = self.oauth_client.build_authorization_url()
        url, state_record = url_and_state
        return url, state_record.state

    raise ValueError("OAuth authentication is not enabled")
|
||||
|
||||
async def handle_callback(self, code: str, state: str) -> AuthContext:
    """Complete the authorization-code flow and authenticate the user.

    Trades the code for tokens, fetches the user's profile with them and
    turns the profile into an authentication context.

    Args:
        code: Authorization code from OAuth provider
        state: State parameter for CSRF protection

    Returns:
        AuthContext for the authenticated user

    Raises:
        ValueError: If OAuth is not enabled or any step of the flow fails
    """
    if not self.enabled:
        raise ValueError("OAuth authentication is not enabled")

    client = self.oauth_client
    try:
        # The consumed state object is returned too, but not needed here
        tokens, _consumed_state = await client.exchange_code_for_tokens(code, state)
        profile = await client.get_user_info(tokens)
        context = await self._create_auth_context(profile, tokens)
    except Exception as e:
        logger.error(f"OAuth callback handling failed: {e}")
        raise ValueError(f"OAuth authentication failed: {str(e)}")
    else:
        try:
            logger.info(f"OAuth authentication successful for user: {context.user_id}")
        except Exception as e:
            logger.error(f"OAuth callback handling failed: {e}")
            raise ValueError(f"OAuth authentication failed: {str(e)}")
        return context
|
||||
|
||||
async def authenticate_with_token(self, access_token: str) -> AuthContext:
    """Authenticate a caller that already holds an OAuth access token.

    The token is wrapped in an ``OAuthTokens`` object, used to fetch the
    user's profile, and the profile is turned into an authentication
    context.

    Args:
        access_token: OAuth access token

    Returns:
        AuthContext for the authenticated user

    Raises:
        ValueError: If OAuth is not enabled or authentication fails
    """
    if not self.enabled:
        raise ValueError("OAuth authentication is not enabled")

    try:
        bearer = OAuthTokens(access_token=access_token)
        profile = await self.oauth_client.get_user_info(bearer)
        context = await self._create_auth_context(profile, bearer)

        logger.info(f"OAuth token authentication successful for user: {context.user_id}")
        return context

    except Exception as e:
        logger.error(f"OAuth token authentication failed: {e}")
        reason = str(e)
        raise ValueError(f"OAuth token authentication failed: {reason}")
|
||||
|
||||
async def refresh_authentication(self, refresh_token: str) -> Tuple[AuthContext, str]:
    """Renew an authentication using a refresh token.

    Obtains new tokens, re-reads the user's profile with them and builds a
    new authentication context from the result.

    Args:
        refresh_token: OAuth refresh token

    Returns:
        Tuple of (AuthContext, new_access_token)

    Raises:
        ValueError: If OAuth is not enabled or the refresh fails
    """
    if not self.enabled:
        raise ValueError("OAuth authentication is not enabled")

    client = self.oauth_client
    try:
        renewed = await client.refresh_tokens(refresh_token)
        profile = await client.get_user_info(renewed)
        context = await self._create_auth_context(profile, renewed)

        logger.info(f"OAuth refresh successful for user: {context.user_id}")
        result = (context, renewed.access_token)
        return result

    except Exception as e:
        logger.error(f"OAuth refresh failed: {e}")
        reason = str(e)
        raise ValueError(f"OAuth refresh failed: {reason}")
|
||||
|
||||
async def _create_auth_context(self, user_info: OAuthUserInfo, tokens: OAuthTokens) -> AuthContext:
    """Create authentication context from OAuth user info

    Args:
        user_info: OAuth user information
        tokens: OAuth tokens (not stored in the context)

    Returns:
        AuthContext for the user
    """
    # Local import: only this method needs the stdlib secrets module
    import secrets

    # Determine security level based on roles or email domain
    security_level = await self._determine_security_level(user_info)

    # Map OAuth roles to application permissions
    permissions = await self._map_permissions(user_info.roles)

    # Generate session ID. The suffix is cryptographically random: a suffix
    # derived from the clock would let anyone who knows the user id and the
    # approximate login time guess the identifier.
    session_id = f"oauth_{user_info.sub}_{secrets.token_urlsafe(32)}"

    # One timestamp for both fields so they are equal at creation time
    now = datetime.utcnow()

    return AuthContext(
        token_id=f"oauth_{user_info.sub}",
        user_id=user_info.sub,
        roles=user_info.roles,
        permissions=permissions,
        security_level=security_level,
        session_id=session_id,
        login_time=now,
        last_activity=now,
        token=""  # OAuth doesn't have raw token, use empty string
    )
|
||||
|
||||
async def _determine_security_level(self, user_info: OAuthUserInfo) -> SecurityLevel:
    """Pick the security level granted to an OAuth user.

    Checks, in this order: an administrative role (SECRET), an email address
    in a trusted domain (CONFIDENTIAL), an elevated role (CONFIDENTIAL).
    Users matching none of these get INTERNAL. Role names are compared
    case-insensitively.

    Args:
        user_info: OAuth user information

    Returns:
        SecurityLevel for the user
    """
    admin_roles = {"admin", "administrator", "data_admin", "super_admin"}
    elevated_roles = {"data_analyst", "developer", "manager"}
    # NOTE(review): these look like placeholder domains — confirm they are
    # replaced with real values before deployment.
    trusted_domains = ("yourcompany.com", "internal.org")  # Configure as needed

    # 1. Administrative roles
    for role in user_info.roles:
        if role.lower() in admin_roles:
            return SecurityLevel.SECRET

    # 2. Email domain for internal users (text after the last '@')
    if user_info.email:
        domain = user_info.email.rsplit("@", 1)[-1].lower()
        if domain in trusted_domains:
            return SecurityLevel.CONFIDENTIAL

    # 3. Special roles
    for role in user_info.roles:
        if role.lower() in elevated_roles:
            return SecurityLevel.CONFIDENTIAL

    # 4. Default to internal level for OAuth users
    return SecurityLevel.INTERNAL
|
||||
|
||||
async def _map_permissions(self, roles: list[str]) -> list[str]:
|
||||
"""Map OAuth roles to application permissions
|
||||
|
||||
Args:
|
||||
roles: OAuth user roles
|
||||
|
||||
Returns:
|
||||
List of application permissions
|
||||
"""
|
||||
permissions = set()
|
||||
|
||||
# Role to permission mapping
|
||||
role_permissions = {
|
||||
"admin": ["admin", "read_data", "write_data", "manage_users"],
|
||||
"administrator": ["admin", "read_data", "write_data", "manage_users"],
|
||||
"data_admin": ["admin", "read_data", "write_data"],
|
||||
"super_admin": ["admin", "read_data", "write_data", "manage_users", "system_admin"],
|
||||
"data_analyst": ["read_data", "query_database"],
|
||||
"developer": ["read_data", "query_database", "debug"],
|
||||
"viewer": ["read_data"],
|
||||
"user": ["read_data"],
|
||||
"oauth_user": ["read_data"] # Default OAuth user permission
|
||||
}
|
||||
|
||||
# Map roles to permissions
|
||||
for role in roles:
|
||||
role_lower = role.lower()
|
||||
if role_lower in role_permissions:
|
||||
permissions.update(role_permissions[role_lower])
|
||||
|
||||
# Ensure OAuth users have at least basic permissions
|
||||
if not permissions:
|
||||
permissions.add("read_data")
|
||||
|
||||
return list(permissions)
|
||||
|
||||
def get_provider_info(self) -> Dict[str, Any]:
    """Describe the configured OAuth provider.

    Returns:
        ``{"enabled": False}`` when OAuth is disabled. Otherwise a dictionary
        with "enabled", "provider", "client_id", "scopes", "redirect_uri",
        "pkce_enabled" and "nonce_enabled", read from the OAuth client's
        provider configuration. The client secret is not part of the result.
    """
    if self.enabled:
        provider_cfg = self.oauth_client.provider_config
        summary: Dict[str, Any] = {
            "enabled": True,
            "provider": provider_cfg.provider.value,
        }
        # The remaining entries are copied from same-named config attributes
        for attr in ("client_id", "scopes", "redirect_uri", "pkce_enabled", "nonce_enabled"):
            summary[attr] = getattr(provider_cfg, attr)
        return summary

    return {"enabled": False}
|
||||
196
doris_mcp_server/auth/oauth_types.py
Normal file
196
doris_mcp_server/auth/oauth_types.py
Normal file
@@ -0,0 +1,196 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
OAuth 2.0/OIDC Type Definitions
|
||||
Provides data types and models for OAuth authentication flow
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Dict, Any, Optional, List
|
||||
|
||||
|
||||
class OAuthProvider(Enum):
    """Identity providers an OAuth configuration can name.

    Each member's value is the provider's lowercase identifier string.
    """

    GOOGLE = "google"
    MICROSOFT = "microsoft"
    GITHUB = "github"
    # Any provider other than the three named above
    CUSTOM = "custom"
|
||||
|
||||
|
||||
class OAuthGrantType(Enum):
    """Grant types used when requesting tokens.

    Each member's value is the literal ``grant_type`` string.
    """

    AUTHORIZATION_CODE = "authorization_code"
    REFRESH_TOKEN = "refresh_token"
|
||||
|
||||
|
||||
@dataclass
class OAuthState:
    """OAuth state parameter for CSRF protection.

    Holds the ``state`` value together with the optional nonce, PKCE
    verifier/challenge and redirect URI that belong to one authorization
    attempt.
    """

    state: str
    nonce: Optional[str] = None
    pkce_verifier: Optional[str] = None
    pkce_challenge: Optional[str] = None
    redirect_uri: str = ""
    # Filled in by __post_init__ when omitted (naive datetime in UTC).
    created_at: Optional[datetime] = None
    # Never defaulted here: stays None unless the caller supplies a value.
    expires_at: Optional[datetime] = None

    def __post_init__(self):
        """Stamp ``created_at`` with the current UTC time when not provided."""
        if self.created_at is None:
            # datetime.utcnow() is deprecated since Python 3.12. Dropping the
            # tzinfo keeps the value naive, exactly as utcnow() produced it,
            # so comparisons against other naive UTC datetimes keep working.
            from datetime import timezone

            self.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
|
||||
|
||||
@dataclass
class OAuthTokens:
    """OAuth token response.

    Only ``access_token`` is required; the remaining fields default to
    ``None`` except ``token_type``, which defaults to "Bearer".
    """

    access_token: str
    token_type: str = "Bearer"
    # presumably a lifetime in seconds, as in a standard token response - confirm
    expires_in: Optional[int] = None
    refresh_token: Optional[str] = None
    scope: Optional[str] = None
    id_token: Optional[str] = None  # OIDC ID token
    # Filled in by __post_init__ when omitted (naive datetime in UTC).
    created_at: Optional[datetime] = None

    def __post_init__(self):
        """Stamp ``created_at`` with the current UTC time when not provided."""
        if self.created_at is None:
            # datetime.utcnow() is deprecated since Python 3.12. Dropping the
            # tzinfo keeps the value naive, exactly as utcnow() produced it,
            # so comparisons against other naive UTC datetimes keep working.
            from datetime import timezone

            self.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
|
||||
|
||||
@dataclass
class OAuthUserInfo:
    """User profile obtained through OAuth/OIDC.

    Only ``sub`` is required. ``roles`` and ``raw_claims`` are replaced by a
    fresh empty list / dict when they are left as ``None``.
    """

    sub: str  # Subject identifier
    email: Optional[str] = None
    email_verified: Optional[bool] = None
    name: Optional[str] = None
    given_name: Optional[str] = None
    family_name: Optional[str] = None
    picture: Optional[str] = None
    locale: Optional[str] = None
    roles: List[str] = None
    raw_claims: Dict[str, Any] = None

    def __post_init__(self):
        """Give every instance its own empty containers for unset fields."""
        for attr, make_empty in (("roles", list), ("raw_claims", dict)):
            if getattr(self, attr) is None:
                setattr(self, attr, make_empty())
|
||||
|
||||
|
||||
@dataclass
class OIDCDiscovery:
    """Fields of an OIDC discovery document used by the OAuth layer.

    The four ``*_supported`` lists fall back to a single-element default
    when they are left as ``None``.
    """

    issuer: str
    authorization_endpoint: str
    token_endpoint: str
    userinfo_endpoint: Optional[str] = None
    jwks_uri: Optional[str] = None
    scopes_supported: List[str] = None
    response_types_supported: List[str] = None
    subject_types_supported: List[str] = None
    id_token_signing_alg_values_supported: List[str] = None

    def __post_init__(self):
        """Apply the fallback value to each list field that was not supplied."""
        fallbacks = (
            ("scopes_supported", "openid"),
            ("response_types_supported", "code"),
            ("subject_types_supported", "public"),
            ("id_token_signing_alg_values_supported", "RS256"),
        )
        for attr, only_value in fallbacks:
            if getattr(self, attr) is None:
                # A new list per instance, so instances never share state
                setattr(self, attr, [only_value])
|
||||
|
||||
|
||||
@dataclass
class OAuthError:
    """Error details returned by an OAuth provider.

    Only the ``error`` code is required; the other fields default to ``None``.
    """

    # Error code string
    error: str
    # Optional human-readable explanation
    error_description: Optional[str] = None
    # Optional link with more information about the error
    error_uri: Optional[str] = None
    # Optional ``state`` value accompanying the error
    state: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
class OAuthProviderConfig:
    """Settings needed to talk to one OAuth provider.

    Client credentials, redirect URI, scopes and the authorization/token
    endpoints are required. Everything else has a default.
    """

    provider: OAuthProvider
    client_id: str
    client_secret: str
    redirect_uri: str
    scopes: List[str]

    # Endpoints
    authorization_endpoint: str
    token_endpoint: str
    userinfo_endpoint: Optional[str] = None
    jwks_uri: Optional[str] = None

    # Discovery
    discovery_url: Optional[str] = None

    # Settings
    pkce_enabled: bool = True
    nonce_enabled: bool = True

    # User mapping: names of the claims to read for each user attribute
    user_id_claim: str = "sub"
    email_claim: str = "email"
    name_claim: str = "name"
    roles_claim: str = "roles"
    default_roles: List[str] = None

    def __post_init__(self):
        """Fall back to the single role "oauth_user" when none were given."""
        if self.default_roles is not None:
            return
        self.default_roles = ["oauth_user"]
|
||||
|
||||
|
||||
# Pre-defined provider configurations
#
# Endpoint, scope and claim-name presets keyed by OAuthProvider member.
# OAuthProvider.CUSTOM has no entry here. The GitHub entry defines no
# "jwks_uri" or "discovery_url".
#
# The discovery documents live at "/.well-known/openid-configuration"
# (hyphen); the underscore spelling is not the standard well-known path.
OAUTH_PROVIDERS = {
    OAuthProvider.GOOGLE: {
        "authorization_endpoint": "https://accounts.google.com/o/oauth2/auth",
        "token_endpoint": "https://oauth2.googleapis.com/token",
        "userinfo_endpoint": "https://openidconnect.googleapis.com/v1/userinfo",
        "jwks_uri": "https://www.googleapis.com/oauth2/v3/certs",
        "discovery_url": "https://accounts.google.com/.well-known/openid-configuration",
        "scopes": ["openid", "email", "profile"],
        "user_id_claim": "sub",
        "email_claim": "email",
        "name_claim": "name"
    },
    OAuthProvider.MICROSOFT: {
        "authorization_endpoint": "https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
        "token_endpoint": "https://login.microsoftonline.com/common/oauth2/v2.0/token",
        # NOTE(review): Graph "/me" may not expose the "sub"/"email"/"name"
        # claim names configured below - verify against the code that reads
        # this endpoint's response.
        "userinfo_endpoint": "https://graph.microsoft.com/v1.0/me",
        "jwks_uri": "https://login.microsoftonline.com/common/discovery/v2.0/keys",
        "discovery_url": "https://login.microsoftonline.com/common/v2.0/.well-known/openid-configuration",
        "scopes": ["openid", "profile", "email", "User.Read"],
        "user_id_claim": "sub",
        "email_claim": "email",
        "name_claim": "name"
    },
    OAuthProvider.GITHUB: {
        "authorization_endpoint": "https://github.com/login/oauth/authorize",
        "token_endpoint": "https://github.com/login/oauth/access_token",
        "userinfo_endpoint": "https://api.github.com/user",
        "scopes": ["user:email", "read:user"],
        "user_id_claim": "id",
        "email_claim": "email",
        "name_claim": "name"
    }
}
|
||||
677
doris_mcp_server/auth/token_handlers.py
Normal file
677
doris_mcp_server/auth/token_handlers.py
Normal file
@@ -0,0 +1,677 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Token Authentication HTTP Handlers
|
||||
|
||||
Provides HTTP endpoints for token management including creation, revocation,
|
||||
listing, and statistics. Used for administrative token management in HTTP mode.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Dict, Any
|
||||
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, HTMLResponse
|
||||
|
||||
from ..utils.logger import get_logger
|
||||
from ..utils.security import SecurityLevel
|
||||
from ..utils.config import DatabaseConfig
|
||||
from .token_security_middleware import TokenSecurityMiddleware
|
||||
|
||||
|
||||
class TokenHandlers:
|
||||
"""Token Authentication HTTP Handlers"""
|
||||
|
||||
def __init__(self, security_manager, config=None):
    """Store the security manager and set up optional access control.

    Args:
        security_manager: Object the handlers delegate token operations to.
        config: Passed to TokenSecurityMiddleware when truthy. When falsy,
            no middleware is created and a warning is logged.
    """
    self.security_manager = security_manager
    self.logger = get_logger(__name__)

    # Access control exists only when a configuration was supplied
    self.security_middleware = TokenSecurityMiddleware(config) if config else None
    if self.security_middleware is None:
        self.logger.warning("Token handlers initialized without security middleware - access control disabled")
|
||||
|
||||
async def handle_create_token(self, request: Request) -> JSONResponse:
    """Handle token creation request.

    Inputs are read from the query string on GET and from a JSON object
    body on any other method:

        token_id       required identifier of the new token
        expires_hours  optional integer
        description    optional text, defaults to ""
        custom_token   optional caller-supplied token value

    A per-token database configuration is taken from the ``db_host``,
    ``db_port``, ``db_user``, ``db_password``, ``db_database`` and
    ``db_fe_http_port`` query parameters on GET (only when ``db_host`` is
    present), or from the ``database_config`` object of the JSON body.

    Returns:
        JSONResponse with status
        200 - token created (body contains the token),
        400 - malformed input, or ``create_token`` raised,
        503 - no token manager is configured,
        500 - any other unexpected error.
        When the security middleware rejects the request, its response is
        returned unchanged.
    """
    # Apply security checks
    if self.security_middleware:
        security_response = await self.security_middleware.check_token_management_access(request)
        if security_response:
            return security_response

    try:
        # Check if token manager is available
        if not self.security_manager.auth_provider.token_manager:
            return JSONResponse({
                "error": "Token authentication is not enabled"
            }, status_code=503)

        # Parse request data into one mapping plus an optional database section
        if request.method == "GET":
            # NOTE(review): on GET, custom_token and db_password travel in the
            # URL and may end up in access logs or browser history. Prefer POST.
            params = dict(request.query_params)
            db_data = None
            if params.get("db_host"):
                # Rename db_* query parameters to the JSON field names so both
                # transports share the parsing code below.
                db_data = {
                    field: params[query_name]
                    for query_name, field in (
                        ("db_host", "host"),
                        ("db_port", "port"),
                        ("db_user", "user"),
                        ("db_password", "password"),
                        ("db_database", "database"),
                        ("db_fe_http_port", "fe_http_port"),
                    )
                    if query_name in params
                }
        else:
            try:
                params = await request.json()
            except Exception:
                # "except Exception" rather than a bare except, so task
                # cancellation and interpreter shutdown are not swallowed.
                return JSONResponse({
                    "error": "Invalid JSON body"
                }, status_code=400)

            # Valid JSON that is not an object (list, string, ...) has no fields
            if not isinstance(params, dict):
                return JSONResponse({
                    "error": "Invalid JSON body"
                }, status_code=400)

            db_data = params.get("database_config") or None

        token_id = params.get("token_id")
        expires_hours_str = params.get("expires_hours")
        description = params.get("description", "")
        custom_token = params.get("custom_token")

        # Build the optional database configuration; bad values are a client error
        db_config = None
        if db_data:
            if not isinstance(db_data, dict):
                return JSONResponse({
                    "error": "Invalid database configuration: database_config must be an object"
                }, status_code=400)
            try:
                db_config = DatabaseConfig(
                    host=db_data.get("host", "localhost"),
                    port=int(db_data.get("port", 9030)),
                    user=db_data.get("user", "root"),
                    password=db_data.get("password", ""),
                    database=db_data.get("database", "information_schema"),
                    fe_http_port=int(db_data.get("fe_http_port", 8030))
                )
            except (ValueError, TypeError) as e:
                return JSONResponse({
                    "error": f"Invalid database configuration: {str(e)}"
                }, status_code=400)

        # Validate required fields
        if not token_id:
            return JSONResponse({
                "error": "token_id is required"
            }, status_code=400)

        # Parse expires_hours (a str from a query string, any JSON type from a body)
        expires_hours = None
        if expires_hours_str:
            try:
                expires_hours = int(expires_hours_str)
            except (ValueError, TypeError):
                return JSONResponse({
                    "error": "expires_hours must be an integer"
                }, status_code=400)

        # Create token using the actual API
        try:
            token = await self.security_manager.create_token(
                token_id=token_id,
                expires_hours=expires_hours,
                description=description,
                custom_token=custom_token,
                database_config=db_config
            )

            return JSONResponse({
                "success": True,
                "token_id": token_id,
                "token": token,
                "expires_hours": expires_hours,
                "description": description,
                "message": "Token created successfully"
            })

        except Exception as e:
            self.logger.error(f"Token creation failed: {e}")
            return JSONResponse({
                "error": f"Token creation failed: {str(e)}"
            }, status_code=400)

    except Exception as e:
        self.logger.error(f"Error in handle_create_token: {e}")
        return JSONResponse({
            "error": f"Internal server error: {str(e)}"
        }, status_code=500)
|
||||
|
||||
async def handle_revoke_token(self, request: Request) -> JSONResponse:
    """Revoke the token named by the request.

    The token id comes from the ``token_id`` query parameter. For DELETE
    requests without that parameter, the last path segment is used when the
    path splits into at least four "/"-separated parts
    (e.g. /token/revoke/{token_id}).

    Returns:
        JSONResponse with status 200 when revoked, 404 when the security
        manager reports failure, 400 when no token id was given, 503 when no
        token manager is configured, and 500 on an unexpected error. When the
        security middleware rejects the request, its response is returned.
    """
    # Access control comes first and is not covered by the error handler below
    guard = self.security_middleware
    if guard:
        denial = await guard.check_token_management_access(request)
        if denial:
            return denial

    try:
        # Token management needs a configured token manager
        if not self.security_manager.auth_provider.token_manager:
            return JSONResponse(
                {"error": "Token authentication is not enabled"},
                status_code=503,
            )

        # Query parameter first, then the trailing path segment for DELETE
        target_id = request.query_params.get("token_id")
        if not target_id and request.method == "DELETE":
            segments = str(request.url.path).split("/")
            if len(segments) >= 4:
                target_id = segments[-1]

        if not target_id:
            return JSONResponse({"error": "token_id is required"}, status_code=400)

        revoked = await self.security_manager.revoke_token(target_id)
        if not revoked:
            return JSONResponse(
                {
                    "success": False,
                    "token_id": target_id,
                    "message": "Token not found or already revoked",
                },
                status_code=404,
            )

        return JSONResponse(
            {
                "success": True,
                "token_id": target_id,
                "message": "Token revoked successfully",
            }
        )

    except Exception as exc:
        self.logger.error(f"Error in handle_revoke_token: {exc}")
        return JSONResponse(
            {"error": f"Internal server error: {str(exc)}"},
            status_code=500,
        )
|
||||
|
||||
async def handle_list_tokens(self, request: Request) -> JSONResponse:
    """Return the tokens known to the security manager.

    Returns:
        JSONResponse with "success", "count" and "tokens" on success, 503
        when no token manager is configured, and 500 on an unexpected error.
        When the security middleware rejects the request, its response is
        returned.
    """
    # Access control comes first and is not covered by the error handler below
    guard = self.security_middleware
    if guard:
        denial = await guard.check_token_management_access(request)
        if denial:
            return denial

    try:
        # Token management needs a configured token manager
        if not self.security_manager.auth_provider.token_manager:
            return JSONResponse(
                {"error": "Token authentication is not enabled"},
                status_code=503,
            )

        listing = await self.security_manager.list_tokens()
        payload = {
            "success": True,
            "count": len(listing),
            "tokens": listing,
        }
        return JSONResponse(payload)

    except Exception as exc:
        self.logger.error(f"Error in handle_list_tokens: {exc}")
        return JSONResponse(
            {"error": f"Internal server error: {str(exc)}"},
            status_code=500,
        )
|
||||
|
||||
async def handle_token_stats(self, request: Request) -> JSONResponse:
    """Return token statistics from the security manager.

    Returns:
        JSONResponse with "success" and "stats" on success, 503 when no token
        manager is configured, and 500 on an unexpected error. When the
        security middleware rejects the request, its response is returned.
    """
    # Access control comes first and is not covered by the error handler below
    guard = self.security_middleware
    if guard:
        denial = await guard.check_token_management_access(request)
        if denial:
            return denial

    try:
        # Token management needs a configured token manager
        if not self.security_manager.auth_provider.token_manager:
            return JSONResponse(
                {"error": "Token authentication is not enabled"},
                status_code=503,
            )

        # get_token_stats is called synchronously (no await)
        payload = {
            "success": True,
            "stats": self.security_manager.get_token_stats(),
        }
        return JSONResponse(payload)

    except Exception as exc:
        self.logger.error(f"Error in handle_token_stats: {exc}")
        return JSONResponse(
            {"error": f"Internal server error: {str(exc)}"},
            status_code=500,
        )
|
||||
|
||||
async def handle_cleanup_tokens(self, request: Request) -> JSONResponse:
    """Ask the security manager to clean up expired tokens.

    Returns:
        JSONResponse with "success", "cleaned_count" and "message" on
        success, 503 when no token manager is configured, and 500 on an
        unexpected error. When the security middleware rejects the request,
        its response is returned.
    """
    # Access control comes first and is not covered by the error handler below
    guard = self.security_middleware
    if guard:
        denial = await guard.check_token_management_access(request)
        if denial:
            return denial

    try:
        # Token management needs a configured token manager
        if not self.security_manager.auth_provider.token_manager:
            return JSONResponse(
                {"error": "Token authentication is not enabled"},
                status_code=503,
            )

        removed = await self.security_manager.cleanup_expired_tokens()
        payload = {
            "success": True,
            "cleaned_count": removed,
            "message": f"Cleaned up {removed} expired tokens",
        }
        return JSONResponse(payload)

    except Exception as exc:
        self.logger.error(f"Error in handle_cleanup_tokens: {exc}")
        return JSONResponse(
            {"error": f"Internal server error: {str(exc)}"},
            status_code=500,
        )
|
||||
|
||||
async def handle_management_page(self, request: Request) -> HTMLResponse:
|
||||
"""Handle token management demo page"""
|
||||
# Apply security checks
|
||||
if self.security_middleware:
|
||||
security_response = await self.security_middleware.check_token_management_access(request)
|
||||
if security_response:
|
||||
# Convert JSON response to HTML for demo page
|
||||
error_data = security_response.body.decode('utf-8') if hasattr(security_response, 'body') else '{"error": "Access denied"}'
|
||||
try:
|
||||
error_info = json.loads(error_data)
|
||||
except:
|
||||
error_info = {"error": "Access denied"}
|
||||
|
||||
error_html = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Access Denied - Token Management</title>
|
||||
<style>
|
||||
body {{ font-family: Arial, sans-serif; margin: 50px; background: #f5f5f5; }}
|
||||
.container {{ max-width: 600px; margin: 0 auto; background: white; padding: 30px; border-radius: 8px; }}
|
||||
.error {{ color: #dc3545; background: #f8d7da; border: 1px solid #f5c6cb; padding: 15px; border-radius: 5px; }}
|
||||
.security-info {{ background: #d1ecf1; border: 1px solid #bee5eb; padding: 15px; border-radius: 5px; margin-top: 20px; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>🔐 Token Management - Access Denied</h1>
|
||||
<div class="error">
|
||||
<h3>Access Denied</h3>
|
||||
<p><strong>Error:</strong> {error_info.get('error', 'Access denied')}</p>
|
||||
<p><strong>Message:</strong> {error_info.get('message', 'Token management access is restricted')}</p>
|
||||
{'<p><strong>Your IP:</strong> ' + str(error_info.get('client_ip', 'Unknown')) + '</p>' if 'client_ip' in error_info else ''}
|
||||
</div>
|
||||
|
||||
<div class="security-info">
|
||||
<h3>🛡️ Security Information</h3>
|
||||
<p>Token management endpoints are protected by the following security measures:</p>
|
||||
<ul>
|
||||
<li><strong>IP Restrictions:</strong> Only localhost/127.0.0.1 access allowed</li>
|
||||
<li><strong>Admin Authentication:</strong> Valid admin token required</li>
|
||||
<li><strong>Configuration Control:</strong> Must be explicitly enabled</li>
|
||||
</ul>
|
||||
<p>If you need access, please:</p>
|
||||
<ol>
|
||||
<li>Access from the server host (127.0.0.1)</li>
|
||||
<li>Ensure HTTP token management is enabled in configuration</li>
|
||||
<li>Provide valid admin authentication</li>
|
||||
</ol>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
return HTMLResponse(error_html, status_code=security_response.status_code)
|
||||
|
||||
try:
|
||||
# Check if token manager is available
|
||||
if not self.security_manager.auth_provider.token_manager:
|
||||
html_content = """
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Token Management - Not Available</title>
|
||||
<style>
|
||||
body { font-family: Arial, sans-serif; margin: 50px; }
|
||||
.error { color: red; font-size: 18px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Token Management</h1>
|
||||
<div class="error">Token authentication is not enabled on this server.</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
return HTMLResponse(html_content)
|
||||
|
||||
# Get current stats for demo
|
||||
stats = self.security_manager.get_token_stats()
|
||||
|
||||
html_content = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Doris MCP Server - Token Management</title>
|
||||
<style>
|
||||
body {{ font-family: Arial, sans-serif; margin: 50px; background: #f5f5f5; }}
|
||||
.container {{ max-width: 1200px; margin: 0 auto; background: white; padding: 30px; border-radius: 8px; }}
|
||||
h1 {{ color: #333; }}
|
||||
.section {{ margin: 30px 0; padding: 20px; border: 1px solid #ddd; border-radius: 5px; }}
|
||||
.stats {{ display: grid; grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); gap: 15px; }}
|
||||
.stat-item {{ padding: 15px; background: #f8f9fa; border-radius: 5px; text-align: center; }}
|
||||
.stat-value {{ font-size: 24px; font-weight: bold; color: #007bff; }}
|
||||
.form-group {{ margin: 15px 0; }}
|
||||
.form-group label {{ display: block; margin-bottom: 5px; font-weight: bold; }}
|
||||
.form-group input, .form-group textarea {{ width: 100%; padding: 8px; border: 1px solid #ddd; border-radius: 4px; }}
|
||||
button {{ padding: 10px 20px; margin: 5px; border: none; border-radius: 4px; cursor: pointer; }}
|
||||
.btn-primary {{ background: #007bff; color: white; }}
|
||||
.btn-danger {{ background: #dc3545; color: white; }}
|
||||
.btn-success {{ background: #28a745; color: white; }}
|
||||
.response {{ margin: 15px 0; padding: 15px; border-radius: 5px; }}
|
||||
.response.success {{ background: #d4edda; border: 1px solid #c3e6cb; }}
|
||||
.response.error {{ background: #f8d7da; border: 1px solid #f5c6cb; }}
|
||||
.token-list {{ margin: 15px 0; }}
|
||||
.token-item {{ padding: 10px; margin: 5px 0; background: #f8f9fa; border-radius: 4px; }}
|
||||
pre {{ background: #f8f9fa; padding: 10px; border-radius: 4px; overflow-x: auto; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>🔐 Doris MCP Server - Token Management</h1>
|
||||
|
||||
<div class="section">
|
||||
<h2>📊 Token Statistics</h2>
|
||||
<div class="stats">
|
||||
<div class="stat-item">
|
||||
<div class="stat-value">{stats.get('total_tokens', 0)}</div>
|
||||
<div>Total Tokens</div>
|
||||
</div>
|
||||
<div class="stat-item">
|
||||
<div class="stat-value">{stats.get('active_tokens', 0)}</div>
|
||||
<div>Active Tokens</div>
|
||||
</div>
|
||||
<div class="stat-item">
|
||||
<div class="stat-value">{stats.get('expired_tokens', 0)}</div>
|
||||
<div>Expired Tokens</div>
|
||||
</div>
|
||||
</div>
|
||||
<p><strong>Token Expiry:</strong> {'Enabled' if stats.get('expiry_enabled') else 'Disabled'}</p>
|
||||
<p><strong>Default Expiry:</strong> {stats.get('default_expiry_hours', 0)} hours</p>
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>➕ Create New Token</h2>
|
||||
<form id="createTokenForm">
|
||||
<div class="form-group">
|
||||
<label for="token_id">Token ID (required):</label>
|
||||
<input type="text" id="token_id" name="token_id" placeholder="e.g., my-app-token" required>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="expires_hours">Expires Hours (optional):</label>
|
||||
<input type="number" id="expires_hours" name="expires_hours" placeholder="e.g., 720 (30 days), leave empty for default">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="description">Description (optional):</label>
|
||||
<textarea id="description" name="description" placeholder="Token description"></textarea>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="custom_token">Custom Token (optional):</label>
|
||||
<input type="text" id="custom_token" name="custom_token" placeholder="Leave empty to auto-generate">
|
||||
<small style="color: #666; display: block; margin-top: 5px;">If not provided, a secure token will be generated automatically</small>
|
||||
</div>
|
||||
|
||||
<div class="section" style="margin: 20px 0; padding: 15px; background: #f8f9fa; border-radius: 5px;">
|
||||
<h3>🗄️ Database Configuration (Optional)</h3>
|
||||
<p style="color: #666; font-size: 14px; margin-bottom: 15px;">Configure database connection for this token. Leave empty to use system defaults.</p>
|
||||
|
||||
<div style="display: grid; grid-template-columns: 1fr 1fr; gap: 10px;">
|
||||
<div class="form-group">
|
||||
<label for="db_host">Host:</label>
|
||||
<input type="text" id="db_host" name="db_host" placeholder="localhost">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="db_port">Port:</label>
|
||||
<input type="number" id="db_port" name="db_port" placeholder="9030">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="db_user">User:</label>
|
||||
<input type="text" id="db_user" name="db_user" placeholder="root">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="db_password">Password:</label>
|
||||
<input type="password" id="db_password" name="db_password" placeholder="(optional)">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="db_database">Database:</label>
|
||||
<input type="text" id="db_database" name="db_database" placeholder="information_schema">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="db_fe_http_port">FE HTTP Port:</label>
|
||||
<input type="number" id="db_fe_http_port" name="db_fe_http_port" placeholder="8030">
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<button type="submit" class="btn-primary">Create Token</button>
|
||||
</form>
|
||||
<div id="createTokenResponse"></div>
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>📋 Token Management</h2>
|
||||
<button id="listTokensBtn" class="btn-success">Refresh Token List</button>
|
||||
<button id="cleanupTokensBtn" class="btn-primary">Cleanup Expired Tokens</button>
|
||||
<div id="tokenListResponse"></div>
|
||||
|
||||
<h3>Revoke Token</h3>
|
||||
<div class="form-group">
|
||||
<input type="text" id="revokeTokenId" placeholder="Enter token ID to revoke">
|
||||
<button id="revokeTokenBtn" class="btn-danger">Revoke Token</button>
|
||||
</div>
|
||||
<div id="revokeTokenResponse"></div>
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>🔧 API Endpoints</h2>
|
||||
<p>Use these endpoints for programmatic token management:</p>
|
||||
<ul>
|
||||
<li><strong>POST /token/create</strong> - Create new token</li>
|
||||
<li><strong>DELETE /token/revoke?token_id=...</strong> - Revoke token</li>
|
||||
<li><strong>GET /token/list</strong> - List all tokens</li>
|
||||
<li><strong>GET /token/stats</strong> - Get token statistics</li>
|
||||
<li><strong>POST /token/cleanup</strong> - Cleanup expired tokens</li>
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
// Get admin token from URL parameters
|
||||
const urlParams = new URLSearchParams(window.location.search);
|
||||
const adminToken = urlParams.get('admin_token');
|
||||
|
||||
// Create request headers with admin token
|
||||
function getAuthHeaders() {{
|
||||
if (adminToken) {{
|
||||
return {{
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': `Bearer ${{adminToken}}`
|
||||
}};
|
||||
}} else {{
|
||||
return {{'Content-Type': 'application/json'}};
|
||||
}}
|
||||
}}
|
||||
|
||||
// Create URL with admin token parameter
|
||||
function getAuthURL(baseUrl) {{
|
||||
if (adminToken) {{
|
||||
const separator = baseUrl.includes('?') ? '&' : '?';
|
||||
return `${{baseUrl}}${{separator}}admin_token=${{encodeURIComponent(adminToken)}}`;
|
||||
}}
|
||||
return baseUrl;
|
||||
}}
|
||||
|
||||
function showResponse(elementId, data, isSuccess = true) {{
|
||||
const element = document.getElementById(elementId);
|
||||
element.innerHTML = '<pre>' + JSON.stringify(data, null, 2) + '</pre>';
|
||||
element.className = 'response ' + (isSuccess ? 'success' : 'error');
|
||||
}}
|
||||
|
||||
// Create token form - updated to match actual API
|
||||
document.getElementById('createTokenForm').addEventListener('submit', async (e) => {{
|
||||
e.preventDefault();
|
||||
const formData = new FormData(e.target);
|
||||
const data = Object.fromEntries(formData.entries());
|
||||
|
||||
// Remove empty fields for optional parameters
|
||||
if (!data.expires_hours) delete data.expires_hours;
|
||||
if (!data.description) delete data.description;
|
||||
if (!data.custom_token) delete data.custom_token;
|
||||
|
||||
// Handle database configuration
|
||||
if (data.db_host) {{
|
||||
data.database_config = {{
|
||||
host: data.db_host,
|
||||
port: data.db_port ? parseInt(data.db_port) : 9030,
|
||||
user: data.db_user || 'root',
|
||||
password: data.db_password || '',
|
||||
database: data.db_database || 'information_schema',
|
||||
fe_http_port: data.db_fe_http_port ? parseInt(data.db_fe_http_port) : 8030
|
||||
}};
|
||||
}}
|
||||
|
||||
// Remove individual database fields from data
|
||||
delete data.db_host;
|
||||
delete data.db_port;
|
||||
delete data.db_user;
|
||||
delete data.db_password;
|
||||
delete data.db_database;
|
||||
delete data.db_fe_http_port;
|
||||
|
||||
try {{
|
||||
const response = await fetch(getAuthURL('/token/create'), {{
|
||||
method: 'POST',
|
||||
headers: getAuthHeaders(),
|
||||
body: JSON.stringify(data)
|
||||
}});
|
||||
const result = await response.json();
|
||||
showResponse('createTokenResponse', result, response.ok);
|
||||
|
||||
// Refresh token list if creation was successful
|
||||
if (response.ok) {{
|
||||
document.getElementById('listTokensBtn').click();
|
||||
}}
|
||||
}} catch (error) {{
|
||||
showResponse('createTokenResponse', {{error: error.message}}, false);
|
||||
}}
|
||||
}});
|
||||
|
||||
// List tokens
|
||||
document.getElementById('listTokensBtn').addEventListener('click', async () => {{
|
||||
try {{
|
||||
const response = await fetch(getAuthURL('/token/list'), {{
|
||||
headers: getAuthHeaders()
|
||||
}});
|
||||
const result = await response.json();
|
||||
showResponse('tokenListResponse', result, response.ok);
|
||||
}} catch (error) {{
|
||||
showResponse('tokenListResponse', {{error: error.message}}, false);
|
||||
}}
|
||||
}});
|
||||
|
||||
// Cleanup tokens
|
||||
document.getElementById('cleanupTokensBtn').addEventListener('click', async () => {{
|
||||
try {{
|
||||
const response = await fetch(getAuthURL('/token/cleanup'), {{
|
||||
method: 'POST',
|
||||
headers: getAuthHeaders()
|
||||
}});
|
||||
const result = await response.json();
|
||||
showResponse('tokenListResponse', result, response.ok);
|
||||
}} catch (error) {{
|
||||
showResponse('tokenListResponse', {{error: error.message}}, false);
|
||||
}}
|
||||
}});
|
||||
|
||||
// Revoke token
|
||||
document.getElementById('revokeTokenBtn').addEventListener('click', async () => {{
|
||||
const tokenId = document.getElementById('revokeTokenId').value;
|
||||
if (!tokenId) {{
|
||||
showResponse('revokeTokenResponse', {{error: 'Token ID is required'}}, false);
|
||||
return;
|
||||
}}
|
||||
|
||||
try {{
|
||||
const response = await fetch(getAuthURL(`/token/revoke?token_id=${{encodeURIComponent(tokenId)}}`), {{
|
||||
method: 'DELETE',
|
||||
headers: getAuthHeaders()
|
||||
}});
|
||||
const result = await response.json();
|
||||
showResponse('revokeTokenResponse', result, response.ok);
|
||||
|
||||
// Refresh token list if revocation was successful
|
||||
if (response.ok) {{
|
||||
document.getElementById('listTokensBtn').click();
|
||||
document.getElementById('revokeTokenId').value = '';
|
||||
}}
|
||||
}} catch (error) {{
|
||||
showResponse('revokeTokenResponse', {{error: error.message}}, false);
|
||||
}}
|
||||
}});
|
||||
|
||||
// Load token list on page load
|
||||
document.getElementById('listTokensBtn').click();
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
return HTMLResponse(html_content)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in handle_demo_page: {e}")
|
||||
error_html = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Token Management Error</title>
|
||||
<style>body {{ font-family: Arial, sans-serif; margin: 50px; }}</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Token Management Error</h1>
|
||||
<p>Error loading token management page: {str(e)}</p>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
return HTMLResponse(error_html, status_code=500)
|
||||
827
doris_mcp_server/auth/token_manager.py
Normal file
827
doris_mcp_server/auth/token_manager.py
Normal file
@@ -0,0 +1,827 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Token Authentication Management Module
|
||||
|
||||
Provides enterprise-grade token authentication system with configurable tokens,
|
||||
expiration management, role-based access control and secure token storage.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import secrets
|
||||
import time
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any
|
||||
from pathlib import Path
|
||||
|
||||
from ..utils.logger import get_logger
|
||||
from ..utils.security import SecurityLevel
|
||||
|
||||
|
||||
@dataclass
class DatabaseConfig:
    """Database connection settings that can be bound to a token.

    Only ``host`` is required; every other field has a default.
    """

    host: str
    port: int = 9030
    user: str = ""
    # Kept out of the generated repr so that logging or printing a config
    # object never exposes the plaintext password.
    password: str = field(default="", repr=False)
    database: str = "information_schema"
    charset: str = "UTF8"
    # NOTE(review): presumably the Doris FE HTTP port - confirm against callers.
    fe_http_port: int = 8030
|
||||
|
||||
|
||||
@dataclass
class TokenInfo:
    """Metadata kept for one token; the token value itself is not stored here."""

    # Unique identifier used for audit and management operations.
    token_id: str
    # Naive UTC timestamp taken when the record is constructed.
    created_at: datetime = field(default_factory=datetime.utcnow)
    # None means the token never expires.
    expires_at: Optional[datetime] = None
    # None until the token is first used.
    last_used: Optional[datetime] = None
    # Free-text note describing what the token is for.
    description: str = ""
    is_active: bool = True
    # Optional database binding for this token.
    database_config: Optional[DatabaseConfig] = None
|
||||
|
||||
|
||||
@dataclass
class TokenValidationResult:
    """Outcome of a token validation attempt."""

    # True only when the token was accepted.
    is_valid: bool
    # Record of the accepted token; left as None on rejection.
    token_info: Optional[TokenInfo] = None
    # Human-readable rejection reason; left as None on success.
    error_message: Optional[str] = None
|
||||
|
||||
|
||||
class TokenManager:
|
||||
"""Enterprise Token Authentication Manager
|
||||
|
||||
Features:
|
||||
- Configurable token storage (file-based or environment variables)
|
||||
- Token expiration management
|
||||
- Secure token hashing
|
||||
- Role-based access control
|
||||
- Token lifecycle management
|
||||
"""
|
||||
|
||||
def __init__(self, config):
    """Set up storage, read settings from ``config.security`` and load tokens.

    Settings missing from ``config.security`` fall back to the defaults
    written below. Loading order is: built-in default tokens (only when
    no custom configuration is detected), then environment variables,
    then the token file. Hot reload monitoring is started last.
    """
    self.config = config
    self.logger = get_logger(__name__)

    # In-memory stores: token_hash -> TokenInfo and token_id -> token_hash
    self._tokens: Dict[str, TokenInfo] = {}
    self._token_ids: Dict[str, str] = {}

    # Settings, with fallbacks for attributes the config does not define
    security = config.security
    self.token_file_path = getattr(security, 'token_file_path', 'tokens.json')
    self.enable_token_expiry = getattr(security, 'enable_token_expiry', True)
    # Default lifetime: 30 days expressed in hours
    self.default_token_expiry_hours = getattr(security, 'default_token_expiry_hours', 24 * 30)
    self.token_hash_algorithm = getattr(security, 'token_hash_algorithm', 'sha256')

    # Hot reload state; the token file is polled every 10 seconds
    self.enable_hot_reload = True
    self.hot_reload_interval = 10
    self._file_last_modified = 0
    self._hot_reload_task = None

    # Populate the stores
    self._initialize_default_tokens()
    self._load_tokens()

    if self.enable_hot_reload:
        self._start_hot_reload()

    self.logger.info(f"TokenManager initialized with {len(self._tokens)} tokens, hot reload: {self.enable_hot_reload}")
|
||||
|
||||
def _initialize_default_tokens(self):
|
||||
"""Initialize default tokens for basic authentication (configurable via environment)"""
|
||||
# Default token configurations (can be overridden by environment variables)
|
||||
default_tokens = [
|
||||
{
|
||||
'token_id': 'admin-token',
|
||||
'token': os.getenv('DEFAULT_ADMIN_TOKEN', 'doris_admin_token_123456'),
|
||||
'description': os.getenv('DEFAULT_ADMIN_DESCRIPTION', 'Default admin API access token'),
|
||||
'expires_hours': None # Never expires
|
||||
},
|
||||
{
|
||||
'token_id': 'analyst-token',
|
||||
'token': os.getenv('DEFAULT_ANALYST_TOKEN', 'doris_analyst_token_123456'),
|
||||
'description': os.getenv('DEFAULT_ANALYST_DESCRIPTION', 'Default data analysis API access token'),
|
||||
'expires_hours': None # Never expires
|
||||
},
|
||||
{
|
||||
'token_id': 'readonly-token',
|
||||
'token': os.getenv('DEFAULT_READONLY_TOKEN', 'doris_readonly_token_123456'),
|
||||
'description': os.getenv('DEFAULT_READONLY_DESCRIPTION', 'Default read-only API access token'),
|
||||
'expires_hours': None # Never expires
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
# Only add default tokens if no custom tokens are defined via environment variables
|
||||
# Check if any TOKEN_* environment variables exist (excluding system and legacy configs)
|
||||
excluded_prefixes = ('DEFAULT_', 'TOKEN_FILE_PATH', 'TOKEN_HASH_')
|
||||
excluded_vars = {'TOKEN_SECRET', 'TOKEN_EXPIRY'}
|
||||
|
||||
custom_tokens_exist = any(
|
||||
key.startswith('TOKEN_') and
|
||||
not key.startswith(excluded_prefixes) and
|
||||
not key.endswith(('_EXPIRES_HOURS', '_DESCRIPTION')) and
|
||||
key not in excluded_vars
|
||||
for key in os.environ.keys()
|
||||
)
|
||||
|
||||
# Also check if token file exists and has content
|
||||
token_file_exists = False
|
||||
if os.path.exists(self.token_file_path):
|
||||
try:
|
||||
with open(self.token_file_path, 'r') as f:
|
||||
content = f.read().strip()
|
||||
if content and content != '{}':
|
||||
token_file_exists = True
|
||||
except:
|
||||
pass
|
||||
|
||||
# Add default tokens only if no custom configuration exists
|
||||
if not custom_tokens_exist and not token_file_exists:
|
||||
for token_config in default_tokens:
|
||||
self._add_token_from_config(token_config)
|
||||
|
||||
self.logger.info(f"Initialized {len(default_tokens)} default tokens (no custom config found)")
|
||||
else:
|
||||
self.logger.info("Skipped default tokens initialization (custom tokens detected)")
|
||||
|
||||
def _add_token_from_config(self, token_config: Dict[str, Any]):
    """Register one token described by a configuration dictionary.

    Args:
        token_config: Mapping with the required keys ``token_id`` and
            ``token`` (the raw token value) and the optional keys
            ``expires_hours``, ``description``, ``is_active`` and
            ``database_config``.

    Raises:
        Exception: Any error raised while building the record (for
            example ``KeyError`` for a missing required key) is logged
            and re-raised.
    """
    try:
        # Expiry is only applied when expiry is enabled; an explicit None
        # in the config means "never expires".
        expires_at = None
        if self.enable_token_expiry:
            expires_hours = token_config.get('expires_hours', self.default_token_expiry_hours)
            if expires_hours is not None:
                expires_at = datetime.utcnow() + timedelta(hours=expires_hours)

        # Optional database binding; a null value is treated as "no binding"
        database_config = None
        db_config = token_config.get('database_config')
        if db_config is not None:
            database_config = DatabaseConfig(
                host=db_config.get('host', 'localhost'),
                port=db_config.get('port', 9030),
                user=db_config.get('user', 'root'),
                password=db_config.get('password', ''),
                database=db_config.get('database', 'information_schema'),
                charset=db_config.get('charset', 'UTF8'),
                fe_http_port=db_config.get('fe_http_port', 8030)
            )

        token_info = TokenInfo(
            token_id=token_config['token_id'],
            expires_at=expires_at,
            description=token_config.get('description', ''),
            is_active=token_config.get('is_active', True),
            database_config=database_config
        )

        # Only the hash of the raw token is kept in memory
        raw_token = token_config['token']
        token_hash = self._hash_token(raw_token)

        # Keep both maps one-to-one. Without this, redefining a token ID
        # would leave the previous secret valid and impossible to revoke.
        previous_hash = self._token_ids.get(token_info.token_id)
        if previous_hash is not None and previous_hash != token_hash:
            self._tokens.pop(previous_hash, None)
            self.logger.warning(
                f"Token '{token_info.token_id}' was redefined; its previous value is no longer valid"
            )

        # Same token value registered under another ID: drop the stale ID
        previous_info = self._tokens.get(token_hash)
        if previous_info is not None and previous_info.token_id != token_info.token_id:
            self._token_ids.pop(previous_info.token_id, None)

        self._tokens[token_hash] = token_info
        self._token_ids[token_info.token_id] = token_hash

        db_info = f" with DB binding ({database_config.host})" if database_config else ""
        self.logger.debug(f"Added token '{token_info.token_id}'{db_info}")

    except Exception as e:
        self.logger.error(f"Failed to add token from config: {e}")
        raise
|
||||
|
||||
def _load_tokens(self):
|
||||
"""Load tokens from configuration sources"""
|
||||
# 1. Load from environment variables
|
||||
self._load_tokens_from_env()
|
||||
|
||||
# 2. Load from token file if exists
|
||||
if os.path.exists(self.token_file_path):
|
||||
self._load_tokens_from_file()
|
||||
|
||||
self.logger.info(f"Token loading completed, total tokens: {len(self._tokens)}")
|
||||
|
||||
def _load_tokens_from_env(self):
|
||||
"""Load tokens from environment variables
|
||||
|
||||
Simplified format:
|
||||
TOKEN_<ID>=<token>
|
||||
TOKEN_<ID>_EXPIRES_HOURS=<hours>
|
||||
TOKEN_<ID>_DESCRIPTION=<description>
|
||||
"""
|
||||
token_prefixes = set()
|
||||
|
||||
# Find all TOKEN_ environment variables (exclude legacy and system variables)
|
||||
excluded_token_vars = {
|
||||
'TOKEN_SECRET', # Legacy token secret
|
||||
'TOKEN_EXPIRY', # Legacy token expiry
|
||||
'TOKEN_FILE_PATH', # System config
|
||||
'TOKEN_HASH_ALGORITHM' # System config
|
||||
}
|
||||
|
||||
for key in os.environ:
|
||||
if (key.startswith('TOKEN_') and
|
||||
not key.endswith(('_EXPIRES_HOURS', '_DESCRIPTION')) and
|
||||
key not in excluded_token_vars):
|
||||
token_id = key[6:] # Remove 'TOKEN_' prefix
|
||||
token_prefixes.add(token_id)
|
||||
|
||||
# Load each token
|
||||
for token_id in token_prefixes:
|
||||
try:
|
||||
token = os.environ.get(f'TOKEN_{token_id}')
|
||||
if not token:
|
||||
continue
|
||||
|
||||
expires_hours_str = os.environ.get(f'TOKEN_{token_id}_EXPIRES_HOURS', str(self.default_token_expiry_hours))
|
||||
description = os.environ.get(f'TOKEN_{token_id}_DESCRIPTION', f'Environment token {token_id}')
|
||||
|
||||
expires_hours = None
|
||||
try:
|
||||
if expires_hours_str and expires_hours_str.lower() != 'none':
|
||||
expires_hours = int(expires_hours_str)
|
||||
except ValueError:
|
||||
expires_hours = self.default_token_expiry_hours
|
||||
|
||||
# Add token
|
||||
token_config = {
|
||||
'token_id': token_id.lower(),
|
||||
'token': token,
|
||||
'expires_hours': expires_hours,
|
||||
'description': description
|
||||
}
|
||||
|
||||
self._add_token_from_config(token_config)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to load token {token_id} from environment: {e}")
|
||||
|
||||
def _load_tokens_from_file(self):
|
||||
"""Load tokens from JSON file"""
|
||||
try:
|
||||
with open(self.token_file_path, 'r', encoding='utf-8') as f:
|
||||
tokens_data = json.load(f)
|
||||
|
||||
if isinstance(tokens_data, dict) and 'tokens' in tokens_data:
|
||||
tokens_list = tokens_data['tokens']
|
||||
elif isinstance(tokens_data, list):
|
||||
tokens_list = tokens_data
|
||||
else:
|
||||
self.logger.error(f"Invalid token file format: {self.token_file_path}")
|
||||
return
|
||||
|
||||
for token_config in tokens_list:
|
||||
self._add_token_from_config(token_config)
|
||||
|
||||
self.logger.info(f"Loaded {len(tokens_list)} tokens from file: {self.token_file_path}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to load tokens from file {self.token_file_path}: {e}")
|
||||
|
||||
def _hash_token(self, token: str) -> str:
|
||||
"""Hash token for secure storage"""
|
||||
if self.token_hash_algorithm == 'sha256':
|
||||
return hashlib.sha256(token.encode('utf-8')).hexdigest()
|
||||
elif self.token_hash_algorithm == 'sha512':
|
||||
return hashlib.sha512(token.encode('utf-8')).hexdigest()
|
||||
else:
|
||||
# Fallback to sha256
|
||||
return hashlib.sha256(token.encode('utf-8')).hexdigest()
|
||||
|
||||
async def validate_token(self, token: str) -> TokenValidationResult:
    """Check a raw token and report whether it is currently usable.

    A token is rejected when it is unknown, inactive, or past its
    expiry time. On success the record's ``last_used`` time is updated.

    Returns:
        A TokenValidationResult; on rejection ``error_message`` states
        the reason. Unexpected errors are logged and reported as a
        rejection instead of being raised.
    """
    try:
        token_info = self._tokens.get(self._hash_token(token))

        rejection = None
        if not token_info:
            rejection = "Invalid token"
        elif not token_info.is_active:
            rejection = "Token is inactive"
        elif token_info.expires_at and datetime.utcnow() > token_info.expires_at:
            rejection = "Token has expired"

        if rejection is not None:
            return TokenValidationResult(is_valid=False, error_message=rejection)

        # Accepted: record the time of use
        token_info.last_used = datetime.utcnow()
        return TokenValidationResult(is_valid=True, token_info=token_info)

    except Exception as e:
        self.logger.error(f"Token validation error: {e}")
        return TokenValidationResult(
            is_valid=False,
            error_message=f"Token validation failed: {str(e)}"
        )
|
||||
|
||||
def generate_token(self, length: int = 32) -> str:
|
||||
"""Generate a cryptographically secure random token"""
|
||||
return secrets.token_urlsafe(length)
|
||||
|
||||
async def create_token(
|
||||
self,
|
||||
token_id: str,
|
||||
expires_hours: Optional[int] = None,
|
||||
description: str = "",
|
||||
custom_token: Optional[str] = None,
|
||||
database_config: Optional[DatabaseConfig] = None
|
||||
) -> str:
|
||||
"""Create a new token"""
|
||||
try:
|
||||
# Check if token_id already exists
|
||||
if token_id in self._token_ids:
|
||||
raise ValueError(f"Token ID '{token_id}' already exists")
|
||||
|
||||
# Generate or use provided token
|
||||
if custom_token:
|
||||
raw_token = custom_token
|
||||
else:
|
||||
raw_token = self.generate_token()
|
||||
|
||||
# Calculate expiration
|
||||
expires_at = None
|
||||
if expires_hours is not None:
|
||||
expires_at = datetime.utcnow() + timedelta(hours=expires_hours)
|
||||
elif self.enable_token_expiry:
|
||||
expires_at = datetime.utcnow() + timedelta(hours=self.default_token_expiry_hours)
|
||||
|
||||
# Create token info
|
||||
token_info = TokenInfo(
|
||||
token_id=token_id,
|
||||
expires_at=expires_at,
|
||||
description=description,
|
||||
database_config=database_config
|
||||
)
|
||||
|
||||
# Hash and store token
|
||||
token_hash = self._hash_token(raw_token)
|
||||
self._tokens[token_hash] = token_info
|
||||
self._token_ids[token_id] = token_hash
|
||||
|
||||
self.logger.info(f"Created new token '{token_id}'")
|
||||
|
||||
# Save token to file
|
||||
self._save_token_to_file(token_id, raw_token, token_info)
|
||||
|
||||
return raw_token
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to create token: {e}")
|
||||
raise
|
||||
|
||||
async def revoke_token(self, token_id: str) -> bool:
|
||||
"""Revoke a token by token ID"""
|
||||
try:
|
||||
if token_id not in self._token_ids:
|
||||
self.logger.warning(f"Token ID '{token_id}' not found")
|
||||
return False
|
||||
|
||||
# Get token hash and remove from storage
|
||||
token_hash = self._token_ids[token_id]
|
||||
if token_hash in self._tokens:
|
||||
del self._tokens[token_hash]
|
||||
del self._token_ids[token_id]
|
||||
|
||||
self.logger.info(f"Revoked token '{token_id}'")
|
||||
|
||||
# Save updated tokens to file
|
||||
self._remove_token_from_file(token_id)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to revoke token '{token_id}': {e}")
|
||||
return False
|
||||
|
||||
def _save_tokens_to_file(self):
|
||||
"""Save current tokens to JSON file"""
|
||||
try:
|
||||
# Convert current tokens to file format
|
||||
tokens_list = []
|
||||
|
||||
for token_hash, token_info in self._tokens.items():
|
||||
# Find the raw token for this token_info
|
||||
raw_token = None
|
||||
for tid, thash in self._token_ids.items():
|
||||
if thash == token_hash and tid == token_info.token_id:
|
||||
# We can't recover the original token from hash,
|
||||
# so we'll create a placeholder for existing tokens
|
||||
raw_token = f"<existing_token_hash_{token_hash[:8]}>"
|
||||
break
|
||||
|
||||
if raw_token is None:
|
||||
continue
|
||||
|
||||
token_config = {
|
||||
"token_id": token_info.token_id,
|
||||
"token": raw_token,
|
||||
"description": token_info.description,
|
||||
"expires_hours": None,
|
||||
"is_active": token_info.is_active
|
||||
}
|
||||
|
||||
# Add expiration info
|
||||
if token_info.expires_at:
|
||||
# Calculate remaining hours from now
|
||||
remaining = token_info.expires_at - datetime.utcnow()
|
||||
if remaining.total_seconds() > 0:
|
||||
token_config["expires_hours"] = int(remaining.total_seconds() / 3600)
|
||||
else:
|
||||
token_config["expires_hours"] = 0
|
||||
|
||||
# Add database config if present
|
||||
if token_info.database_config:
|
||||
token_config["database_config"] = {
|
||||
"host": token_info.database_config.host,
|
||||
"port": token_info.database_config.port,
|
||||
"user": token_info.database_config.user,
|
||||
"password": token_info.database_config.password,
|
||||
"database": token_info.database_config.database,
|
||||
"charset": token_info.database_config.charset,
|
||||
"fe_http_port": token_info.database_config.fe_http_port
|
||||
}
|
||||
|
||||
tokens_list.append(token_config)
|
||||
|
||||
# Create file content
|
||||
file_content = {
|
||||
"version": "1.0",
|
||||
"description": "Doris MCP Server Token configuration file",
|
||||
"created_at": datetime.utcnow().isoformat() + "Z",
|
||||
"tokens": tokens_list,
|
||||
"notes": [
|
||||
"This file is automatically updated when tokens are created or revoked",
|
||||
"Please backup this file before making manual changes",
|
||||
"Tokens with hash placeholders were loaded from previous configuration"
|
||||
]
|
||||
}
|
||||
|
||||
# Save to file
|
||||
with open(self.token_file_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(file_content, f, indent=2, ensure_ascii=False)
|
||||
|
||||
self.logger.info(f"Saved {len(tokens_list)} tokens to file: {self.token_file_path}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to save tokens to file {self.token_file_path}: {e}")
|
||||
|
||||
def _save_token_to_file(self, token_id: str, raw_token: str, token_info: TokenInfo):
|
||||
"""Save a single new token to file (for newly created tokens only)"""
|
||||
try:
|
||||
# Load existing file
|
||||
existing_data = {"tokens": []}
|
||||
if os.path.exists(self.token_file_path):
|
||||
try:
|
||||
with open(self.token_file_path, 'r', encoding='utf-8') as f:
|
||||
existing_data = json.load(f)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Could not load existing token file: {e}")
|
||||
|
||||
# Ensure tokens list exists
|
||||
if 'tokens' not in existing_data or not isinstance(existing_data['tokens'], list):
|
||||
existing_data['tokens'] = []
|
||||
|
||||
# Check if token already exists in file
|
||||
token_exists = False
|
||||
for i, token_config in enumerate(existing_data['tokens']):
|
||||
if token_config.get('token_id') == token_id:
|
||||
# Update existing token
|
||||
existing_data['tokens'][i] = self._token_info_to_config(token_id, raw_token, token_info)
|
||||
token_exists = True
|
||||
break
|
||||
|
||||
# Add new token if it doesn't exist
|
||||
if not token_exists:
|
||||
new_token_config = self._token_info_to_config(token_id, raw_token, token_info)
|
||||
existing_data['tokens'].append(new_token_config)
|
||||
|
||||
# Update metadata
|
||||
existing_data.update({
|
||||
"version": "1.0",
|
||||
"description": "Doris MCP Server Token configuration file",
|
||||
"updated_at": datetime.utcnow().isoformat() + "Z"
|
||||
})
|
||||
|
||||
# Save to file
|
||||
with open(self.token_file_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(existing_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
self.logger.info(f"Saved token '{token_id}' to file: {self.token_file_path}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to save token '{token_id}' to file: {e}")
|
||||
|
||||
def _token_info_to_config(self, token_id: str, raw_token: str, token_info: TokenInfo) -> dict:
|
||||
"""Convert TokenInfo to file configuration format"""
|
||||
token_config = {
|
||||
"token_id": token_id,
|
||||
"token": raw_token,
|
||||
"description": token_info.description,
|
||||
"expires_hours": None,
|
||||
"is_active": token_info.is_active
|
||||
}
|
||||
|
||||
# Add expiration info
|
||||
if token_info.expires_at:
|
||||
# Calculate remaining hours from creation time
|
||||
remaining = token_info.expires_at - token_info.created_at
|
||||
token_config["expires_hours"] = int(remaining.total_seconds() / 3600) if remaining.total_seconds() > 0 else None
|
||||
|
||||
# Add database config if present
|
||||
if token_info.database_config:
|
||||
token_config["database_config"] = {
|
||||
"host": token_info.database_config.host,
|
||||
"port": token_info.database_config.port,
|
||||
"user": token_info.database_config.user,
|
||||
"password": token_info.database_config.password,
|
||||
"database": token_info.database_config.database,
|
||||
"charset": token_info.database_config.charset,
|
||||
"fe_http_port": token_info.database_config.fe_http_port
|
||||
}
|
||||
|
||||
return token_config
|
||||
|
||||
def _remove_token_from_file(self, token_id: str):
|
||||
"""Remove a token from the JSON file"""
|
||||
try:
|
||||
if not os.path.exists(self.token_file_path):
|
||||
return
|
||||
|
||||
# Load existing file
|
||||
with open(self.token_file_path, 'r', encoding='utf-8') as f:
|
||||
existing_data = json.load(f)
|
||||
|
||||
if 'tokens' not in existing_data or not isinstance(existing_data['tokens'], list):
|
||||
return
|
||||
|
||||
# Remove the token
|
||||
original_count = len(existing_data['tokens'])
|
||||
existing_data['tokens'] = [
|
||||
token for token in existing_data['tokens']
|
||||
if token.get('token_id') != token_id
|
||||
]
|
||||
|
||||
if len(existing_data['tokens']) < original_count:
|
||||
# Update metadata
|
||||
existing_data.update({
|
||||
"version": "1.0",
|
||||
"description": "Doris MCP Server Token configuration file",
|
||||
"updated_at": datetime.utcnow().isoformat() + "Z"
|
||||
})
|
||||
|
||||
# Save to file
|
||||
with open(self.token_file_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(existing_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
self.logger.info(f"Removed token '{token_id}' from file: {self.token_file_path}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to remove token '{token_id}' from file: {e}")
|
||||
|
||||
async def list_tokens(self) -> List[Dict[str, Any]]:
|
||||
"""List all tokens (without sensitive data)"""
|
||||
tokens = []
|
||||
|
||||
for token_hash, token_info in self._tokens.items():
|
||||
token_data = {
|
||||
'token_id': token_info.token_id,
|
||||
'created_at': token_info.created_at.isoformat(),
|
||||
'expires_at': token_info.expires_at.isoformat() if token_info.expires_at else None,
|
||||
'last_used': token_info.last_used.isoformat() if token_info.last_used else None,
|
||||
'is_active': token_info.is_active,
|
||||
'description': token_info.description,
|
||||
'is_expired': token_info.expires_at and datetime.utcnow() > token_info.expires_at if token_info.expires_at else False
|
||||
}
|
||||
|
||||
# Add database binding info (without sensitive data)
|
||||
if token_info.database_config:
|
||||
token_data['database_binding'] = {
|
||||
'host': token_info.database_config.host,
|
||||
'port': token_info.database_config.port,
|
||||
'user': token_info.database_config.user,
|
||||
'database': token_info.database_config.database,
|
||||
'has_password': bool(token_info.database_config.password)
|
||||
}
|
||||
else:
|
||||
token_data['database_binding'] = None
|
||||
|
||||
tokens.append(token_data)
|
||||
|
||||
# Sort by creation time
|
||||
tokens.sort(key=lambda x: x['created_at'], reverse=True)
|
||||
|
||||
return tokens
|
||||
|
||||
async def cleanup_expired_tokens(self) -> int:
|
||||
"""Remove expired tokens and return count"""
|
||||
if not self.enable_token_expiry:
|
||||
return 0
|
||||
|
||||
now = datetime.utcnow()
|
||||
expired_tokens = []
|
||||
|
||||
# Find expired tokens
|
||||
for token_hash, token_info in self._tokens.items():
|
||||
if token_info.expires_at and now > token_info.expires_at:
|
||||
expired_tokens.append((token_hash, token_info.token_id))
|
||||
|
||||
# Remove expired tokens
|
||||
for token_hash, token_id in expired_tokens:
|
||||
del self._tokens[token_hash]
|
||||
if token_id in self._token_ids:
|
||||
del self._token_ids[token_id]
|
||||
|
||||
if expired_tokens:
|
||||
self.logger.info(f"Cleaned up {len(expired_tokens)} expired tokens")
|
||||
|
||||
return len(expired_tokens)
|
||||
|
||||
async def save_tokens_to_file(self, file_path: Optional[str] = None) -> bool:
|
||||
"""Save current tokens to JSON file"""
|
||||
try:
|
||||
file_path = file_path or self.token_file_path
|
||||
tokens_list = await self.list_tokens()
|
||||
|
||||
tokens_data = {
|
||||
'version': '1.0',
|
||||
'created_at': datetime.utcnow().isoformat(),
|
||||
'tokens': tokens_list
|
||||
}
|
||||
|
||||
with open(file_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(tokens_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
self.logger.info(f"Saved {len(tokens_list)} tokens to file: {file_path}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to save tokens to file: {e}")
|
||||
return False
|
||||
|
||||
def get_database_config_by_token(self, token: str) -> Optional[DatabaseConfig]:
|
||||
"""Get database configuration bound to a token
|
||||
|
||||
Args:
|
||||
token: The raw token string
|
||||
|
||||
Returns:
|
||||
DatabaseConfig if token exists and has database binding, None otherwise
|
||||
"""
|
||||
try:
|
||||
token_hash = self._hash_token(token)
|
||||
token_info = self._tokens.get(token_hash)
|
||||
|
||||
if not token_info or not token_info.is_active:
|
||||
return None
|
||||
|
||||
# Check expiration
|
||||
if token_info.expires_at and datetime.utcnow() > token_info.expires_at:
|
||||
return None
|
||||
|
||||
return token_info.database_config
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to get database config for token: {e}")
|
||||
return None
|
||||
|
||||
def get_token_stats(self) -> Dict[str, Any]:
|
||||
"""Get token statistics"""
|
||||
now = datetime.utcnow()
|
||||
total_tokens = len(self._tokens)
|
||||
active_tokens = sum(1 for info in self._tokens.values() if info.is_active)
|
||||
expired_tokens = sum(1 for info in self._tokens.values()
|
||||
if info.expires_at and now > info.expires_at)
|
||||
tokens_with_db = sum(1 for info in self._tokens.values()
|
||||
if info.database_config is not None)
|
||||
|
||||
return {
|
||||
'total_tokens': total_tokens,
|
||||
'active_tokens': active_tokens,
|
||||
'expired_tokens': expired_tokens,
|
||||
'tokens_with_database_binding': tokens_with_db,
|
||||
'expiry_enabled': self.enable_token_expiry,
|
||||
'default_expiry_hours': self.default_token_expiry_hours,
|
||||
'hot_reload_enabled': self.enable_hot_reload,
|
||||
'last_file_check': datetime.fromtimestamp(self._file_last_modified).isoformat() if self._file_last_modified else None
|
||||
}
|
||||
|
||||
def _start_hot_reload(self):
|
||||
"""Start hot reload monitoring task"""
|
||||
if self._hot_reload_task:
|
||||
return # Already running
|
||||
|
||||
# Update initial file modification time
|
||||
self._update_file_modified_time()
|
||||
|
||||
# Start monitoring task
|
||||
self._hot_reload_task = asyncio.create_task(self._hot_reload_monitor())
|
||||
self.logger.info(f"Started hot reload monitoring for {self.token_file_path}")
|
||||
|
||||
def stop_hot_reload(self):
|
||||
"""Stop hot reload monitoring"""
|
||||
if self._hot_reload_task:
|
||||
self._hot_reload_task.cancel()
|
||||
self._hot_reload_task = None
|
||||
self.logger.info("Stopped hot reload monitoring")
|
||||
|
||||
def _update_file_modified_time(self):
|
||||
"""Update the last modified time of tokens file"""
|
||||
try:
|
||||
if os.path.exists(self.token_file_path):
|
||||
self._file_last_modified = os.path.getmtime(self.token_file_path)
|
||||
except Exception as e:
|
||||
self.logger.debug(f"Failed to get file modification time: {e}")
|
||||
|
||||
async def _hot_reload_monitor(self):
|
||||
"""Background task to monitor tokens.json file changes"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(self.hot_reload_interval)
|
||||
|
||||
if not os.path.exists(self.token_file_path):
|
||||
continue
|
||||
|
||||
# Check if file was modified
|
||||
current_mtime = os.path.getmtime(self.token_file_path)
|
||||
if current_mtime > self._file_last_modified:
|
||||
self.logger.info(f"Detected changes in {self.token_file_path}, reloading tokens...")
|
||||
|
||||
try:
|
||||
# Backup current tokens
|
||||
old_tokens = self._tokens.copy()
|
||||
old_token_ids = self._token_ids.copy()
|
||||
|
||||
# Clear and reload
|
||||
self._tokens.clear()
|
||||
self._token_ids.clear()
|
||||
|
||||
# Reinitialize default tokens
|
||||
self._initialize_default_tokens()
|
||||
|
||||
# Load from file
|
||||
self._load_tokens_from_file()
|
||||
|
||||
# Update modification time
|
||||
self._file_last_modified = current_mtime
|
||||
|
||||
self.logger.info(f"Hot reload completed, {len(self._tokens)} tokens loaded")
|
||||
|
||||
except Exception as reload_error:
|
||||
# Restore backup on failure
|
||||
self.logger.error(f"Hot reload failed, restoring previous tokens: {reload_error}")
|
||||
self._tokens = old_tokens
|
||||
self._token_ids = old_token_ids
|
||||
|
||||
except asyncio.CancelledError:
|
||||
self.logger.info("Hot reload monitor stopped")
|
||||
break
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in hot reload monitor: {e}")
|
||||
227
doris_mcp_server/auth/token_security_middleware.py
Normal file
227
doris_mcp_server/auth/token_security_middleware.py
Normal file
@@ -0,0 +1,227 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Token Management Security Middleware
|
||||
|
||||
Provides comprehensive security controls for token management endpoints including
|
||||
IP restrictions, admin authentication, and configuration-based access control.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import ipaddress
|
||||
import secrets
|
||||
import time
|
||||
from typing import Optional, List, Dict, Any
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from ..utils.logger import get_logger
|
||||
from ..utils.config import DorisConfig
|
||||
|
||||
|
||||
class TokenSecurityMiddleware:
    """Security middleware for token management endpoints.

    Access to a token-management route is granted only when all of these hold:

    1. ``config.security.enable_http_token_management`` is truthy,
    2. the client IP lies inside one of the configured allowed networks,
    3. if ``config.security.require_admin_auth`` is truthy, the request carries
       an admin token whose SHA-256 digest matches the configured one.
    """

    def __init__(self, config: DorisConfig):
        """Read the security settings and pre-compute the derived state.

        Args:
            config: Configuration object; only ``config.security`` is read.
        """
        self.config = config
        self.logger = get_logger(__name__)

        # Initialize admin token hash if provided (only the digest is kept)
        self._admin_token_hash = None
        if config.security.token_management_admin_token:
            self._admin_token_hash = self._hash_token(config.security.token_management_admin_token)

        # Normalize allowed IPs
        self._allowed_networks = self._parse_allowed_networks(
            config.security.token_management_allowed_ips
        )

        self.logger.info(f"Token management security initialized: "
                         f"HTTP endpoints {'enabled' if config.security.enable_http_token_management else 'disabled'}, "
                         f"Admin auth {'required' if config.security.require_admin_auth else 'optional'}, "
                         f"Allowed networks: {len(self._allowed_networks)}")

    def _hash_token(self, token: str) -> str:
        """Return the hex SHA-256 digest of ``token`` (UTF-8 encoded)."""
        return hashlib.sha256(token.encode('utf-8')).hexdigest()

    def _parse_allowed_networks(self, allowed_ips: List[str]) -> List[ipaddress.IPv4Network | ipaddress.IPv6Network]:
        """Parse allowed IP addresses and networks.

        Args:
            allowed_ips: Iterable of single addresses or CIDR networks. ``None``
                or an empty value yields an empty list; a plain string is
                treated as a comma-separated list.

        Returns:
            One network object per valid entry (single addresses become /32 or
            /128 networks). Invalid entries are logged and skipped.
        """
        if not allowed_ips:
            return []
        if isinstance(allowed_ips, str):
            # Iterating a bare string would yield single characters.
            allowed_ips = allowed_ips.split(',')

        networks = []
        for raw_entry in allowed_ips:
            entry = str(raw_entry).strip()
            if not entry:
                continue
            try:
                # strict=False accepts host bits in CIDR input and turns a bare
                # address into a single-host network (/32 or /128).
                networks.append(ipaddress.ip_network(entry, strict=False))
            except ValueError as e:
                self.logger.warning(f"Invalid IP/network '{entry}': {e}")

        return networks

    def _get_client_ip(self, request: Request) -> str:
        """Extract client IP from request, considering proxies.

        Order: first entry of ``X-Forwarded-For``, then ``X-Real-IP``, then the
        socket peer address, else the literal ``"unknown"``.
        """
        # NOTE(review): X-Forwarded-For / X-Real-IP are client-controlled unless a
        # trusted reverse proxy overwrites them. As written, a direct client can
        # send "X-Forwarded-For: 127.0.0.1" and pass the IP allow-list. Confirm the
        # deployment always sits behind such a proxy, or restrict header trust to
        # known proxy addresses.
        forwarded_for = request.headers.get('X-Forwarded-For')
        if forwarded_for:
            # Take the first IP (original client)
            client_ip = forwarded_for.split(',')[0].strip()
        elif request.headers.get('X-Real-IP'):
            client_ip = request.headers.get('X-Real-IP').strip()
        else:
            # Direct connection
            client_ip = request.client.host if request.client else "unknown"

        return client_ip

    def _is_ip_allowed(self, client_ip: str) -> bool:
        """Check if client IP is in allowed networks.

        IPv4-mapped IPv6 addresses (``::ffff:a.b.c.d``) are also tested in their
        IPv4 form so that IPv4 allow-list entries match them.

        Returns:
            True when the address is inside any allowed network; False otherwise,
            including when ``client_ip`` is not a valid address.
        """
        try:
            client_addr = ipaddress.ip_address(client_ip)
        except ValueError:
            self.logger.warning(f"Invalid client IP format: {client_ip}")
            return False

        candidates = [client_addr]
        mapped_v4 = getattr(client_addr, 'ipv4_mapped', None)
        if mapped_v4 is not None:
            candidates.append(mapped_v4)

        # Membership across IP versions is simply False, so mixing is safe.
        return any(
            addr in network
            for network in self._allowed_networks
            for addr in candidates
        )

    def _extract_admin_token(self, request: Request) -> Optional[str]:
        """Extract admin token from the request.

        Sources, in priority order: ``Authorization: Bearer|Token <value>``,
        the ``X-Admin-Token`` header, then the ``admin_token`` query parameter.

        Returns:
            The token string, or None when no source supplied one.
        """
        # Try Authorization header first
        auth_header = request.headers.get('Authorization', '')
        if auth_header.startswith('Bearer '):
            return auth_header[7:]
        elif auth_header.startswith('Token '):
            return auth_header[6:]

        # Try X-Admin-Token header
        admin_token = request.headers.get('X-Admin-Token', '')
        if admin_token:
            return admin_token

        # Try query parameter as fallback (not recommended for production)
        admin_token = request.query_params.get('admin_token', '')
        if admin_token:
            self.logger.warning("Admin token passed via query parameter - this is insecure for production")
            return admin_token

        return None

    def _verify_admin_token(self, provided_token: str) -> bool:
        """Verify provided admin token against configured token.

        Returns:
            False when no admin token is configured; otherwise whether the
            digests match.
        """
        if not self._admin_token_hash:
            self.logger.warning("No admin token configured for token management")
            return False

        provided_hash = self._hash_token(provided_token)

        # Use constant-time comparison to prevent timing attacks
        return hmac.compare_digest(self._admin_token_hash, provided_hash)

    async def check_token_management_access(self, request: Request) -> Optional[JSONResponse]:
        """
        Check if request is authorized for token management operations

        Returns:
            None if access is granted
            JSONResponse with error if access is denied (403 when the feature
            is disabled or the IP is not allowed, 401 for a missing or invalid
            admin token)
        """

        # Check if HTTP token management is enabled
        if not self.config.security.enable_http_token_management:
            self.logger.warning(f"Token management endpoint access denied - HTTP management disabled: {request.url.path}")
            return JSONResponse({
                "error": "Token management endpoints are disabled for security",
                "message": "HTTP token management is disabled. Use file-based token management instead.",
                "suggestion": "Edit tokens.json file directly or enable HTTP management with proper security configuration"
            }, status_code=403)

        # Extract client IP
        client_ip = self._get_client_ip(request)

        # Check IP restrictions
        if not self._is_ip_allowed(client_ip):
            self.logger.warning(f"Token management access denied for IP {client_ip}: not in allowed list")
            # NOTE(review): this response reveals the allow-list to a rejected
            # client; consider dropping "allowed_networks" if no caller needs it.
            return JSONResponse({
                "error": "Access denied - IP not allowed",
                "client_ip": client_ip,
                "message": "Token management is restricted to specific IP addresses",
                "allowed_networks": [str(net) for net in self._allowed_networks]
            }, status_code=403)

        # Check admin authentication if required
        if self.config.security.require_admin_auth:
            admin_token = self._extract_admin_token(request)

            if not admin_token:
                self.logger.warning(f"Token management access denied for IP {client_ip}: missing admin token")
                return JSONResponse({
                    "error": "Admin authentication required",
                    "message": "Token management requires admin authentication",
                    "hint": "Provide admin token in Authorization header: 'Bearer <admin_token>' or 'X-Admin-Token: <admin_token>'"
                }, status_code=401)

            if not self._verify_admin_token(admin_token):
                self.logger.warning(f"Token management access denied for IP {client_ip}: invalid admin token")
                return JSONResponse({
                    "error": "Invalid admin token",
                    "message": "The provided admin token is invalid"
                }, status_code=401)

        # Log successful access
        self.logger.info(f"Token management access granted for IP {client_ip} to {request.url.path}")

        # Access granted
        return None

    def get_security_info(self) -> Dict[str, Any]:
        """Get current security configuration info (for demo/status pages).

        The admin token itself is never included, only whether one is configured.
        """
        return {
            "http_token_management_enabled": self.config.security.enable_http_token_management,
            "admin_auth_required": self.config.security.require_admin_auth,
            "admin_token_configured": bool(self._admin_token_hash),
            "allowed_networks_count": len(self._allowed_networks),
            "allowed_networks": [str(net) for net in self._allowed_networks],
            "security_features": [
                "IP address restrictions",
                "Admin token authentication" if self.config.security.require_admin_auth else "Optional admin authentication",
                "Secure token hashing",
                "Request logging and auditing"
            ]
        }

    def generate_admin_token(self) -> str:
        """Generate a secure admin token (URL-safe text from 32 random bytes)."""
        return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
# Convenience function for middleware creation
|
||||
def create_token_security_middleware(config: DorisConfig) -> TokenSecurityMiddleware:
    """Build a ``TokenSecurityMiddleware`` bound to ``config`` and return it."""
    middleware = TokenSecurityMiddleware(config)
    return middleware
|
||||
365
doris_mcp_server/auth/token_validators.py
Normal file
365
doris_mcp_server/auth/token_validators.py
Normal file
@@ -0,0 +1,365 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
JWT Token Validation Module
|
||||
Provides token validation, blacklist management and security features
|
||||
"""
|
||||
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Dict, Set, Optional, Any
|
||||
from datetime import datetime, timedelta
|
||||
from collections import defaultdict
|
||||
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class TokenBlacklist:
    """JWT Token Blacklist Manager

    Manages revoked tokens to prevent revoked tokens from being used again.
    Entries are kept in an in-memory dict only (``{jti: expiry_timestamp}``);
    nothing in this class persists them across restarts.
    """

    def __init__(self, cleanup_interval: int = 3600):
        """Initialize token blacklist

        Args:
            cleanup_interval: Interval for cleaning up expired tokens (seconds)
        """
        self.cleanup_interval = cleanup_interval
        # Storage format: {token_jti: expiry_timestamp}
        self._blacklisted_tokens: Dict[str, float] = {}
        self._cleanup_task = None

        logger.info("TokenBlacklist initialized")

    async def start(self):
        """Start blacklist manager

        Spawns the periodic cleanup task. Calling ``start`` again while that
        task is still running is a no-op, so no task is ever orphaned.
        """
        if self._cleanup_task is not None and not self._cleanup_task.done():
            return
        self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
        logger.info("TokenBlacklist started with periodic cleanup")

    async def stop(self):
        """Stop blacklist manager

        Cancels the cleanup task, waits for it to finish and clears the handle
        so that a later ``start`` creates a fresh task.
        """
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            finally:
                self._cleanup_task = None
        logger.info("TokenBlacklist stopped")

    async def add_token(self, jti: str, exp: float):
        """Add token to blacklist

        Args:
            jti: JWT ID (unique identifier)
            exp: Token expiration timestamp
        """
        self._blacklisted_tokens[jti] = exp
        logger.info(f"Token {jti} added to blacklist")

    async def is_blacklisted(self, jti: str) -> bool:
        """Check if token is blacklisted

        Args:
            jti: JWT ID

        Returns:
            True if blacklisted, False otherwise. Entries whose expiry has
            passed still count until the next cleanup removes them.
        """
        return jti in self._blacklisted_tokens

    async def remove_token(self, jti: str) -> bool:
        """Remove token from blacklist

        Args:
            jti: JWT ID

        Returns:
            True if removed, False if not found
        """
        if jti in self._blacklisted_tokens:
            del self._blacklisted_tokens[jti]
            logger.info(f"Token {jti} removed from blacklist")
            return True
        return False

    async def cleanup_expired(self) -> int:
        """Clean up expired blacklisted tokens

        Returns:
            Number of tokens cleaned up
        """
        current_time = time.time()
        # Collect first, delete afterwards: the dict must not change size
        # while it is being iterated.
        expired_tokens = [
            jti for jti, exp in self._blacklisted_tokens.items()
            if exp <= current_time
        ]

        for jti in expired_tokens:
            del self._blacklisted_tokens[jti]

        if expired_tokens:
            logger.info(f"Cleaned up {len(expired_tokens)} expired tokens from blacklist")

        return len(expired_tokens)

    async def get_stats(self) -> Dict[str, Any]:
        """Get blacklist statistics

        Returns:
            Dict with total, still-active and already-expired entry counts plus
            the configured cleanup interval.
        """
        current_time = time.time()
        active_tokens = sum(1 for exp in self._blacklisted_tokens.values() if exp > current_time)

        return {
            "total_blacklisted": len(self._blacklisted_tokens),
            "active_blacklisted": active_tokens,
            "expired_blacklisted": len(self._blacklisted_tokens) - active_tokens,
            "cleanup_interval": self.cleanup_interval
        }

    async def _periodic_cleanup(self):
        """Periodically clean up expired tokens

        Sleeps ``cleanup_interval`` seconds between passes. Exits on
        cancellation; any other error is logged and the loop continues.
        """
        while True:
            try:
                await asyncio.sleep(self.cleanup_interval)
                await self.cleanup_expired()
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Error during periodic cleanup: {e}")
|
||||
|
||||
|
||||
class RateLimiter:
    """Token usage rate limiter

    Sliding-window limiter: for every user id it keeps the timestamps of the
    requests accepted during the last ``time_window`` seconds, in memory.
    """

    def __init__(self, max_requests: int = 100, time_window: int = 3600):
        """Initialize rate limiter

        Args:
            max_requests: Maximum requests within time window
            time_window: Time window in seconds
        """
        self.max_requests = max_requests
        self.time_window = time_window
        # Storage format: {user_id: [timestamp1, timestamp2, ...]}
        self._request_history: Dict[str, list] = defaultdict(list)

        logger.info(f"RateLimiter initialized: {max_requests} requests per {time_window} seconds")

    async def is_allowed(self, user_id: str) -> bool:
        """Check if user is allowed to make request

        An accepted request is recorded; a rejected one is not.

        Args:
            user_id: User ID

        Returns:
            True if allowed, False otherwise
        """
        current_time = time.time()
        user_requests = self._request_history[user_id]

        # Clean up expired request records (in place, so the stored list shrinks)
        cutoff_time = current_time - self.time_window
        user_requests[:] = [t for t in user_requests if t > cutoff_time]

        # Check if limit exceeded
        if len(user_requests) >= self.max_requests:
            logger.warning(f"Rate limit exceeded for user {user_id}")
            return False

        # Record current request
        user_requests.append(current_time)
        return True

    async def get_usage(self, user_id: str) -> Dict[str, Any]:
        """Get user usage information

        Read-only: querying a user that has never made a request does not
        create a history entry for it.

        Args:
            user_id: User ID

        Returns:
            Usage statistics
        """
        current_time = time.time()
        # .get() instead of indexing: indexing a defaultdict would insert an
        # empty list for every unknown user id that is merely queried.
        user_requests = self._request_history.get(user_id, ())

        # Count only the records still inside the window
        cutoff_time = current_time - self.time_window
        active_requests = [t for t in user_requests if t > cutoff_time]

        return {
            "user_id": user_id,
            "requests_in_window": len(active_requests),
            "max_requests": self.max_requests,
            "time_window": self.time_window,
            "remaining_requests": max(0, self.max_requests - len(active_requests))
        }
|
||||
|
||||
|
||||
class TokenValidator:
    """JWT Token Validator

    Validates the claims of an already-decoded JWT payload (issuer, audience,
    exp / nbf / iat), checks the token blacklist and applies per-user rate
    limiting. The ``verify_signature`` setting is stored and reported, but no
    signature check is performed in this class.
    """

    def __init__(self, config, blacklist: Optional[TokenBlacklist] = None):
        """Initialize token validator

        Args:
            config: DorisConfig configuration object (with security attribute),
                or the security configuration object itself
            blacklist: Token blacklist manager; a new one is created if omitted
        """
        self.config = config
        self.blacklist = blacklist or TokenBlacklist()
        self.rate_limiter = RateLimiter()

        # Access JWT settings through the security configuration
        if hasattr(config, 'security'):
            security_config = config.security
        else:
            # Fallback if config is passed directly as SecurityConfig
            security_config = config

        # Validation options
        self.verify_signature = security_config.jwt_verify_signature
        self.verify_audience = security_config.jwt_verify_audience
        self.verify_issuer = security_config.jwt_verify_issuer
        self.require_exp = security_config.jwt_require_exp
        self.require_iat = security_config.jwt_require_iat
        self.require_nbf = security_config.jwt_require_nbf
        self.leeway = security_config.jwt_leeway

        # Expected values
        self.expected_audience = security_config.jwt_audience
        self.expected_issuer = security_config.jwt_issuer

        logger.info("TokenValidator initialized")

    @staticmethod
    def _read_time_claim(payload: Dict[str, Any], name: str, required: bool) -> Optional[float]:
        """Return the numeric time claim ``name``, or None when absent and optional.

        Raises:
            ValueError: the claim is required (or present) but None, or its
                value is not a number.
        """
        if not required and name not in payload:
            return None
        value = payload.get(name)
        # Test for None explicitly: 0 is a legitimate timestamp and must not
        # be reported as a missing claim.
        if value is None:
            raise ValueError(f"Missing '{name}' claim")
        # bool is a subclass of int but is never a valid timestamp.
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            raise ValueError(f"Invalid '{name}' claim: expected a numeric timestamp")
        return value

    async def validate_claims(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Validate JWT claims

        Checks run in this order: issuer, audience, exp, nbf, iat, blacklist,
        rate limit. The rate-limit check records a request for the ``sub`` user.

        Args:
            payload: JWT payload

        Returns:
            Validation result: ``{"valid": True, "user_id": <sub>, "payload": payload}``

        Raises:
            ValueError: Validation failed
        """
        current_time = time.time()

        # Validate issuer
        if self.verify_issuer:
            if payload.get('iss') != self.expected_issuer:
                raise ValueError(f"Invalid issuer: expected {self.expected_issuer}")

        # Validate audience ('aud' may be a single value or a list of values)
        if self.verify_audience:
            aud = payload.get('aud')
            if isinstance(aud, list):
                if self.expected_audience not in aud:
                    raise ValueError(f"Invalid audience: {self.expected_audience} not in {aud}")
            elif aud != self.expected_audience:
                raise ValueError(f"Invalid audience: expected {self.expected_audience}")

        # Validate expiration time
        exp = self._read_time_claim(payload, 'exp', self.require_exp)
        if exp is not None and current_time > exp + self.leeway:
            raise ValueError("Token has expired")

        # Validate not before time
        nbf = self._read_time_claim(payload, 'nbf', self.require_nbf)
        if nbf is not None and current_time < nbf - self.leeway:
            raise ValueError("Token not yet valid")

        # Validate issued at time
        iat = self._read_time_claim(payload, 'iat', self.require_iat)
        # Allow some clock skew, but cannot be future time
        if iat is not None and iat > current_time + self.leeway:
            raise ValueError("Token issued in the future")

        # Check blacklist
        jti = payload.get('jti')
        if jti and await self.blacklist.is_blacklisted(jti):
            raise ValueError("Token has been revoked")

        # Rate limit check
        user_id = payload.get('sub')
        if user_id:
            if not await self.rate_limiter.is_allowed(user_id):
                raise ValueError("Rate limit exceeded")

        return {
            "valid": True,
            "user_id": user_id,
            "payload": payload
        }

    async def start(self):
        """Start validator (starts the blacklist manager)"""
        await self.blacklist.start()
        logger.info("TokenValidator started")

    async def stop(self):
        """Stop validator (stops the blacklist manager)"""
        await self.blacklist.stop()
        logger.info("TokenValidator stopped")

    async def revoke_token(self, jti: str, exp: float):
        """Revoke token

        Args:
            jti: JWT ID
            exp: Token expiration time
        """
        await self.blacklist.add_token(jti, exp)
        logger.info(f"Token {jti} has been revoked")

    async def get_validation_stats(self) -> Dict[str, Any]:
        """Get validation statistics (blacklist stats plus the active validation options)"""
        blacklist_stats = await self.blacklist.get_stats()

        return {
            "blacklist": blacklist_stats,
            "validation_config": {
                "verify_signature": self.verify_signature,
                "verify_audience": self.verify_audience,
                "verify_issuer": self.verify_issuer,
                "require_exp": self.require_exp,
                "require_iat": self.require_iat,
                "require_nbf": self.require_nbf,
                "leeway": self.leeway
            }
        }

    async def get_user_rate_limit_info(self, user_id: str) -> Dict[str, Any]:
        """Get user rate limit information"""
        return await self.rate_limiter.get_usage(user_id)
|
||||
@@ -221,6 +221,8 @@ logger = logging.getLogger(__name__)
|
||||
_default_config = DorisConfig()
|
||||
|
||||
|
||||
|
||||
|
||||
class DorisServer:
|
||||
"""Apache Doris MCP Server main class"""
|
||||
|
||||
@@ -228,11 +230,15 @@ class DorisServer:
|
||||
self.config = config
|
||||
self.server = Server("doris-mcp-server")
|
||||
|
||||
# Initialize security manager
|
||||
# Initialize security manager (without connection_manager initially)
|
||||
self.security_manager = DorisSecurityManager(config)
|
||||
|
||||
# Initialize connection manager, pass in security manager
|
||||
self.connection_manager = DorisConnectionManager(config, self.security_manager)
|
||||
# Initialize connection manager, pass in security manager and token manager for token-bound DB config
|
||||
token_manager = self.security_manager.auth_provider.token_manager if hasattr(self.security_manager, 'auth_provider') and hasattr(self.security_manager.auth_provider, 'token_manager') else None
|
||||
self.connection_manager = DorisConnectionManager(config, self.security_manager, token_manager)
|
||||
|
||||
# Set connection manager reference in security manager for database validation
|
||||
self.security_manager.connection_manager = self.connection_manager
|
||||
|
||||
# Initialize independent managers
|
||||
self.resources_manager = DorisResourcesManager(self.connection_manager)
|
||||
@@ -244,6 +250,40 @@ class DorisServer:
|
||||
self.logger = get_logger(f"{__name__}.DorisServer")
|
||||
self._setup_handlers()
|
||||
|
||||
async def _extract_auth_info_from_scope(self, scope, headers):
|
||||
"""Extract authentication information from ASGI scope and headers"""
|
||||
auth_info = {}
|
||||
|
||||
# Extract client IP
|
||||
client = scope.get("client")
|
||||
if client:
|
||||
auth_info["client_ip"] = client[0]
|
||||
else:
|
||||
auth_info["client_ip"] = "unknown"
|
||||
|
||||
# Extract token from Authorization header
|
||||
authorization = headers.get(b'authorization', b'').decode('utf-8')
|
||||
if authorization:
|
||||
if authorization.startswith('Bearer '):
|
||||
auth_info["token"] = authorization[7:]
|
||||
auth_info["authorization"] = authorization
|
||||
elif authorization.startswith('Token '):
|
||||
auth_info["token"] = authorization[6:]
|
||||
auth_info["authorization"] = authorization
|
||||
|
||||
# Extract token from query parameters (for compatibility)
|
||||
query_string = scope.get("query_string", b"").decode('utf-8')
|
||||
if query_string and "token=" in query_string:
|
||||
import urllib.parse
|
||||
query_params = urllib.parse.parse_qs(query_string)
|
||||
if "token" in query_params:
|
||||
auth_info["token"] = query_params["token"][0]
|
||||
|
||||
# If no token found, this will be handled by the authentication system
|
||||
# (either return anonymous context if auth disabled, or raise error if auth enabled)
|
||||
|
||||
return auth_info
|
||||
|
||||
def _get_mcp_capabilities(self):
|
||||
"""Get MCP capabilities with version compatibility"""
|
||||
try:
|
||||
@@ -388,6 +428,10 @@ class DorisServer:
|
||||
self.logger.info("Starting Doris MCP Server (stdio mode)")
|
||||
|
||||
try:
|
||||
# Initialize security manager first (includes JWT setup if enabled)
|
||||
await self.security_manager.initialize()
|
||||
self.logger.info("Security manager initialization completed")
|
||||
|
||||
# Ensure connection manager is initialized
|
||||
await self.connection_manager.initialize()
|
||||
self.logger.info("Connection manager initialization completed")
|
||||
@@ -449,11 +493,15 @@ class DorisServer:
|
||||
|
||||
|
||||
|
||||
async def start_http(self, host: str = os.getenv("SERVER_HOST", _default_config.database.host), port: int = os.getenv("SERVER_PORT", _default_config.server_port)):
|
||||
"""Start Streamable HTTP transport mode"""
|
||||
self.logger.info(f"Starting Doris MCP Server (Streamable HTTP mode) - {host}:{port}")
|
||||
async def start_http(self, host: str = os.getenv("SERVER_HOST", _default_config.database.host), port: int = os.getenv("SERVER_PORT", _default_config.server_port), workers: int = 1):
|
||||
"""Start Streamable HTTP transport mode with workers support"""
|
||||
self.logger.info(f"Starting Doris MCP Server (Streamable HTTP mode) - {host}:{port}, workers: {workers}")
|
||||
|
||||
try:
|
||||
# Initialize security manager first (includes JWT setup if enabled)
|
||||
await self.security_manager.initialize()
|
||||
self.logger.info("Security manager initialization completed")
|
||||
|
||||
# Ensure connection manager is initialized
|
||||
await self.connection_manager.initialize()
|
||||
|
||||
@@ -480,6 +528,44 @@ class DorisServer:
|
||||
async def health_check(request):
|
||||
return JSONResponse({"status": "healthy", "service": "doris-mcp-server"})
|
||||
|
||||
# OAuth endpoints
|
||||
from .auth.oauth_handlers import OAuthHandlers
|
||||
oauth_handlers = OAuthHandlers(self.security_manager)
|
||||
|
||||
async def oauth_login(request):
|
||||
return await oauth_handlers.handle_login(request)
|
||||
|
||||
async def oauth_callback(request):
|
||||
return await oauth_handlers.handle_callback(request)
|
||||
|
||||
async def oauth_provider_info(request):
|
||||
return await oauth_handlers.handle_provider_info(request)
|
||||
|
||||
async def oauth_demo(request):
|
||||
return await oauth_handlers.handle_demo_page(request)
|
||||
|
||||
# Token management endpoints
|
||||
from .auth.token_handlers import TokenHandlers
|
||||
token_handlers = TokenHandlers(self.security_manager, self.config)
|
||||
|
||||
async def token_create(request):
|
||||
return await token_handlers.handle_create_token(request)
|
||||
|
||||
async def token_revoke(request):
|
||||
return await token_handlers.handle_revoke_token(request)
|
||||
|
||||
async def token_list(request):
|
||||
return await token_handlers.handle_list_tokens(request)
|
||||
|
||||
async def token_stats(request):
|
||||
return await token_handlers.handle_token_stats(request)
|
||||
|
||||
async def token_cleanup(request):
|
||||
return await token_handlers.handle_cleanup_tokens(request)
|
||||
|
||||
async def token_management(request):
|
||||
return await token_handlers.handle_management_page(request)
|
||||
|
||||
# Lifecycle manager - simplified since we manage session_manager externally
|
||||
@contextlib.asynccontextmanager
|
||||
async def lifespan(app: Starlette) -> AsyncIterator[None]:
|
||||
@@ -495,6 +581,18 @@ class DorisServer:
|
||||
debug=True,
|
||||
routes=[
|
||||
Route("/health", health_check, methods=["GET"]),
|
||||
# OAuth endpoints
|
||||
Route("/auth/login", oauth_login, methods=["GET"]),
|
||||
Route("/auth/callback", oauth_callback, methods=["GET"]),
|
||||
Route("/auth/provider", oauth_provider_info, methods=["GET"]),
|
||||
Route("/auth/demo", oauth_demo, methods=["GET"]),
|
||||
# Token management endpoints
|
||||
Route("/token/create", token_create, methods=["GET", "POST"]),
|
||||
Route("/token/revoke", token_revoke, methods=["GET", "DELETE"]),
|
||||
Route("/token/list", token_list, methods=["GET"]),
|
||||
Route("/token/stats", token_stats, methods=["GET"]),
|
||||
Route("/token/cleanup", token_cleanup, methods=["GET", "POST"]),
|
||||
Route("/token/management", token_management, methods=["GET"]),
|
||||
],
|
||||
lifespan=lifespan,
|
||||
)
|
||||
@@ -512,8 +610,10 @@ class DorisServer:
|
||||
self.logger.info(f"Received request for path: {path}")
|
||||
|
||||
try:
|
||||
# Handle health check
|
||||
if path.startswith("/health"):
|
||||
# Handle health check, auth, and token management endpoints
|
||||
if (path.startswith("/health") or
|
||||
path.startswith("/auth/") or
|
||||
path.startswith("/token/")):
|
||||
await starlette_app(scope, receive, send)
|
||||
return
|
||||
|
||||
@@ -526,6 +626,29 @@ class DorisServer:
|
||||
self.logger.info(f"MCP Request - Method: {method}")
|
||||
self.logger.info(f"MCP Request - Headers: {headers}")
|
||||
|
||||
# Authentication check for MCP requests
|
||||
try:
|
||||
# Extract authentication information
|
||||
auth_info = await self._extract_auth_info_from_scope(scope, headers)
|
||||
|
||||
# Authenticate the request
|
||||
auth_context = await self.security_manager.authenticate_request(auth_info)
|
||||
self.logger.info(f"MCP request authenticated: token_id={auth_context.token_id}, client_ip={auth_context.client_ip}")
|
||||
|
||||
# Store auth context in scope for potential use by tools/resources
|
||||
scope["auth_context"] = auth_context
|
||||
|
||||
except Exception as auth_error:
|
||||
self.logger.error(f"MCP authentication failed: {auth_error}")
|
||||
# Return 401 Unauthorized
|
||||
from starlette.responses import JSONResponse
|
||||
response = JSONResponse(
|
||||
{"error": "Authentication required", "message": str(auth_error)},
|
||||
status_code=401
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
# Handle Dify compatibility for GET requests
|
||||
if method == "GET":
|
||||
accept_header = headers.get(b'accept', b'').decode('utf-8')
|
||||
@@ -568,19 +691,35 @@ class DorisServer:
|
||||
self.logger.warning(f"Unsupported scope type: {scope['type']}")
|
||||
return
|
||||
|
||||
# Start uvicorn server with session manager lifecycle
|
||||
config = uvicorn.Config(
|
||||
app=mcp_app,
|
||||
host=host,
|
||||
port=port,
|
||||
log_level="info"
|
||||
)
|
||||
server = uvicorn.Server(config)
|
||||
# Choose startup method based on worker count
|
||||
if workers > 1:
|
||||
self.logger.info(f"Using multi-process mode with {workers} workers")
|
||||
self.logger.info("Note: Multi-worker mode provides full MCP functionality with independent worker processes")
|
||||
|
||||
# Run session manager and server together
|
||||
async with session_manager.run():
|
||||
self.logger.info("Session manager started, now starting HTTP server")
|
||||
await server.serve()
|
||||
# Use the dedicated multiworker app module with full MCP support
|
||||
uvicorn.run(
|
||||
"doris_mcp_server.multiworker_app:app",
|
||||
host=host,
|
||||
port=port,
|
||||
workers=workers,
|
||||
log_level="info"
|
||||
)
|
||||
|
||||
else:
|
||||
self.logger.info("Using single-process mode")
|
||||
# Single worker mode, use original logic with session manager lifecycle
|
||||
config = uvicorn.Config(
|
||||
app=mcp_app,
|
||||
host=host,
|
||||
port=port,
|
||||
log_level="info"
|
||||
)
|
||||
server = uvicorn.Server(config)
|
||||
|
||||
# Run session manager and server together
|
||||
async with session_manager.run():
|
||||
self.logger.info("Session manager started, now starting HTTP server")
|
||||
await server.serve()
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Streamable HTTP server startup failed: {e}")
|
||||
@@ -595,10 +734,16 @@ class DorisServer:
|
||||
self.logger.error(f" Exception {i+1}: {type(exc).__name__}: {exc}")
|
||||
raise
|
||||
|
||||
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shutdown server"""
|
||||
self.logger.info("Shutting down Doris MCP Server")
|
||||
try:
|
||||
# Shutdown security manager first (includes JWT cleanup)
|
||||
await self.security_manager.shutdown()
|
||||
self.logger.info("Security manager shutdown completed")
|
||||
|
||||
await self.connection_manager.close()
|
||||
self.logger.info("Doris MCP Server has been shut down")
|
||||
except Exception as e:
|
||||
@@ -618,6 +763,11 @@ Transport Modes:
|
||||
Examples:
|
||||
python -m doris_mcp_server --transport stdio
|
||||
python -m doris_mcp_server --transport http --host 0.0.0.0 --port 3000
|
||||
python -m doris_mcp_server --transport stdio --doris-host localhost --doris-port 9030
|
||||
python -m doris_mcp_server --transport http --doris-user admin --doris-database test_db
|
||||
|
||||
# Backward compatibility: --db-* parameters are also supported
|
||||
python -m doris_mcp_server --transport stdio --db-host localhost --db-port 9030
|
||||
"""
|
||||
)
|
||||
|
||||
@@ -632,35 +782,45 @@ Examples:
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default=os.getenv("SERVER_HOST", _default_config.database.host),
|
||||
help=f"Host address for HTTP mode (default: {_default_config.database.host})",
|
||||
default=os.getenv("SERVER_HOST", _default_config.server_host),
|
||||
help=f"Host address for HTTP mode (default: {_default_config.server_host})",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--port", type=int, default=os.getenv("SERVER_PORT", _default_config.server_port), help=f"Port number for HTTP mode (default: {_default_config.server_port})"
|
||||
"--port",
|
||||
type=int,
|
||||
default=3000,
|
||||
help="Port number for HTTP mode (default: 3000)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--db-host",
|
||||
"--workers",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of worker processes for HTTP mode (default: 1, use 0 for auto-detect CPU cores)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--doris-host", "--db-host",
|
||||
type=str,
|
||||
default=os.getenv("DB_HOST", _default_config.database.host),
|
||||
default=os.getenv("DORIS_HOST", _default_config.database.host),
|
||||
help=f"Doris database host address (default: {_default_config.database.host})",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--db-port", type=int, default=os.getenv("DB_PORT", _default_config.database.port), help=f"Doris database port number (default: {_default_config.database.port})"
|
||||
"--doris-port", "--db-port", type=int, default=9030, help="Doris database port number (default: 9030)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--db-user", type=str, default=os.getenv("DB_USER", _default_config.database.user), help=f"Doris database username (default: {_default_config.database.user})"
|
||||
"--doris-user", "--db-user", type=str, default=os.getenv("DORIS_USER", _default_config.database.user), help=f"Doris database username (default: {_default_config.database.user})"
|
||||
)
|
||||
|
||||
parser.add_argument("--db-password", type=str, default="", help="Doris database password")
|
||||
parser.add_argument("--doris-password", "--db-password", type=str, default=os.getenv("DORIS_PASSWORD", ""), help="Doris database password")
|
||||
|
||||
parser.add_argument(
|
||||
"--db-database",
|
||||
"--doris-database", "--db-database",
|
||||
type=str,
|
||||
default=os.getenv("DB_DATABASE", _default_config.database.database),
|
||||
default=os.getenv("DORIS_DATABASE", _default_config.database.database),
|
||||
help=f"Doris database name (default: {_default_config.database.database})",
|
||||
)
|
||||
|
||||
@@ -675,28 +835,58 @@ Examples:
|
||||
return parser
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main function"""
|
||||
def update_configuration(config: DorisConfig):
|
||||
"""Update doris configuration object"""
|
||||
# For some arguments, if not specified, environment variables or default configurations will be used as default values
|
||||
parser = create_arg_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
# Create configuration - priority: command line arguments > .env file > default values
|
||||
config = DorisConfig.from_env() # First load from .env file and environment variables
|
||||
|
||||
# Update config values
|
||||
# Command line arguments override configuration (if provided)
|
||||
if args.db_host != _default_config.database.host: # If not default value, use command line argument
|
||||
config.database.host = args.db_host
|
||||
if args.db_port != _default_config.database.port:
|
||||
config.database.port = args.db_port
|
||||
if args.db_user != _default_config.database.user:
|
||||
config.database.user = args.db_user
|
||||
if args.db_password: # Use password if provided
|
||||
config.database.password = args.db_password
|
||||
if args.db_database != _default_config.database.database:
|
||||
config.database.database = args.db_database
|
||||
# basic
|
||||
if args.transport != _default_config.transport:
|
||||
config.transport = args.transport
|
||||
if args.host != _default_config.server_host:
|
||||
config.server_host = args.host
|
||||
if args.port != _default_config.server_port:
|
||||
config.server_port = args.port
|
||||
server_name = os.getenv("SERVER_NAME")
|
||||
if server_name:
|
||||
config.server_name = server_name
|
||||
server_version = os.getenv("SERVER_VERSION")
|
||||
if server_version:
|
||||
config.server_version = server_version
|
||||
|
||||
# database
|
||||
if args.doris_host != _default_config.database.host: # If not default value, use command line argument
|
||||
config.database.host = args.doris_host
|
||||
if args.doris_port != _default_config.database.port:
|
||||
config.database.port = args.doris_port
|
||||
if args.doris_user != _default_config.database.user:
|
||||
config.database.user = args.doris_user
|
||||
if args.doris_password: # Use password if provided
|
||||
config.database.password = args.doris_password
|
||||
if args.doris_database != _default_config.database.database:
|
||||
config.database.database = args.doris_database
|
||||
|
||||
# logging
|
||||
if args.log_level != _default_config.logging.level:
|
||||
config.logging.level = args.log_level
|
||||
|
||||
# workers (add to config for HTTP mode)
|
||||
if hasattr(args, 'workers'):
|
||||
config.workers = args.workers
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main function"""
|
||||
# Create configuration - priority: command line arguments > env variables > .env file > default values
|
||||
# First load from .env file and environment variables
|
||||
config = DorisConfig.from_env()
|
||||
|
||||
# Then parse the command line arguments, and update the config object.
|
||||
update_configuration(config)
|
||||
|
||||
# Initialize enhanced logging system
|
||||
from .utils.config import ConfigManager
|
||||
config_manager = ConfigManager(config)
|
||||
@@ -710,19 +900,26 @@ async def main():
|
||||
log_system_info()
|
||||
|
||||
logger.info("Starting Doris MCP Server...")
|
||||
logger.info(f"Transport: {args.transport}")
|
||||
logger.info(f"Transport: {config.transport}")
|
||||
logger.info(f"Log Level: {config.logging.level}")
|
||||
|
||||
# Create server instance
|
||||
server = DorisServer(config)
|
||||
|
||||
try:
|
||||
if args.transport == "stdio":
|
||||
if config.transport == "stdio":
|
||||
await server.start_stdio()
|
||||
elif args.transport == "http":
|
||||
await server.start_http(args.host, args.port)
|
||||
elif config.transport == "http":
|
||||
# Get workers configuration with auto-detection support
|
||||
workers = getattr(config, 'workers', 1)
|
||||
if workers == 0:
|
||||
import multiprocessing
|
||||
workers = multiprocessing.cpu_count()
|
||||
logger.info(f"Auto-detected {workers} CPU cores for worker processes")
|
||||
|
||||
await server.start_http(config.server_host, config.server_port, workers)
|
||||
else:
|
||||
logger.error(f"Unsupported transport protocol: {args.transport}")
|
||||
logger.error(f"Unsupported transport protocol: {config.transport}")
|
||||
await server.shutdown()
|
||||
return 1
|
||||
|
||||
|
||||
627
doris_mcp_server/multiworker_app.py
Normal file
627
doris_mcp_server/multiworker_app.py
Normal file
@@ -0,0 +1,627 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Multi-worker application module for doris-mcp-server
|
||||
|
||||
This module provides full MCP functionality with multi-worker support.
|
||||
Each worker process creates its own MCP server and session manager using the same
|
||||
robust architecture as the single-worker mode.
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
# Import MCP components with compatibility handling
|
||||
# Use the same import strategy as main.py for consistency
|
||||
# Placeholders for the MCP SDK symbols. They stay at these defaults until
# _import_mcp_with_compatibility() rebinds them to the imported classes.
MCP_VERSION = 'unknown'
Server = InitializationOptions = Prompt = Resource = TextContent = Tool = None
|
||||
|
||||
def _import_mcp_with_compatibility():
|
||||
"""Import MCP components with multi-version compatibility"""
|
||||
global MCP_VERSION, Server, InitializationOptions, Prompt, Resource, TextContent, Tool
|
||||
|
||||
try:
|
||||
# Strategy 1: Try direct server-only imports to avoid client-side issues
|
||||
from mcp.server import Server as _Server
|
||||
from mcp.server.models import InitializationOptions as _InitOptions
|
||||
from mcp.types import (
|
||||
Prompt as _Prompt,
|
||||
Resource as _Resource,
|
||||
TextContent as _TextContent,
|
||||
Tool as _Tool,
|
||||
)
|
||||
|
||||
# Assign to globals
|
||||
Server = _Server
|
||||
InitializationOptions = _InitOptions
|
||||
Prompt = _Prompt
|
||||
Resource = _Resource
|
||||
TextContent = _TextContent
|
||||
Tool = _Tool
|
||||
|
||||
# Try to get version safely
|
||||
try:
|
||||
import mcp
|
||||
MCP_VERSION = getattr(mcp, '__version__', None)
|
||||
if not MCP_VERSION:
|
||||
# Fallback: try to get version from package metadata
|
||||
try:
|
||||
import importlib.metadata
|
||||
MCP_VERSION = importlib.metadata.version('mcp')
|
||||
except Exception:
|
||||
# Second fallback: try pkg_resources
|
||||
try:
|
||||
import pkg_resources
|
||||
MCP_VERSION = pkg_resources.get_distribution('mcp').version
|
||||
except Exception:
|
||||
MCP_VERSION = 'detected-but-version-unknown'
|
||||
except Exception:
|
||||
# Version detection failed, but imports worked
|
||||
try:
|
||||
import importlib.metadata
|
||||
MCP_VERSION = importlib.metadata.version('mcp')
|
||||
except Exception:
|
||||
try:
|
||||
import pkg_resources
|
||||
MCP_VERSION = pkg_resources.get_distribution('mcp').version
|
||||
except Exception:
|
||||
MCP_VERSION = 'imported-successfully'
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info(f"MCP components imported successfully in multiworker, version: {MCP_VERSION}")
|
||||
return True
|
||||
|
||||
except Exception as import_error:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Strategy 2: Handle RequestContext compatibility issues in 1.9.x versions
|
||||
error_str = str(import_error).lower()
|
||||
if 'requestcontext' in error_str and 'too few arguments' in error_str:
|
||||
logger.warning(f"Detected MCP RequestContext compatibility issue: {import_error}")
|
||||
logger.info("Attempting comprehensive workaround for MCP 1.9.x RequestContext issue...")
|
||||
|
||||
try:
|
||||
# Comprehensive monkey patch approach
|
||||
import sys
|
||||
import types
|
||||
|
||||
# Create and install mock modules before any MCP imports
|
||||
if 'mcp.shared.context' not in sys.modules:
|
||||
mock_context_module = types.ModuleType('mcp.shared.context')
|
||||
|
||||
class FlexibleRequestContext:
|
||||
"""Flexible RequestContext that accepts variable arguments"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
def __class_getitem__(cls, params):
|
||||
# Accept any number of parameters and return cls
|
||||
return cls
|
||||
|
||||
# Add other methods that might be called
|
||||
def __getattr__(self, name):
|
||||
return lambda *args, **kwargs: None
|
||||
|
||||
mock_context_module.RequestContext = FlexibleRequestContext
|
||||
sys.modules['mcp.shared.context'] = mock_context_module
|
||||
|
||||
# Also patch the typing system to be more permissive
|
||||
original_check_generic = None
|
||||
try:
|
||||
import typing
|
||||
if hasattr(typing, '_check_generic'):
|
||||
original_check_generic = typing._check_generic
|
||||
def permissive_check_generic(cls, params, elen):
|
||||
# Don't enforce strict parameter count checking
|
||||
return
|
||||
typing._check_generic = permissive_check_generic
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Clear any cached imports that might have failed
|
||||
modules_to_clear = [k for k in sys.modules.keys() if k.startswith('mcp.')]
|
||||
for module in modules_to_clear:
|
||||
if module in sys.modules:
|
||||
del sys.modules[module]
|
||||
|
||||
# Now try importing again with the patches in place
|
||||
from mcp.server import Server as _Server
|
||||
from mcp.server.models import InitializationOptions as _InitOptions
|
||||
from mcp.types import (
|
||||
Prompt as _Prompt,
|
||||
Resource as _Resource,
|
||||
TextContent as _TextContent,
|
||||
Tool as _Tool,
|
||||
)
|
||||
|
||||
# Assign to globals
|
||||
Server = _Server
|
||||
InitializationOptions = _InitOptions
|
||||
Prompt = _Prompt
|
||||
Resource = _Resource
|
||||
TextContent = _TextContent
|
||||
Tool = _Tool
|
||||
|
||||
# Try to detect actual version even in compatibility mode
|
||||
try:
|
||||
import importlib.metadata
|
||||
actual_version = importlib.metadata.version('mcp')
|
||||
MCP_VERSION = f'compatibility-mode-{actual_version}'
|
||||
except Exception:
|
||||
try:
|
||||
import pkg_resources
|
||||
actual_version = pkg_resources.get_distribution('mcp').version
|
||||
MCP_VERSION = f'compatibility-mode-{actual_version}'
|
||||
except Exception:
|
||||
MCP_VERSION = 'compatibility-mode-1.9.x'
|
||||
|
||||
logger.info("MCP 1.9.x compatibility workaround successful in multiworker!")
|
||||
|
||||
# Restore original typing function if we patched it
|
||||
if original_check_generic:
|
||||
typing._check_generic = original_check_generic
|
||||
|
||||
return True
|
||||
|
||||
except Exception as workaround_error:
|
||||
logger.error(f"MCP compatibility workaround failed in multiworker: {workaround_error}")
|
||||
|
||||
# Restore original typing function if we patched it
|
||||
if original_check_generic:
|
||||
try:
|
||||
import typing
|
||||
typing._check_generic = original_check_generic
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.error(f"Failed to import MCP components in multiworker: {import_error}")
|
||||
return False
|
||||
|
||||
# Perform MCP import with compatibility handling
|
||||
# Fail fast at module import time when the MCP SDK could not be loaded.
if not _import_mcp_with_compatibility():
    raise ImportError(
        "Failed to import MCP components in multiworker. "
        "Please ensure MCP is properly installed. "
        "Supported versions: 1.8.x, 1.9.x"
    )
|
||||
|
||||
from starlette.applications import Starlette
|
||||
from starlette.routing import Route
|
||||
from starlette.responses import JSONResponse, Response
|
||||
|
||||
# Import Doris MCP components
|
||||
from .tools.tools_manager import DorisToolsManager
|
||||
from .tools.prompts_manager import DorisPromptsManager
|
||||
from .tools.resources_manager import DorisResourcesManager
|
||||
from .utils.config import DorisConfig
|
||||
from .utils.db import DorisConnectionManager
|
||||
from .utils.security import DorisSecurityManager
|
||||
|
||||
# Global variables for worker-specific instances
|
||||
_worker_server = None
|
||||
_worker_session_manager = None
|
||||
_worker_connection_manager = None
|
||||
_worker_security_manager = None
|
||||
_worker_session_manager_context = None
|
||||
_worker_initialized = False
|
||||
|
||||
def get_mcp_capabilities():
    """Build the capabilities dict advertised by this worker.

    The ``NotificationOptions`` import is used only as an availability probe
    (the original comment associates it with MCP 1.9.x and newer). When it
    succeeds the result also carries a ``notification_options`` section; when
    it fails a warning is logged and only the three base sections are
    returned.

    Returns:
        dict: capability sections keyed by name.
    """
    capabilities = {"resources": {}, "tools": {}, "prompts": {}}

    try:
        # Probe only - the imported class itself is not used.
        from mcp.server.lowlevel.server import NotificationOptions  # noqa: F401
    except Exception as e:
        from .utils.logger import get_logger
        get_logger(__name__).warning(f"Failed to get full capabilities in multiworker: {e}")
        return capabilities

    capabilities["notification_options"] = {
        "prompts_changed": True,
        "resources_changed": True,
        "tools_changed": True
    }
    return capabilities
|
||||
|
||||
async def initialize_worker():
    """Initialize the MCP server and its managers for this worker process.

    Idempotent: returns immediately when ``_worker_initialized`` is already
    set. Otherwise it loads the configuration from the environment, sets up
    logging, creates and initializes the security and connection managers,
    registers the MCP resource/tool/prompt handlers, enters the stateless
    streamable-HTTP session manager context, creates the OAuth/token
    handlers and finally flips ``_worker_initialized``.

    Raises:
        Exception: any initialization failure is logged together with its
        traceback and then re-raised. Globals assigned before the failure
        stay set, which lets lifespan() release them.
    """
    global _worker_server, _worker_session_manager, _worker_connection_manager, _worker_security_manager, _worker_session_manager_context, _worker_initialized, _oauth_handlers, _token_handlers

    if _worker_initialized:
        return

    def _error_json(payload):
        """Serialize an error payload; default=str keeps odd values encodable."""
        return json.dumps(payload, ensure_ascii=False, indent=2, default=str)

    try:
        from .utils.logger import get_logger
        logger = get_logger(__name__)

        logger.info(f"Initializing MCP worker process {os.getpid()}")

        # Create configuration
        config = DorisConfig.from_env()

        # Initialize enhanced logging system
        from .utils.config import ConfigManager
        config_manager = ConfigManager(config)
        config_manager.setup_logging()

        # Security manager goes first (includes JWT setup if enabled)
        _worker_security_manager = DorisSecurityManager(config)
        await _worker_security_manager.initialize()
        logger.info(f"Worker {os.getpid()} security manager initialization completed")

        # Connection manager gets the token manager (if any) for
        # token-bound DB config; None when either attribute is missing.
        auth_provider = getattr(_worker_security_manager, 'auth_provider', None)
        token_manager = getattr(auth_provider, 'token_manager', None)
        _worker_connection_manager = DorisConnectionManager(config, _worker_security_manager, token_manager)

        # Back-reference used by the security manager for database validation
        _worker_security_manager.connection_manager = _worker_connection_manager

        await _worker_connection_manager.initialize()

        # Create MCP server
        _worker_server = Server("doris-mcp-server")

        # Create managers
        resources_manager = DorisResourcesManager(_worker_connection_manager)
        tools_manager = DorisToolsManager(_worker_connection_manager)
        prompts_manager = DorisPromptsManager(_worker_connection_manager)

        # Setup MCP handlers. Each handler converts failures into an empty
        # list or a JSON error document instead of raising.
        @_worker_server.list_resources()
        async def handle_list_resources() -> list[Resource]:
            """Handle resource list request"""
            try:
                logger.info("Handling resource list request in worker")
                resources = await resources_manager.list_resources()
                logger.info(f"Returning {len(resources)} resources from worker")
                return resources
            except Exception as e:
                logger.error(f"Failed to handle resource list request in worker: {e}")
                return []

        @_worker_server.read_resource()
        async def handle_read_resource(uri: str) -> str:
            """Handle resource read request"""
            try:
                logger.info(f"Handling resource read request in worker: {uri}")
                content = await resources_manager.read_resource(uri)
                return content
            except Exception as e:
                logger.error(f"Failed to handle resource read request in worker: {e}")
                # NOTE(review): the SDK is expected to pass a URL object
                # rather than a plain str here (confirm against the installed
                # MCP version); stringify it so the error payload can always
                # be JSON-encoded.
                return _error_json(
                    {"error": f"Failed to read resource: {str(e)}", "uri": str(uri)}
                )

        @_worker_server.list_tools()
        async def handle_list_tools() -> list[Tool]:
            """Handle tool list request"""
            try:
                logger.info("Handling tool list request in worker")
                tools = await tools_manager.list_tools()
                logger.info(f"Returning {len(tools)} tools from worker")
                return tools
            except Exception as e:
                logger.error(f"Failed to handle tool list request in worker: {e}")
                return []

        @_worker_server.call_tool()
        async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
            """Handle tool call request"""
            try:
                logger.info(f"Handling tool call request in worker: {name}")
                result = await tools_manager.call_tool(name, arguments)
                return [TextContent(type="text", text=result)]
            except Exception as e:
                logger.error(f"Failed to handle tool call request in worker: {e}")
                error_result = _error_json(
                    {
                        "error": f"Tool call failed: {str(e)}",
                        "tool_name": name,
                        "arguments": arguments,
                    }
                )
                return [TextContent(type="text", text=error_result)]

        @_worker_server.list_prompts()
        async def handle_list_prompts() -> list[Prompt]:
            """Handle prompt list request"""
            try:
                logger.info("Handling prompt list request in worker")
                prompts = await prompts_manager.list_prompts()
                logger.info(f"Returning {len(prompts)} prompts from worker")
                return prompts
            except Exception as e:
                logger.error(f"Failed to handle prompt list request in worker: {e}")
                return []

        @_worker_server.get_prompt()
        async def handle_get_prompt(name: str, arguments: dict[str, Any]) -> str:
            """Handle prompt get request"""
            try:
                logger.info(f"Handling prompt get request in worker: {name}")
                result = await prompts_manager.get_prompt(name, arguments)
                return result
            except Exception as e:
                logger.error(f"Failed to handle prompt get request in worker: {e}")
                return _error_json(
                    {
                        "error": f"Failed to get prompt: {str(e)}",
                        "prompt_name": name,
                        "arguments": arguments,
                    }
                )

        # Create session manager for this worker
        from mcp.server.streamable_http_manager import StreamableHTTPSessionManager

        _worker_session_manager = StreamableHTTPSessionManager(
            app=_worker_server,
            json_response=True,
            stateless=True  # Use stateless mode for multi-worker compatibility
        )

        # Enter the session manager context manually; lifespan() exits it.
        _worker_session_manager_context = _worker_session_manager.run()
        await _worker_session_manager_context.__aenter__()

        # Initialize OAuth and Token handlers
        from .auth.oauth_handlers import OAuthHandlers
        from .auth.token_handlers import TokenHandlers
        _oauth_handlers = OAuthHandlers(_worker_security_manager)
        _token_handlers = TokenHandlers(_worker_security_manager, config)

        _worker_initialized = True
        logger.info(f"Worker {os.getpid()} MCP initialization completed successfully")

    except Exception as e:
        # Re-acquire the logger: the failure may have happened before the
        # one above was bound.
        from .utils.logger import get_logger
        logger = get_logger(__name__)
        logger.error(f"Failed to initialize worker {os.getpid()}: {e}")
        import traceback
        logger.error("Complete error stack:")
        logger.error(traceback.format_exc())
        raise
|
||||
|
||||
async def health_check(request):
    """Return a JSON health report for the worker that served the request.

    The payload includes the worker PID, whether MCP initialization has
    completed in this process, and the detected MCP version string.
    """
    report = dict(
        status="healthy",
        service="doris-mcp-server",
        worker_pid=os.getpid(),
        worker_mode="multi-process-full-mcp",
        mcp_initialized=_worker_initialized,
        mcp_version=MCP_VERSION,
    )
    return JSONResponse(report)
|
||||
|
||||
# OAuth and Token handlers (initialize after worker setup)
|
||||
_oauth_handlers = None
|
||||
_token_handlers = None
|
||||
|
||||
async def oauth_login(request):
    """OAuth login endpoint; answers 503 until the OAuth handlers exist."""
    if _oauth_handlers:
        return await _oauth_handlers.handle_login(request)
    return JSONResponse({"error": "OAuth not initialized"}, status_code=503)
|
||||
|
||||
async def oauth_callback(request):
    """OAuth callback endpoint; answers 503 until the OAuth handlers exist."""
    if _oauth_handlers:
        return await _oauth_handlers.handle_callback(request)
    return JSONResponse({"error": "OAuth not initialized"}, status_code=503)
|
||||
|
||||
async def oauth_provider_info(request):
    """OAuth provider info endpoint; answers 503 until the OAuth handlers exist."""
    if _oauth_handlers:
        return await _oauth_handlers.handle_provider_info(request)
    return JSONResponse({"error": "OAuth not initialized"}, status_code=503)
|
||||
|
||||
async def oauth_demo(request):
    """OAuth demo page endpoint.

    Delegates to the OAuth handlers once they exist. Before that it returns
    a small HTML notice with status 503, matching the JSON OAuth endpoints
    (the notice was previously served with the default 200 status).
    """
    if not _oauth_handlers:
        from starlette.responses import HTMLResponse
        return HTMLResponse("<h1>OAuth not initialized</h1>", status_code=503)
    return await _oauth_handlers.handle_demo_page(request)
|
||||
|
||||
# Token management endpoints
|
||||
async def token_create(request):
    """Token creation endpoint; answers 503 until the token handlers exist."""
    if _token_handlers:
        return await _token_handlers.handle_create_token(request)
    return JSONResponse({"error": "Token handlers not initialized"}, status_code=503)
|
||||
|
||||
async def token_revoke(request):
    """Token revocation endpoint; answers 503 until the token handlers exist."""
    if _token_handlers:
        return await _token_handlers.handle_revoke_token(request)
    return JSONResponse({"error": "Token handlers not initialized"}, status_code=503)
|
||||
|
||||
async def token_list(request):
    """Token listing endpoint; answers 503 until the token handlers exist."""
    if _token_handlers:
        return await _token_handlers.handle_list_tokens(request)
    return JSONResponse({"error": "Token handlers not initialized"}, status_code=503)
|
||||
|
||||
async def token_stats(request):
    """Token statistics endpoint; answers 503 until the token handlers exist."""
    if _token_handlers:
        return await _token_handlers.handle_token_stats(request)
    return JSONResponse({"error": "Token handlers not initialized"}, status_code=503)
|
||||
|
||||
async def token_cleanup(request):
    """Token cleanup endpoint; answers 503 until the token handlers exist."""
    if _token_handlers:
        return await _token_handlers.handle_cleanup_tokens(request)
    return JSONResponse({"error": "Token handlers not initialized"}, status_code=503)
|
||||
|
||||
async def token_management(request):
    """Token management page endpoint.

    Delegates to the token handlers once they exist. Before that it returns
    a small HTML notice with status 503, matching the JSON token endpoints
    (the notice was previously served with the default 200 status).
    """
    if not _token_handlers:
        from starlette.responses import HTMLResponse
        return HTMLResponse("<h1>Token handlers not initialized</h1>", status_code=503)
    return await _token_handlers.handle_management_page(request)
|
||||
|
||||
async def root_info(request):
    """Root endpoint: describe the service, this worker and its main routes."""
    endpoint_map = {
        "health": "/health",
        "mcp": "/mcp"
    }
    summary = {
        "service": "doris-mcp-server",
        "mode": "multi-worker-full-mcp",
        "worker_pid": os.getpid(),
        "mcp_initialized": _worker_initialized,
        "mcp_version": MCP_VERSION,
        "endpoints": endpoint_map,
    }
    return JSONResponse(summary)
|
||||
|
||||
@asynccontextmanager
async def lifespan(app):
    """Application lifespan: initialize the worker, then release it on exit.

    Startup runs initialize_worker(). Teardown runs in a ``finally`` block,
    so it also executes when startup itself raised: it exits the session
    manager context, closes the connection manager, shuts down the security
    manager and finally the logging system. Each teardown step logs its own
    failure and does not prevent the following steps.
    """
    try:
        # Startup
        await initialize_worker()
        from .utils.logger import get_logger
        startup_logger = get_logger(__name__)
        startup_logger.info(f"Worker {os.getpid()} startup completed")

        yield

    finally:
        # Shutdown
        from .utils.logger import get_logger
        logger = get_logger(__name__)

        # 1) session manager context entered in initialize_worker()
        session_ctx = _worker_session_manager_context
        if session_ctx:
            try:
                await session_ctx.__aexit__(None, None, None)
                logger.info(f"Worker {os.getpid()} session manager context closed")
            except Exception as exc:
                logger.error(f"Error closing worker session manager context: {exc}")

        # 2) database connections
        connections = _worker_connection_manager
        if connections:
            try:
                await connections.close()
                logger.info(f"Worker {os.getpid()} connection manager closed")
            except Exception as exc:
                logger.error(f"Error closing worker connection manager: {exc}")

        # 3) security manager
        security = _worker_security_manager
        if security:
            try:
                await security.shutdown()
                logger.info(f"Worker {os.getpid()} security manager shutdown completed")
            except Exception as exc:
                logger.error(f"Error shutting down worker security manager: {exc}")

        # 4) logging system goes last so the steps above can still log
        try:
            from .utils.logger import shutdown_logging
            shutdown_logging()
        except Exception as exc:
            logger.error(f"Error shutting down logging system: {exc}")
|
||||
|
||||
async def mcp_asgi_app(scope, receive, send):
    """ASGI endpoint for MCP traffic.

    Hands the request to this worker's session manager once the worker is
    initialized; before that it answers with a raw 503 JSON response.

    NOTE(review): no authentication step is visible in this function -
    confirm whether the session manager or an outer layer enforces it for
    multi-worker mode.
    """
    if _worker_initialized:
        from .utils.logger import get_logger
        logger = get_logger(__name__)

        # Request line, for debug logging only
        method = scope.get('method', 'UNKNOWN')
        path = scope.get('path', '')
        logger.debug(f"Worker {os.getpid()} handling MCP request: {method} {path}")

        # Handle the request directly without a nested run context
        await _worker_session_manager.handle_request(scope, receive, send)
        return

    # Worker not initialized: emit the error as two raw ASGI messages
    response_start = {
        'type': 'http.response.start',
        'status': 503,
        'headers': [(b'content-type', b'application/json')]
    }
    response_body = {
        'type': 'http.response.body',
        'body': b'{"error": "Worker not initialized"}'
    }
    await send(response_start)
    await send(response_body)
|
||||
|
||||
# Create Starlette app with basic routes
|
||||
basic_app = Starlette(
    # Debug mode makes Starlette return full tracebacks in HTTP error
    # responses. This app serves authentication and token endpoints, so it
    # must stay off to avoid leaking internals to clients.
    debug=False,
    routes=[
        # Service information and liveness
        Route("/", root_info, methods=["GET"]),
        Route("/health", health_check, methods=["GET"]),
        # OAuth endpoints
        Route("/auth/login", oauth_login, methods=["GET"]),
        Route("/auth/callback", oauth_callback, methods=["GET"]),
        Route("/auth/provider", oauth_provider_info, methods=["GET"]),
        Route("/auth/demo", oauth_demo, methods=["GET"]),
        # Token management endpoints
        # NOTE(review): create/revoke/cleanup also accept GET; confirm the
        # handlers enforce admin authentication, since GET requests are easy
        # to trigger from another site.
        Route("/token/create", token_create, methods=["GET", "POST"]),
        Route("/token/revoke", token_revoke, methods=["GET", "DELETE"]),
        Route("/token/list", token_list, methods=["GET"]),
        Route("/token/stats", token_stats, methods=["GET"]),
        Route("/token/cleanup", token_cleanup, methods=["GET", "POST"]),
        Route("/token/management", token_management, methods=["GET"]),
    ],
    # Worker startup/shutdown is driven by the lifespan context manager
    lifespan=lifespan
)
|
||||
|
||||
# Create main ASGI app that routes between basic app and MCP
|
||||
async def app(scope, receive, send):
    """Top-level ASGI entry point that dispatches on the request path.

    "/mcp" and everything below "/mcp/" is served by ``mcp_asgi_app``; every
    other path (including the auth endpoints) is served by ``basic_app``.
    A scope without a ``path`` key is treated as "/".
    """
    request_path = scope.get('path', '/')
    is_mcp_request = request_path == "/mcp" or request_path.startswith('/mcp/')

    # Pick the handler first, then forward the untouched ASGI triple to it.
    handler = mcp_asgi_app if is_mcp_request else basic_app
    await handler(scope, receive, send)
|
||||
@@ -61,7 +61,7 @@ class DorisToolsManager:
|
||||
# Initialize v0.5.0 advanced analytics tools
|
||||
self.data_governance_tools = DataGovernanceTools(connection_manager)
|
||||
self.data_exploration_tools = DataExplorationTools(connection_manager)
|
||||
self.data_quality_tools = DataQualityTools(connection_manager)
|
||||
self.data_quality_tools = DataQualityTools(connection_manager, connection_manager.config)
|
||||
self.security_analytics_tools = SecurityAnalyticsTools(connection_manager)
|
||||
self.dependency_analysis_tools = DependencyAnalysisTools(connection_manager)
|
||||
self.performance_analytics_tools = PerformanceAnalyticsTools(connection_manager)
|
||||
@@ -464,41 +464,87 @@ class DorisToolsManager:
|
||||
|
||||
# 🔄 Unified Data Quality Analysis Tool (New in v0.5.0)
|
||||
@mcp.tool(
|
||||
"analyze_data_quality",
|
||||
description="""[Function Description]: Comprehensive data quality analysis combining completeness and distribution analysis.
|
||||
"get_table_basic_info",
|
||||
description="""[Function Description]: Get basic information about a table including row count, column count, partitions, and size.
|
||||
|
||||
[Parameter Content]:
|
||||
|
||||
- table_name (string) [Required] - Name of the table to analyze
|
||||
- analysis_scope (string) [Optional] - Analysis scope, default is "comprehensive"
|
||||
* "completeness": Only completeness analysis (null rates, business rules)
|
||||
* "distribution": Only distribution analysis (statistical patterns)
|
||||
* "comprehensive": Full analysis including both completeness and distribution
|
||||
- catalog_name (string) [Optional] - Target catalog name
|
||||
- db_name (string) [Optional] - Target database name
|
||||
""",
|
||||
)
|
||||
async def get_table_basic_info_tool(
|
||||
table_name: str,
|
||||
catalog_name: str = None,
|
||||
db_name: str = None
|
||||
) -> str:
|
||||
"""Get table basic information"""
|
||||
return await self.call_tool("get_table_basic_info", {
|
||||
"table_name": table_name,
|
||||
"catalog_name": catalog_name,
|
||||
"db_name": db_name
|
||||
})
|
||||
|
||||
@mcp.tool(
|
||||
"analyze_columns",
|
||||
description="""[Function Description]: Analyze completeness and distribution of specified columns in a table.
|
||||
|
||||
[Parameter Content]:
|
||||
|
||||
- table_name (string) [Required] - Name of the table to analyze
|
||||
- columns (array) [Required] - List of column names to analyze
|
||||
- analysis_types (array) [Optional] - Types of analysis to perform, default is ["both"]
|
||||
* "completeness": Only completeness analysis (null rates, non-null counts)
|
||||
* "distribution": Only distribution analysis (statistical patterns by data type)
|
||||
* "both": Both completeness and distribution analysis
|
||||
- sample_size (integer) [Optional] - Maximum number of rows to sample, default is 100000
|
||||
- include_all_columns (boolean) [Optional] - Whether to analyze all columns, default is false
|
||||
- business_rules (array) [Optional] - Business rule validations in format [{"rule_name": "email_format", "sql_condition": "email REGEXP '^[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Za-z]{2,}$'"}]
|
||||
- catalog_name (string) [Optional] - Target catalog name
|
||||
- db_name (string) [Optional] - Target database name
|
||||
- detailed_response (boolean) [Optional] - Whether to return detailed response including raw data, default is false
|
||||
""",
|
||||
)
|
||||
async def analyze_data_quality_tool(
|
||||
async def analyze_columns_tool(
|
||||
table_name: str,
|
||||
analysis_scope: str = "comprehensive",
|
||||
columns: List[str],
|
||||
analysis_types: List[str] = None,
|
||||
sample_size: int = 100000,
|
||||
include_all_columns: bool = False,
|
||||
business_rules: List[dict] = None,
|
||||
catalog_name: str = None,
|
||||
db_name: str = None,
|
||||
detailed_response: bool = False
|
||||
) -> str:
|
||||
"""Unified data quality analysis tool"""
|
||||
return await self.call_tool("analyze_data_quality", {
|
||||
"""Analyze table columns"""
|
||||
return await self.call_tool("analyze_columns", {
|
||||
"table_name": table_name,
|
||||
"analysis_scope": analysis_scope,
|
||||
"columns": columns,
|
||||
"analysis_types": analysis_types or ["both"],
|
||||
"sample_size": sample_size,
|
||||
"include_all_columns": include_all_columns,
|
||||
"business_rules": business_rules,
|
||||
"catalog_name": catalog_name,
|
||||
"db_name": db_name,
|
||||
"detailed_response": detailed_response
|
||||
})
|
||||
|
||||
@mcp.tool(
|
||||
"analyze_table_storage",
|
||||
description="""[Function Description]: Analyze table's physical distribution and storage information.
|
||||
|
||||
[Parameter Content]:
|
||||
|
||||
- table_name (string) [Required] - Name of the table to analyze
|
||||
- catalog_name (string) [Optional] - Target catalog name
|
||||
- db_name (string) [Optional] - Target database name
|
||||
- detailed_response (boolean) [Optional] - Whether to return detailed response including raw data, default is false
|
||||
""",
|
||||
)
|
||||
async def analyze_table_storage_tool(
|
||||
table_name: str,
|
||||
catalog_name: str = None,
|
||||
db_name: str = None,
|
||||
detailed_response: bool = False
|
||||
) -> str:
|
||||
"""Analyze table storage"""
|
||||
return await self.call_tool("analyze_table_storage", {
|
||||
"table_name": table_name,
|
||||
"catalog_name": catalog_name,
|
||||
"db_name": db_name,
|
||||
"detailed_response": detailed_response
|
||||
@@ -721,7 +767,7 @@ No parameters required. Returns connection status, configuration, and diagnostic
|
||||
"""Get ADBC connection information and status"""
|
||||
return await self.call_tool("get_adbc_connection_info", {})
|
||||
|
||||
logger.info("Successfully registered 23 tools to MCP server (14 basic + 7 advanced analytics + 2 ADBC tools)")
|
||||
logger.info("Successfully registered 25 tools to MCP server (14 basic + 9 advanced analytics + 2 ADBC tools)")
|
||||
|
||||
async def list_tools(self) -> List[Tool]:
|
||||
"""List all available query tools (for stdio mode)"""
|
||||
@@ -1064,20 +1110,14 @@ No parameters required. Returns connection status, configuration, and diagnostic
|
||||
},
|
||||
),
|
||||
# ==================== v0.5.0 Advanced Analytics Tools ====================
|
||||
# Atomic Data Quality Analysis Tools
|
||||
Tool(
|
||||
name="analyze_data_quality",
|
||||
description="""[Function Description]: Comprehensive data quality analysis combining completeness and distribution analysis.
|
||||
name="get_table_basic_info",
|
||||
description="""[Function Description]: Get basic information about a table including row count, column count, partitions, and size.
|
||||
|
||||
[Parameter Content]:
|
||||
|
||||
- table_name (string) [Required] - Name of the table to analyze
|
||||
- analysis_scope (string) [Optional] - Analysis scope, default is "comprehensive"
|
||||
* "completeness": Only completeness analysis (null rates, business rules)
|
||||
* "distribution": Only distribution analysis (statistical patterns)
|
||||
* "comprehensive": Full analysis including both completeness and distribution
|
||||
- sample_size (integer) [Optional] - Maximum number of rows to sample, default is 100000
|
||||
- include_all_columns (boolean) [Optional] - Whether to analyze all columns, default is false
|
||||
- business_rules (array) [Optional] - Business rule validations in format [{"rule_name": "email_format", "sql_condition": "email REGEXP '^[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Za-z]{2,}$'"}]
|
||||
- catalog_name (string) [Optional] - Target catalog name
|
||||
- db_name (string) [Optional] - Target database name
|
||||
""",
|
||||
@@ -1085,10 +1125,58 @@ No parameters required. Returns connection status, configuration, and diagnostic
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"table_name": {"type": "string", "description": "Name of the table to analyze"},
|
||||
"analysis_scope": {"type": "string", "enum": ["completeness", "distribution", "comprehensive"], "description": "Analysis scope", "default": "comprehensive"},
|
||||
"catalog_name": {"type": "string", "description": "Target catalog name"},
|
||||
"db_name": {"type": "string", "description": "Target database name"},
|
||||
},
|
||||
"required": ["table_name"],
|
||||
},
|
||||
),
|
||||
Tool(
|
||||
name="analyze_columns",
|
||||
description="""[Function Description]: Analyze completeness and distribution of specified columns in a table.
|
||||
|
||||
[Parameter Content]:
|
||||
|
||||
- table_name (string) [Required] - Name of the table to analyze
|
||||
- columns (array) [Required] - List of column names to analyze
|
||||
- analysis_types (array) [Optional] - Types of analysis to perform, default is ["both"]
|
||||
* "completeness": Only completeness analysis (null rates, non-null counts)
|
||||
* "distribution": Only distribution analysis (statistical patterns by data type)
|
||||
* "both": Both completeness and distribution analysis
|
||||
- sample_size (integer) [Optional] - Maximum number of rows to sample, default is 100000
|
||||
- catalog_name (string) [Optional] - Target catalog name
|
||||
- db_name (string) [Optional] - Target database name
|
||||
- detailed_response (boolean) [Optional] - Whether to return detailed response including raw data, default is false
|
||||
""",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"table_name": {"type": "string", "description": "Name of the table to analyze"},
|
||||
"columns": {"type": "array", "items": {"type": "string"}, "description": "List of column names to analyze"},
|
||||
"analysis_types": {"type": "array", "items": {"type": "string", "enum": ["completeness", "distribution", "both"]}, "description": "Types of analysis to perform", "default": ["both"]},
|
||||
"sample_size": {"type": "integer", "description": "Maximum number of rows to sample", "default": 100000},
|
||||
"include_all_columns": {"type": "boolean", "description": "Whether to analyze all columns", "default": False},
|
||||
"business_rules": {"type": "array", "items": {"type": "object"}, "description": "Business rule validations"},
|
||||
"catalog_name": {"type": "string", "description": "Target catalog name"},
|
||||
"db_name": {"type": "string", "description": "Target database name"},
|
||||
"detailed_response": {"type": "boolean", "description": "Whether to return detailed response including raw data", "default": False},
|
||||
},
|
||||
"required": ["table_name", "columns"],
|
||||
},
|
||||
),
|
||||
Tool(
|
||||
name="analyze_table_storage",
|
||||
description="""[Function Description]: Analyze table's physical distribution and storage information.
|
||||
|
||||
[Parameter Content]:
|
||||
|
||||
- table_name (string) [Required] - Name of the table to analyze
|
||||
- catalog_name (string) [Optional] - Target catalog name
|
||||
- db_name (string) [Optional] - Target database name
|
||||
- detailed_response (boolean) [Optional] - Whether to return detailed response including raw data, default is false
|
||||
""",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"table_name": {"type": "string", "description": "Name of the table to analyze"},
|
||||
"catalog_name": {"type": "string", "description": "Target catalog name"},
|
||||
"db_name": {"type": "string", "description": "Target database name"},
|
||||
"detailed_response": {"type": "boolean", "description": "Whether to return detailed response including raw data", "default": False},
|
||||
@@ -1096,7 +1184,6 @@ No parameters required. Returns connection status, configuration, and diagnostic
|
||||
"required": ["table_name"],
|
||||
},
|
||||
),
|
||||
|
||||
Tool(
|
||||
name="trace_column_lineage",
|
||||
description="""[Function Description]: Trace data lineage for specified columns through SQL analysis and dependency mapping.
|
||||
@@ -1323,9 +1410,13 @@ No parameters required. Returns connection status, configuration, and diagnostic
|
||||
elif name == "get_historical_memory_stats":
|
||||
arguments["data_type"] = "historical"
|
||||
result = await self._get_memory_stats_tool(arguments)
|
||||
# v0.5.0 Advanced Analytics Tools
|
||||
elif name == "analyze_data_quality":
|
||||
result = await self._analyze_data_quality_tool(arguments)
|
||||
# v0.5.0 Advanced Analytics Tools - Atomic Data Quality Tools
|
||||
elif name == "get_table_basic_info":
|
||||
result = await self._get_table_basic_info_tool(arguments)
|
||||
elif name == "analyze_columns":
|
||||
result = await self._analyze_columns_tool(arguments)
|
||||
elif name == "analyze_table_storage":
|
||||
result = await self._analyze_table_storage_tool(arguments)
|
||||
elif name == "trace_column_lineage":
|
||||
result = await self._trace_column_lineage_tool(arguments)
|
||||
elif name == "monitor_data_freshness":
|
||||
@@ -1595,26 +1686,46 @@ No parameters required. Returns connection status, configuration, and diagnostic
|
||||
|
||||
# ==================== v0.5.0 Advanced Analytics Tools Private Methods ====================
|
||||
|
||||
async def _analyze_data_quality_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Unified data quality analysis tool routing"""
|
||||
async def _get_table_basic_info_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Get table basic information tool routing"""
|
||||
try:
|
||||
# Extract parameters
|
||||
table_name = arguments.get("table_name")
|
||||
analysis_scope = arguments.get("analysis_scope", "comprehensive")
|
||||
catalog_name = arguments.get("catalog_name")
|
||||
db_name = arguments.get("db_name")
|
||||
|
||||
# Delegate to atomic data quality tools
|
||||
result = await self.data_quality_tools.get_table_basic_info(
|
||||
table_name=table_name,
|
||||
catalog_name=catalog_name,
|
||||
db_name=db_name
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"error": str(e),
|
||||
"analysis_type": "table_basic_info",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
async def _analyze_columns_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Analyze columns tool routing"""
|
||||
try:
|
||||
table_name = arguments.get("table_name")
|
||||
columns = arguments.get("columns")
|
||||
analysis_types = arguments.get("analysis_types", ["both"])
|
||||
sample_size = arguments.get("sample_size", 100000)
|
||||
include_all_columns = arguments.get("include_all_columns", False)
|
||||
business_rules = arguments.get("business_rules", [])
|
||||
catalog_name = arguments.get("catalog_name")
|
||||
db_name = arguments.get("db_name")
|
||||
detailed_response = arguments.get("detailed_response", False)
|
||||
|
||||
# Delegate to the unified data quality tools
|
||||
result = await self.data_quality_tools.analyze_data_quality(
|
||||
# Delegate to atomic data quality tools
|
||||
result = await self.data_quality_tools.analyze_columns(
|
||||
table_name=table_name,
|
||||
analysis_scope=analysis_scope,
|
||||
columns=columns,
|
||||
analysis_types=analysis_types,
|
||||
sample_size=sample_size,
|
||||
include_all_columns=include_all_columns,
|
||||
business_rules=business_rules,
|
||||
catalog_name=catalog_name,
|
||||
db_name=db_name,
|
||||
detailed_response=detailed_response
|
||||
@@ -1625,7 +1736,32 @@ No parameters required. Returns connection status, configuration, and diagnostic
|
||||
except Exception as e:
|
||||
return {
|
||||
"error": str(e),
|
||||
"analysis_type": "unified_data_quality",
|
||||
"analysis_type": "columns_analysis",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
async def _analyze_table_storage_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Analyze table storage tool routing"""
|
||||
try:
|
||||
table_name = arguments.get("table_name")
|
||||
catalog_name = arguments.get("catalog_name")
|
||||
db_name = arguments.get("db_name")
|
||||
detailed_response = arguments.get("detailed_response", False)
|
||||
|
||||
# Delegate to atomic data quality tools
|
||||
result = await self.data_quality_tools.analyze_table_storage(
|
||||
table_name=table_name,
|
||||
catalog_name=catalog_name,
|
||||
db_name=db_name,
|
||||
detailed_response=detailed_response
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"error": str(e),
|
||||
"analysis_type": "table_storage_analysis",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
|
||||
@@ -516,7 +516,7 @@ class SQLAnalyzer:
|
||||
try:
|
||||
# Switch to specified database/catalog if provided
|
||||
if catalog_name:
|
||||
await connection.execute(f"USE `{catalog_name}`")
|
||||
await connection.execute(f"SWITCH `{catalog_name}`")
|
||||
if db_name:
|
||||
await connection.execute(f"USE `{db_name}`")
|
||||
|
||||
|
||||
@@ -77,11 +77,50 @@ class DatabaseConfig:
|
||||
class SecurityConfig:
|
||||
"""Security configuration"""
|
||||
|
||||
# Authentication configuration
|
||||
auth_type: str = "token" # token, basic, oauth
|
||||
token_secret: str = "default_secret"
|
||||
# Independent authentication switches - any one enabled allows that method
|
||||
enable_token_auth: bool = False # Enable token-based authentication (default: disabled)
|
||||
enable_jwt_auth: bool = False # Enable JWT authentication (default: disabled)
|
||||
enable_oauth_auth: bool = False # Enable OAuth 2.0/OIDC authentication (default: disabled)
|
||||
|
||||
# Legacy configuration (kept for backward compatibility)
|
||||
auth_type: str = "token" # jwt, token, basic, oauth (deprecated: use individual switches)
|
||||
token_secret: str = "default_secret" # Legacy token secret for backward compatibility
|
||||
token_expiry: int = 3600
|
||||
|
||||
# Enhanced Token Authentication Configuration
|
||||
token_file_path: str = "tokens.json" # Path to token configuration file
|
||||
enable_token_expiry: bool = True # Enable token expiration
|
||||
default_token_expiry_hours: int = 24 * 30 # Default expiry: 30 days
|
||||
token_hash_algorithm: str = "sha256" # Token hashing algorithm: sha256, sha512
|
||||
|
||||
# Token Management Security (New in v0.6.0)
|
||||
enable_http_token_management: bool = False # Enable HTTP token management endpoints (default: disabled for security)
|
||||
token_management_admin_token: str = "" # Admin token for token management endpoints (required if HTTP management enabled)
|
||||
token_management_allowed_ips: list[str] = field(default_factory=lambda: ["127.0.0.1", "::1", "localhost"]) # Allowed IPs for token management
|
||||
require_admin_auth: bool = True # Require admin authentication for token management (default: true)
|
||||
|
||||
# JWT Configuration
|
||||
jwt_algorithm: str = "RS256" # RS256, ES256, HS256
|
||||
jwt_issuer: str = "doris-mcp-server"
|
||||
jwt_audience: str = "doris-mcp-client"
|
||||
jwt_private_key_path: str = ""
|
||||
jwt_public_key_path: str = ""
|
||||
jwt_secret_key: str = "" # Only used for HS256 algorithm
|
||||
jwt_access_token_expiry: int = 3600 # 1 hour
|
||||
jwt_refresh_token_expiry: int = 86400 # 24 hours
|
||||
enable_token_refresh: bool = True
|
||||
enable_token_revocation: bool = True
|
||||
key_rotation_interval: int = 30 * 24 * 3600 # 30 days in seconds
|
||||
|
||||
# JWT Security Features
|
||||
jwt_require_iat: bool = True # Require "issued at" claim
|
||||
jwt_require_exp: bool = True # Require "expires at" claim
|
||||
jwt_require_nbf: bool = False # Require "not before" claim
|
||||
jwt_leeway: int = 10 # Clock skew tolerance in seconds
|
||||
jwt_verify_signature: bool = True # Verify JWT signature
|
||||
jwt_verify_audience: bool = True # Verify audience claim
|
||||
jwt_verify_issuer: bool = True # Verify issuer claim
|
||||
|
||||
# SQL security configuration
|
||||
enable_security_check: bool = True # Main switch: whether to enable SQL security check
|
||||
blocked_keywords: list[str] = field(
|
||||
@@ -115,6 +154,45 @@ class SecurityConfig:
|
||||
enable_masking: bool = True
|
||||
masking_rules: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
# OAuth 2.0/OIDC Configuration
|
||||
oauth_enabled: bool = False
|
||||
oauth_provider: str = "" # 'google', 'microsoft', 'github', 'custom'
|
||||
oauth_client_id: str = ""
|
||||
oauth_client_secret: str = ""
|
||||
oauth_redirect_uri: str = "http://localhost:3000/auth/callback"
|
||||
|
||||
# OIDC Discovery
|
||||
oidc_discovery_url: str = "" # e.g., https://accounts.google.com/.well-known/openid_configuration
|
||||
oauth_authorization_endpoint: str = ""
|
||||
oauth_token_endpoint: str = ""
|
||||
oauth_userinfo_endpoint: str = ""
|
||||
oauth_jwks_uri: str = ""
|
||||
|
||||
# OAuth Scopes and Settings
|
||||
oauth_scopes: list[str] = field(default_factory=list)
|
||||
oauth_state_expiry: int = 600 # State parameter expiry in seconds (10 minutes)
|
||||
oauth_pkce_enabled: bool = True # Enable PKCE for better security
|
||||
oauth_nonce_enabled: bool = True # Enable nonce for OIDC
|
||||
|
||||
# User Mapping Configuration
|
||||
oauth_user_id_claim: str = "sub" # JWT claim for user ID
|
||||
oauth_email_claim: str = "email"
|
||||
oauth_name_claim: str = "name"
|
||||
oauth_roles_claim: str = "roles" # Custom claim for roles
|
||||
oauth_default_roles: list[str] = field(default_factory=lambda: ["oauth_user"])
|
||||
|
||||
def __post_init__(self):
|
||||
"""Initialize default OAuth scopes based on provider"""
|
||||
if not self.oauth_scopes and self.oauth_provider:
|
||||
if self.oauth_provider == "google":
|
||||
self.oauth_scopes = ["openid", "email", "profile"]
|
||||
elif self.oauth_provider == "microsoft":
|
||||
self.oauth_scopes = ["openid", "profile", "email", "User.Read"]
|
||||
elif self.oauth_provider == "github":
|
||||
self.oauth_scopes = ["user:email", "read:user"]
|
||||
else:
|
||||
self.oauth_scopes = ["openid", "email", "profile"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class PerformanceConfig:
|
||||
@@ -137,6 +215,33 @@ class PerformanceConfig:
|
||||
max_response_content_size: int = 4096
|
||||
|
||||
|
||||
@dataclass
class DataQualityConfig:
    """Settings for the data quality analysis tools."""

    # --- Column analysis ---
    # Upper bound on the number of columns analyzed in one batch
    max_columns_per_batch: int = 20
    # Sample size used when the caller does not specify one
    default_sample_size: int = 100000

    # --- Sampling strategy ---
    # Tables below this row count are analyzed in full
    small_table_threshold: int = 100000
    # Tables below this row count use simple LIMIT sampling;
    # larger tables use systematic sampling
    medium_table_threshold: int = 1000000

    # --- Performance ---
    # Analyze multiple columns together in batches
    enable_batch_analysis: bool = True
    # Batch analysis timeout, in seconds
    batch_timeout: int = 300

    # --- Accuracy vs. performance ---
    # Trade accuracy for speed by using approximate algorithms
    enable_fast_mode: bool = False
    # Sample size used in fast mode
    fast_mode_sample_size: int = 10000

    # --- Statistical analysis ---
    # Turn distribution analysis on or off
    enable_distribution_analysis: bool = True
    # Number of histogram bins
    histogram_bins: int = 20
    # Percentile levels to calculate
    percentile_levels: list[float] = field(
        default_factory=lambda: [0.25, 0.5, 0.75, 0.95, 0.99]
    )
|
||||
|
||||
|
||||
@dataclass
|
||||
class ADBCConfig:
|
||||
"""ADBC (Arrow Flight SQL) configuration"""
|
||||
@@ -198,6 +303,7 @@ class DorisConfig:
|
||||
# Basic configuration
|
||||
server_name: str = "doris-mcp-server"
|
||||
server_version: str = "0.4.1"
|
||||
server_host: str = "localhost"
|
||||
server_port: int = 3000
|
||||
transport: str = "stdio"
|
||||
|
||||
@@ -208,6 +314,7 @@ class DorisConfig:
|
||||
database: DatabaseConfig = field(default_factory=DatabaseConfig)
|
||||
security: SecurityConfig = field(default_factory=SecurityConfig)
|
||||
performance: PerformanceConfig = field(default_factory=PerformanceConfig)
|
||||
data_quality: DataQualityConfig = field(default_factory=DataQualityConfig)
|
||||
logging: LoggingConfig = field(default_factory=LoggingConfig)
|
||||
monitoring: MonitoringConfig = field(default_factory=MonitoringConfig)
|
||||
adbc: ADBCConfig = field(default_factory=ADBCConfig)
|
||||
@@ -240,6 +347,9 @@ class DorisConfig:
|
||||
def from_env(cls, env_file: str | None = None) -> "DorisConfig":
|
||||
"""Load configuration from environment variables
|
||||
|
||||
The key-value pairs in the .env file will be loaded as environment variables,
|
||||
but the existing environment variables will not be overridden.
|
||||
|
||||
Args:
|
||||
env_file: .env file path, if None, search in the following order:
|
||||
.env, .env.local, .env.production, .env.development
|
||||
@@ -258,7 +368,7 @@ class DorisConfig:
|
||||
env_files = [".env", ".env.local", ".env.production", ".env.development"]
|
||||
for env_path in env_files:
|
||||
if Path(env_path).exists():
|
||||
load_dotenv(env_path)
|
||||
load_dotenv(env_path, override=False)
|
||||
logging.getLogger(__name__).info(f"Loaded environment configuration file: {env_path}")
|
||||
break
|
||||
else:
|
||||
@@ -268,19 +378,34 @@ class DorisConfig:
|
||||
|
||||
config = cls()
|
||||
|
||||
# Database configuration
|
||||
config.database.host = os.getenv("DORIS_HOST", config.database.host)
|
||||
config.database.port = int(os.getenv("DORIS_PORT", str(config.database.port)))
|
||||
config.database.user = os.getenv("DORIS_USER", config.database.user)
|
||||
config.database.password = os.getenv("DORIS_PASSWORD", config.database.password)
|
||||
config.database.database = os.getenv("DORIS_DATABASE", config.database.database)
|
||||
config.database.fe_http_port = int(os.getenv("DORIS_FE_HTTP_PORT", str(config.database.fe_http_port)))
|
||||
# Database configuration - handle empty strings properly
|
||||
doris_host = os.getenv("DORIS_HOST", "").strip()
|
||||
config.database.host = doris_host if doris_host else config.database.host
|
||||
|
||||
doris_port = os.getenv("DORIS_PORT", "").strip()
|
||||
if doris_port and doris_port.isdigit():
|
||||
config.database.port = int(doris_port)
|
||||
|
||||
doris_user = os.getenv("DORIS_USER", "").strip()
|
||||
config.database.user = doris_user if doris_user else config.database.user
|
||||
|
||||
doris_password = os.getenv("DORIS_PASSWORD", "")
|
||||
config.database.password = doris_password if doris_password else config.database.password
|
||||
|
||||
doris_database = os.getenv("DORIS_DATABASE", "").strip()
|
||||
config.database.database = doris_database if doris_database else config.database.database
|
||||
|
||||
doris_fe_http_port = os.getenv("DORIS_FE_HTTP_PORT", "").strip()
|
||||
if doris_fe_http_port and doris_fe_http_port.isdigit():
|
||||
config.database.fe_http_port = int(doris_fe_http_port)
|
||||
|
||||
# BE nodes configuration
|
||||
be_hosts_env = os.getenv("DORIS_BE_HOSTS", "")
|
||||
if be_hosts_env:
|
||||
config.database.be_hosts = [host.strip() for host in be_hosts_env.split(",") if host.strip()]
|
||||
config.database.be_webserver_port = int(os.getenv("DORIS_BE_WEBSERVER_PORT", str(config.database.be_webserver_port)))
|
||||
be_webserver_port = os.getenv("DORIS_BE_WEBSERVER_PORT", "").strip()
|
||||
if be_webserver_port and be_webserver_port.isdigit():
|
||||
config.database.be_webserver_port = int(be_webserver_port)
|
||||
|
||||
# Arrow Flight SQL Configuration
|
||||
fe_arrow_port_env = os.getenv("FE_ARROW_FLIGHT_SQL_PORT")
|
||||
@@ -306,6 +431,10 @@ class DorisConfig:
|
||||
)
|
||||
|
||||
# Security configuration
|
||||
# Independent authentication switches
|
||||
config.security.enable_token_auth = os.getenv("ENABLE_TOKEN_AUTH", str(config.security.enable_token_auth)).lower() == "true"
|
||||
config.security.enable_jwt_auth = os.getenv("ENABLE_JWT_AUTH", str(config.security.enable_jwt_auth)).lower() == "true"
|
||||
config.security.enable_oauth_auth = os.getenv("ENABLE_OAUTH_AUTH", str(config.security.enable_oauth_auth)).lower() == "true"
|
||||
config.security.auth_type = os.getenv("AUTH_TYPE", config.security.auth_type)
|
||||
config.security.token_secret = os.getenv("TOKEN_SECRET", config.security.token_secret)
|
||||
config.security.token_expiry = int(
|
||||
@@ -337,6 +466,31 @@ class DorisConfig:
|
||||
os.getenv("ENABLE_MASKING", str(config.security.enable_masking).lower()).lower() == "true"
|
||||
)
|
||||
|
||||
# Enhanced Token Authentication configuration
|
||||
config.security.token_file_path = os.getenv("TOKEN_FILE_PATH", config.security.token_file_path)
|
||||
config.security.enable_token_expiry = (
|
||||
os.getenv("ENABLE_TOKEN_EXPIRY", str(config.security.enable_token_expiry).lower()).lower() == "true"
|
||||
)
|
||||
config.security.default_token_expiry_hours = int(
|
||||
os.getenv("DEFAULT_TOKEN_EXPIRY_HOURS", str(config.security.default_token_expiry_hours))
|
||||
)
|
||||
config.security.token_hash_algorithm = os.getenv("TOKEN_HASH_ALGORITHM", config.security.token_hash_algorithm)
|
||||
|
||||
# Token Management Security Configuration (New in v0.6.0)
|
||||
config.security.enable_http_token_management = (
|
||||
os.getenv("ENABLE_HTTP_TOKEN_MANAGEMENT", str(config.security.enable_http_token_management).lower()).lower() == "true"
|
||||
)
|
||||
config.security.token_management_admin_token = os.getenv("TOKEN_MANAGEMENT_ADMIN_TOKEN", config.security.token_management_admin_token)
|
||||
|
||||
# Parse allowed IPs from comma-separated string
|
||||
allowed_ips_str = os.getenv("TOKEN_MANAGEMENT_ALLOWED_IPS", "")
|
||||
if allowed_ips_str:
|
||||
config.security.token_management_allowed_ips = [ip.strip() for ip in allowed_ips_str.split(",") if ip.strip()]
|
||||
|
||||
config.security.require_admin_auth = (
|
||||
os.getenv("REQUIRE_ADMIN_AUTH", str(config.security.require_admin_auth).lower()).lower() == "true"
|
||||
)
|
||||
|
||||
# Performance configuration
|
||||
config.performance.enable_query_cache = (
|
||||
os.getenv("ENABLE_QUERY_CACHE", "true").lower() == "true"
|
||||
@@ -404,10 +558,44 @@ class DorisConfig:
|
||||
os.getenv("ADBC_ENABLED", str(config.adbc.enabled).lower()).lower() == "true"
|
||||
)
|
||||
|
||||
# Data quality configuration
|
||||
config.data_quality.max_columns_per_batch = int(
|
||||
os.getenv("DATA_QUALITY_MAX_COLUMNS_PER_BATCH", str(config.data_quality.max_columns_per_batch))
|
||||
)
|
||||
config.data_quality.default_sample_size = int(
|
||||
os.getenv("DATA_QUALITY_DEFAULT_SAMPLE_SIZE", str(config.data_quality.default_sample_size))
|
||||
)
|
||||
config.data_quality.small_table_threshold = int(
|
||||
os.getenv("DATA_QUALITY_SMALL_TABLE_THRESHOLD", str(config.data_quality.small_table_threshold))
|
||||
)
|
||||
config.data_quality.medium_table_threshold = int(
|
||||
os.getenv("DATA_QUALITY_MEDIUM_TABLE_THRESHOLD", str(config.data_quality.medium_table_threshold))
|
||||
)
|
||||
config.data_quality.enable_batch_analysis = (
|
||||
os.getenv("DATA_QUALITY_ENABLE_BATCH_ANALYSIS", str(config.data_quality.enable_batch_analysis).lower()).lower() == "true"
|
||||
)
|
||||
config.data_quality.batch_timeout = int(
|
||||
os.getenv("DATA_QUALITY_BATCH_TIMEOUT", str(config.data_quality.batch_timeout))
|
||||
)
|
||||
config.data_quality.enable_fast_mode = (
|
||||
os.getenv("DATA_QUALITY_ENABLE_FAST_MODE", str(config.data_quality.enable_fast_mode).lower()).lower() == "true"
|
||||
)
|
||||
config.data_quality.fast_mode_sample_size = int(
|
||||
os.getenv("DATA_QUALITY_FAST_MODE_SAMPLE_SIZE", str(config.data_quality.fast_mode_sample_size))
|
||||
)
|
||||
config.data_quality.enable_distribution_analysis = (
|
||||
os.getenv("DATA_QUALITY_ENABLE_DISTRIBUTION_ANALYSIS", str(config.data_quality.enable_distribution_analysis).lower()).lower() == "true"
|
||||
)
|
||||
config.data_quality.histogram_bins = int(
|
||||
os.getenv("DATA_QUALITY_HISTOGRAM_BINS", str(config.data_quality.histogram_bins))
|
||||
)
|
||||
|
||||
# Server configuration
|
||||
config.server_name = os.getenv("SERVER_NAME", config.server_name)
|
||||
config.server_version = os.getenv("SERVER_VERSION", config.server_version)
|
||||
config.server_port = int(os.getenv("SERVER_PORT", str(config.server_port)))
|
||||
server_port = os.getenv("SERVER_PORT", "").strip()
|
||||
if server_port and server_port.isdigit():
|
||||
config.server_port = int(server_port)
|
||||
config.temp_files_dir = os.getenv("TEMP_FILES_DIR", config.temp_files_dir)
|
||||
|
||||
return config
|
||||
@@ -443,6 +631,13 @@ class DorisConfig:
|
||||
if hasattr(config.performance, key):
|
||||
setattr(config.performance, key, value)
|
||||
|
||||
# Update data quality configuration
|
||||
if "data_quality" in config_data:
|
||||
dq_config = config_data["data_quality"]
|
||||
for key, value in dq_config.items():
|
||||
if hasattr(config.data_quality, key):
|
||||
setattr(config.data_quality, key, value)
|
||||
|
||||
# Update logging configuration
|
||||
if "logging" in config_data:
|
||||
log_config = config_data["logging"]
|
||||
@@ -516,6 +711,19 @@ class DorisConfig:
|
||||
"idle_timeout": self.performance.idle_timeout,
|
||||
"max_response_content_size": self.performance.max_response_content_size,
|
||||
},
|
||||
"data_quality": {
|
||||
"max_columns_per_batch": self.data_quality.max_columns_per_batch,
|
||||
"default_sample_size": self.data_quality.default_sample_size,
|
||||
"small_table_threshold": self.data_quality.small_table_threshold,
|
||||
"medium_table_threshold": self.data_quality.medium_table_threshold,
|
||||
"enable_batch_analysis": self.data_quality.enable_batch_analysis,
|
||||
"batch_timeout": self.data_quality.batch_timeout,
|
||||
"enable_fast_mode": self.data_quality.enable_fast_mode,
|
||||
"fast_mode_sample_size": self.data_quality.fast_mode_sample_size,
|
||||
"enable_distribution_analysis": self.data_quality.enable_distribution_analysis,
|
||||
"histogram_bins": self.data_quality.histogram_bins,
|
||||
"percentile_levels": self.data_quality.percentile_levels,
|
||||
},
|
||||
"logging": {
|
||||
"level": self.logging.level,
|
||||
"format": self.logging.format,
|
||||
@@ -602,6 +810,31 @@ class DorisConfig:
|
||||
if self.performance.query_timeout <= 0:
|
||||
errors.append("Query timeout must be greater than 0")
|
||||
|
||||
# Validate data quality configuration
|
||||
if self.data_quality.max_columns_per_batch <= 0:
|
||||
errors.append("Max columns per batch must be greater than 0")
|
||||
|
||||
if self.data_quality.default_sample_size <= 0:
|
||||
errors.append("Default sample size must be greater than 0")
|
||||
|
||||
if self.data_quality.small_table_threshold <= 0:
|
||||
errors.append("Small table threshold must be greater than 0")
|
||||
|
||||
if self.data_quality.medium_table_threshold <= 0:
|
||||
errors.append("Medium table threshold must be greater than 0")
|
||||
|
||||
if self.data_quality.small_table_threshold >= self.data_quality.medium_table_threshold:
|
||||
errors.append("Small table threshold must be less than medium table threshold")
|
||||
|
||||
if self.data_quality.batch_timeout <= 0:
|
||||
errors.append("Batch timeout must be greater than 0")
|
||||
|
||||
if self.data_quality.fast_mode_sample_size <= 0:
|
||||
errors.append("Fast mode sample size must be greater than 0")
|
||||
|
||||
if self.data_quality.histogram_bins <= 0:
|
||||
errors.append("Histogram bins must be greater than 0")
|
||||
|
||||
# Validate logging configuration
|
||||
if self.logging.level not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]:
|
||||
errors.append("Log level must be one of DEBUG, INFO, WARNING, ERROR, or CRITICAL")
|
||||
|
||||
@@ -59,29 +59,57 @@ class DataGovernanceTools:
|
||||
"""
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
# 🚀 PROGRESS: Initialize column lineage tracing
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"🔍 Starting Column Lineage Tracing")
|
||||
logger.info(f"📊 Target: {table_name}.{column_name}")
|
||||
logger.info(f"🎯 Trace depth: {depth}")
|
||||
logger.info("=" * 60)
|
||||
|
||||
connection = await self.connection_manager.get_connection("query")
|
||||
|
||||
full_table_name = self._build_full_table_name(table_name, catalog_name, db_name)
|
||||
target_column = f"{full_table_name}.{column_name}"
|
||||
|
||||
# 1. Verify target column exists
|
||||
logger.info(f"📝 Full target: {target_column}")
|
||||
|
||||
# 🚀 PROGRESS: Step 1 - Verify target column exists
|
||||
logger.info("🔍 Step 1/4: Verifying target column exists...")
|
||||
verify_start = time.time()
|
||||
if not await self._verify_column_exists(connection, full_table_name, column_name):
|
||||
logger.error(f"❌ Column {column_name} not found in table {full_table_name}")
|
||||
return {"error": f"Column {column_name} not found in table {full_table_name}"}
|
||||
|
||||
# 2. Analyze SQL logs to get lineage relationships
|
||||
verify_time = time.time() - verify_start
|
||||
logger.info(f"✅ Column verified in {verify_time:.2f}s")
|
||||
|
||||
# 🚀 PROGRESS: Step 2 - Analyze SQL logs for lineage relationships
|
||||
logger.info(f"📊 Step 2/4: Analyzing SQL logs for lineage (depth={depth})...")
|
||||
lineage_start = time.time()
|
||||
source_chain = await self._analyze_sql_logs_for_lineage(
|
||||
connection, full_table_name, column_name, depth
|
||||
)
|
||||
lineage_time = time.time() - lineage_start
|
||||
logger.info(f"✅ Found {len(source_chain)} lineage relationships in {lineage_time:.2f}s")
|
||||
|
||||
# 3. Analyze downstream usage
|
||||
# 🚀 PROGRESS: Step 3 - Analyze downstream usage
|
||||
logger.info("⬇️ Step 3/4: Analyzing downstream column usage...")
|
||||
downstream_start = time.time()
|
||||
downstream_usage = await self._analyze_downstream_column_usage(
|
||||
connection, full_table_name, column_name
|
||||
)
|
||||
downstream_time = time.time() - downstream_start
|
||||
logger.info(f"✅ Found {len(downstream_usage)} downstream usages in {downstream_time:.2f}s")
|
||||
|
||||
# 4. Analyze field transformation rules
|
||||
# 🚀 PROGRESS: Step 4 - Extract transformation rules
|
||||
logger.info("🔄 Step 4/4: Extracting transformation rules...")
|
||||
transform_start = time.time()
|
||||
transformation_rules = await self._extract_transformation_rules(
|
||||
connection, full_table_name, column_name
|
||||
)
|
||||
transform_time = time.time() - transform_start
|
||||
logger.info(f"✅ Found {len(transformation_rules)} transformation rules in {transform_time:.2f}s")
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -28,8 +28,7 @@ import time
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List
|
||||
import random
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import aiomysql
|
||||
from aiomysql import Connection, Pool
|
||||
@@ -191,54 +190,410 @@ class DorisConnection:
|
||||
logging.error(f"Error occurred while closing connection: {e}")
|
||||
|
||||
|
||||
class DorisConnectionManager:
|
||||
"""Doris database connection manager - Simplified Strategy
|
||||
class DorisSessionCache:
|
||||
"""Doris database session cache
|
||||
|
||||
Uses direct connection pool management without session-level caching
|
||||
Implements connection pool health monitoring and proactive cleanup
|
||||
Save doris session in memory and get session by session id.
|
||||
Provide cache_system_session/cache_user_session to specify whether to save system/user type sessions.
|
||||
By default, only session_id is "query" or "system" will be saved.
|
||||
"""
|
||||
|
||||
def __init__(self, config, security_manager=None):
|
||||
def __init__(self, connection_manager=None, cache_system_session=True, cache_user_session=False):
    """Create an empty in-memory session cache.

    Args:
        connection_manager: Optional manager whose ``release_connection`` is
            invoked for every cached entry when the cache is cleared.
        cache_system_session: Whether sessions classified as "system"
            sessions by ``_is_system_session`` are stored.
        cache_user_session: Whether all other sessions are stored.
    """
    # Caching policy first, then collaborators, then the storage itself.
    self.cache_user_session = cache_user_session
    self.cache_system_session = cache_system_session
    self.connection_manager = connection_manager
    self.cached = dict()
    self.logger = get_logger(__name__)

    startup_note = (
        f"Session Cache initialized, save system session: {cache_system_session}, save user session: {cache_user_session}"
    )
    self.logger.info(startup_note)
|
||||
|
||||
def save(self, connection: DorisConnection):
|
||||
if self._should_cache(connection.session_id):
|
||||
self.cached[connection.session_id] = connection
|
||||
|
||||
def get(self, session_id: str) -> Optional[DorisConnection]:
|
||||
self.logger.debug(f"Use cached connection: {session_id}")
|
||||
return self.cached.get(session_id)
|
||||
|
||||
def remove(self, session_id):
|
||||
if session_id in self.cached:
|
||||
del self.cached[session_id]
|
||||
self.logger.debug(f"Removed session {session_id} from cache.")
|
||||
else:
|
||||
if self._should_cache(session_id):
|
||||
self.logger.warning(f"Session {session_id} is not existed.")
|
||||
|
||||
def clear(self):
|
||||
if self.connection_manager:
|
||||
for k, v in self.cached.items():
|
||||
self.connection_manager.release_connection(k, v)
|
||||
self.cached = {}
|
||||
|
||||
def _is_system_session(self, session_id) -> bool:
|
||||
return session_id in ["query", "system"]
|
||||
|
||||
def _should_cache(self, session_id):
|
||||
return (self.cache_system_session and self._is_system_session(session_id)) or (self.cache_user_session and not self._is_system_session(session_id))
|
||||
|
||||
|
||||
class DorisConnectionManager:
|
||||
"""Doris database connection manager - Enhanced Strategy
|
||||
|
||||
Uses direct connection pool management with proper synchronization
|
||||
Implements connection pool health monitoring and proactive cleanup
|
||||
Supports token-bound database configurations for multi-tenant access
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, config, security_manager=None, token_manager=None):
|
||||
self.config = config
|
||||
self.security_manager = security_manager
|
||||
self.pool: Pool | None = None
|
||||
self.logger = get_logger(__name__)
|
||||
self.metrics = ConnectionMetrics()
|
||||
self.security_manager = security_manager
|
||||
self.token_manager = token_manager # Token manager for token-bound DB config
|
||||
self.session_cache = DorisSessionCache(self)
|
||||
|
||||
# Remove session-level connection management
|
||||
# self.session_connections = {} # REMOVED
|
||||
# Store original database config for fallback
|
||||
self.original_db_config = {
|
||||
'host': config.database.host,
|
||||
'port': config.database.port,
|
||||
'user': config.database.user,
|
||||
'password': config.database.password,
|
||||
'database': config.database.database,
|
||||
'charset': config.database.charset
|
||||
}
|
||||
|
||||
# Pool health monitoring
|
||||
self.health_check_interval = 30 # seconds
|
||||
self.pool_warmup_size = 3 # connections to maintain
|
||||
# Current active database config (may be overridden by token-bound config)
|
||||
self.active_db_config = self.original_db_config.copy()
|
||||
|
||||
# Connection pool state management
|
||||
self.pool_recovering = False
|
||||
self.pool_health_check_task = None
|
||||
self.pool_cleanup_task = None
|
||||
|
||||
# Pool recovery lock to prevent race conditions
|
||||
self.pool_recovery_lock = asyncio.Lock()
|
||||
self.pool_recovering = False
|
||||
# Metrics tracking
|
||||
self.metrics = ConnectionMetrics()
|
||||
|
||||
# 🔧 FIX: Add connection acquisition lock to prevent race conditions
|
||||
self._connection_lock = asyncio.Lock()
|
||||
self._recovery_lock = asyncio.Lock()
|
||||
|
||||
# 🔧 FIX: Add connection acquisition queue to serialize requests
|
||||
self._connection_semaphore = asyncio.Semaphore(value=20) # Max concurrent acquisitions
|
||||
|
||||
# Database connection parameters from config.database
|
||||
self.host = config.database.host
|
||||
self.port = config.database.port
|
||||
self.user = config.database.user
|
||||
self.password = config.database.password
|
||||
self.database = config.database.database
|
||||
# Convert charset to aiomysql compatible format
|
||||
charset_map = {"UTF8": "utf8", "UTF8MB4": "utf8mb4"}
|
||||
self.charset = charset_map.get(config.database.charset.upper(), config.database.charset.lower())
|
||||
self.pool_recovery_lock = self._recovery_lock # Compatibility alias
|
||||
self._update_db_params_from_config(self.active_db_config)
|
||||
self.connect_timeout = config.database.connection_timeout
|
||||
|
||||
# Connection pool parameters - more conservative settings
|
||||
self.minsize = config.database.min_connections # This is always 0
|
||||
self.maxsize = config.database.max_connections or 10
|
||||
self.maxsize = config.database.max_connections or 20
|
||||
self.pool_recycle = config.database.max_connection_age or 3600 # 1 hour, more conservative
|
||||
|
||||
# 🔧 FIX: Add missing monitoring parameters that were removed during refactoring
|
||||
self.health_check_interval = 30 # seconds
|
||||
self.pool_warmup_size = 3 # connections to maintain
|
||||
|
||||
def _update_db_params_from_config(self, db_config: dict):
|
||||
"""Update database connection parameters from config dictionary"""
|
||||
self.host = db_config['host']
|
||||
self.port = db_config['port']
|
||||
self.user = db_config['user']
|
||||
self.password = db_config['password']
|
||||
self.database = db_config['database']
|
||||
# Convert charset to aiomysql compatible format
|
||||
charset_map = {"UTF8": "utf8", "UTF8MB4": "utf8mb4"}
|
||||
self.charset = charset_map.get(db_config['charset'].upper(), db_config['charset'].lower())
|
||||
|
||||
def _is_config_empty(self, config_value) -> bool:
|
||||
"""Check if a config value is empty (None, empty string, or 'null')"""
|
||||
return config_value is None or config_value == '' or str(config_value).lower() == 'null'
|
||||
|
||||
def _has_valid_global_config(self) -> bool:
|
||||
"""Check if global database configuration is valid and non-empty"""
|
||||
return (not self._is_config_empty(self.original_db_config['host']) and
|
||||
not self._is_config_empty(self.original_db_config['user']))
|
||||
|
||||
def _find_available_token_with_db_config(self) -> str:
|
||||
"""Find the first available token with database configuration
|
||||
|
||||
Returns:
|
||||
Raw token string if found, empty string if not found
|
||||
"""
|
||||
if not self.token_manager:
|
||||
return ""
|
||||
|
||||
try:
|
||||
for token_hash, token_info in self.token_manager._tokens.items():
|
||||
if (token_info.database_config and
|
||||
token_info.is_active and
|
||||
not self._is_config_empty(token_info.database_config.host) and
|
||||
not self._is_config_empty(token_info.database_config.user)):
|
||||
|
||||
# We need to find the raw token from the hash
|
||||
# This is a bit tricky since we only store hashes
|
||||
# We'll need to use the admin token from tokens.json if it has db config
|
||||
if token_info.token_id == 'admin-token':
|
||||
# Try the known admin token
|
||||
return 'doris_admin_token_123456'
|
||||
elif 'tenant' in token_info.token_id:
|
||||
# For tenant tokens, we'll need a different approach
|
||||
# For now, skip these as we don't know the raw token
|
||||
continue
|
||||
|
||||
return ""
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error finding available token: {e}")
|
||||
return ""
|
||||
|
||||
async def configure_for_token(self, token: str) -> tuple[bool, str]:
|
||||
"""Configure connection manager for token with new priority logic
|
||||
|
||||
Priority: Token-bound DB config > .env config > error
|
||||
|
||||
Args:
|
||||
token: Authentication token to get database config for
|
||||
|
||||
Returns:
|
||||
(success: bool, config_source: str): Result and which config was used
|
||||
|
||||
Raises:
|
||||
RuntimeError: If no valid database configuration is available
|
||||
"""
|
||||
try:
|
||||
# Priority 1: Try token-bound database config first
|
||||
if self.token_manager:
|
||||
db_config = self.token_manager.get_database_config_by_token(token)
|
||||
if db_config:
|
||||
# Convert DatabaseConfig to dictionary
|
||||
token_db_config = {
|
||||
'host': db_config.host,
|
||||
'port': db_config.port,
|
||||
'user': db_config.user,
|
||||
'password': db_config.password,
|
||||
'database': db_config.database,
|
||||
'charset': db_config.charset
|
||||
}
|
||||
|
||||
# Check if token-bound config is valid
|
||||
if (not self._is_config_empty(token_db_config['host']) and
|
||||
not self._is_config_empty(token_db_config['user'])):
|
||||
self.logger.info(f"Using token-bound database configuration for host: {token_db_config['host']}")
|
||||
self.active_db_config = token_db_config
|
||||
self._update_db_params_from_config(self.active_db_config)
|
||||
|
||||
# Create/recreate connection pool with token-bound config
|
||||
await self._ensure_pool_with_current_config()
|
||||
|
||||
return True, "token-bound"
|
||||
|
||||
# Priority 2: Use global .env config if available
|
||||
if self._has_valid_global_config():
|
||||
self.logger.info("Using global .env database configuration")
|
||||
self.active_db_config = self.original_db_config.copy()
|
||||
self._update_db_params_from_config(self.active_db_config)
|
||||
|
||||
# Create/recreate connection pool with global config
|
||||
await self._ensure_pool_with_current_config()
|
||||
|
||||
return True, "global-env"
|
||||
|
||||
# Priority 3: No valid configuration available
|
||||
error_msg = (
|
||||
"No valid database configuration available for this token. "
|
||||
"Please contact administrator to:\n"
|
||||
"1. Add database configuration to tokens.json for this token, OR\n"
|
||||
"2. Configure valid global database settings in .env file\n"
|
||||
"Required fields: DB_HOST, DB_USER"
|
||||
)
|
||||
self.logger.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to configure database for token: {e}")
|
||||
raise
|
||||
|
||||
async def _ensure_pool_with_current_config(self):
|
||||
"""Ensure connection pool exists with current configuration"""
|
||||
try:
|
||||
# If pool exists with different config, need to recreate it
|
||||
# If no pool exists, create one with current config
|
||||
if self.pool and not self.pool.closed:
|
||||
# Since we can't reliably check pool config attributes,
|
||||
# we'll recreate the pool if we detect a potential config change
|
||||
# by checking if current config differs from what we stored
|
||||
pool_needs_recreation = False
|
||||
|
||||
# Compare current config with what we might have used before
|
||||
if hasattr(self, '_last_pool_config'):
|
||||
current_config = {
|
||||
'host': self.host,
|
||||
'port': self.port,
|
||||
'user': self.user,
|
||||
'database': self.database
|
||||
}
|
||||
if current_config != self._last_pool_config:
|
||||
pool_needs_recreation = True
|
||||
|
||||
if pool_needs_recreation:
|
||||
self.logger.info("Database configuration changed, recreating connection pool")
|
||||
await self._recreate_pool()
|
||||
elif not self.pool:
|
||||
self.logger.info("Creating connection pool with current configuration")
|
||||
await self._create_pool_with_current_config()
|
||||
|
||||
# Test the connection immediately
|
||||
if not await self._test_pool_health():
|
||||
raise RuntimeError(f"Database connection test failed for {self.host}:{self.port}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to ensure connection pool: {e}")
|
||||
raise
|
||||
|
||||
async def _create_pool_with_current_config(self):
    """Create the aiomysql pool from the manager's current parameters.

    Steps: create the pool, record the parameters it was built with (for
    later change detection), health-check it, start the background monitor
    tasks if they are not running, then warm the pool up.

    If the initial health check fails, the freshly created pool is closed
    and detached before the error is raised, so a pool that cannot serve
    queries is neither leaked nor left assigned to ``self.pool``.

    Raises:
        RuntimeError: If the new pool fails its health check.
        Exception: Any creation/warm-up error is logged and re-raised.
    """
    try:
        new_pool = await aiomysql.create_pool(
            host=self.host,
            port=self.port,
            user=self.user,
            password=self.password,
            db=self.database,
            charset=self.charset,
            minsize=self.minsize,
            maxsize=self.maxsize,
            pool_recycle=self.pool_recycle,
            connect_timeout=self.connect_timeout,
            autocommit=True
        )
        self.pool = new_pool

        # Store the current config for comparison later
        self._last_pool_config = {
            'host': self.host,
            'port': self.port,
            'user': self.user,
            'database': self.database
        }

        # Test initial connection
        if not await self._test_pool_health():
            # Detach only if nobody replaced the pool in the meantime.
            if self.pool is new_pool:
                self.pool = None
            try:
                if not new_pool.closed:
                    new_pool.close()
                    await new_pool.wait_closed()
            except Exception as close_error:
                # Cleanup is best effort; the health failure is the real error.
                self.logger.warning(f"Failed to close unhealthy connection pool: {close_error}")
            raise RuntimeError("Connection pool health check failed")

        # Start background monitoring tasks if not already running
        if not self.pool_health_check_task or self.pool_health_check_task.done():
            self.pool_health_check_task = asyncio.create_task(self._pool_health_monitor())
        if not self.pool_cleanup_task or self.pool_cleanup_task.done():
            self.pool_cleanup_task = asyncio.create_task(self._pool_cleanup_monitor())

        # Perform initial pool warmup
        await self._warmup_pool()

        self.logger.info(f"Connection pool created successfully with {self.host}:{self.port}")

    except Exception as e:
        self.logger.error(f"Failed to create connection pool: {e}")
        raise
|
||||
|
||||
async def _recreate_pool(self):
|
||||
"""Recreate connection pool with current database configuration"""
|
||||
try:
|
||||
# Close existing pool
|
||||
if self.pool and not self.pool.closed:
|
||||
self.pool.close()
|
||||
await self.pool.wait_closed()
|
||||
self.pool = None
|
||||
|
||||
# Create new pool with current config
|
||||
await self._create_pool_with_current_config()
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to recreate connection pool: {e}")
|
||||
raise
|
||||
|
||||
def validate_database_configuration(self) -> tuple[bool, str]:
|
||||
"""Validate database configuration completeness
|
||||
|
||||
Returns:
|
||||
(is_valid, error_message): Configuration validation result
|
||||
"""
|
||||
# Check if Token authentication is enabled
|
||||
token_auth_enabled = getattr(self.config.security, 'enable_token_auth', False)
|
||||
|
||||
# Check if tokens.json exists and has valid tokens with database configs
|
||||
tokens_file_available = False
|
||||
token_bound_configs_available = False
|
||||
|
||||
if self.token_manager:
|
||||
try:
|
||||
# Check if tokens.json file exists
|
||||
import os
|
||||
tokens_file_path = getattr(self.token_manager, 'token_file_path', 'tokens.json')
|
||||
tokens_file_available = os.path.exists(tokens_file_path)
|
||||
|
||||
# Check if any tokens have database configurations
|
||||
if tokens_file_available or self.token_manager._tokens:
|
||||
for token_hash, token_info in self.token_manager._tokens.items():
|
||||
if token_info.database_config:
|
||||
token_bound_configs_available = True
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Validate .env database configuration
|
||||
env_config_valid = self._has_valid_global_config()
|
||||
|
||||
# Decision logic
|
||||
if token_auth_enabled:
|
||||
if tokens_file_available:
|
||||
# tokens.json exists - either .env OR token-bound config must be valid
|
||||
if env_config_valid or token_bound_configs_available:
|
||||
return True, "Configuration valid"
|
||||
else:
|
||||
return False, (
|
||||
"Token authentication is enabled and tokens.json exists, but no valid database "
|
||||
"configuration found. Please provide either:\n"
|
||||
"1. Valid database configuration in .env file (DB_HOST, DB_USER, etc.)\n"
|
||||
"2. Database configuration in tokens.json for at least one token"
|
||||
)
|
||||
else:
|
||||
# tokens.json does not exist - must have valid .env config
|
||||
if env_config_valid:
|
||||
return True, "Configuration valid"
|
||||
else:
|
||||
return False, (
|
||||
"Token authentication is enabled but tokens.json file not found. "
|
||||
"Either:\n"
|
||||
"1. Create tokens.json file with token configurations\n"
|
||||
"2. Provide valid database configuration in .env file (DB_HOST, DB_USER, etc.)"
|
||||
)
|
||||
else:
|
||||
# Token auth is disabled, must have valid .env config
|
||||
if env_config_valid:
|
||||
return True, "Configuration valid"
|
||||
else:
|
||||
return False, (
|
||||
"Token authentication is disabled. Valid database configuration is required "
|
||||
"in .env file (DB_HOST, DB_USER, etc.)"
|
||||
)
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize connection pool with health monitoring"""
|
||||
try:
|
||||
# First validate configuration
|
||||
is_valid, error_message = self.validate_database_configuration()
|
||||
if not is_valid:
|
||||
self.logger.error(f"Database configuration validation failed: {error_message}")
|
||||
raise RuntimeError(f"Database configuration validation failed:\n{error_message}")
|
||||
|
||||
self.logger.info(f"Database configuration validated successfully")
|
||||
self.logger.info(f"Initializing connection pool to {self.host}:{self.port}")
|
||||
|
||||
# Only create connection pool if we have valid global config
|
||||
# Token-bound configs will be handled dynamically during requests
|
||||
if not self._has_valid_global_config():
|
||||
self.logger.info("No valid global database config, pool will be created dynamically for token-bound configs")
|
||||
return
|
||||
|
||||
# Create connection pool
|
||||
self.pool = await aiomysql.create_pool(
|
||||
host=self.host,
|
||||
@@ -506,77 +861,145 @@ class DorisConnectionManager:
|
||||
finally:
|
||||
self.pool_recovering = False
|
||||
|
||||
async def _recover_pool_with_lock(self):
|
||||
"""🔧 FIX: Recovery method that uses the new recovery lock to prevent races"""
|
||||
async with self._recovery_lock:
|
||||
if not self.pool_recovering: # Only recover if not already in progress
|
||||
await self._recover_pool()
|
||||
|
||||
async def get_connection(self, session_id: str) -> DorisConnection:
|
||||
"""Get database connection - Simplified Strategy with pool validation
|
||||
"""🔧 FIX: Simplified connection acquisition without double locking
|
||||
|
||||
Always acquire fresh connection from pool, no session caching
|
||||
Uses only semaphore to prevent too many concurrent acquisitions.
|
||||
If the connection is successfully obtained, it will be added to the connection pool cache.
|
||||
"""
|
||||
try:
|
||||
# Wait for any ongoing recovery to complete
|
||||
if self.pool_recovering:
|
||||
self.logger.debug(f"Pool recovery in progress, waiting for completion...")
|
||||
# Wait for recovery to complete (max 10 seconds)
|
||||
for _ in range(10):
|
||||
if not self.pool_recovering:
|
||||
break
|
||||
await asyncio.sleep(0.5)
|
||||
cached_conn = self.session_cache.get(session_id)
|
||||
if cached_conn:
|
||||
return cached_conn
|
||||
|
||||
# 🔧 FIX: Use only semaphore to limit concurrent acquisitions (remove double locking)
|
||||
async with self._connection_semaphore:
|
||||
try:
|
||||
# Wait for any ongoing recovery to complete
|
||||
if self.pool_recovering:
|
||||
self.logger.error("Pool recovery is taking too long, proceeding anyway")
|
||||
# Don't raise error, try to continue
|
||||
self.logger.debug(f"Pool recovery in progress, waiting for completion...")
|
||||
# Wait for recovery to complete (max 10 seconds)
|
||||
start_wait = time.time()
|
||||
while self.pool_recovering and (time.time() - start_wait) < 10:
|
||||
await asyncio.sleep(0.1) # More frequent checks
|
||||
|
||||
# Check if pool is available
|
||||
if not self.pool:
|
||||
self.logger.warning("Connection pool is not available, attempting recovery...")
|
||||
await self._recover_pool()
|
||||
if self.pool_recovering:
|
||||
self.logger.error("Pool recovery is taking too long, proceeding anyway")
|
||||
# Continue but log the issue
|
||||
|
||||
# Check if pool is available
|
||||
if not self.pool:
|
||||
raise RuntimeError("Connection pool is not available and recovery failed")
|
||||
self.logger.warning("Connection pool is not available, attempting recovery...")
|
||||
|
||||
# Check if pool is closed
|
||||
if self.pool.closed:
|
||||
self.logger.warning("Connection pool is closed, attempting recovery...")
|
||||
await self._recover_pool()
|
||||
# Try to use token-bound configuration if available
|
||||
if self.token_manager and not self._has_valid_global_config():
|
||||
available_token = self._find_available_token_with_db_config()
|
||||
if available_token:
|
||||
self.logger.info(f"Using token-bound configuration for pool creation: {available_token}")
|
||||
try:
|
||||
await self.configure_for_token(available_token)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to configure with token-bound config: {e}")
|
||||
|
||||
if not self.pool or self.pool.closed:
|
||||
raise RuntimeError("Connection pool is closed and recovery failed")
|
||||
# Fallback to recovery
|
||||
if not self.pool:
|
||||
await self._recover_pool_with_lock()
|
||||
|
||||
# Simple strategy: always get fresh connection from pool
|
||||
raw_conn = await self.pool.acquire()
|
||||
if not self.pool:
|
||||
raise RuntimeError("Connection pool is not available and recovery failed")
|
||||
|
||||
# Wrap in DorisConnection
|
||||
doris_conn = DorisConnection(raw_conn, session_id, self.security_manager)
|
||||
# Check if pool is closed
|
||||
if self.pool.closed:
|
||||
self.logger.warning("Connection pool is closed, attempting recovery...")
|
||||
await self._recover_pool_with_lock()
|
||||
|
||||
# Simple validation - just check if connection is open
|
||||
if raw_conn.closed:
|
||||
raise RuntimeError("Acquired connection is already closed")
|
||||
if not self.pool or self.pool.closed:
|
||||
raise RuntimeError("Connection pool is closed and recovery failed")
|
||||
|
||||
self.logger.debug(f"✅ Acquired fresh connection for session {session_id}")
|
||||
return doris_conn
|
||||
# 🔧 FIX: Increased timeout to prevent hanging
|
||||
try:
|
||||
raw_conn = await asyncio.wait_for(self.pool.acquire(), timeout=10.0)
|
||||
except asyncio.TimeoutError:
|
||||
self.logger.error(f"Connection acquisition timed out for session {session_id}")
|
||||
# Try one recovery attempt
|
||||
await self._recover_pool_with_lock()
|
||||
if self.pool and not self.pool.closed:
|
||||
try:
|
||||
raw_conn = await asyncio.wait_for(self.pool.acquire(), timeout=5.0)
|
||||
except asyncio.TimeoutError:
|
||||
raise RuntimeError("Connection acquisition timed out after recovery")
|
||||
else:
|
||||
raise RuntimeError("Connection acquisition timed out")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to get connection for session {session_id}: {e}")
|
||||
raise
|
||||
# Wrap in DorisConnection
|
||||
doris_conn = DorisConnection(raw_conn, session_id, self.security_manager)
|
||||
|
||||
# Basic validation - check if connection is open
|
||||
if raw_conn.closed:
|
||||
# Return connection and raise error
|
||||
try:
|
||||
self.pool.release(raw_conn)
|
||||
except Exception:
|
||||
pass
|
||||
raise RuntimeError("Acquired connection is already closed")
|
||||
|
||||
self.logger.debug(f"✅ Acquired fresh connection for session {session_id}")
|
||||
|
||||
self.session_cache.save(doris_conn)
|
||||
return doris_conn
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to get connection for session {session_id}: {e}")
|
||||
raise
|
||||
|
||||
async def release_connection(self, session_id: str, connection: DorisConnection):
|
||||
"""Release connection back to pool - Simplified Strategy"""
|
||||
"""🔧 FIX: Release connection back to pool with proper error handling"""
|
||||
cached_conn = self.session_cache.get(session_id)
|
||||
if cached_conn:
|
||||
self.session_cache.remove(session_id)
|
||||
if not (cached_conn is connection):
|
||||
self.logger.warning("Invalid connection")
|
||||
connection = cached_conn
|
||||
|
||||
if not connection or not connection.connection:
|
||||
self.logger.debug(f"No connection to release for session {session_id}")
|
||||
return
|
||||
|
||||
try:
|
||||
if connection and connection.connection:
|
||||
# Simple strategy: always return to pool
|
||||
if not connection.connection.closed:
|
||||
self.pool.release(connection.connection)
|
||||
self.logger.debug(f"✅ Released connection for session {session_id}")
|
||||
else:
|
||||
self.logger.debug(f"Connection already closed for session {session_id}")
|
||||
# Check pool availability before attempting release
|
||||
if not self.pool or self.pool.closed:
|
||||
self.logger.warning(f"Pool unavailable during release for session {session_id}, force closing connection")
|
||||
try:
|
||||
await connection.connection.ensure_closed()
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
|
||||
# Check connection state before release
|
||||
if connection.connection.closed:
|
||||
self.logger.debug(f"Connection already closed for session {session_id}")
|
||||
return
|
||||
|
||||
# 🔧 FIX: Simplified release operation without thread wrapper
|
||||
try:
|
||||
self.pool.release(connection.connection)
|
||||
self.logger.debug(f"✅ Released connection for session {session_id}")
|
||||
except Exception as release_error:
|
||||
self.logger.warning(f"Connection release failed for session {session_id}: {release_error}, force closing")
|
||||
await connection.connection.ensure_closed()
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error releasing connection for session {session_id}: {e}")
|
||||
# Force close if release fails
|
||||
try:
|
||||
if connection and connection.connection:
|
||||
await connection.connection.ensure_closed()
|
||||
except Exception:
|
||||
pass
|
||||
await connection.connection.ensure_closed()
|
||||
except Exception as close_error:
|
||||
self.logger.debug(f"Error force closing connection: {close_error}")
|
||||
|
||||
async def close(self):
|
||||
"""Close connection manager"""
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -426,27 +426,27 @@ class DorisQueryExecutor:
|
||||
self, query_request: QueryRequest, auth_context
|
||||
) -> QueryResult:
|
||||
"""Internal query execution"""
|
||||
|
||||
# Database configuration should already be handled during authentication
|
||||
# No need to configure again during query execution
|
||||
|
||||
# Optimize query
|
||||
optimized_sql = await self.query_optimizer.optimize_query(
|
||||
query_request.sql, {"user_roles": getattr(auth_context, 'roles', [])}
|
||||
)
|
||||
|
||||
# Execute query
|
||||
connection = await self.connection_manager.get_connection(
|
||||
query_request.session_id
|
||||
)
|
||||
|
||||
# Set timeout if specified
|
||||
if query_request.timeout:
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
connection.execute(optimized_sql, query_request.parameters, auth_context),
|
||||
self.connection_manager.execute_query(query_request.session_id, optimized_sql, query_request.parameters, auth_context),
|
||||
timeout=query_request.timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
raise Exception(f"Query timeout after {query_request.timeout} seconds")
|
||||
else:
|
||||
result = await connection.execute(optimized_sql, query_request.parameters, auth_context)
|
||||
result = await self.connection_manager.execute_query(query_request.session_id, optimized_sql, query_request.parameters, auth_context)
|
||||
|
||||
return result
|
||||
|
||||
@@ -561,23 +561,68 @@ class DorisQueryExecutor:
|
||||
"data": None
|
||||
}
|
||||
|
||||
# Import required security modules
|
||||
from .security import DorisSecurityManager, AuthContext, SecurityLevel
|
||||
|
||||
# Create proper auth context with read-only permissions
|
||||
auth_context = AuthContext(
|
||||
user_id=user_id,
|
||||
roles=["read_only_user"], # Restrictive role for MCP interface
|
||||
permissions=["read_data"], # Only read permissions
|
||||
session_id=session_id,
|
||||
security_level=SecurityLevel.INTERNAL
|
||||
)
|
||||
|
||||
# Perform SQL security validation if enabled
|
||||
if hasattr(self.connection_manager, 'config') and hasattr(self.connection_manager.config, 'security'):
|
||||
if self.connection_manager.config.security.enable_security_check:
|
||||
try:
|
||||
security_manager = DorisSecurityManager(self.connection_manager.config)
|
||||
validation_result = await security_manager.validate_sql_security(sql, auth_context)
|
||||
|
||||
if not validation_result.is_valid:
|
||||
self.logger.warning(f"SQL security validation failed for query: {sql[:100]}...")
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"SQL security validation failed: {validation_result.error_message}",
|
||||
"error_type": "security_violation",
|
||||
"blocked_operations": validation_result.blocked_operations,
|
||||
"risk_level": validation_result.risk_level,
|
||||
"data": None,
|
||||
"metadata": {
|
||||
"query": sql,
|
||||
"validation_details": {
|
||||
"blocked_operations": validation_result.blocked_operations,
|
||||
"risk_level": validation_result.risk_level
|
||||
}
|
||||
}
|
||||
}
|
||||
else:
|
||||
self.logger.debug(f"SQL security validation passed for query: {sql[:100]}...")
|
||||
except Exception as security_error:
|
||||
self.logger.error(f"Security validation error: {str(security_error)}")
|
||||
# In case of security validation error, fail safe
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Security validation system error: {str(security_error)}",
|
||||
"error_type": "security_system_error",
|
||||
"data": None,
|
||||
"metadata": {
|
||||
"query": sql,
|
||||
"security_error": str(security_error)
|
||||
}
|
||||
}
|
||||
else:
|
||||
self.logger.info("SQL security check is disabled in configuration")
|
||||
else:
|
||||
self.logger.warning("Security configuration not found, proceeding without validation")
|
||||
|
||||
# Add LIMIT if not present and it's a SELECT query
|
||||
if sql.upper().startswith("SELECT") and "LIMIT" not in sql.upper():
|
||||
if sql.endswith(";"):
|
||||
sql = sql[:-1]
|
||||
sql = f"{sql} LIMIT {limit}"
|
||||
|
||||
# Create auth context for MCP calls
|
||||
class MockAuthContext:
|
||||
def __init__(self):
|
||||
self.user_id = user_id
|
||||
self.roles = ["data_analyst"]
|
||||
self.permissions = ["read_data", "execute_query"]
|
||||
self.session_id = session_id
|
||||
self.security_level = "internal"
|
||||
|
||||
auth_context = MockAuthContext()
|
||||
|
||||
# Create query request
|
||||
query_request = QueryRequest(
|
||||
sql=sql,
|
||||
@@ -831,9 +876,13 @@ class QueryPerformanceMonitor:
|
||||
|
||||
# Unified convenience function for MCP integration
|
||||
async def execute_sql_query(sql: str, connection_manager: DorisConnectionManager, **kwargs) -> Dict[str, Any]:
|
||||
"""Execute SQL query - unified convenience function for MCP tools"""
|
||||
"""Execute SQL query - unified convenience function for MCP tools
|
||||
|
||||
This function now includes security validation to ensure safe query execution.
|
||||
All queries are validated against the configured security policies before execution.
|
||||
"""
|
||||
try:
|
||||
# Create query executor
|
||||
# Create query executor with the connection manager's configuration
|
||||
executor = DorisQueryExecutor(connection_manager)
|
||||
|
||||
try:
|
||||
@@ -843,6 +892,7 @@ async def execute_sql_query(sql: str, connection_manager: DorisConnectionManager
|
||||
session_id = kwargs.get("session_id", "mcp_session")
|
||||
user_id = kwargs.get("user_id", "mcp_user")
|
||||
|
||||
# The execute_sql_for_mcp method now includes security validation
|
||||
result = await executor.execute_sql_for_mcp(
|
||||
sql=sql,
|
||||
limit=limit,
|
||||
@@ -858,5 +908,10 @@ async def execute_sql_query(sql: str, connection_manager: DorisConnectionManager
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Query execution failed: {str(e)}",
|
||||
"data": None
|
||||
"error_type": "execution_error",
|
||||
"data": None,
|
||||
"metadata": {
|
||||
"query": sql,
|
||||
"execution_error": str(e)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,9 +36,6 @@ from .logger import get_logger
|
||||
# Configure logging
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv(override=True)
|
||||
|
||||
METADATA_DB_NAME="information_schema"
|
||||
ENABLE_MULTI_DATABASE=os.getenv("ENABLE_MULTI_DATABASE",True)
|
||||
MULTI_DATABASE_NAMES=os.getenv("MULTI_DATABASE_NAMES","")
|
||||
@@ -416,7 +413,7 @@ class MetadataExtractor:
|
||||
|
||||
return matches
|
||||
|
||||
def get_table_schema(self, table_name: str, db_name: Optional[str] = None, catalog_name: str = None) -> Dict[str, Any]:
|
||||
async def get_table_schema(self, table_name: str, db_name: Optional[str] = None, catalog_name: str = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Get the schema information for a table
|
||||
|
||||
@@ -439,7 +436,7 @@ class MetadataExtractor:
|
||||
return self.metadata_cache[cache_key]
|
||||
|
||||
try:
|
||||
# Use information_schema.columns table to get table schema
|
||||
# Use information_schema.columns table to get table schema (async)
|
||||
query = f"""
|
||||
SELECT
|
||||
COLUMN_NAME,
|
||||
@@ -459,7 +456,7 @@ class MetadataExtractor:
|
||||
ORDINAL_POSITION
|
||||
"""
|
||||
|
||||
result = self._execute_query_with_catalog(query, db_name, effective_catalog)
|
||||
result = await self._execute_query_with_catalog_async(query, db_name, effective_catalog)
|
||||
|
||||
if not result:
|
||||
logger.warning(f"Table {effective_catalog or 'default'}.{db_name}.{table_name} does not exist or has no columns")
|
||||
@@ -468,7 +465,6 @@ class MetadataExtractor:
|
||||
# Create structured table schema information
|
||||
columns = []
|
||||
for col in result:
|
||||
# Ensure using actual column values, not column names
|
||||
column_info = {
|
||||
"name": col.get("COLUMN_NAME", ""),
|
||||
"type": col.get("DATA_TYPE", ""),
|
||||
@@ -481,8 +477,8 @@ class MetadataExtractor:
|
||||
}
|
||||
columns.append(column_info)
|
||||
|
||||
# Get table comment
|
||||
table_comment = self.get_table_comment(table_name, db_name, effective_catalog)
|
||||
# Get table comment (async)
|
||||
table_comment = await self.get_table_comment_async(table_name, db_name, effective_catalog)
|
||||
|
||||
# Build complete structure
|
||||
schema = {
|
||||
@@ -493,7 +489,7 @@ class MetadataExtractor:
|
||||
"create_time": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# Get table type information
|
||||
# Get table type information (async)
|
||||
try:
|
||||
table_type_query = f"""
|
||||
SELECT
|
||||
@@ -505,7 +501,7 @@ class MetadataExtractor:
|
||||
TABLE_SCHEMA = '{db_name}'
|
||||
AND TABLE_NAME = '{table_name}'
|
||||
"""
|
||||
table_type_result = self._execute_query(table_type_query)
|
||||
table_type_result = await self._execute_query_async(table_type_query)
|
||||
if table_type_result:
|
||||
schema["table_type"] = table_type_result[0].get("TABLE_TYPE", "")
|
||||
schema["engine"] = table_type_result[0].get("ENGINE", "")
|
||||
@@ -521,6 +517,7 @@ class MetadataExtractor:
|
||||
logger.error(f"Error getting table schema: {str(e)}")
|
||||
return {}
|
||||
|
||||
# Deprecated: sync method (kept for compatibility, will be removed)
|
||||
def get_table_comment(self, table_name: str, db_name: Optional[str] = None, catalog_name: str = None) -> str:
|
||||
"""
|
||||
Get the comment for a table
|
||||
@@ -571,6 +568,7 @@ class MetadataExtractor:
|
||||
logger.error(f"Error getting table comment: {str(e)}")
|
||||
return ""
|
||||
|
||||
# Deprecated: sync method (kept for compatibility, will be removed)
|
||||
def get_column_comments(self, table_name: str, db_name: Optional[str] = None, catalog_name: str = None) -> Dict[str, str]:
|
||||
"""
|
||||
Get comments for all columns in a table
|
||||
@@ -626,6 +624,7 @@ class MetadataExtractor:
|
||||
logger.error(f"Error getting column comments: {str(e)}")
|
||||
return {}
|
||||
|
||||
# Deprecated: sync method (kept for compatibility, will be removed)
|
||||
def get_table_indexes(self, table_name: str, db_name: Optional[str] = None, catalog_name: str = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get the index information for a table
|
||||
@@ -657,51 +656,36 @@ class MetadataExtractor:
|
||||
query = f"SHOW INDEX FROM `{db_name}`.`{table_name}`"
|
||||
|
||||
try:
|
||||
df = self._execute_query(query, return_dataframe=True)
|
||||
|
||||
# Process results
|
||||
# NOTE: Deprecated sync path retained for compatibility; use async variant instead.
|
||||
# Deprecated sync path removed; return empty indexes on failure
|
||||
result = []
|
||||
indexes = []
|
||||
current_index = None
|
||||
|
||||
if not df.empty:
|
||||
for _, row in df.iterrows():
|
||||
if result:
|
||||
for r in result:
|
||||
try:
|
||||
index_name = row['Key_name']
|
||||
column_name = row['Column_name']
|
||||
|
||||
if current_index is None or current_index['name'] != index_name:
|
||||
index_name = r.get('Key_name')
|
||||
column_name = r.get('Column_name')
|
||||
if current_index is None or current_index.get('name') != index_name:
|
||||
if current_index is not None:
|
||||
indexes.append(current_index)
|
||||
|
||||
current_index = {
|
||||
'name': index_name,
|
||||
'columns': [column_name],
|
||||
'unique': row['Non_unique'] == 0,
|
||||
'type': row['Index_type']
|
||||
'columns': [column_name] if column_name else [],
|
||||
'unique': r.get('Non_unique', 1) == 0,
|
||||
'type': r.get('Index_type', '')
|
||||
}
|
||||
else:
|
||||
current_index['columns'].append(column_name)
|
||||
if column_name:
|
||||
current_index['columns'].append(column_name)
|
||||
except Exception as row_error:
|
||||
logger.warning(f"Failed to process index row data: {row_error}")
|
||||
continue
|
||||
|
||||
if current_index is not None:
|
||||
indexes.append(current_index)
|
||||
except Exception as df_error:
|
||||
logger.warning(f"DataFrame processing failed, trying regular query: {df_error}")
|
||||
# Fall back to regular query
|
||||
result = self._execute_query(query, return_dataframe=False)
|
||||
logger.warning(f"Sync index query (deprecated) failed: {df_error}")
|
||||
indexes = []
|
||||
if result:
|
||||
# Simple processing, no complex index grouping
|
||||
for row in result:
|
||||
if isinstance(row, dict):
|
||||
indexes.append({
|
||||
'name': row.get('Key_name', ''),
|
||||
'columns': [row.get('Column_name', '')],
|
||||
'unique': row.get('Non_unique', 1) == 0,
|
||||
'type': row.get('Index_type', '')
|
||||
})
|
||||
|
||||
# Update cache
|
||||
self.metadata_cache[cache_key] = indexes
|
||||
@@ -712,7 +696,7 @@ class MetadataExtractor:
|
||||
logger.error(f"Error getting index information: {str(e)}")
|
||||
return []
|
||||
|
||||
def get_table_relationships(self) -> List[Dict[str, Any]]:
|
||||
async def get_table_relationships(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Infer table relationships from table comments and naming patterns
|
||||
|
||||
@@ -725,13 +709,13 @@ class MetadataExtractor:
|
||||
|
||||
try:
|
||||
# Get all tables
|
||||
tables = self.get_database_tables(self.db_name)
|
||||
tables = await self.get_database_tables_async(self.db_name)
|
||||
relationships = []
|
||||
|
||||
# Simple foreign key naming convention detection
|
||||
# Example: If a table has a column named xxx_id and another table named xxx exists, it might be a foreign key relationship
|
||||
for table_name in tables:
|
||||
schema = self.get_table_schema(table_name, self.db_name)
|
||||
schema = await self.get_table_schema(table_name, self.db_name)
|
||||
columns = schema.get("columns", [])
|
||||
|
||||
for column in columns:
|
||||
@@ -743,7 +727,7 @@ class MetadataExtractor:
|
||||
# Check if the possible table exists
|
||||
if ref_table_name in tables:
|
||||
# Find possible primary key column
|
||||
ref_schema = self.get_table_schema(ref_table_name, self.db_name)
|
||||
ref_schema = await self.get_table_schema(ref_table_name, self.db_name)
|
||||
ref_columns = ref_schema.get("columns", [])
|
||||
|
||||
# Assume primary key column name is id
|
||||
@@ -766,6 +750,7 @@ class MetadataExtractor:
|
||||
logger.error(f"Error inferring table relationships: {str(e)}")
|
||||
return []
|
||||
|
||||
# Deprecated: sync method (kept for compatibility, will be removed)
|
||||
def get_recent_audit_logs(self, days: int = 7, limit: int = 100) -> pd.DataFrame:
|
||||
"""
|
||||
Get recent audit logs
|
||||
@@ -792,13 +777,14 @@ class MetadataExtractor:
|
||||
ORDER BY time DESC
|
||||
LIMIT {limit}
|
||||
"""
|
||||
df = self._execute_query(query, return_dataframe=True)
|
||||
# Deprecated sync path removed; this method is deprecated overall
|
||||
df = pd.DataFrame()
|
||||
return df
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting audit logs: {str(e)}")
|
||||
return pd.DataFrame()
|
||||
|
||||
def get_catalog_list(self) -> List[Dict[str, Any]]:
|
||||
async def get_catalog_list(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get a list of all catalogs in Doris with detailed information
|
||||
|
||||
@@ -812,7 +798,7 @@ class MetadataExtractor:
|
||||
try:
|
||||
# Use SHOW CATALOGS command to get catalog list
|
||||
query = "SHOW CATALOGS"
|
||||
result = self._execute_query(query)
|
||||
result = await self._execute_query_async(query)
|
||||
|
||||
if not result:
|
||||
catalogs = []
|
||||
@@ -1101,7 +1087,8 @@ class MetadataExtractor:
|
||||
AND TABLE_NAME = '{table_name}'
|
||||
"""
|
||||
|
||||
partitions = self._execute_query(query)
|
||||
# Deprecated sync path removed
|
||||
partitions = []
|
||||
|
||||
if not partitions:
|
||||
return {}
|
||||
@@ -1124,31 +1111,25 @@ class MetadataExtractor:
|
||||
logger.error(f"Error getting partition information for table {db_name}.{table_name}: {str(e)}")
|
||||
return {}
|
||||
|
||||
def _execute_query_with_catalog(self, query: str, db_name: str = None, catalog_name: str = None):
|
||||
# Removed sync _execute_query_with_catalog; use async variant instead
|
||||
|
||||
async def _execute_query_with_catalog_async(self, query: str, db_name: str = None, catalog_name: str = None):
|
||||
"""
|
||||
Execute query with catalog-aware metadata operations using three-part naming
|
||||
Async version of _execute_query_with_catalog to avoid cross-event-loop issues.
|
||||
|
||||
Args:
|
||||
query: SQL query to execute
|
||||
db_name: Database name to use
|
||||
catalog_name: Catalog name for three-part naming
|
||||
|
||||
Returns:
|
||||
Query result
|
||||
When catalog_name is provided and the SQL targets information_schema, we rewrite
|
||||
the SQL to use three-part naming: `{catalog}.information_schema` and execute it
|
||||
via the same running event loop.
|
||||
"""
|
||||
try:
|
||||
# If catalog_name is specified, modify the query to use three-part naming
|
||||
# for information_schema queries
|
||||
if catalog_name and 'information_schema' in query.lower():
|
||||
# Replace 'information_schema' with 'catalog_name.information_schema'
|
||||
modified_query = query.replace('information_schema', f'{catalog_name}.information_schema')
|
||||
logger.info(f"Modified query for catalog {catalog_name}: {modified_query}")
|
||||
return self._execute_query(modified_query, db_name)
|
||||
return await self._execute_query_async(modified_query, db_name)
|
||||
else:
|
||||
# Execute the original query
|
||||
return self._execute_query(query, db_name)
|
||||
return await self._execute_query_async(query, db_name)
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing query with catalog: {str(e)}")
|
||||
logger.error(f"Error executing async query with catalog: {str(e)}")
|
||||
raise
|
||||
|
||||
async def _execute_query_async(self, query: str, db_name: str = None, return_dataframe: bool = False):
|
||||
@@ -1200,70 +1181,7 @@ class MetadataExtractor:
|
||||
else:
|
||||
return []
|
||||
|
||||
def _execute_query(self, query: str, db_name: str = None, return_dataframe: bool = False):
|
||||
"""
|
||||
Execute database query with proper session management (sync wrapper)
|
||||
|
||||
Args:
|
||||
query: SQL query to execute
|
||||
db_name: Database name to use (optional)
|
||||
return_dataframe: Whether to return a pandas DataFrame instead of list
|
||||
|
||||
Returns:
|
||||
Query result data (list of dictionaries or pandas DataFrame)
|
||||
"""
|
||||
try:
|
||||
if self.connection_manager:
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import threading
|
||||
|
||||
# Always run in a separate thread with new event loop to avoid conflicts
|
||||
def run_in_new_loop():
|
||||
# Create new event loop for this thread
|
||||
new_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(new_loop)
|
||||
try:
|
||||
return new_loop.run_until_complete(
|
||||
self._execute_query_async(query, db_name, return_dataframe)
|
||||
)
|
||||
finally:
|
||||
try:
|
||||
# Properly close the loop
|
||||
pending = asyncio.all_tasks(new_loop)
|
||||
if pending:
|
||||
new_loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
|
||||
finally:
|
||||
new_loop.close()
|
||||
|
||||
# Use ThreadPoolExecutor to run in separate thread
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
|
||||
future = executor.submit(run_in_new_loop)
|
||||
try:
|
||||
return future.result(timeout=30)
|
||||
except concurrent.futures.TimeoutError:
|
||||
logger.error("Query execution timed out after 30 seconds")
|
||||
if return_dataframe:
|
||||
import pandas as pd
|
||||
return pd.DataFrame()
|
||||
else:
|
||||
return []
|
||||
else:
|
||||
# Fallback: Return empty result
|
||||
logger.warning("No connection manager provided, returning empty result")
|
||||
if return_dataframe:
|
||||
import pandas as pd
|
||||
return pd.DataFrame()
|
||||
else:
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing query: {str(e)}")
|
||||
# Return empty result instead of raising exception to prevent cascade failures
|
||||
if return_dataframe:
|
||||
import pandas as pd
|
||||
return pd.DataFrame()
|
||||
else:
|
||||
return []
|
||||
# Removed sync _execute_query; use async methods exclusively
|
||||
|
||||
async def get_table_schema_async(self, table_name: str, db_name: str = None, catalog_name: str = None) -> List[Dict[str, Any]]:
|
||||
"""Asynchronously get table schema information"""
|
||||
@@ -1395,6 +1313,129 @@ class MetadataExtractor:
|
||||
logger.error(f"Failed to get catalog list: {e}")
|
||||
return []
|
||||
|
||||
async def get_table_comment_async(self, table_name: str, db_name: str = None, catalog_name: str = None) -> str:
|
||||
"""Async version: get the comment for a table."""
|
||||
try:
|
||||
effective_db = db_name or self.db_name
|
||||
effective_catalog = catalog_name or self.catalog_name
|
||||
|
||||
query = f"""
|
||||
SELECT
|
||||
TABLE_COMMENT
|
||||
FROM
|
||||
information_schema.tables
|
||||
WHERE
|
||||
TABLE_SCHEMA = '{effective_db}'
|
||||
AND TABLE_NAME = '{table_name}'
|
||||
"""
|
||||
|
||||
result = await self._execute_query_with_catalog_async(query, effective_db, effective_catalog)
|
||||
if not result or not result[0]:
|
||||
return ""
|
||||
return result[0].get("TABLE_COMMENT", "") or ""
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get table comment asynchronously: {e}")
|
||||
return ""
|
||||
|
||||
async def get_column_comments_async(self, table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, str]:
|
||||
"""Async version: get comments for all columns in a table."""
|
||||
try:
|
||||
effective_db = db_name or self.db_name
|
||||
effective_catalog = catalog_name or self.catalog_name
|
||||
|
||||
query = f"""
|
||||
SELECT
|
||||
COLUMN_NAME,
|
||||
COLUMN_COMMENT
|
||||
FROM
|
||||
information_schema.columns
|
||||
WHERE
|
||||
TABLE_SCHEMA = '{effective_db}'
|
||||
AND TABLE_NAME = '{table_name}'
|
||||
ORDER BY
|
||||
ORDINAL_POSITION
|
||||
"""
|
||||
|
||||
rows = await self._execute_query_with_catalog_async(query, effective_db, effective_catalog)
|
||||
comments: Dict[str, str] = {}
|
||||
for col in rows or []:
|
||||
name = col.get("COLUMN_NAME", "")
|
||||
if name:
|
||||
comments[name] = col.get("COLUMN_COMMENT", "") or ""
|
||||
return comments
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get column comments asynchronously: {e}")
|
||||
return {}
|
||||
|
||||
async def get_table_indexes_async(self, table_name: str, db_name: str = None, catalog_name: str = None) -> List[Dict[str, Any]]:
|
||||
"""Async version: get index information for a table."""
|
||||
try:
|
||||
effective_db = db_name or self.db_name
|
||||
effective_catalog = catalog_name or self.catalog_name
|
||||
|
||||
# Build query with catalog prefix if specified
|
||||
if effective_catalog:
|
||||
query = f"SHOW INDEX FROM `{effective_catalog}`.`{effective_db}`.`{table_name}`"
|
||||
logger.info(f"Using three-part naming for async index query: {query}")
|
||||
else:
|
||||
query = f"SHOW INDEX FROM `{effective_db}`.`{table_name}`"
|
||||
|
||||
rows = await self._execute_query_async(query, effective_db)
|
||||
indexes: List[Dict[str, Any]] = []
|
||||
if rows:
|
||||
# Group by Key_name
|
||||
current_index: Dict[str, Any] | None = None
|
||||
for r in rows:
|
||||
try:
|
||||
index_name = r.get('Key_name')
|
||||
column_name = r.get('Column_name')
|
||||
if current_index is None or current_index.get('name') != index_name:
|
||||
if current_index is not None:
|
||||
indexes.append(current_index)
|
||||
current_index = {
|
||||
'name': index_name,
|
||||
'columns': [column_name] if column_name else [],
|
||||
'unique': r.get('Non_unique', 1) == 0,
|
||||
'type': r.get('Index_type', '')
|
||||
}
|
||||
else:
|
||||
if column_name:
|
||||
current_index['columns'].append(column_name)
|
||||
except Exception as row_error:
|
||||
logger.warning(f"Failed to process async index row data: {row_error}")
|
||||
continue
|
||||
if current_index is not None:
|
||||
indexes.append(current_index)
|
||||
|
||||
return indexes
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting index information asynchronously: {str(e)}")
|
||||
return []
|
||||
|
||||
async def get_recent_audit_logs_async(self, days: int = 7, limit: int = 100):
|
||||
"""Async version: get recent audit logs and return a pandas DataFrame."""
|
||||
try:
|
||||
start_date = (datetime.now() - timedelta(days=days)).strftime('%Y-%m-%d')
|
||||
query = f"""
|
||||
SELECT client_ip, user, db, time, stmt_id, stmt, state, error_code
|
||||
FROM `__internal_schema`.`audit_log`
|
||||
WHERE `time` >= '{start_date}'
|
||||
AND state = 'EOF' AND error_code = 0
|
||||
AND `stmt` NOT LIKE 'SHOW%'
|
||||
AND `stmt` NOT LIKE 'DESC%'
|
||||
AND `stmt` NOT LIKE 'EXPLAIN%'
|
||||
AND `stmt` NOT LIKE 'SELECT 1%'
|
||||
ORDER BY time DESC
|
||||
LIMIT {limit}
|
||||
"""
|
||||
rows = await self._execute_query_async(query)
|
||||
import pandas as pd
|
||||
return pd.DataFrame(rows or [])
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting audit logs asynchronously: {str(e)}")
|
||||
import pandas as pd
|
||||
return pd.DataFrame()
|
||||
|
||||
# ==================== Business layer methods (original metadata_tools.py functionality) ====================
|
||||
|
||||
def _format_response(self, success: bool, result: Any = None, error: str = None, message: str = "") -> Dict[str, Any]:
|
||||
@@ -1513,7 +1554,7 @@ class MetadataExtractor:
|
||||
return self._format_response(success=False, error="Missing table_name parameter")
|
||||
|
||||
try:
|
||||
comment = self.get_table_comment(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
comment = await self.get_table_comment_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
return self._format_response(success=True, result=comment)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get table comment: {str(e)}", exc_info=True)
|
||||
@@ -1532,7 +1573,7 @@ class MetadataExtractor:
|
||||
return self._format_response(success=False, error="Missing table_name parameter")
|
||||
|
||||
try:
|
||||
comments = self.get_column_comments(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
comments = await self.get_column_comments_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
return self._format_response(success=True, result=comments)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get table column comments: {str(e)}", exc_info=True)
|
||||
@@ -1551,7 +1592,7 @@ class MetadataExtractor:
|
||||
return self._format_response(success=False, error="Missing table_name parameter")
|
||||
|
||||
try:
|
||||
indexes = self.get_table_indexes(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
indexes = await self.get_table_indexes_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
return self._format_response(success=True, result=indexes)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get table indexes: {str(e)}", exc_info=True)
|
||||
@@ -1575,7 +1616,7 @@ class MetadataExtractor:
|
||||
logger.info(f"Getting audit logs: Days: {days}, Limit: {limit}")
|
||||
|
||||
try:
|
||||
logs_df = self.get_recent_audit_logs(days=days, limit=limit)
|
||||
logs_df = await self.get_recent_audit_logs_async(days=days, limit=limit)
|
||||
|
||||
# Convert DataFrame to JSON format
|
||||
if hasattr(logs_df, 'to_dict'):
|
||||
|
||||
@@ -22,16 +22,17 @@ Implements enterprise-level authentication, authorization, SQL security validati
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
import sqlparse
|
||||
from sqlparse.sql import Statement
|
||||
from sqlparse.tokens import Keyword, Name
|
||||
|
||||
from .logger import get_logger
|
||||
from .config import DatabaseConfig
|
||||
|
||||
|
||||
class SecurityLevel(Enum):
|
||||
@@ -45,15 +46,18 @@ class SecurityLevel(Enum):
|
||||
|
||||
@dataclass
|
||||
class AuthContext:
|
||||
"""Authentication context"""
|
||||
"""Authentication context for audit and session tracking"""
|
||||
|
||||
user_id: str
|
||||
roles: list[str]
|
||||
permissions: list[str]
|
||||
session_id: str
|
||||
login_time: datetime | None = None
|
||||
token_id: str = "" # Token identifier for audit logging
|
||||
user_id: str = "" # User identifier
|
||||
roles: list[str] = field(default_factory=list) # User roles
|
||||
permissions: list[str] = field(default_factory=list) # User permissions
|
||||
security_level: 'SecurityLevel' = field(default_factory=lambda: SecurityLevel.INTERNAL) # Security level
|
||||
client_ip: str = "unknown" # Client IP address
|
||||
session_id: str = "" # Session identifier
|
||||
login_time: datetime = field(default_factory=datetime.utcnow)
|
||||
last_activity: datetime | None = None
|
||||
security_level: SecurityLevel = SecurityLevel.INTERNAL
|
||||
token: str = "" # Raw token for token-bound database configuration
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -86,12 +90,13 @@ class DorisSecurityManager:
|
||||
Provides complete security control functionality, including authentication, authorization, SQL security validation and data masking
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, connection_manager=None):
|
||||
self.config = config
|
||||
self.logger = get_logger(__name__)
|
||||
self.connection_manager = connection_manager
|
||||
|
||||
# Initialize security components
|
||||
self.auth_provider = AuthenticationProvider(config)
|
||||
self.auth_provider = AuthenticationProvider(config, self)
|
||||
self.authz_provider = AuthorizationProvider(config)
|
||||
self.sql_validator = SQLSecurityValidator(config)
|
||||
self.masking_processor = DataMaskingProcessor(config)
|
||||
@@ -101,6 +106,36 @@ class DorisSecurityManager:
|
||||
self.sensitive_tables = self._load_sensitive_tables()
|
||||
self.masking_rules = self._load_masking_rules()
|
||||
|
||||
# Track initialization state
|
||||
self._initialized = False
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize security manager components"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
try:
|
||||
# Initialize authentication provider (for JWT setup)
|
||||
await self.auth_provider.initialize()
|
||||
|
||||
self._initialized = True
|
||||
self.logger.info("DorisSecurityManager initialized successfully")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to initialize DorisSecurityManager: {e}")
|
||||
raise
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shutdown security manager components"""
|
||||
try:
|
||||
await self.auth_provider.shutdown()
|
||||
self._initialized = False
|
||||
self.logger.info("DorisSecurityManager shutdown completed")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error during DorisSecurityManager shutdown: {e}")
|
||||
raise
|
||||
|
||||
def _load_blocked_keywords(self) -> set[str]:
|
||||
"""Load blocked SQL keywords from configuration"""
|
||||
# Load keywords from configuration, unified source of truth
|
||||
@@ -184,8 +219,59 @@ class DorisSecurityManager:
|
||||
return default_rules
|
||||
|
||||
async def authenticate_request(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||
"""Validate request authentication information"""
|
||||
return await self.auth_provider.authenticate(auth_info)
|
||||
"""Validate request authentication information
|
||||
|
||||
Tries authentication methods in order: Token -> JWT -> OAuth
|
||||
Any one method succeeding allows access
|
||||
If all methods are disabled, returns anonymous context
|
||||
"""
|
||||
# Check if any authentication method is enabled
|
||||
if not (self.config.security.enable_token_auth or
|
||||
self.config.security.enable_jwt_auth or
|
||||
self.config.security.enable_oauth_auth):
|
||||
self.logger.debug("All authentication methods are disabled")
|
||||
# Return anonymous context when no authentication is enabled
|
||||
return AuthContext(
|
||||
token_id="anonymous",
|
||||
user_id="anonymous",
|
||||
roles=["anonymous"],
|
||||
permissions=["read"],
|
||||
security_level=SecurityLevel.PUBLIC,
|
||||
client_ip=auth_info.get("client_ip", "unknown"),
|
||||
session_id="anonymous_session"
|
||||
)
|
||||
|
||||
# Try authentication methods in order of preference
|
||||
last_error = None
|
||||
|
||||
# 1. Try Token authentication first (most common)
|
||||
if self.config.security.enable_token_auth:
|
||||
try:
|
||||
return await self.auth_provider.authenticate_token(auth_info)
|
||||
except Exception as e:
|
||||
self.logger.debug(f"Token authentication failed: {e}")
|
||||
last_error = e
|
||||
|
||||
# 2. Try JWT authentication
|
||||
if self.config.security.enable_jwt_auth:
|
||||
try:
|
||||
return await self.auth_provider.authenticate_jwt(auth_info)
|
||||
except Exception as e:
|
||||
self.logger.debug(f"JWT authentication failed: {e}")
|
||||
last_error = e
|
||||
|
||||
# 3. Try OAuth authentication
|
||||
if self.config.security.enable_oauth_auth:
|
||||
try:
|
||||
return await self.auth_provider.authenticate_oauth(auth_info)
|
||||
except Exception as e:
|
||||
self.logger.debug(f"OAuth authentication failed: {e}")
|
||||
last_error = e
|
||||
|
||||
# All enabled authentication methods failed
|
||||
error_message = f"Authentication failed: {str(last_error)}" if last_error else "No authentication method succeeded"
|
||||
self.logger.warning(f"Authentication failed for client {auth_info.get('client_ip', 'unknown')}: {error_message}")
|
||||
raise ValueError(error_message)
|
||||
|
||||
async def authorize_resource_access(
|
||||
self, auth_context: AuthContext, resource_uri: str
|
||||
@@ -207,43 +293,362 @@ class DorisSecurityManager:
|
||||
"""Apply data masking processing"""
|
||||
return await self.masking_processor.process(data, auth_context)
|
||||
|
||||
# OAuth-specific methods
|
||||
def get_oauth_authorization_url(self) -> tuple[str, str]:
|
||||
"""Get OAuth authorization URL
|
||||
|
||||
Returns:
|
||||
Tuple of (authorization_url, state)
|
||||
"""
|
||||
if not self.auth_provider.oauth_provider:
|
||||
raise ValueError("OAuth is not enabled")
|
||||
return self.auth_provider.oauth_provider.get_authorization_url()
|
||||
|
||||
async def handle_oauth_callback(self, code: str, state: str) -> AuthContext:
|
||||
"""Handle OAuth callback
|
||||
|
||||
Args:
|
||||
code: Authorization code from OAuth provider
|
||||
state: State parameter for CSRF protection
|
||||
|
||||
Returns:
|
||||
AuthContext for authenticated user
|
||||
"""
|
||||
if not self.auth_provider.oauth_provider:
|
||||
raise ValueError("OAuth is not enabled")
|
||||
return await self.auth_provider.oauth_provider.handle_callback(code, state)
|
||||
|
||||
def get_oauth_provider_info(self) -> dict[str, Any]:
|
||||
"""Get OAuth provider information
|
||||
|
||||
Returns:
|
||||
OAuth provider information
|
||||
"""
|
||||
if not self.auth_provider.oauth_provider:
|
||||
return {"enabled": False}
|
||||
return self.auth_provider.oauth_provider.get_provider_info()
|
||||
|
||||
# Token management methods
|
||||
async def create_token(
|
||||
self,
|
||||
token_id: str,
|
||||
expires_hours: Optional[int] = None,
|
||||
description: str = "",
|
||||
custom_token: Optional[str] = None,
|
||||
database_config: Optional[DatabaseConfig] = None
|
||||
) -> str:
|
||||
"""Create a new API access token
|
||||
|
||||
Args:
|
||||
token_id: Unique token identifier for audit and management
|
||||
expires_hours: Token expiration in hours (None for no expiration)
|
||||
description: Token description for management purposes
|
||||
custom_token: Custom token string (if None, generates random token)
|
||||
database_config: Optional database configuration for this token
|
||||
|
||||
Returns:
|
||||
Generated token string
|
||||
"""
|
||||
if not self.auth_provider.token_manager:
|
||||
raise ValueError("Token manager not initialized")
|
||||
|
||||
return await self.auth_provider.token_manager.create_token(
|
||||
token_id=token_id,
|
||||
expires_hours=expires_hours,
|
||||
description=description,
|
||||
custom_token=custom_token,
|
||||
database_config=database_config
|
||||
)
|
||||
|
||||
async def revoke_token(self, token_id: str) -> bool:
|
||||
"""Revoke a token by token ID
|
||||
|
||||
Args:
|
||||
token_id: Token ID to revoke
|
||||
|
||||
Returns:
|
||||
True if token was revoked successfully
|
||||
"""
|
||||
if not self.auth_provider.token_manager:
|
||||
raise ValueError("Token manager not initialized")
|
||||
|
||||
return await self.auth_provider.token_manager.revoke_token(token_id)
|
||||
|
||||
async def list_tokens(self) -> list[dict[str, Any]]:
|
||||
"""List all tokens (without sensitive data)
|
||||
|
||||
Returns:
|
||||
List of token information
|
||||
"""
|
||||
if not self.auth_provider.token_manager:
|
||||
raise ValueError("Token manager not initialized")
|
||||
|
||||
return await self.auth_provider.token_manager.list_tokens()
|
||||
|
||||
async def cleanup_expired_tokens(self) -> int:
|
||||
"""Remove expired tokens and return count
|
||||
|
||||
Returns:
|
||||
Number of expired tokens removed
|
||||
"""
|
||||
if not self.auth_provider.token_manager:
|
||||
return 0
|
||||
|
||||
return await self.auth_provider.token_manager.cleanup_expired_tokens()
|
||||
|
||||
def get_token_stats(self) -> dict[str, Any]:
|
||||
"""Get token statistics
|
||||
|
||||
Returns:
|
||||
Token statistics dictionary
|
||||
"""
|
||||
if not self.auth_provider.token_manager:
|
||||
return {"error": "Token manager not initialized"}
|
||||
|
||||
return self.auth_provider.token_manager.get_token_stats()
|
||||
|
||||
async def _validate_token_database_config(self, token: str, token_info) -> None:
|
||||
"""Validate database configuration for token immediately during authentication
|
||||
|
||||
This ensures database connectivity issues are caught at authentication time,
|
||||
not during query execution, providing better user experience.
|
||||
|
||||
Args:
|
||||
token: Raw authentication token
|
||||
token_info: TokenInfo object from token validation
|
||||
|
||||
Raises:
|
||||
ValueError: If database configuration is invalid or connection fails
|
||||
"""
|
||||
try:
|
||||
if not self.connection_manager:
|
||||
self.logger.warning("Connection manager not available for immediate database validation")
|
||||
return
|
||||
|
||||
# Configure and test database connection for this token
|
||||
success, config_source = await self.connection_manager.configure_for_token(token)
|
||||
|
||||
if success:
|
||||
self.logger.info(f"Database configuration validated successfully for token {token_info.token_id} (source: {config_source})")
|
||||
else:
|
||||
raise ValueError("Database configuration validation failed")
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Database configuration validation failed for token {token_info.token_id}: {str(e)}"
|
||||
self.logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
|
||||
class AuthenticationProvider:
|
||||
"""Authentication provider"""
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, security_manager=None):
|
||||
self.config = config
|
||||
self.logger = get_logger(__name__)
|
||||
self.session_cache = {}
|
||||
self.jwt_manager = None
|
||||
self.oauth_provider = None
|
||||
self.token_manager = None
|
||||
self.security_manager = security_manager
|
||||
|
||||
async def authenticate(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||
"""Perform identity authentication"""
|
||||
auth_type = auth_info.get("type", "token")
|
||||
# Initialize authentication providers based on individual switches
|
||||
auth_methods_enabled = []
|
||||
|
||||
if auth_type == "token":
|
||||
return await self._authenticate_token(auth_info)
|
||||
elif auth_type == "basic":
|
||||
return await self._authenticate_basic(auth_info)
|
||||
# Initialize Token manager if enabled
|
||||
if config.security.enable_token_auth:
|
||||
self._initialize_token_manager()
|
||||
auth_methods_enabled.append("Token")
|
||||
|
||||
# Initialize JWT manager if enabled
|
||||
if config.security.enable_jwt_auth:
|
||||
self._initialize_jwt_manager()
|
||||
auth_methods_enabled.append("JWT")
|
||||
|
||||
# Initialize OAuth provider if enabled
|
||||
if config.security.enable_oauth_auth or (hasattr(config.security, 'oauth_enabled') and config.security.oauth_enabled):
|
||||
self._initialize_oauth_provider()
|
||||
auth_methods_enabled.append("OAuth")
|
||||
|
||||
if auth_methods_enabled:
|
||||
self.logger.info(f"Authentication enabled with methods: {', '.join(auth_methods_enabled)}")
|
||||
else:
|
||||
raise ValueError(f"Unsupported authentication type: {auth_type}")
|
||||
self.logger.info("All authentication methods are disabled - anonymous access allowed")
|
||||
|
||||
def _initialize_jwt_manager(self):
|
||||
"""Initialize JWT manager"""
|
||||
try:
|
||||
from ..auth.jwt_manager import JWTManager
|
||||
self.jwt_manager = JWTManager(self.config)
|
||||
self.logger.info("JWT manager initialized")
|
||||
except ImportError as e:
|
||||
self.logger.error(f"Failed to import JWT manager: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to initialize JWT manager: {e}")
|
||||
raise
|
||||
|
||||
def _initialize_token_manager(self):
|
||||
"""Initialize Token manager"""
|
||||
try:
|
||||
from ..auth.token_manager import TokenManager
|
||||
self.token_manager = TokenManager(self.config)
|
||||
self.logger.info("Token manager initialized")
|
||||
except ImportError as e:
|
||||
self.logger.error(f"Failed to import Token manager: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to initialize Token manager: {e}")
|
||||
raise
|
||||
|
||||
def _initialize_oauth_provider(self):
|
||||
"""Initialize OAuth provider"""
|
||||
try:
|
||||
from ..auth.oauth_provider import OAuthAuthenticationProvider
|
||||
self.oauth_provider = OAuthAuthenticationProvider(self.config)
|
||||
self.logger.info("OAuth provider initialized")
|
||||
except ImportError as e:
|
||||
self.logger.error(f"Failed to import OAuth provider: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to initialize OAuth provider: {e}")
|
||||
raise
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize authentication provider asynchronously"""
|
||||
if self.jwt_manager:
|
||||
success = await self.jwt_manager.initialize()
|
||||
if not success:
|
||||
raise RuntimeError("Failed to initialize JWT manager")
|
||||
self.logger.info("JWT authentication provider initialized successfully")
|
||||
|
||||
if self.token_manager:
|
||||
# Token manager doesn't need async initialization, just log success
|
||||
self.logger.info("Token authentication provider initialized successfully")
|
||||
|
||||
if self.oauth_provider:
|
||||
success = await self.oauth_provider.initialize()
|
||||
if not success:
|
||||
raise RuntimeError("Failed to initialize OAuth provider")
|
||||
self.logger.info("OAuth authentication provider initialized successfully")
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shutdown authentication provider"""
|
||||
if self.jwt_manager:
|
||||
await self.jwt_manager.shutdown()
|
||||
self.logger.info("JWT authentication provider shutdown completed")
|
||||
|
||||
if self.token_manager:
|
||||
# Token manager doesn't need async shutdown, just log
|
||||
self.logger.info("Token authentication provider shutdown completed")
|
||||
|
||||
if self.oauth_provider:
|
||||
await self.oauth_provider.shutdown()
|
||||
self.logger.info("OAuth authentication provider shutdown completed")
|
||||
|
||||
async def authenticate_token(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||
"""Perform token authentication"""
|
||||
if not self.config.security.enable_token_auth:
|
||||
raise ValueError("Token authentication is not enabled")
|
||||
return await self._authenticate_token(auth_info)
|
||||
|
||||
async def authenticate_jwt(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||
"""Perform JWT authentication"""
|
||||
if not self.config.security.enable_jwt_auth:
|
||||
raise ValueError("JWT authentication is not enabled")
|
||||
return await self._authenticate_jwt(auth_info)
|
||||
|
||||
async def authenticate_oauth(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||
"""Perform OAuth authentication"""
|
||||
if not self.config.security.enable_oauth_auth:
|
||||
raise ValueError("OAuth authentication is not enabled")
|
||||
return await self._authenticate_oauth(auth_info)
|
||||
|
||||
async def _authenticate_jwt(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||
"""JWT authentication"""
|
||||
if not self.jwt_manager:
|
||||
raise ValueError("JWT manager not initialized")
|
||||
|
||||
token = auth_info.get("token")
|
||||
if not token:
|
||||
# Try to extract from Authorization header
|
||||
authorization = auth_info.get("authorization")
|
||||
if authorization and authorization.startswith('Bearer '):
|
||||
token = authorization[7:]
|
||||
|
||||
if not token:
|
||||
raise ValueError("Missing JWT token")
|
||||
|
||||
try:
|
||||
# Use JWT middleware for authentication
|
||||
from ..auth.auth_middleware import AuthMiddleware
|
||||
middleware = AuthMiddleware(self.jwt_manager)
|
||||
return await middleware.authenticate_request(auth_info)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"JWT authentication failed: {e}")
|
||||
raise ValueError(f"JWT authentication failed: {str(e)}")
|
||||
|
||||
async def _authenticate_oauth(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||
"""OAuth authentication"""
|
||||
if not self.oauth_provider:
|
||||
raise ValueError("OAuth provider not initialized")
|
||||
|
||||
# Handle different OAuth authentication scenarios
|
||||
if "access_token" in auth_info:
|
||||
# Direct OAuth access token authentication
|
||||
return await self.oauth_provider.authenticate_with_token(auth_info["access_token"])
|
||||
elif "code" in auth_info and "state" in auth_info:
|
||||
# OAuth callback authentication
|
||||
return await self.oauth_provider.handle_callback(auth_info["code"], auth_info["state"])
|
||||
else:
|
||||
raise ValueError("OAuth authentication requires either access_token or code+state")
|
||||
|
||||
async def _authenticate_token(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||
"""Token authentication"""
|
||||
if not self.token_manager:
|
||||
raise ValueError("Token manager not initialized")
|
||||
|
||||
token = auth_info.get("token")
|
||||
if not token:
|
||||
# Try to extract from Authorization header
|
||||
authorization = auth_info.get("authorization")
|
||||
if authorization and authorization.startswith('Bearer '):
|
||||
token = authorization[7:]
|
||||
elif authorization and authorization.startswith('Token '):
|
||||
token = authorization[6:]
|
||||
|
||||
if not token:
|
||||
raise ValueError("Missing authentication token")
|
||||
|
||||
# Validate token (simplified implementation, should validate JWT or query authentication service in practice)
|
||||
user_info = await self._validate_token(token)
|
||||
try:
|
||||
# Validate token using TokenManager
|
||||
validation_result = await self.token_manager.validate_token(token)
|
||||
|
||||
return AuthContext(
|
||||
user_id=user_info["user_id"],
|
||||
roles=user_info["roles"],
|
||||
permissions=user_info["permissions"],
|
||||
session_id=auth_info.get("session_id", "default"),
|
||||
login_time=datetime.utcnow(),
|
||||
security_level=SecurityLevel(user_info.get("security_level", "internal")),
|
||||
)
|
||||
if not validation_result.is_valid:
|
||||
raise ValueError(f"Token validation failed: {validation_result.error_message}")
|
||||
|
||||
token_info = validation_result.token_info
|
||||
|
||||
# Immediately validate database configuration for this token
|
||||
if self.security_manager:
|
||||
await self.security_manager._validate_token_database_config(token, token_info)
|
||||
|
||||
return AuthContext(
|
||||
token_id=token_info.token_id,
|
||||
user_id=token_info.token_id, # Use token_id as user_id for token auth
|
||||
roles=["token_user"], # Default role for token users
|
||||
permissions=["read", "write"], # Default permissions for token users
|
||||
security_level=SecurityLevel.INTERNAL,
|
||||
client_ip=auth_info.get("client_ip", "unknown"),
|
||||
session_id=auth_info.get("session_id", f"session_{token_info.token_id}"),
|
||||
login_time=datetime.utcnow(),
|
||||
last_activity=token_info.last_used,
|
||||
token=token # Store raw token for token-bound database configuration
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Token authentication failed: {e}")
|
||||
raise ValueError(f"Token authentication failed: {str(e)}")
|
||||
|
||||
async def _authenticate_basic(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||
"""Basic authentication (username password)"""
|
||||
@@ -537,7 +942,7 @@ class SQLSecurityValidator:
|
||||
"""Check SQL injection risks"""
|
||||
# Check common SQL injection patterns
|
||||
injection_patterns = [
|
||||
r"(\s|^)(union|select|insert|update|delete|drop|create|alter)\s+.*\s+(union|select|insert|update|delete|drop|create|alter)",
|
||||
r"(?i)(?<![A-Za-z0-9_])(union|select|insert|update|delete|drop|create|alter)(?![A-Za-z0-9_])\s+[\s\S]*?\s+(?<![A-Za-z0-9_])(union|select|insert|update|delete|drop|create|alter)(?![A-Za-z0-9_])",
|
||||
r"(\s|^)(or|and)\s+\d+\s*=\s*\d+",
|
||||
r"(\s|^)(or|and)\s+['\"].*['\"]",
|
||||
r";\s*(drop|delete|truncate|alter|create)",
|
||||
|
||||
@@ -56,16 +56,31 @@ class SecurityAnalyticsTools:
|
||||
"""
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
# 🚀 PROGRESS: Initialize security analysis
|
||||
logger.info("=" * 70)
|
||||
logger.info(f"🔒 Starting Data Access Pattern Analysis")
|
||||
logger.info(f"📅 Analysis period: {days} days")
|
||||
logger.info(f"👥 Include system users: {include_system_users}")
|
||||
logger.info(f"🎯 Min query threshold: {min_query_threshold}")
|
||||
logger.info("=" * 70)
|
||||
|
||||
connection = await self.connection_manager.get_connection("query")
|
||||
|
||||
# Define analysis period
|
||||
end_date = datetime.now()
|
||||
start_date = end_date - timedelta(days=days)
|
||||
|
||||
# 1. Get audit log data
|
||||
logger.info(f"📊 Period: {start_date.strftime('%Y-%m-%d')} to {end_date.strftime('%Y-%m-%d')}")
|
||||
|
||||
# 🚀 PROGRESS: Step 1 - Get audit log data
|
||||
logger.info("📋 Step 1/5: Retrieving audit log data...")
|
||||
audit_start = time.time()
|
||||
audit_data = await self._get_audit_log_data(connection, start_date, end_date, include_system_users)
|
||||
audit_time = time.time() - audit_start
|
||||
|
||||
if not audit_data:
|
||||
logger.warning("⚠️ No audit data available for the specified period")
|
||||
return {
|
||||
"error": "No audit data available for the specified period",
|
||||
"analysis_period": {
|
||||
@@ -75,25 +90,49 @@ class SecurityAnalyticsTools:
|
||||
}
|
||||
}
|
||||
|
||||
# 2. Analyze user access patterns
|
||||
logger.info(f"✅ Retrieved {len(audit_data)} audit records in {audit_time:.2f}s")
|
||||
|
||||
# 🚀 PROGRESS: Step 2 - Analyze user access patterns
|
||||
logger.info("👤 Step 2/5: Analyzing user access patterns...")
|
||||
user_start = time.time()
|
||||
user_access_analysis = await self._analyze_user_access_patterns(
|
||||
audit_data, min_query_threshold
|
||||
)
|
||||
user_time = time.time() - user_start
|
||||
logger.info(f"✅ Analyzed {len(user_access_analysis)} users in {user_time:.2f}s")
|
||||
|
||||
# 3. Analyze role-based access
|
||||
# 🚀 PROGRESS: Step 3 - Analyze role-based access
|
||||
logger.info("🎭 Step 3/5: Analyzing role-based access patterns...")
|
||||
role_start = time.time()
|
||||
role_access_analysis = await self._analyze_role_access_patterns(
|
||||
connection, user_access_analysis
|
||||
)
|
||||
role_time = time.time() - role_start
|
||||
logger.info(f"✅ Role analysis completed in {role_time:.2f}s")
|
||||
|
||||
# 4. Detect security anomalies
|
||||
# 🚀 PROGRESS: Step 4 - Detect security anomalies
|
||||
logger.info("🚨 Step 4/5: Detecting security anomalies...")
|
||||
anomaly_start = time.time()
|
||||
security_alerts = await self._detect_security_anomalies(
|
||||
audit_data, user_access_analysis
|
||||
)
|
||||
anomaly_time = time.time() - anomaly_start
|
||||
logger.info(f"✅ Found {len(security_alerts)} security alerts in {anomaly_time:.2f}s")
|
||||
|
||||
# 5. Generate access insights
|
||||
# Log alert summary
|
||||
if security_alerts:
|
||||
high_alerts = sum(1 for alert in security_alerts if alert.get("severity") == "high")
|
||||
medium_alerts = sum(1 for alert in security_alerts if alert.get("severity") == "medium")
|
||||
logger.info(f"🚨 Alert breakdown: {high_alerts} high, {medium_alerts} medium")
|
||||
|
||||
# 🚀 PROGRESS: Step 5 - Generate access insights
|
||||
logger.info("💡 Step 5/5: Generating access insights...")
|
||||
insights_start = time.time()
|
||||
access_insights = await self._generate_access_insights(
|
||||
user_access_analysis, role_access_analysis
|
||||
)
|
||||
insights_time = time.time() - insights_start
|
||||
logger.info(f"✅ Access insights generated in {insights_time:.2f}s")
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "doris-mcp-server"
|
||||
version = "0.5.0"
|
||||
version = "0.6.0"
|
||||
description = "Enterprise-grade Model Context Protocol (MCP) server implementation for Apache Doris"
|
||||
authors = [
|
||||
{name = "Yijia Su", email = "freeoneplus@apache.org"}
|
||||
|
||||
@@ -67,6 +67,7 @@ fi
|
||||
export MCP_TRANSPORT_TYPE="http"
|
||||
export MCP_HOST="${MCP_HOST:-0.0.0.0}"
|
||||
export MCP_PORT="${MCP_PORT:-3000}"
|
||||
export WORKERS="${WORKERS:-1}"
|
||||
export ALLOWED_ORIGINS="${ALLOWED_ORIGINS:-*}"
|
||||
export LOG_LEVEL="${LOG_LEVEL:-info}"
|
||||
export MCP_ALLOW_CREDENTIALS="${MCP_ALLOW_CREDENTIALS:-false}"
|
||||
@@ -80,10 +81,11 @@ echo -e "${YELLOW}Service will run on http://${MCP_HOST}:${MCP_PORT}/mcp${NC}"
|
||||
echo -e "${YELLOW}Health Check: http://${MCP_HOST}:${MCP_PORT}/health${NC}"
|
||||
echo -e "${YELLOW}MCP Endpoint: http://${MCP_HOST}:${MCP_PORT}/mcp${NC}"
|
||||
echo -e "${YELLOW}Local access: http://localhost:${MCP_PORT}/mcp${NC}"
|
||||
echo -e "${YELLOW}Workers: ${WORKERS}${NC}"
|
||||
echo -e "${YELLOW}Use Ctrl+C to stop the service${NC}"
|
||||
|
||||
# Start the server in HTTP mode (Streamable HTTP)
|
||||
python -m doris_mcp_server.main --transport http --host ${MCP_HOST} --port ${MCP_PORT}
|
||||
python -m doris_mcp_server.main --transport http --host ${MCP_HOST} --port ${MCP_PORT} --workers ${WORKERS}
|
||||
|
||||
# Check exit status
|
||||
if [ $? -ne 0 ]; then
|
||||
|
||||
@@ -59,7 +59,6 @@ def test_config():
|
||||
config.database.password = "test_password"
|
||||
config.database.database = "test_db"
|
||||
config.database.health_check_interval = 60
|
||||
config.database.min_connections = 5
|
||||
config.database.max_connections = 20
|
||||
config.database.connection_timeout = 30
|
||||
config.database.max_connection_age = 3600
|
||||
|
||||
@@ -34,7 +34,7 @@ class TestEndToEndIntegration:
|
||||
@pytest.fixture
|
||||
def mock_config(self):
|
||||
"""Create mock configuration"""
|
||||
from doris_mcp_server.utils.config import DatabaseConfig, SecurityConfig
|
||||
from doris_mcp_server.utils.config import ADBCConfig, DatabaseConfig, SecurityConfig
|
||||
|
||||
config = Mock(spec=DorisConfig)
|
||||
|
||||
@@ -46,7 +46,6 @@ class TestEndToEndIntegration:
|
||||
config.database.password = "test_password"
|
||||
config.database.database = "test_db"
|
||||
config.database.health_check_interval = 60
|
||||
config.database.min_connections = 5
|
||||
config.database.max_connections = 20
|
||||
config.database.connection_timeout = 30
|
||||
config.database.max_connection_age = 3600
|
||||
@@ -57,6 +56,11 @@ class TestEndToEndIntegration:
|
||||
config.security.auth_type = "token"
|
||||
config.security.token_secret = "test_secret"
|
||||
config.security.token_expiry = 3600
|
||||
config.security.blocked_keywords = ["DROP"]
|
||||
|
||||
# Add adbc config
|
||||
config.adbc = Mock(spec=ADBCConfig)
|
||||
config.adbc.enabled = True
|
||||
|
||||
return config
|
||||
|
||||
@@ -231,7 +235,7 @@ class TestEndToEndIntegration:
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_execution_with_security(self, doris_server):
|
||||
"""Test tool execution with security checks"""
|
||||
with patch.object(doris_server.tools_manager.query_executor, 'execute_query') as mock_execute:
|
||||
with patch.object(doris_server.tools_manager.connection_manager, 'execute_query') as mock_execute:
|
||||
mock_execute.return_value = [{"Database": "test_db"}]
|
||||
|
||||
# Test tool execution through tools manager
|
||||
@@ -258,7 +262,7 @@ class TestEndToEndIntegration:
|
||||
@pytest.mark.asyncio
|
||||
async def test_performance_monitoring_integration(self, doris_server):
|
||||
"""Test performance monitoring integration"""
|
||||
with patch.object(doris_server.tools_manager.query_executor, 'execute_query') as mock_execute:
|
||||
with patch.object(doris_server.tools_manager.connection_manager, 'execute_query') as mock_execute:
|
||||
mock_execute.return_value = [
|
||||
{
|
||||
"query_count": 1500,
|
||||
|
||||
@@ -44,22 +44,31 @@
|
||||
}
|
||||
},
|
||||
"expected_tools": [
|
||||
"analyze_columns",
|
||||
"analyze_data_access_patterns",
|
||||
"analyze_data_flow_dependencies",
|
||||
"analyze_resource_growth_curves",
|
||||
"analyze_slow_queries_topn",
|
||||
"analyze_table_storage",
|
||||
"exec_adbc_query",
|
||||
"exec_query",
|
||||
"get_adbc_connection_info",
|
||||
"get_catalog_list",
|
||||
"get_db_list",
|
||||
"get_db_table_list",
|
||||
"get_table_schema",
|
||||
"get_table_comment",
|
||||
"get_table_column_comments",
|
||||
"get_table_indexes",
|
||||
"get_memory_stats",
|
||||
"get_monitoring_metrics",
|
||||
"get_recent_audit_logs",
|
||||
"get_catalog_list",
|
||||
"get_sql_explain",
|
||||
"get_sql_profile",
|
||||
"get_table_basic_info",
|
||||
"get_table_column_comments",
|
||||
"get_table_comment",
|
||||
"get_table_data_size",
|
||||
"get_monitoring_metrics_info",
|
||||
"get_monitoring_metrics_data",
|
||||
"get_realtime_memory_stats",
|
||||
"get_historical_memory_stats"
|
||||
"get_table_indexes",
|
||||
"get_table_schema",
|
||||
"monitor_data_freshness",
|
||||
"trace_column_lineage"
|
||||
],
|
||||
"expected_resources": [
|
||||
"database",
|
||||
|
||||
@@ -185,8 +185,9 @@ async def test_server_connectivity(transport: Optional[str] = None) -> bool:
|
||||
logger.error(f"Connectivity test failed: {e}")
|
||||
return False
|
||||
|
||||
result = await client.connect_and_run(test_connection)
|
||||
return result
|
||||
await client.connect_and_run(test_connection)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to test server connectivity: {e}")
|
||||
return False
|
||||
|
||||
@@ -72,8 +72,7 @@ class TestToolsClientServer:
|
||||
|
||||
return tools
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert len(result) > 0
|
||||
await client.connect_and_run(test_callback)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_exec_query_via_client(self, client, test_config):
|
||||
@@ -91,14 +90,13 @@ class TestToolsClientServer:
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
|
||||
if result["success"]:
|
||||
assert "result" in result, "Successful result should contain 'result' field"
|
||||
assert "data" in result, "Successful result should contain 'data' field"
|
||||
else:
|
||||
assert "error" in result, "Failed result should contain 'error' field"
|
||||
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
# Don't assert success=True as it depends on actual server state
|
||||
await client.connect_and_run(test_callback)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_get_db_list_via_client(self, client, test_config):
|
||||
@@ -115,8 +113,7 @@ class TestToolsClientServer:
|
||||
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
await client.connect_and_run(test_callback)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_get_table_schema_via_client(self, client, test_config):
|
||||
@@ -133,10 +130,7 @@ class TestToolsClientServer:
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
|
||||
|
||||
await client.connect_and_run(test_callback)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_error_handling_via_client(self, client, test_config):
|
||||
@@ -151,8 +145,7 @@ class TestToolsClientServer:
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
await client.connect_and_run(test_callback)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_with_auth_token_via_client(self, client, test_config):
|
||||
@@ -171,5 +164,4 @@ class TestToolsClientServer:
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
await client.connect_and_run(test_callback)
|
||||
|
||||
@@ -45,7 +45,6 @@ class TestDorisToolsManager:
|
||||
config.database.password = "test_password"
|
||||
config.database.database = "test_db"
|
||||
config.database.health_check_interval = 60
|
||||
config.database.min_connections = 5
|
||||
config.database.max_connections = 20
|
||||
config.database.connection_timeout = 30
|
||||
config.database.max_connection_age = 3600
|
||||
|
||||
78
test/utils/test_db.py
Normal file
78
test/utils/test_db.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from unittest.mock import MagicMock
|
||||
import pytest
|
||||
|
||||
from doris_mcp_server.utils.db import DorisConnection, DorisSessionCache
|
||||
|
||||
|
||||
@pytest.fixture
def session_cache():
    """Yield a ``(cache, manager)`` pair for the session-cache tests.

    The cache is a DorisSessionCache constructed with a MagicMock standing in
    for its connection manager, so tests can inspect the calls made to it.
    """
    mock_manager = MagicMock()
    yield DorisSessionCache(connection_manager=mock_manager), mock_manager
|
||||
|
||||
|
||||
class TestDorisSessionCache:
    """Tests for DorisSessionCache: defaults, caching policy, save/get, remove, clear."""

    @staticmethod
    def _make_connection(session_id):
        """Build a DorisConnection-spec'd mock carrying the given session id."""
        connection = MagicMock(spec=DorisConnection)
        connection.session_id = session_id
        return connection

    def test_initialization(self, session_cache):
        """A new cache has system caching on, user caching off, and nothing cached."""
        fresh_cache = session_cache[0]
        assert fresh_cache.cache_system_session is True
        assert fresh_cache.cache_user_session is False
        assert not fresh_cache.cached

    def test_should_cache(self, session_cache):
        """`_should_cache` accepts "query"/"system"; user ids only once enabled."""
        cache = session_cache[0]
        user_session_id = "user-test-session-id"

        for builtin_id in ("query", "system"):
            assert cache._should_cache(builtin_id) is True
        assert cache._should_cache(user_session_id) is False

        # Flipping the switch makes user sessions cacheable as well.
        cache.cache_user_session = True
        assert cache._should_cache(user_session_id) is True

    def test_save_and_get_session(self, session_cache):
        """Saved connections come back from `get` only when their id is cacheable."""
        cache = session_cache[0]

        query_conn = self._make_connection("query")
        cache.save(query_conn)
        assert cache.get("query") is query_conn

        # With user caching off, saving a user session leaves nothing to fetch.
        user_conn = self._make_connection("user-test-session-id")
        cache.save(user_conn)
        assert cache.get("user-test-session-id") is None

        # After enabling user caching the same save is retained.
        cache.cache_user_session = True
        cache.save(user_conn)
        assert cache.get("user-test-session-id") is user_conn

    def test_remove_session(self, session_cache):
        """`remove` drops a previously saved session so `get` returns None."""
        cache = session_cache[0]
        system_conn = self._make_connection("system")

        cache.save(system_conn)
        assert cache.get("system") is not None

        cache.remove("system")
        assert cache.get("system") is None

    def test_clear_cache(self, session_cache):
        """`clear` empties the cache and releases each connection via the manager."""
        cache, manager = session_cache
        saved = {sid: self._make_connection(sid) for sid in ("query", "system")}

        for connection in saved.values():
            cache.save(connection)
        assert len(cache.cached) == 2

        cache.clear()

        assert not cache.cached
        release = manager.release_connection
        for sid, connection in saved.items():
            release.assert_any_call(sid, connection)
        assert release.call_count == 2
|
||||
@@ -44,7 +44,6 @@ class TestDorisQueryExecutor:
|
||||
config.database.password = "test_password"
|
||||
config.database.database = "test_db"
|
||||
config.database.health_check_interval = 60
|
||||
config.database.min_connections = 5
|
||||
config.database.max_connections = 20
|
||||
config.database.connection_timeout = 30
|
||||
config.database.max_connection_age = 3600
|
||||
|
||||
@@ -21,8 +21,6 @@ Tests the query execution functionality through actual MCP client-server communi
|
||||
Assumes the server is already running and configured properly
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import pytest
|
||||
import os
|
||||
import sys
|
||||
@@ -66,14 +64,13 @@ class TestQueryExecutorClientServer:
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
|
||||
if result["success"]:
|
||||
assert "result" in result, "Successful result should contain 'result' field"
|
||||
assert "data" in result, "Successful result should contain 'data' field"
|
||||
else:
|
||||
assert "error" in result, "Failed result should contain 'error' field"
|
||||
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
await client.connect_and_run(test_callback)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_show_databases_query_via_client(self, client, test_config):
|
||||
@@ -87,8 +84,7 @@ class TestQueryExecutorClientServer:
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
await client.connect_and_run(test_callback)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_information_schema_query_via_client(self, client, test_config):
|
||||
@@ -102,8 +98,7 @@ class TestQueryExecutorClientServer:
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
await client.connect_and_run(test_callback)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_with_max_rows_parameter_via_client(self, client, test_config):
|
||||
@@ -118,8 +113,7 @@ class TestQueryExecutorClientServer:
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
await client.connect_and_run(test_callback)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_error_handling_via_client(self, client, test_config):
|
||||
@@ -131,8 +125,7 @@ class TestQueryExecutorClientServer:
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
await client.connect_and_run(test_callback)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_with_auth_token_via_client(self, client, test_config):
|
||||
@@ -152,5 +145,4 @@ class TestQueryExecutorClientServer:
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
await client.connect_and_run(test_callback)
|
||||
|
||||
64
tokens.json
Normal file
64
tokens.json
Normal file
@@ -0,0 +1,64 @@
|
||||
{
|
||||
"version": "1.0",
|
||||
"description": "Doris MCP Server Token configuration file",
|
||||
"created_at": "2025-09-01T00:00:00Z",
|
||||
"tokens": [
|
||||
{
|
||||
"token_id": "admin-token",
|
||||
"token": "doris_admin_token_123456",
|
||||
"description": "Doris admin API access token",
|
||||
"expires_hours": null,
|
||||
"is_active": true,
|
||||
"database_config": {
|
||||
"host": "127.0.0.1",
|
||||
"port": 9030,
|
||||
"user": "root",
|
||||
"password": "",
|
||||
"database": "information_schema",
|
||||
"charset": "UTF8",
|
||||
"fe_http_port": 8030
|
||||
}
|
||||
},
|
||||
{
|
||||
"token_id": "analyst-token",
|
||||
"token": "doris_analyst_token_123456",
|
||||
"description": "Doris analyst API access token",
|
||||
"expires_hours": 8760,
|
||||
"is_active": true,
|
||||
"database_config": {
|
||||
"host": "127.0.0.1",
|
||||
"port": 9030,
|
||||
"user": "root",
|
||||
"password": "",
|
||||
"database": "information_schema",
|
||||
"charset": "UTF8",
|
||||
"fe_http_port": 8030
|
||||
}
|
||||
},
|
||||
{
|
||||
"token_id": "readonly-token",
|
||||
"token": "doris_readonly_token_123456",
|
||||
"description": "Doris readonly API access token",
|
||||
"expires_hours": 4320,
|
||||
"is_active": true,
|
||||
"database_config": {
|
||||
"host": "127.0.0.1",
|
||||
"port": 9030,
|
||||
"user": "root",
|
||||
"password": "",
|
||||
"database": "information_schema",
|
||||
"charset": "UTF8",
|
||||
"fe_http_port": 8030
|
||||
}
|
||||
}
|
||||
],
|
||||
"notes": [
|
||||
"The admin-token, analyst-token, and readonly-token entries are default tokens. Please change them before using in production!",
|
||||
"The token_id is the key of the token. Please use the token_id to identify the token.",
|
||||
"The token is the value of the token,Please use the token to identify the token",
|
||||
"The description is the description of the token,Please use the description to identify the token",
|
||||
"The expires_hours field is the expiry time of the token in hours. Please use expires_hours to identify the token.",
|
||||
"The is_active field is the active flag of the token. Please use is_active to identify the token.",
|
||||
"The token_id, token, description, expires_hours, and is_active fields are the metadata of the token. Please use the metadata to identify the token."
|
||||
]
|
||||
}
|
||||
Reference in New Issue
Block a user