Compare commits
58 Commits
0.3.0
...
infrastruc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a6f893628b | ||
|
|
81305ffbf9 | ||
|
|
43143f0b30 | ||
|
|
e58361e04b | ||
|
|
a125a2f5f8 | ||
|
|
2613912df3 | ||
|
|
067f160b3e | ||
|
|
9ba4cc6f45 | ||
|
|
f99399c6c7 | ||
|
|
c3d487ccdd | ||
|
|
c1e3b13851 | ||
|
|
5923cc1c89 | ||
|
|
9b5ac8533d | ||
|
|
cc84d605e5 | ||
|
|
55dbdd5e14 | ||
|
|
affa4a0319 | ||
|
|
ecb5db8137 | ||
|
|
5d15f6f3a4 | ||
|
|
6247d49192 | ||
|
|
fb5e864a24 | ||
|
|
9bb5b17199 | ||
|
|
6d3c128f54 | ||
|
|
651d524814 | ||
|
|
54572d0861 | ||
|
|
d12dfbd014 | ||
|
|
4052b7e938 | ||
|
|
693c48d5ee | ||
|
|
c1ce9a5cc7 | ||
|
|
282a1c0bd9 | ||
|
|
e3b9bf96ab | ||
|
|
667cecbbe0 | ||
|
|
c777905bd3 | ||
|
|
d4ea125e35 | ||
|
|
f135d9b949 | ||
|
|
124dd0da88 | ||
|
|
775b4cb630 | ||
|
|
26e8bc1149 | ||
|
|
8526cb75fe | ||
|
|
97006a756d | ||
|
|
72865654e2 | ||
|
|
050c09f902 | ||
|
|
159399bd38 | ||
|
|
e859fbb778 | ||
|
|
1b9cb29f5f | ||
|
|
c95c0fe03c | ||
|
|
1e2e79d90d | ||
|
|
609816bc4a | ||
|
|
5d46d153e1 | ||
|
|
0a81d5693b | ||
|
|
a4306867f6 | ||
|
|
a22ff3ae9b | ||
|
|
2c5f26889c | ||
|
|
e47534c296 | ||
|
|
0f52591259 | ||
|
|
3b429f37b3 | ||
|
|
f5a4c8abbe | ||
|
|
87563ef6e1 | ||
|
|
b6157c500b |
27
.asf.yaml
@@ -24,9 +24,28 @@ github:
|
||||
- olap
|
||||
- lakehouse
|
||||
- mcp
|
||||
- ai
|
||||
enabled_merge_buttons:
|
||||
squash: true
|
||||
merge: false
|
||||
rebase: false
|
||||
squash: true
|
||||
merge: false
|
||||
rebase: false
|
||||
features:
|
||||
issues: true
|
||||
projects: true
|
||||
rulesets:
|
||||
- name: "Default Branch Protection"
|
||||
type: branch
|
||||
branches:
|
||||
includes:
|
||||
- "~DEFAULT_BRANCH"
|
||||
- "release/*"
|
||||
- "rel/*"
|
||||
excludes: []
|
||||
bypass_teams:
|
||||
- root
|
||||
restrict_deletion: true
|
||||
restrict_force_push: true
|
||||
notifications:
|
||||
pullrequests_status: commits@doris.apache.org
|
||||
issues: commits@doris.apache.org
|
||||
commits: commits@doris.apache.org
|
||||
pullrequests: commits@doris.apache.org
|
||||
|
||||
2
.dockerignore
Normal file
@@ -0,0 +1,2 @@
|
||||
**/.venv
|
||||
**/venv
|
||||
500
.env.example
@@ -14,58 +14,512 @@
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
# ===================================================================
|
||||
# Doris MCP Server Environment Configuration Example
|
||||
# ===================================================================
|
||||
# Copy this file to .env and modify the configuration values as needed
|
||||
|
||||
# Doris MCP Server Environment Configuration
|
||||
# Copy this file to .env and modify the values as needed
|
||||
# ===================================================================
|
||||
# Database Connection Configuration
|
||||
# ===================================================================
|
||||
|
||||
# Database Configuration
|
||||
# Doris FE (Frontend) connection settings
|
||||
DORIS_HOST=localhost
|
||||
DORIS_PORT=9030
|
||||
DORIS_USER=root
|
||||
DORIS_PASSWORD=your_password_here
|
||||
DORIS_DATABASE=your_database_name
|
||||
DORIS_PASSWORD=
|
||||
DORIS_DATABASE=information_schema
|
||||
|
||||
# Connection Pool Settings
|
||||
DORIS_MIN_CONNECTIONS=5
|
||||
# Doris FE HTTP API port (for Profile and other HTTP APIs)
|
||||
DORIS_FE_HTTP_PORT=8030
|
||||
|
||||
# Doris BE (Backend) nodes configuration (optional, for external access)
|
||||
# Format: host1,host2,host3 (if empty, will use "show backends" to get BE nodes)
|
||||
DORIS_BE_HOSTS=
|
||||
DORIS_BE_WEBSERVER_PORT=8040
|
||||
|
||||
# Connection pool configuration
|
||||
DORIS_MAX_CONNECTIONS=20
|
||||
DORIS_CONNECTION_TIMEOUT=30
|
||||
DORIS_HEALTH_CHECK_INTERVAL=60
|
||||
DORIS_MAX_CONNECTION_AGE=3600
|
||||
|
||||
# Security Settings
|
||||
# Arrow Flight SQL Configuration (Required for ADBC tools)
|
||||
# FE_ARROW_FLIGHT_SQL_PORT=
|
||||
# BE_ARROW_FLIGHT_SQL_PORT=
|
||||
|
||||
# ===================================================================
|
||||
# Security Configuration
|
||||
# ===================================================================
|
||||
|
||||
# Independent Authentication Switches - NEW DESIGN!
|
||||
# Each authentication method can be enabled/disabled independently
|
||||
# Any enabled method that succeeds will allow access
|
||||
# If all methods are disabled, anonymous access is allowed
|
||||
|
||||
# Legacy configuration - kept for backward compatibility
|
||||
# AUTH_TYPE is now deprecated - use individual switches above
|
||||
AUTH_TYPE=token
|
||||
TOKEN_SECRET=your_256_bit_secret_key_here
|
||||
|
||||
# Token Authentication (Default method - simple and effective)
|
||||
ENABLE_TOKEN_AUTH=false
|
||||
|
||||
# JWT Authentication (For stateless applications)
|
||||
ENABLE_JWT_AUTH=false
|
||||
|
||||
# OAuth 2.0/OIDC Authentication (For enterprise integration)
|
||||
ENABLE_OAUTH_AUTH=false
|
||||
|
||||
# ===================================================================
|
||||
# Token Authentication Configuration (Enable with ENABLE_TOKEN_AUTH=true)
|
||||
# ===================================================================
|
||||
|
||||
# Basic token authentication settings
|
||||
TOKEN_FILE_PATH=tokens.json
|
||||
ENABLE_TOKEN_EXPIRY=true
|
||||
DEFAULT_TOKEN_EXPIRY_HOURS=720
|
||||
TOKEN_HASH_ALGORITHM=sha256
|
||||
|
||||
# ===================================================================
|
||||
# Token Management Security Configuration (NEW in v0.6.0) - CRITICAL SECURITY SETTINGS
|
||||
# ===================================================================
|
||||
|
||||
# HTTP Token Management Endpoints (DISABLED BY DEFAULT FOR SECURITY)
|
||||
# WARNING: These endpoints allow creation, deletion, and management of authentication tokens
|
||||
# Only enable if you need HTTP-based token management and understand the security implications
|
||||
ENABLE_HTTP_TOKEN_MANAGEMENT=true
|
||||
|
||||
# Admin Authentication Token (REQUIRED if HTTP token management is enabled)
|
||||
# This token is required to access HTTP token management endpoints
|
||||
# SECURITY: Generate a secure random token in production - NEVER use default values
|
||||
TOKEN_MANAGEMENT_ADMIN_TOKEN=
|
||||
|
||||
# IP Address Restrictions for Token Management (CRITICAL SECURITY CONTROL)
|
||||
# Only these IP addresses/networks can access token management endpoints
|
||||
# DEFAULT: localhost only (most secure) - add other IPs/networks only if necessary
|
||||
# Format: comma-separated list of IPs and CIDR networks
|
||||
# Examples:
|
||||
# - Localhost only: 127.0.0.1,::1
|
||||
# - Private network: 127.0.0.1,192.168.1.0/24,10.0.0.0/8
|
||||
# - Specific IPs: 127.0.0.1,192.168.1.10,192.168.1.11
|
||||
TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,::1
|
||||
|
||||
# Require Admin Authentication (ENABLED BY DEFAULT FOR SECURITY)
|
||||
# When true, all token management operations require valid admin token
|
||||
# When false, only IP restrictions apply (NOT RECOMMENDED for production)
|
||||
REQUIRE_ADMIN_AUTH=true
|
||||
|
||||
# ===================================================================
|
||||
# JWT Authentication Configuration (Enable with ENABLE_JWT_AUTH=true)
|
||||
# ===================================================================
|
||||
|
||||
# JWT token settings (when ENABLE_JWT_AUTH=true)
|
||||
JWT_SECRET_KEY=your_jwt_secret_key_here_change_in_production
|
||||
JWT_ALGORITHM=HS256
|
||||
JWT_EXPIRATION_HOURS=24
|
||||
JWT_ISSUER=doris-mcp-server
|
||||
JWT_AUDIENCE=doris-mcp-client
|
||||
|
||||
# JWT token validation settings
|
||||
JWT_VERIFY_SIGNATURE=true
|
||||
JWT_VERIFY_EXPIRATION=true
|
||||
JWT_VERIFY_AUDIENCE=true
|
||||
JWT_VERIFY_ISSUER=true
|
||||
|
||||
# JWT refresh token settings
|
||||
ENABLE_JWT_REFRESH=true
|
||||
JWT_REFRESH_EXPIRATION_DAYS=30
|
||||
JWT_REFRESH_SECRET_KEY=your_jwt_refresh_secret_key_here
|
||||
|
||||
# JWT user claims configuration
|
||||
JWT_USER_ID_CLAIM=user_id
|
||||
JWT_ROLES_CLAIM=roles
|
||||
JWT_PERMISSIONS_CLAIM=permissions
|
||||
JWT_SECURITY_LEVEL_CLAIM=security_level
|
||||
|
||||
# ===================================================================
|
||||
# OAuth 2.0 / OpenID Connect Configuration (Enable with ENABLE_OAUTH_AUTH=true)
|
||||
# ===================================================================
|
||||
|
||||
# OAuth provider settings (when ENABLE_OAUTH_AUTH=true)
|
||||
OAUTH_PROVIDER_TYPE=generic
|
||||
OAUTH_CLIENT_ID=your_oauth_client_id
|
||||
OAUTH_CLIENT_SECRET=your_oauth_client_secret
|
||||
OAUTH_REDIRECT_URI=http://localhost:3000/auth/callback
|
||||
|
||||
# OAuth endpoints (for generic provider)
|
||||
OAUTH_AUTHORIZATION_URL=https://your-provider.com/auth
|
||||
OAUTH_TOKEN_URL=https://your-provider.com/token
|
||||
OAUTH_USERINFO_URL=https://your-provider.com/userinfo
|
||||
OAUTH_JWKS_URL=https://your-provider.com/.well-known/jwks.json
|
||||
|
||||
# OAuth scope and claims
|
||||
OAUTH_SCOPE=openid profile email
|
||||
OAUTH_USER_ID_CLAIM=sub
|
||||
OAUTH_USERNAME_CLAIM=preferred_username
|
||||
OAUTH_EMAIL_CLAIM=email
|
||||
OAUTH_ROLES_CLAIM=roles
|
||||
OAUTH_GROUPS_CLAIM=groups
|
||||
|
||||
# OAuth session settings
|
||||
OAUTH_SESSION_SECRET=your_oauth_session_secret_here
|
||||
OAUTH_SESSION_EXPIRY=3600
|
||||
OAUTH_STATE_EXPIRY=300
|
||||
|
||||
# Popular OAuth providers presets (uncomment and configure as needed)
|
||||
|
||||
# Google OAuth Configuration
|
||||
# OAUTH_PROVIDER_TYPE=google
|
||||
# OAUTH_CLIENT_ID=your_google_client_id.apps.googleusercontent.com
|
||||
# OAUTH_CLIENT_SECRET=your_google_client_secret
|
||||
# OAUTH_AUTHORIZATION_URL=https://accounts.google.com/o/oauth2/auth
|
||||
# OAUTH_TOKEN_URL=https://oauth2.googleapis.com/token
|
||||
# OAUTH_USERINFO_URL=https://www.googleapis.com/oauth2/v1/userinfo
|
||||
# OAUTH_JWKS_URL=https://www.googleapis.com/oauth2/v3/certs
|
||||
# OAUTH_SCOPE=openid profile email
|
||||
|
||||
# Microsoft Azure AD Configuration
|
||||
# OAUTH_PROVIDER_TYPE=azure
|
||||
# OAUTH_CLIENT_ID=your_azure_client_id
|
||||
# OAUTH_CLIENT_SECRET=your_azure_client_secret
|
||||
# OAUTH_TENANT_ID=your_tenant_id
|
||||
# OAUTH_AUTHORIZATION_URL=https://login.microsoftonline.com/{tenant}/oauth2/v2.0/authorize
|
||||
# OAUTH_TOKEN_URL=https://login.microsoftonline.com/{tenant}/oauth2/v2.0/token
|
||||
# OAUTH_USERINFO_URL=https://graph.microsoft.com/v1.0/me
|
||||
# OAUTH_JWKS_URL=https://login.microsoftonline.com/{tenant}/discovery/v2.0/keys
|
||||
# OAUTH_SCOPE=openid profile email
|
||||
|
||||
# GitHub OAuth Configuration
|
||||
# OAUTH_PROVIDER_TYPE=github
|
||||
# OAUTH_CLIENT_ID=your_github_client_id
|
||||
# OAUTH_CLIENT_SECRET=your_github_client_secret
|
||||
# OAUTH_AUTHORIZATION_URL=https://github.com/login/oauth/authorize
|
||||
# OAUTH_TOKEN_URL=https://github.com/login/oauth/access_token
|
||||
# OAUTH_USERINFO_URL=https://api.github.com/user
|
||||
# OAUTH_SCOPE=user:email
|
||||
|
||||
# GitLab OAuth Configuration
|
||||
# OAUTH_PROVIDER_TYPE=gitlab
|
||||
# OAUTH_CLIENT_ID=your_gitlab_client_id
|
||||
# OAUTH_CLIENT_SECRET=your_gitlab_client_secret
|
||||
# OAUTH_AUTHORIZATION_URL=https://gitlab.com/oauth/authorize
|
||||
# OAUTH_TOKEN_URL=https://gitlab.com/oauth/token
|
||||
# OAUTH_USERINFO_URL=https://gitlab.com/api/v4/user
|
||||
# OAUTH_SCOPE=read_user
|
||||
|
||||
# Keycloak OAuth Configuration
|
||||
# OAUTH_PROVIDER_TYPE=keycloak
|
||||
# OAUTH_CLIENT_ID=your_keycloak_client_id
|
||||
# OAUTH_CLIENT_SECRET=your_keycloak_client_secret
|
||||
# OAUTH_REALM=your_realm
|
||||
# OAUTH_SERVER_URL=https://your-keycloak-server.com
|
||||
# OAUTH_AUTHORIZATION_URL=https://your-keycloak-server.com/auth/realms/{realm}/protocol/openid-connect/auth
|
||||
# OAUTH_TOKEN_URL=https://your-keycloak-server.com/auth/realms/{realm}/protocol/openid-connect/token
|
||||
# OAUTH_USERINFO_URL=https://your-keycloak-server.com/auth/realms/{realm}/protocol/openid-connect/userinfo
|
||||
# OAUTH_JWKS_URL=https://your-keycloak-server.com/auth/realms/{realm}/protocol/openid-connect/certs
|
||||
# OAUTH_SCOPE=openid profile email
|
||||
|
||||
# Legacy token settings (for backward compatibility)
|
||||
TOKEN_SECRET=your_secret_key_here
|
||||
TOKEN_EXPIRY=3600
|
||||
|
||||
# SQL security check
|
||||
ENABLE_SECURITY_CHECK=true
|
||||
|
||||
# Blocked keywords (comma separated)
|
||||
BLOCKED_KEYWORDS=DROP,CREATE,ALTER,TRUNCATE,DELETE,INSERT,UPDATE,GRANT,REVOKE,EXEC,EXECUTE,SHUTDOWN,KILL
|
||||
|
||||
# Query limits
|
||||
MAX_QUERY_COMPLEXITY=100
|
||||
MAX_RESULT_ROWS=10000
|
||||
|
||||
# Data masking
|
||||
ENABLE_MASKING=true
|
||||
|
||||
# Performance Settings
|
||||
# ===================================================================
|
||||
# Performance Configuration
|
||||
# ===================================================================
|
||||
|
||||
# Query cache
|
||||
ENABLE_QUERY_CACHE=true
|
||||
CACHE_TTL=300
|
||||
MAX_CACHE_SIZE=1000
|
||||
|
||||
# Concurrency control
|
||||
MAX_CONCURRENT_QUERIES=50
|
||||
QUERY_TIMEOUT=300
|
||||
|
||||
# Logging Configuration
|
||||
LOG_LEVEL=INFO
|
||||
LOG_FILE_PATH=./log/doris-mcp-server.log
|
||||
ENABLE_AUDIT=true
|
||||
AUDIT_FILE_PATH=./log/doris-mcp-audit.log
|
||||
# Response content size limit (characters)
|
||||
MAX_RESPONSE_CONTENT_SIZE=4096
|
||||
|
||||
# Monitoring Settings
|
||||
# ===================================================================
|
||||
# ADBC (Arrow Flight SQL) Configuration
|
||||
# ===================================================================
|
||||
# Enable/disable ADBC tools
|
||||
ADBC_ENABLED=true
|
||||
|
||||
# Default ADBC query parameters
|
||||
ADBC_DEFAULT_MAX_ROWS=100000
|
||||
ADBC_DEFAULT_TIMEOUT=60
|
||||
# Format: "arrow", "pandas", "dict"
|
||||
ADBC_DEFAULT_RETURN_FORMAT=arrow
|
||||
|
||||
# ADBC connection timeout
|
||||
ADBC_CONNECTION_TIMEOUT=300
|
||||
|
||||
# ===================================================================
|
||||
# Logging Configuration
|
||||
# ===================================================================
|
||||
|
||||
# Basic logging configuration
|
||||
LOG_LEVEL=INFO
|
||||
LOG_FILE_PATH=
|
||||
|
||||
# Audit logging
|
||||
ENABLE_AUDIT=true
|
||||
AUDIT_FILE_PATH=
|
||||
|
||||
# Log file rotation configuration
|
||||
LOG_MAX_FILE_SIZE=10485760
|
||||
LOG_BACKUP_COUNT=5
|
||||
|
||||
# ===================================================================
|
||||
# Log Cleanup Configuration - NEW!
|
||||
# ===================================================================
|
||||
|
||||
# Enable automatic log cleanup
|
||||
ENABLE_LOG_CLEANUP=true
|
||||
|
||||
# Maximum age of log files in days (files older than this will be deleted)
|
||||
LOG_MAX_AGE_DAYS=30
|
||||
|
||||
# Cleanup check interval in hours
|
||||
LOG_CLEANUP_INTERVAL_HOURS=24
|
||||
|
||||
# ===================================================================
|
||||
# Monitoring Configuration
|
||||
# ===================================================================
|
||||
|
||||
# Metrics collection
|
||||
ENABLE_METRICS=true
|
||||
METRICS_PORT=3001
|
||||
METRICS_PATH=/metrics
|
||||
HEALTH_CHECK_PORT=3002
|
||||
HEALTH_CHECK_PATH=/health
|
||||
|
||||
# Alert configuration
|
||||
ENABLE_ALERTS=false
|
||||
ALERT_WEBHOOK_URL=
|
||||
|
||||
# Server Settings
|
||||
# ===================================================================
|
||||
# Server Configuration
|
||||
# ===================================================================
|
||||
|
||||
# Basic server information
|
||||
SERVER_NAME=doris-mcp-server
|
||||
SERVER_VERSION=0.3.0
|
||||
SERVER_VERSION=0.6.0
|
||||
SERVER_PORT=3000
|
||||
|
||||
# Development Settings (for development environment only)
|
||||
DEBUG=false
|
||||
VERBOSE=false
|
||||
# Temporary files directory
|
||||
TEMP_FILES_DIR=tmp
|
||||
|
||||
# ===================================================================
|
||||
# Configuration Examples for Different Environments
|
||||
# ===================================================================
|
||||
|
||||
# Development Environment Example:
|
||||
# LOG_LEVEL=DEBUG
|
||||
# LOG_MAX_AGE_DAYS=7
|
||||
# LOG_CLEANUP_INTERVAL_HOURS=6
|
||||
# ENABLE_SECURITY_CHECK=false
|
||||
|
||||
# Production Environment Example:
|
||||
# LOG_LEVEL=INFO
|
||||
# LOG_MAX_AGE_DAYS=30
|
||||
# LOG_CLEANUP_INTERVAL_HOURS=24
|
||||
# ENABLE_SECURITY_CHECK=true
|
||||
# ENABLE_LOG_CLEANUP=true
|
||||
|
||||
# Testing Environment Example:
|
||||
# LOG_LEVEL=WARNING
|
||||
# LOG_MAX_AGE_DAYS=3
|
||||
# LOG_CLEANUP_INTERVAL_HOURS=1
|
||||
# MAX_RESULT_ROWS=1000
|
||||
|
||||
# ===================================================================
|
||||
# Advanced Configuration Notes
|
||||
# ===================================================================
|
||||
|
||||
# 1. Log Cleanup Feature:
|
||||
# - ENABLE_LOG_CLEANUP: Controls whether to enable automatic cleanup
|
||||
# - LOG_MAX_AGE_DAYS: File retention days, recommended 30 days for production, 7 days for development
|
||||
# - LOG_CLEANUP_INTERVAL_HOURS: Check frequency, recommended 24 hours
|
||||
|
||||
# 2. Security Best Practices:
|
||||
# - NEW: Enable individual authentication methods using ENABLE_TOKEN_AUTH, ENABLE_JWT_AUTH, ENABLE_OAUTH_AUTH
|
||||
# - When all methods are disabled, ALL requests are allowed with anonymous access
|
||||
# - Authentication methods work independently - any one succeeding allows access
|
||||
# - Token Auth: Change default tokens (DEFAULT_ADMIN_TOKEN, etc.) in production
|
||||
# - JWT Auth: Change JWT_SECRET_KEY and JWT_REFRESH_SECRET_KEY in production
|
||||
# - OAuth Auth: Configure OAuth provider settings and secure client secrets
|
||||
# - Must change TOKEN_SECRET in production environment (legacy compatibility)
|
||||
# - Adjust BLOCKED_KEYWORDS according to business needs
|
||||
# - Enable ENABLE_SECURITY_CHECK and ENABLE_MASKING
|
||||
# - NEW v0.6.0: Token Management Security (CRITICAL):
|
||||
# * ENABLE_HTTP_TOKEN_MANAGEMENT=false by default (SECURE BY DEFAULT)
|
||||
# * Only enable if you need HTTP token management endpoints
|
||||
# * TOKEN_MANAGEMENT_ADMIN_TOKEN: Use secure random token in production
|
||||
# * TOKEN_MANAGEMENT_ALLOWED_IPS: Restrict to localhost (127.0.0.1,::1) only
|
||||
# * REQUIRE_ADMIN_AUTH=true: Always require admin authentication
|
||||
# * Never expose token management endpoints to external networks
|
||||
|
||||
# 3. Performance Tuning:
|
||||
# - Adjust MAX_CONCURRENT_QUERIES based on hardware resources
|
||||
# - Adjust QUERY_TIMEOUT based on query complexity
|
||||
# - Adjust MAX_CACHE_SIZE based on memory size
|
||||
|
||||
# 4. Connection Pool Optimization:
|
||||
# - DORIS_MAX_CONNECTIONS recommended to be 2-4 times the number of CPU cores
|
||||
# - DORIS_CONNECTION_TIMEOUT adjust based on network latency
|
||||
# - DORIS_MAX_CONNECTION_AGE recommended 1 hour to avoid long connection issues
|
||||
|
||||
# 5. ADBC (Arrow Flight SQL) Configuration:
|
||||
# - FE_ARROW_FLIGHT_SQL_PORT and BE_ARROW_FLIGHT_SQL_PORT: Required for ADBC functionality
|
||||
# - ADBC_DEFAULT_MAX_ROWS: Default maximum rows for ADBC queries (recommended: 100000)
|
||||
# - ADBC_DEFAULT_TIMEOUT: Default timeout for ADBC queries in seconds (recommended: 60)
|
||||
# - ADBC_DEFAULT_RETURN_FORMAT: Default return format (arrow/pandas/dict, recommended: arrow)
|
||||
# - ADBC_CONNECTION_TIMEOUT: Connection timeout for ADBC (recommended: 30)
|
||||
# - ADBC_ENABLED: Enable or disable ADBC tools (true/false)
|
||||
# - Prerequisites: Install adbc_driver_manager, adbc_driver_flightsql, pyarrow packages
|
||||
|
||||
# 6. Authentication Configuration Guide - UPDATED DESIGN!
|
||||
#
|
||||
# Independent Authentication Control (NEW):
|
||||
# - ENABLE_TOKEN_AUTH=false (default): Disable token authentication
|
||||
# - ENABLE_JWT_AUTH=false (default): Disable JWT authentication
|
||||
# - ENABLE_OAUTH_AUTH=false (default): Disable OAuth authentication
|
||||
# - When all methods are disabled, no authentication is required (anonymous access)
|
||||
# - When multiple methods are enabled, any one succeeding allows access
|
||||
# - Recommended for development/testing: all false, production: enable needed methods
|
||||
#
|
||||
# Token Authentication (ENABLE_TOKEN_AUTH=true) - Recommended for most use cases:
|
||||
# - Simple and secure token-based authentication
|
||||
# - Configurable default tokens via environment variables
|
||||
# - Support for custom tokens via TOKEN_* environment variables
|
||||
# - Token file configuration via tokens.json
|
||||
# - Built-in token management HTTP endpoints
|
||||
# - No user management complexity - pure API access control
|
||||
#
|
||||
# JWT Authentication (ENABLE_JWT_AUTH=true) - For stateless applications:
|
||||
# - JSON Web Token based authentication
|
||||
# - Configurable token expiration and refresh
|
||||
# - Support for standard JWT claims
|
||||
# - RSA/ECDSA/HS256 algorithm support
|
||||
# - Suitable for microservices and distributed systems
|
||||
#
|
||||
# OAuth 2.0/OIDC (ENABLE_OAUTH_AUTH=true) - For enterprise integration:
|
||||
# - Integration with external identity providers
|
||||
# - Support for popular providers (Google, Microsoft, GitHub, GitLab, Keycloak)
|
||||
# - OpenID Connect compatibility
|
||||
# - Automatic user provisioning from provider
|
||||
# - Secure authorization code flow
|
||||
#
|
||||
# Authentication Method Selection Guide:
|
||||
# - No Auth (all switches false): Development, testing, trusted networks
|
||||
# - Token Auth only: Small teams, simple deployment, direct API access
|
||||
# - JWT Auth only: Stateless apps, microservices, mobile clients
|
||||
# - OAuth Auth only: Enterprise SSO, large teams, external identity providers
|
||||
# - Multiple methods: Flexible access, different client types, migration scenarios
|
||||
|
||||
# 7. Token Management Security Configuration Guide (NEW in v0.6.0) - CRITICAL!
|
||||
#
|
||||
# ⚠️ SECURITY WARNING: Token management endpoints are POWERFUL and DANGEROUS
|
||||
# They allow creation, revocation, and management of authentication tokens.
|
||||
# Improper configuration can lead to complete system compromise.
|
||||
#
|
||||
# 🔒 SECURE BY DEFAULT:
|
||||
# - ENABLE_HTTP_TOKEN_MANAGEMENT=false (disabled by default)
|
||||
# - REQUIRE_ADMIN_AUTH=true (admin auth required by default)
|
||||
# - TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,::1 (localhost only by default)
|
||||
#
|
||||
# 🛡️ SECURITY LAYERS (Applied in order):
|
||||
# 1. Configuration Check: HTTP token management must be explicitly enabled
|
||||
# 2. IP Restrictions: Only allowed IP addresses/networks can access endpoints
|
||||
# 3. Admin Authentication: Valid admin token required for all operations
|
||||
#
|
||||
# 📋 CONFIGURATION OPTIONS:
|
||||
#
|
||||
# Disable Token Management (RECOMMENDED for most deployments):
|
||||
# ENABLE_HTTP_TOKEN_MANAGEMENT=false
|
||||
# # All token management endpoints will return 403 Forbidden
|
||||
#
|
||||
# Enable with Maximum Security (Production):
|
||||
# ENABLE_HTTP_TOKEN_MANAGEMENT=true
|
||||
# TOKEN_MANAGEMENT_ADMIN_TOKEN=<secure-random-token-256-bit>
|
||||
# TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,::1
|
||||
# REQUIRE_ADMIN_AUTH=true
|
||||
#
|
||||
# Enable for Private Network (Advanced):
|
||||
# ENABLE_HTTP_TOKEN_MANAGEMENT=true
|
||||
# TOKEN_MANAGEMENT_ADMIN_TOKEN=<secure-random-token-256-bit>
|
||||
# TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,192.168.1.0/24,10.0.0.0/8
|
||||
# REQUIRE_ADMIN_AUTH=true
|
||||
#
|
||||
# 🔑 ADMIN TOKEN GENERATION:
|
||||
# # Generate secure admin token (Linux/macOS):
|
||||
# openssl rand -hex 32
|
||||
#
|
||||
# # Generate secure admin token (Python):
|
||||
# python -c "import secrets; print(secrets.token_urlsafe(32))"
|
||||
#
|
||||
# 🌐 IP CONFIGURATION EXAMPLES:
|
||||
# # Localhost only (most secure):
|
||||
# TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,::1
|
||||
#
|
||||
# # Private network + localhost:
|
||||
# TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,::1,192.168.1.0/24,10.0.0.0/8
|
||||
#
|
||||
# # Specific servers only:
|
||||
# TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,192.168.1.10,192.168.1.11
|
||||
#
|
||||
# # Corporate network (be careful):
|
||||
# TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,172.16.0.0/12,192.168.0.0/16
|
||||
#
|
||||
# 🚫 NEVER DO THIS (Security Anti-Patterns):
|
||||
# # NEVER allow all IPs:
|
||||
# # TOKEN_MANAGEMENT_ALLOWED_IPS=0.0.0.0/0 # DANGEROUS!
|
||||
#
|
||||
# # NEVER disable admin auth in production:
|
||||
# # REQUIRE_ADMIN_AUTH=false # DANGEROUS!
|
||||
#
|
||||
# # NEVER use weak admin tokens:
|
||||
# # TOKEN_MANAGEMENT_ADMIN_TOKEN=admin # DANGEROUS!
|
||||
# # TOKEN_MANAGEMENT_ADMIN_TOKEN=123456 # DANGEROUS!
|
||||
#
|
||||
# 📊 ENDPOINT SECURITY TESTING:
|
||||
# # Test security (should fail):
|
||||
# curl -X POST http://external-ip:3000/token/create
|
||||
# # Expected: 403 Forbidden (IP not allowed)
|
||||
#
|
||||
# # Test without auth (should fail):
|
||||
# curl -X POST http://127.0.0.1:3000/token/create
|
||||
# # Expected: 401 Unauthorized (missing admin token)
|
||||
#
|
||||
# # Test with valid auth (should succeed if enabled):
|
||||
# curl -H "Authorization: Bearer your-admin-token" http://127.0.0.1:3000/token/stats
|
||||
# # Expected: 200 OK with token statistics
|
||||
#
|
||||
# 🔍 MONITORING & AUDITING:
|
||||
# # All token management access attempts are logged:
|
||||
# tail -f logs/doris_mcp_server_audit.log | grep "token management"
|
||||
#
|
||||
# # Monitor security events:
|
||||
# tail -f logs/doris_mcp_server_info.log | grep -E "(access denied|token management)"
|
||||
#
|
||||
# ✅ SECURITY BEST PRACTICES:
|
||||
# - Keep ENABLE_HTTP_TOKEN_MANAGEMENT=false unless absolutely necessary
|
||||
# - Use file-based token management (tokens.json) instead of HTTP endpoints
|
||||
# - Generate strong admin tokens using cryptographically secure methods
|
||||
# - Restrict access to localhost (127.0.0.1,::1) whenever possible
|
||||
# - Never expose token management endpoints to public internet
|
||||
# - Regularly audit token management access logs
|
||||
# - Use firewall rules as additional protection layer
|
||||
# - Consider VPN access for remote token management needs
|
||||
23
.gitignore
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
*.log
|
||||
*.log.*
|
||||
*.bak
|
||||
logs
|
||||
/configs/*.py
|
||||
.vscode/
|
||||
__pycache__/
|
||||
*.log
|
||||
|
||||
.python-version
|
||||
Pipfile.lock
|
||||
poetry.lock
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
.idea/
|
||||
.coverage
|
||||
coverage.xml
|
||||
10
Dockerfile
@@ -32,6 +32,7 @@ RUN apt-get update && apt-get install -y \
|
||||
g++ \
|
||||
pkg-config \
|
||||
default-libmysqlclient-dev \
|
||||
dos2unix \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy requirements file
|
||||
@@ -43,12 +44,13 @@ RUN pip install --no-cache-dir -r requirements.txt
|
||||
# Copy application code
|
||||
COPY . .
|
||||
|
||||
# Convert line endings for shell scripts and ensure proper execution format
|
||||
RUN find . -name "*.sh" -exec dos2unix {} \; && \
|
||||
find . -name "*.sh" -exec chmod +x {} \;
|
||||
|
||||
# Create necessary directories
|
||||
RUN mkdir -p /app/logs /app/config /app/data
|
||||
|
||||
# Set permissions
|
||||
RUN chmod +x /app/start.sh
|
||||
|
||||
# Create non-root user
|
||||
RUN groupadd -r doris && useradd -r -g doris doris
|
||||
RUN chown -R doris:doris /app
|
||||
@@ -62,4 +64,4 @@ HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
|
||||
EXPOSE 3000 3001 3002
|
||||
|
||||
# Start command
|
||||
CMD ["/app/start.sh"]
|
||||
CMD ["/app/start_server.sh"]
|
||||
|
||||
@@ -133,9 +133,6 @@ async def database_operations(client):
|
||||
|
||||
# Get table schema
|
||||
schema = await client.get_table_schema("table_name", "db_name")
|
||||
|
||||
# Column data analysis
|
||||
analysis = await client.analyze_column("table", "column", "basic")
|
||||
```
|
||||
|
||||
## 🧪 Testing
|
||||
@@ -177,7 +174,6 @@ python test_unified_client.py benchmark
|
||||
2. get_table_list: Get table list for specified database
|
||||
3. get_table_schema: Get table structure information
|
||||
4. exec_query: Execute SQL query
|
||||
5. column_analysis: Analyze column data distribution and statistics
|
||||
...
|
||||
|
||||
🧪 Testing basic functionality...
|
||||
@@ -189,8 +185,6 @@ python test_unified_client.py benchmark
|
||||
✅ SSB query successful
|
||||
4️⃣ Getting table structure...
|
||||
✅ Table structure retrieved successfully
|
||||
5️⃣ Column data analysis...
|
||||
✅ Column analysis successful
|
||||
|
||||
✅ HTTP mode testing completed!
|
||||
```
|
||||
@@ -255,12 +249,6 @@ async def comprehensive_example():
|
||||
# Get table schema
|
||||
schema_result = await client.get_table_schema("lineorder", "ssb")
|
||||
print(f"Table schema: {schema_result}")
|
||||
|
||||
# Column analysis
|
||||
analysis_result = await client.analyze_column(
|
||||
"lineorder", "lo_orderkey", "basic"
|
||||
)
|
||||
print(f"Column analysis: {analysis_result}")
|
||||
|
||||
await client.connect_and_run(demo_operations)
|
||||
|
||||
|
||||
@@ -323,7 +323,7 @@ class DorisUnifiedClient:
|
||||
async with streamablehttp_client(
|
||||
self.config.server_url,
|
||||
timeout=timedelta(seconds=self.config.timeout)
|
||||
) as (read, write):
|
||||
) as (read, write, _):
|
||||
async with ClientSession(read, write) as session:
|
||||
self.session = session
|
||||
self._init_sub_clients()
|
||||
@@ -422,18 +422,14 @@ class DorisUnifiedClient:
|
||||
|
||||
return await self.call_tool(tool_name, kwargs)
|
||||
|
||||
async def analyze_column(self, table_name: str, column_name: str, analysis_type: str = "basic", **kwargs) -> dict[str, Any]:
|
||||
"""Analyze column"""
|
||||
tool_name = await self._find_tool_by_pattern(["column_analysis", "analyze_column", "column"])
|
||||
async def get_memory_stats(self, tracker_type: str = "overview", include_details: bool = True, **kwargs) -> dict[str, Any]:
|
||||
"""Get memory statistics"""
|
||||
tool_name = await self._find_tool_by_pattern(["memory", "realtime_memory"])
|
||||
if not tool_name:
|
||||
return {"success": False, "error": "Column analysis tool not found"}
|
||||
|
||||
arguments = {
|
||||
"table_name": table_name,
|
||||
"column_name": column_name,
|
||||
"analysis_type": analysis_type,
|
||||
**kwargs
|
||||
}
|
||||
return {"success": False, "error": "Memory stats tool not found"}
|
||||
|
||||
arguments = {"tracker_type": tracker_type, "include_details": include_details}
|
||||
arguments.update(kwargs)
|
||||
return await self.call_tool(tool_name, arguments)
|
||||
|
||||
async def call_tool_by_function(self, function_description: str, arguments: dict[str, Any]) -> dict[str, Any]:
|
||||
@@ -467,7 +463,7 @@ async def create_http_client(server_url: str, timeout: int = 60) -> DorisUnified
|
||||
# Example usage
|
||||
async def example_stdio():
|
||||
"""stdio mode example"""
|
||||
client = await create_stdio_client("python", ["doris_mcp_server/main.py"])
|
||||
client = await create_stdio_client("python", ["-m", "doris_mcp_server.main", "--transport", "stdio"])
|
||||
|
||||
async def test_client(client: DorisUnifiedClient):
|
||||
# Get server capabilities
|
||||
@@ -510,4 +506,4 @@ if __name__ == "__main__":
|
||||
asyncio.run(example_stdio())
|
||||
|
||||
# Run HTTP example
|
||||
# asyncio.run(example_http())
|
||||
# asyncio.run(example_http())
|
||||
|
||||
56
doris_mcp_server/auth/__init__.py
Normal file
@@ -0,0 +1,56 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Doris MCP Server Authentication Module
|
||||
Provides JWT-based, Token-based, and OAuth 2.0/OIDC authentication and authorization services
|
||||
"""
|
||||
|
||||
from .jwt_manager import JWTManager
|
||||
from .key_manager import KeyManager
|
||||
from .token_validators import TokenValidator, TokenBlacklist
|
||||
from .auth_middleware import AuthMiddleware
|
||||
from .token_manager import TokenManager, TokenInfo, TokenValidationResult
|
||||
from .token_handlers import TokenHandlers
|
||||
from .oauth_client import OAuthClient, OAuthStateManager
|
||||
from .oauth_provider import OAuthAuthenticationProvider
|
||||
from .oauth_types import (
|
||||
OAuthProvider, OAuthState, OAuthTokens, OAuthUserInfo,
|
||||
OIDCDiscovery, OAuthError, OAuthProviderConfig
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"JWTManager",
|
||||
"KeyManager",
|
||||
"TokenValidator",
|
||||
"TokenBlacklist",
|
||||
"AuthMiddleware",
|
||||
"TokenManager",
|
||||
"TokenInfo",
|
||||
"TokenValidationResult",
|
||||
"TokenHandlers",
|
||||
"OAuthClient",
|
||||
"OAuthStateManager",
|
||||
"OAuthAuthenticationProvider",
|
||||
"OAuthProvider",
|
||||
"OAuthState",
|
||||
"OAuthTokens",
|
||||
"OAuthUserInfo",
|
||||
"OIDCDiscovery",
|
||||
"OAuthError",
|
||||
"OAuthProviderConfig"
|
||||
]
|
||||
271
doris_mcp_server/auth/auth_middleware.py
Normal file
@@ -0,0 +1,271 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Authentication Middleware Module
|
||||
Provides middleware for JWT authentication in HTTP and MCP contexts
|
||||
"""
|
||||
|
||||
from typing import Optional, Dict, Any, Callable, Awaitable
|
||||
from datetime import datetime
|
||||
|
||||
from .jwt_manager import JWTManager
|
||||
from ..utils.security import AuthContext, SecurityLevel
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AuthMiddleware:
|
||||
"""Authentication Middleware
|
||||
|
||||
Provides JWT authentication functionality for HTTP and MCP requests
|
||||
"""
|
||||
|
||||
def __init__(self, jwt_manager: JWTManager):
|
||||
"""Initialize authentication middleware
|
||||
|
||||
Args:
|
||||
jwt_manager: JWT manager instance
|
||||
"""
|
||||
self.jwt_manager = jwt_manager
|
||||
logger.info("AuthMiddleware initialized")
|
||||
|
||||
def extract_token_from_header(self, authorization: str) -> Optional[str]:
|
||||
"""Extract JWT token from Authorization header
|
||||
|
||||
Args:
|
||||
authorization: Authorization header value
|
||||
|
||||
Returns:
|
||||
JWT token string, or None if not found
|
||||
"""
|
||||
if not authorization:
|
||||
return None
|
||||
|
||||
# Support Bearer format
|
||||
if authorization.startswith('Bearer '):
|
||||
return authorization[7:] # Remove "Bearer " prefix
|
||||
|
||||
# Support direct token format
|
||||
if not authorization.startswith('Basic '):
|
||||
return authorization
|
||||
|
||||
return None
|
||||
|
||||
async def authenticate_request(self, auth_info: Dict[str, Any]) -> AuthContext:
|
||||
"""Authenticate request and return authentication context
|
||||
|
||||
Args:
|
||||
auth_info: Authentication information dictionary
|
||||
|
||||
Returns:
|
||||
AuthContext authentication context
|
||||
|
||||
Raises:
|
||||
ValueError: Authentication failed
|
||||
"""
|
||||
try:
|
||||
auth_type = auth_info.get("type", "jwt")
|
||||
|
||||
if auth_type == "jwt" or auth_type == "token":
|
||||
return await self._authenticate_jwt(auth_info)
|
||||
else:
|
||||
raise ValueError(f"Unsupported authentication type: {auth_type}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Request authentication failed: {e}")
|
||||
raise
|
||||
|
||||
async def _authenticate_jwt(self, auth_info: Dict[str, Any]) -> AuthContext:
|
||||
"""JWT authentication processing
|
||||
|
||||
Args:
|
||||
auth_info: Authentication information containing JWT token
|
||||
|
||||
Returns:
|
||||
AuthContext authentication context
|
||||
"""
|
||||
# Get token
|
||||
token = auth_info.get("token")
|
||||
if not token:
|
||||
# Try to get from Authorization header
|
||||
authorization = auth_info.get("authorization")
|
||||
token = self.extract_token_from_header(authorization)
|
||||
|
||||
if not token:
|
||||
raise ValueError("Missing JWT token")
|
||||
|
||||
try:
|
||||
# Validate token
|
||||
validation_result = await self.jwt_manager.validate_token(token, 'access')
|
||||
payload = validation_result['payload']
|
||||
|
||||
# Build authentication context
|
||||
auth_context = AuthContext(
|
||||
token_id=payload.get('jti', ''),
|
||||
user_id=payload.get('sub'),
|
||||
roles=payload.get('roles', []),
|
||||
permissions=payload.get('permissions', []),
|
||||
security_level=SecurityLevel(payload.get('security_level', 'internal')),
|
||||
session_id=payload.get('jti'), # Use JWT ID as session ID
|
||||
login_time=datetime.fromtimestamp(payload.get('iat', 0)),
|
||||
last_activity=datetime.utcnow(),
|
||||
token=token # Store raw token for token-bound database configuration
|
||||
)
|
||||
|
||||
logger.info(f"JWT authentication successful for user: {auth_context.user_id}")
|
||||
return auth_context
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"JWT authentication failed: {e}")
|
||||
raise ValueError(f"JWT authentication failed: {str(e)}")
|
||||
|
||||
async def create_auth_response_headers(self, auth_context: AuthContext) -> Dict[str, str]:
|
||||
"""Create authentication response headers
|
||||
|
||||
Args:
|
||||
auth_context: Authentication context
|
||||
|
||||
Returns:
|
||||
Response headers dictionary
|
||||
"""
|
||||
return {
|
||||
'X-Auth-User': auth_context.user_id,
|
||||
'X-Auth-Roles': ','.join(auth_context.roles),
|
||||
'X-Auth-Session': auth_context.session_id,
|
||||
'X-Auth-Security-Level': auth_context.security_level.value
|
||||
}
|
||||
|
||||
def create_http_middleware(self, skip_paths: Optional[list] = None):
|
||||
"""Create HTTP middleware function
|
||||
|
||||
Args:
|
||||
skip_paths: List of paths to skip authentication
|
||||
|
||||
Returns:
|
||||
ASGI middleware function
|
||||
"""
|
||||
skip_paths = skip_paths or ['/health', '/docs', '/openapi.json']
|
||||
|
||||
async def middleware(scope, receive, send):
|
||||
"""HTTP authentication middleware"""
|
||||
if scope['type'] != 'http':
|
||||
# Pass through non-HTTP requests directly
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
path = scope.get('path', '')
|
||||
|
||||
# Check if authentication should be skipped
|
||||
if any(path.startswith(skip) for skip in skip_paths):
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
# Extract authentication information
|
||||
headers = dict(scope.get('headers', []))
|
||||
authorization = headers.get(b'authorization', b'').decode()
|
||||
|
||||
try:
|
||||
# Perform authentication
|
||||
auth_info = {
|
||||
'type': 'jwt',
|
||||
'authorization': authorization
|
||||
}
|
||||
auth_context = await self.authenticate_request(auth_info)
|
||||
|
||||
# Add authentication context to scope
|
||||
scope['auth_context'] = auth_context
|
||||
|
||||
# Create response wrapper to add authentication headers
|
||||
async def send_wrapper(message):
|
||||
if message['type'] == 'http.response.start':
|
||||
headers = dict(message.get('headers', []))
|
||||
auth_headers = await self.create_auth_response_headers(auth_context)
|
||||
|
||||
for key, value in auth_headers.items():
|
||||
headers[key.encode()] = value.encode()
|
||||
|
||||
message['headers'] = list(headers.items())
|
||||
|
||||
await send(message)
|
||||
|
||||
return await self.app(scope, receive, send_wrapper)
|
||||
|
||||
except Exception as e:
|
||||
# Authentication failed, return 401 error
|
||||
response_body = f'{{"error": "Authentication failed", "message": "{str(e)}"}}'
|
||||
|
||||
await send({
|
||||
'type': 'http.response.start',
|
||||
'status': 401,
|
||||
'headers': [
|
||||
(b'content-type', b'application/json'),
|
||||
(b'www-authenticate', b'Bearer')
|
||||
]
|
||||
})
|
||||
await send({
|
||||
'type': 'http.response.body',
|
||||
'body': response_body.encode()
|
||||
})
|
||||
|
||||
return middleware
|
||||
|
||||
async def authenticate_mcp_request(self, headers: Dict[str, str]) -> AuthContext:
|
||||
"""Authenticate MCP request
|
||||
|
||||
Args:
|
||||
headers: MCP request headers
|
||||
|
||||
Returns:
|
||||
AuthContext authentication context
|
||||
"""
|
||||
try:
|
||||
# Extract authentication information from multiple possible header fields
|
||||
authorization = (
|
||||
headers.get('Authorization') or
|
||||
headers.get('authorization') or
|
||||
headers.get('X-Auth-Token') or
|
||||
headers.get('x-auth-token')
|
||||
)
|
||||
|
||||
auth_info = {
|
||||
'type': 'jwt',
|
||||
'authorization': authorization
|
||||
}
|
||||
|
||||
return await self.authenticate_request(auth_info)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"MCP request authentication failed: {e}")
|
||||
raise
|
||||
|
||||
|
||||
class AuthenticationError(Exception):
|
||||
"""Authentication error exception"""
|
||||
|
||||
def __init__(self, message: str, error_code: str = "AUTH_FAILED"):
|
||||
self.message = message
|
||||
self.error_code = error_code
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class AuthorizationError(Exception):
|
||||
"""Authorization error exception"""
|
||||
|
||||
def __init__(self, message: str, error_code: str = "ACCESS_DENIED"):
|
||||
self.message = message
|
||||
self.error_code = error_code
|
||||
super().__init__(message)
|
||||
471
doris_mcp_server/auth/jwt_manager.py
Normal file
@@ -0,0 +1,471 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
JWT Manager Module
|
||||
Provides comprehensive JWT token management including generation, validation, refresh and revocation
|
||||
"""
|
||||
|
||||
import time
|
||||
import uuid
|
||||
import asyncio
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
try:
|
||||
import jwt
|
||||
except ImportError:
|
||||
raise ImportError("PyJWT is required for JWT functionality. Install with: pip install PyJWT[crypto]")
|
||||
|
||||
from .key_manager import KeyManager
|
||||
from .token_validators import TokenValidator, TokenBlacklist
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class JWTManager:
|
||||
"""JWT Token Manager
|
||||
|
||||
Provides comprehensive JWT token lifecycle management, including:
|
||||
- Token generation and signing
|
||||
- Token validation and parsing
|
||||
- Token refresh mechanism
|
||||
- Token revocation and blacklist
|
||||
- Automatic key rotation
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
"""Initialize JWT manager
|
||||
|
||||
Args:
|
||||
config: DorisConfig configuration object (with security attribute)
|
||||
"""
|
||||
self.config = config
|
||||
# Access JWT settings through the security configuration
|
||||
if hasattr(config, 'security'):
|
||||
security_config = config.security
|
||||
else:
|
||||
# Fallback if config is passed directly as SecurityConfig
|
||||
security_config = config
|
||||
|
||||
self.algorithm = security_config.jwt_algorithm
|
||||
self.issuer = security_config.jwt_issuer
|
||||
self.audience = security_config.jwt_audience
|
||||
self.access_token_expiry = security_config.jwt_access_token_expiry
|
||||
self.refresh_token_expiry = security_config.jwt_refresh_token_expiry
|
||||
self.enable_refresh = security_config.enable_token_refresh
|
||||
self.enable_revocation = security_config.enable_token_revocation
|
||||
|
||||
# Initialize components
|
||||
self.key_manager = KeyManager(config)
|
||||
self.token_blacklist = TokenBlacklist()
|
||||
self.validator = TokenValidator(config, self.token_blacklist)
|
||||
|
||||
# Automatic key rotation task
|
||||
self._key_rotation_task = None
|
||||
|
||||
logger.info(f"JWTManager initialized with algorithm: {self.algorithm}")
|
||||
|
||||
async def initialize(self) -> bool:
|
||||
"""Initialize JWT manager"""
|
||||
try:
|
||||
# Initialize key manager
|
||||
if not await self.key_manager.initialize():
|
||||
logger.error("Failed to initialize key manager")
|
||||
return False
|
||||
|
||||
# Start token validator
|
||||
await self.validator.start()
|
||||
|
||||
# Start automatic key rotation
|
||||
if self.key_manager.key_rotation_interval > 0:
|
||||
self._key_rotation_task = asyncio.create_task(self._auto_key_rotation())
|
||||
|
||||
logger.info("JWTManager initialization completed")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize JWTManager: {e}")
|
||||
return False
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shutdown JWT manager"""
|
||||
try:
|
||||
# Stop key rotation task
|
||||
if self._key_rotation_task:
|
||||
self._key_rotation_task.cancel()
|
||||
try:
|
||||
await self._key_rotation_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Stop validator
|
||||
await self.validator.stop()
|
||||
|
||||
logger.info("JWTManager shutdown completed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during JWTManager shutdown: {e}")
|
||||
|
||||
async def generate_tokens(self, user_info: Dict[str, Any],
|
||||
custom_claims: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
"""Generate access token and refresh token
|
||||
|
||||
Args:
|
||||
user_info: User information dictionary, containing user_id, roles, permissions, etc.
|
||||
custom_claims: Custom claims
|
||||
|
||||
Returns:
|
||||
Dictionary containing access_token and refresh_token
|
||||
"""
|
||||
try:
|
||||
current_time = int(time.time())
|
||||
jti = str(uuid.uuid4())
|
||||
|
||||
# Build base payload
|
||||
base_payload = {
|
||||
'iss': self.issuer,
|
||||
'aud': self.audience,
|
||||
'iat': current_time,
|
||||
'jti': jti,
|
||||
'sub': user_info.get('user_id'),
|
||||
'roles': user_info.get('roles', []),
|
||||
'permissions': user_info.get('permissions', []),
|
||||
'security_level': user_info.get('security_level', 'internal')
|
||||
}
|
||||
|
||||
# Add custom claims
|
||||
if custom_claims:
|
||||
base_payload.update(custom_claims)
|
||||
|
||||
# Generate access token
|
||||
access_payload = base_payload.copy()
|
||||
access_payload.update({
|
||||
'exp': current_time + self.access_token_expiry,
|
||||
'token_type': 'access'
|
||||
})
|
||||
|
||||
access_token = await self._sign_token(access_payload)
|
||||
|
||||
result = {
|
||||
'access_token': access_token,
|
||||
'token_type': 'Bearer',
|
||||
'expires_in': self.access_token_expiry,
|
||||
'user_id': user_info.get('user_id'),
|
||||
'issued_at': current_time
|
||||
}
|
||||
|
||||
# Generate refresh token (if enabled)
|
||||
if self.enable_refresh:
|
||||
refresh_jti = str(uuid.uuid4())
|
||||
refresh_payload = {
|
||||
'iss': self.issuer,
|
||||
'aud': self.audience,
|
||||
'iat': current_time,
|
||||
'exp': current_time + self.refresh_token_expiry,
|
||||
'jti': refresh_jti,
|
||||
'sub': user_info.get('user_id'),
|
||||
'token_type': 'refresh',
|
||||
'access_jti': jti # Associated access token ID
|
||||
}
|
||||
|
||||
refresh_token = await self._sign_token(refresh_payload)
|
||||
result.update({
|
||||
'refresh_token': refresh_token,
|
||||
'refresh_expires_in': self.refresh_token_expiry
|
||||
})
|
||||
|
||||
logger.info(f"Generated tokens for user: {user_info.get('user_id')}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate tokens: {e}")
|
||||
raise
|
||||
|
||||
async def _sign_token(self, payload: Dict[str, Any]) -> str:
|
||||
"""Sign JWT token
|
||||
|
||||
Args:
|
||||
payload: JWT payload
|
||||
|
||||
Returns:
|
||||
Signed JWT token
|
||||
"""
|
||||
try:
|
||||
signing_key = self.key_manager.get_private_key()
|
||||
|
||||
if self.algorithm == "HS256":
|
||||
# Symmetric key signing
|
||||
token = jwt.encode(payload, signing_key, algorithm=self.algorithm)
|
||||
else:
|
||||
# Asymmetric key signing
|
||||
token = jwt.encode(payload, signing_key, algorithm=self.algorithm)
|
||||
|
||||
return token
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to sign token: {e}")
|
||||
raise
|
||||
|
||||
async def validate_token(self, token: str, token_type: str = 'access') -> Dict[str, Any]:
|
||||
"""Validate JWT token
|
||||
|
||||
Args:
|
||||
token: JWT token string
|
||||
token_type: Token type ('access' or 'refresh')
|
||||
|
||||
Returns:
|
||||
Validation result and user information
|
||||
|
||||
Raises:
|
||||
ValueError: Token validation failed
|
||||
"""
|
||||
try:
|
||||
# Decode token
|
||||
verification_key = self.key_manager.get_public_key()
|
||||
|
||||
# Get security configuration
|
||||
if hasattr(self.config, 'security'):
|
||||
security_config = self.config.security
|
||||
else:
|
||||
security_config = self.config
|
||||
|
||||
# JWT decoding options
|
||||
options = {
|
||||
'verify_signature': security_config.jwt_verify_signature,
|
||||
'verify_exp': security_config.jwt_require_exp,
|
||||
'verify_iat': security_config.jwt_require_iat,
|
||||
'verify_nbf': security_config.jwt_require_nbf,
|
||||
'verify_aud': security_config.jwt_verify_audience,
|
||||
'verify_iss': security_config.jwt_verify_issuer,
|
||||
}
|
||||
|
||||
# Decode JWT
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
verification_key,
|
||||
algorithms=[self.algorithm],
|
||||
audience=self.audience if security_config.jwt_verify_audience else None,
|
||||
issuer=self.issuer if security_config.jwt_verify_issuer else None,
|
||||
leeway=security_config.jwt_leeway,
|
||||
options=options
|
||||
)
|
||||
|
||||
# Check token type
|
||||
if payload.get('token_type') != token_type:
|
||||
raise ValueError(f"Invalid token type: expected {token_type}")
|
||||
|
||||
# Use validator for additional checks
|
||||
validation_result = await self.validator.validate_claims(payload)
|
||||
|
||||
logger.info(f"Token validation successful for user: {payload.get('sub')}")
|
||||
return validation_result
|
||||
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise ValueError("Token has expired")
|
||||
except jwt.InvalidTokenError as e:
|
||||
raise ValueError(f"Invalid token: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Token validation failed: {e}")
|
||||
raise ValueError(f"Token validation failed: {str(e)}")
|
||||
|
||||
async def refresh_token(self, refresh_token: str) -> Dict[str, Any]:
|
||||
"""Refresh access token
|
||||
|
||||
Args:
|
||||
refresh_token: Refresh token
|
||||
|
||||
Returns:
|
||||
New token pair
|
||||
"""
|
||||
if not self.enable_refresh:
|
||||
raise ValueError("Token refresh is disabled")
|
||||
|
||||
try:
|
||||
# Validate refresh token
|
||||
refresh_result = await self.validate_token(refresh_token, 'refresh')
|
||||
refresh_payload = refresh_result['payload']
|
||||
|
||||
# Revoke associated access token (if revocation is enabled)
|
||||
if self.enable_revocation:
|
||||
access_jti = refresh_payload.get('access_jti')
|
||||
if access_jti:
|
||||
# Should revoke old access token here, but since we don't know its expiration time,
|
||||
# in practice might need to store more information or use different strategy
|
||||
pass
|
||||
|
||||
# Build new user information
|
||||
user_info = {
|
||||
'user_id': refresh_payload.get('sub'),
|
||||
'roles': refresh_payload.get('roles', []),
|
||||
'permissions': refresh_payload.get('permissions', []),
|
||||
'security_level': refresh_payload.get('security_level', 'internal')
|
||||
}
|
||||
|
||||
# Generate new token pair
|
||||
new_tokens = await self.generate_tokens(user_info)
|
||||
|
||||
logger.info(f"Token refreshed for user: {user_info['user_id']}")
|
||||
return new_tokens
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Token refresh failed: {e}")
|
||||
raise
|
||||
|
||||
async def revoke_token(self, token: str) -> bool:
|
||||
"""Revoke token
|
||||
|
||||
Args:
|
||||
token: Token to revoke
|
||||
|
||||
Returns:
|
||||
Whether revocation was successful
|
||||
"""
|
||||
if not self.enable_revocation:
|
||||
logger.warning("Token revocation is disabled")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Decode token to get JTI and expiration time
|
||||
verification_key = self.key_manager.get_public_key()
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
verification_key,
|
||||
algorithms=[self.algorithm],
|
||||
options={'verify_exp': False} # Allow decoding expired tokens
|
||||
)
|
||||
|
||||
jti = payload.get('jti')
|
||||
exp = payload.get('exp')
|
||||
|
||||
if not jti or not exp:
|
||||
logger.error("Token missing required claims for revocation")
|
||||
return False
|
||||
|
||||
# Add to blacklist
|
||||
await self.validator.revoke_token(jti, exp)
|
||||
|
||||
logger.info(f"Token {jti} revoked successfully")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Token revocation failed: {e}")
|
||||
return False
|
||||
|
||||
async def decode_token_unsafe(self, token: str) -> Dict[str, Any]:
|
||||
"""Decode token without verifying signature (for debugging only)
|
||||
|
||||
Args:
|
||||
token: JWT token
|
||||
|
||||
Returns:
|
||||
Token payload
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(token, options={'verify_signature': False})
|
||||
return payload
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to decode token: {e}")
|
||||
raise
|
||||
|
||||
async def get_token_info(self, token: str) -> Dict[str, Any]:
|
||||
"""Get token information (without verifying signature)
|
||||
|
||||
Args:
|
||||
token: JWT token
|
||||
|
||||
Returns:
|
||||
Token information
|
||||
"""
|
||||
try:
|
||||
payload = await self.decode_token_unsafe(token)
|
||||
|
||||
return {
|
||||
'jti': payload.get('jti'),
|
||||
'sub': payload.get('sub'),
|
||||
'iss': payload.get('iss'),
|
||||
'aud': payload.get('aud'),
|
||||
'iat': payload.get('iat'),
|
||||
'exp': payload.get('exp'),
|
||||
'token_type': payload.get('token_type'),
|
||||
'roles': payload.get('roles'),
|
||||
'permissions': payload.get('permissions'),
|
||||
'security_level': payload.get('security_level'),
|
||||
'is_expired': payload.get('exp', 0) < time.time() if payload.get('exp') else None
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get token info: {e}")
|
||||
raise
|
||||
|
||||
async def _auto_key_rotation(self):
|
||||
"""Automatic key rotation task"""
|
||||
while True:
|
||||
try:
|
||||
# Check if key rotation is needed
|
||||
if await self.key_manager.is_key_expired():
|
||||
logger.info("Key rotation needed, rotating keys...")
|
||||
await self.key_manager.rotate_keys()
|
||||
|
||||
# Wait until next check
|
||||
await asyncio.sleep(3600) # Check every hour
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in auto key rotation: {e}")
|
||||
# Wait longer before retry after error
|
||||
await asyncio.sleep(3600)
|
||||
|
||||
async def get_public_key_info(self) -> Dict[str, Any]:
|
||||
"""Get public key information (for client verification)
|
||||
|
||||
Returns:
|
||||
Public key information
|
||||
"""
|
||||
key_info = await self.key_manager.get_key_info()
|
||||
public_key_pem = await self.key_manager.export_public_key_pem()
|
||||
|
||||
return {
|
||||
'algorithm': self.algorithm,
|
||||
'public_key_pem': public_key_pem,
|
||||
'key_info': key_info
|
||||
}
|
||||
|
||||
async def get_manager_stats(self) -> Dict[str, Any]:
|
||||
"""Get manager statistics
|
||||
|
||||
Returns:
|
||||
Statistics information
|
||||
"""
|
||||
key_info = await self.key_manager.get_key_info()
|
||||
validation_stats = await self.validator.get_validation_stats()
|
||||
|
||||
return {
|
||||
'jwt_config': {
|
||||
'algorithm': self.algorithm,
|
||||
'issuer': self.issuer,
|
||||
'audience': self.audience,
|
||||
'access_token_expiry': self.access_token_expiry,
|
||||
'refresh_token_expiry': self.refresh_token_expiry,
|
||||
'enable_refresh': self.enable_refresh,
|
||||
'enable_revocation': self.enable_revocation
|
||||
},
|
||||
'key_manager': key_info,
|
||||
'validator': validation_stats
|
||||
}
|
||||
343
doris_mcp_server/auth/key_manager.py
Normal file
@@ -0,0 +1,343 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
JWT Key Management Module
|
||||
Provides secure key generation, loading, rotation and management for JWT tokens
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import secrets
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Union
|
||||
from datetime import datetime, timedelta
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa, ec
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class KeyManager:
|
||||
"""JWT Key Manager
|
||||
|
||||
Responsible for generating, loading, rotating and securely storing JWT signing keys
|
||||
Supports RSA and EC algorithms, provides automatic key rotation functionality
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
"""Initialize key manager
|
||||
|
||||
Args:
|
||||
config: DorisConfig configuration object (with security attribute)
|
||||
"""
|
||||
self.config = config
|
||||
# Access JWT settings through the security configuration
|
||||
if hasattr(config, 'security'):
|
||||
security_config = config.security
|
||||
else:
|
||||
# Fallback if config is passed directly as SecurityConfig
|
||||
security_config = config
|
||||
|
||||
self.algorithm = security_config.jwt_algorithm
|
||||
self.key_rotation_interval = security_config.key_rotation_interval
|
||||
self.private_key_path = security_config.jwt_private_key_path
|
||||
self.public_key_path = security_config.jwt_public_key_path
|
||||
self.secret_key = security_config.jwt_secret_key
|
||||
|
||||
# Key storage
|
||||
self._private_key = None
|
||||
self._public_key = None
|
||||
self._secret_key = None
|
||||
self._key_generated_at = None
|
||||
|
||||
logger.info(f"KeyManager initialized with algorithm: {self.algorithm}")
|
||||
|
||||
async def initialize(self) -> bool:
|
||||
"""Initialize key manager, load or generate keys"""
|
||||
try:
|
||||
if self.algorithm == "HS256":
|
||||
await self._initialize_symmetric_key()
|
||||
else:
|
||||
await self._initialize_asymmetric_keys()
|
||||
|
||||
logger.info("KeyManager initialization completed")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize KeyManager: {e}")
|
||||
return False
|
||||
|
||||
async def _initialize_symmetric_key(self):
|
||||
"""Initialize symmetric key (HS256)"""
|
||||
if self.secret_key:
|
||||
# Use configured key
|
||||
self._secret_key = self.secret_key.encode()
|
||||
logger.info("Loaded symmetric key from configuration")
|
||||
else:
|
||||
# Generate new key
|
||||
self._secret_key = await self.generate_symmetric_key()
|
||||
logger.info("Generated new symmetric key")
|
||||
|
||||
self._key_generated_at = datetime.utcnow()
|
||||
|
||||
async def _initialize_asymmetric_keys(self):
|
||||
"""Initialize asymmetric key pair (RS256/ES256)"""
|
||||
# Try to load keys from files
|
||||
if await self._load_keys_from_files():
|
||||
logger.info("Loaded asymmetric keys from files")
|
||||
return
|
||||
|
||||
# Try to load from environment variables
|
||||
if await self._load_keys_from_env():
|
||||
logger.info("Loaded asymmetric keys from environment")
|
||||
return
|
||||
|
||||
# Generate new key pair
|
||||
await self.generate_key_pair()
|
||||
logger.info("Generated new asymmetric key pair")
|
||||
|
||||
async def _load_keys_from_files(self) -> bool:
|
||||
"""Load keys from files"""
|
||||
try:
|
||||
if not self.private_key_path or not self.public_key_path:
|
||||
return False
|
||||
|
||||
private_path = Path(self.private_key_path)
|
||||
public_path = Path(self.public_key_path)
|
||||
|
||||
if not (private_path.exists() and public_path.exists()):
|
||||
return False
|
||||
|
||||
# Read private key
|
||||
with open(private_path, 'rb') as f:
|
||||
private_key_data = f.read()
|
||||
self._private_key = serialization.load_pem_private_key(
|
||||
private_key_data, password=None, backend=default_backend()
|
||||
)
|
||||
|
||||
# Read public key
|
||||
with open(public_path, 'rb') as f:
|
||||
public_key_data = f.read()
|
||||
self._public_key = serialization.load_pem_public_key(
|
||||
public_key_data, backend=default_backend()
|
||||
)
|
||||
|
||||
# Get key generation time (using file modification time)
|
||||
self._key_generated_at = datetime.fromtimestamp(private_path.stat().st_mtime)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load keys from files: {e}")
|
||||
return False
|
||||
|
||||
async def _load_keys_from_env(self) -> bool:
|
||||
"""Load keys from environment variables"""
|
||||
try:
|
||||
private_key_env = os.getenv('JWT_PRIVATE_KEY')
|
||||
public_key_env = os.getenv('JWT_PUBLIC_KEY')
|
||||
|
||||
if not (private_key_env and public_key_env):
|
||||
return False
|
||||
|
||||
# Parse private key
|
||||
self._private_key = serialization.load_pem_private_key(
|
||||
private_key_env.encode(), password=None, backend=default_backend()
|
||||
)
|
||||
|
||||
# Parse public key
|
||||
self._public_key = serialization.load_pem_public_key(
|
||||
public_key_env.encode(), backend=default_backend()
|
||||
)
|
||||
|
||||
self._key_generated_at = datetime.utcnow()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load keys from environment: {e}")
|
||||
return False
|
||||
|
||||
async def generate_symmetric_key(self, length: int = 32) -> bytes:
|
||||
"""Generate symmetric key
|
||||
|
||||
Args:
|
||||
length: Key length (bytes), default 32 bytes (256 bits)
|
||||
|
||||
Returns:
|
||||
Generated key
|
||||
"""
|
||||
return secrets.token_bytes(length)
|
||||
|
||||
async def generate_key_pair(self) -> Tuple[bytes, bytes]:
|
||||
"""Generate asymmetric key pair
|
||||
|
||||
Returns:
|
||||
(private key PEM, public key PEM) tuple
|
||||
"""
|
||||
try:
|
||||
if self.algorithm == "RS256":
|
||||
private_key = rsa.generate_private_key(
|
||||
public_exponent=65537,
|
||||
key_size=2048,
|
||||
backend=default_backend()
|
||||
)
|
||||
elif self.algorithm == "ES256":
|
||||
private_key = ec.generate_private_key(
|
||||
ec.SECP256R1(), backend=default_backend()
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported algorithm for key generation: {self.algorithm}")
|
||||
|
||||
# Get public key
|
||||
public_key = private_key.public_key()
|
||||
|
||||
# Serialize private key
|
||||
private_pem = private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption()
|
||||
)
|
||||
|
||||
# Serialize public key
|
||||
public_pem = public_key.public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo
|
||||
)
|
||||
|
||||
# Store keys
|
||||
self._private_key = private_key
|
||||
self._public_key = public_key
|
||||
self._key_generated_at = datetime.utcnow()
|
||||
|
||||
# If file paths are configured, save to files
|
||||
if self.private_key_path and self.public_key_path:
|
||||
await self._save_keys_to_files(private_pem, public_pem)
|
||||
|
||||
logger.info(f"Generated new {self.algorithm} key pair")
|
||||
return private_pem, public_pem
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate key pair: {e}")
|
||||
raise
|
||||
|
||||
async def _save_keys_to_files(self, private_pem: bytes, public_pem: bytes):
|
||||
"""Save keys to files"""
|
||||
try:
|
||||
# Ensure directories exist
|
||||
private_path = Path(self.private_key_path)
|
||||
public_path = Path(self.public_key_path)
|
||||
|
||||
private_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
public_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save private key (set secure permissions)
|
||||
with open(private_path, 'wb') as f:
|
||||
f.write(private_pem)
|
||||
os.chmod(private_path, 0o600) # Only owner can read/write
|
||||
|
||||
# Save public key
|
||||
with open(public_path, 'wb') as f:
|
||||
f.write(public_pem)
|
||||
os.chmod(public_path, 0o644) # Owner read/write, others read only
|
||||
|
||||
logger.info(f"Saved keys to files: {private_path}, {public_path}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save keys to files: {e}")
|
||||
raise
|
||||
|
||||
def get_private_key(self):
|
||||
"""Get private key for signing"""
|
||||
if self.algorithm == "HS256":
|
||||
return self._secret_key
|
||||
else:
|
||||
return self._private_key
|
||||
|
||||
def get_public_key(self):
|
||||
"""Get public key for verification"""
|
||||
if self.algorithm == "HS256":
|
||||
return self._secret_key
|
||||
else:
|
||||
return self._public_key
|
||||
|
||||
def get_algorithm(self) -> str:
|
||||
"""Get signing algorithm"""
|
||||
return self.algorithm
|
||||
|
||||
async def is_key_expired(self) -> bool:
|
||||
"""Check if key is expired"""
|
||||
if not self._key_generated_at:
|
||||
return True
|
||||
|
||||
expiry_time = self._key_generated_at + timedelta(seconds=self.key_rotation_interval)
|
||||
return datetime.utcnow() > expiry_time
|
||||
|
||||
async def rotate_keys(self) -> bool:
|
||||
"""Rotate keys"""
|
||||
try:
|
||||
logger.info("Starting key rotation")
|
||||
|
||||
if self.algorithm == "HS256":
|
||||
# Generate new symmetric key
|
||||
self._secret_key = await self.generate_symmetric_key()
|
||||
self._key_generated_at = datetime.utcnow()
|
||||
else:
|
||||
# Generate new asymmetric key pair
|
||||
await self.generate_key_pair()
|
||||
|
||||
logger.info("Key rotation completed successfully")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Key rotation failed: {e}")
|
||||
return False
|
||||
|
||||
async def get_key_info(self) -> dict:
|
||||
"""Get key information"""
|
||||
return {
|
||||
"algorithm": self.algorithm,
|
||||
"key_generated_at": self._key_generated_at.isoformat() if self._key_generated_at else None,
|
||||
"key_expires_at": (
|
||||
self._key_generated_at + timedelta(seconds=self.key_rotation_interval)
|
||||
).isoformat() if self._key_generated_at else None,
|
||||
"is_expired": await self.is_key_expired(),
|
||||
"has_private_key": self._private_key is not None or self._secret_key is not None,
|
||||
"has_public_key": self._public_key is not None or self._secret_key is not None
|
||||
}
|
||||
|
||||
async def export_public_key_pem(self) -> Optional[str]:
|
||||
"""Export public key in PEM format"""
|
||||
if self.algorithm == "HS256":
|
||||
return None # Symmetric key not exported
|
||||
|
||||
if not self._public_key:
|
||||
return None
|
||||
|
||||
try:
|
||||
public_pem = self._public_key.public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo
|
||||
)
|
||||
return public_pem.decode()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to export public key: {e}")
|
||||
return None
|
||||
536
doris_mcp_server/auth/oauth_client.py
Normal file
@@ -0,0 +1,536 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
OAuth 2.0/OIDC Client Manager
|
||||
Provides OAuth authentication client implementation with PKCE and OIDC support
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Optional, Any, Tuple
|
||||
from urllib.parse import urlencode, parse_qs, urlparse
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
except ImportError:
|
||||
raise ImportError("aiohttp is required for OAuth functionality. Install with: pip install aiohttp")
|
||||
|
||||
from .oauth_types import (
|
||||
OAuthProvider, OAuthState, OAuthTokens, OAuthUserInfo,
|
||||
OIDCDiscovery, OAuthError, OAuthProviderConfig, OAUTH_PROVIDERS
|
||||
)
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class OAuthStateManager:
|
||||
"""Manages OAuth state parameters for CSRF protection"""
|
||||
|
||||
def __init__(self, state_expiry: int = 600):
|
||||
"""Initialize state manager
|
||||
|
||||
Args:
|
||||
state_expiry: State expiry time in seconds
|
||||
"""
|
||||
self.state_expiry = state_expiry
|
||||
self._states: Dict[str, OAuthState] = {}
|
||||
self._cleanup_task = None
|
||||
|
||||
logger.info("OAuthStateManager initialized")
|
||||
|
||||
async def start(self):
|
||||
"""Start periodic cleanup task"""
|
||||
self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
|
||||
logger.info("OAuth state manager started")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop periodic cleanup task"""
|
||||
if self._cleanup_task:
|
||||
self._cleanup_task.cancel()
|
||||
try:
|
||||
await self._cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("OAuth state manager stopped")
|
||||
|
||||
def create_state(self, redirect_uri: str, pkce_enabled: bool = True,
|
||||
nonce_enabled: bool = True) -> OAuthState:
|
||||
"""Create new OAuth state
|
||||
|
||||
Args:
|
||||
redirect_uri: OAuth redirect URI
|
||||
pkce_enabled: Whether to enable PKCE
|
||||
nonce_enabled: Whether to enable nonce (for OIDC)
|
||||
|
||||
Returns:
|
||||
OAuth state object
|
||||
"""
|
||||
state = secrets.token_urlsafe(32)
|
||||
nonce = secrets.token_urlsafe(32) if nonce_enabled else None
|
||||
|
||||
pkce_verifier = None
|
||||
pkce_challenge = None
|
||||
if pkce_enabled:
|
||||
pkce_verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).decode('utf-8').rstrip('=')
|
||||
challenge_bytes = hashlib.sha256(pkce_verifier.encode()).digest()
|
||||
pkce_challenge = base64.urlsafe_b64encode(challenge_bytes).decode('utf-8').rstrip('=')
|
||||
|
||||
oauth_state = OAuthState(
|
||||
state=state,
|
||||
nonce=nonce,
|
||||
pkce_verifier=pkce_verifier,
|
||||
pkce_challenge=pkce_challenge,
|
||||
redirect_uri=redirect_uri,
|
||||
created_at=datetime.utcnow(),
|
||||
expires_at=datetime.utcnow() + timedelta(seconds=self.state_expiry)
|
||||
)
|
||||
|
||||
self._states[state] = oauth_state
|
||||
logger.debug(f"Created OAuth state: {state}")
|
||||
return oauth_state
|
||||
|
||||
def get_state(self, state: str) -> Optional[OAuthState]:
|
||||
"""Get OAuth state by state parameter
|
||||
|
||||
Args:
|
||||
state: State parameter
|
||||
|
||||
Returns:
|
||||
OAuth state object or None if not found/expired
|
||||
"""
|
||||
oauth_state = self._states.get(state)
|
||||
if oauth_state and oauth_state.expires_at > datetime.utcnow():
|
||||
return oauth_state
|
||||
elif oauth_state:
|
||||
# Remove expired state
|
||||
del self._states[state]
|
||||
logger.debug(f"Removed expired OAuth state: {state}")
|
||||
return None
|
||||
|
||||
def consume_state(self, state: str) -> Optional[OAuthState]:
|
||||
"""Get and remove OAuth state
|
||||
|
||||
Args:
|
||||
state: State parameter
|
||||
|
||||
Returns:
|
||||
OAuth state object or None if not found/expired
|
||||
"""
|
||||
oauth_state = self.get_state(state)
|
||||
if oauth_state:
|
||||
del self._states[state]
|
||||
logger.debug(f"Consumed OAuth state: {state}")
|
||||
return oauth_state
|
||||
|
||||
async def _periodic_cleanup(self):
|
||||
"""Periodic cleanup of expired states"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(300) # Clean up every 5 minutes
|
||||
current_time = datetime.utcnow()
|
||||
expired_states = [
|
||||
state for state, oauth_state in self._states.items()
|
||||
if oauth_state.expires_at <= current_time
|
||||
]
|
||||
|
||||
for state in expired_states:
|
||||
del self._states[state]
|
||||
|
||||
if expired_states:
|
||||
logger.info(f"Cleaned up {len(expired_states)} expired OAuth states")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error during OAuth state cleanup: {e}")
|
||||
|
||||
|
||||
class OAuthClient:
|
||||
"""OAuth 2.0/OIDC Client implementation"""
|
||||
|
||||
def __init__(self, config):
|
||||
"""Initialize OAuth client
|
||||
|
||||
Args:
|
||||
config: DorisConfig with OAuth configuration
|
||||
"""
|
||||
self.config = config
|
||||
|
||||
# Access OAuth settings through security configuration
|
||||
if hasattr(config, 'security'):
|
||||
security_config = config.security
|
||||
else:
|
||||
security_config = config
|
||||
|
||||
self.enabled = security_config.oauth_enabled
|
||||
if not self.enabled:
|
||||
logger.info("OAuth client disabled by configuration")
|
||||
return
|
||||
|
||||
# Build provider configuration
|
||||
self.provider_config = self._build_provider_config(security_config)
|
||||
self.state_manager = OAuthStateManager(security_config.oauth_state_expiry)
|
||||
|
||||
# HTTP client session
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
|
||||
# Discovery cache
|
||||
self._discovery_cache: Optional[OIDCDiscovery] = None
|
||||
self._discovery_cache_time: Optional[datetime] = None
|
||||
|
||||
logger.info(f"OAuthClient initialized for provider: {self.provider_config.provider.value}")
|
||||
|
||||
def _build_provider_config(self, security_config) -> OAuthProviderConfig:
|
||||
"""Build OAuth provider configuration
|
||||
|
||||
Args:
|
||||
security_config: Security configuration object
|
||||
|
||||
Returns:
|
||||
OAuth provider configuration
|
||||
"""
|
||||
try:
|
||||
provider = OAuthProvider(security_config.oauth_provider)
|
||||
except ValueError:
|
||||
provider = OAuthProvider.CUSTOM
|
||||
|
||||
# Get default configuration for known providers
|
||||
defaults = OAUTH_PROVIDERS.get(provider, {})
|
||||
|
||||
return OAuthProviderConfig(
|
||||
provider=provider,
|
||||
client_id=security_config.oauth_client_id,
|
||||
client_secret=security_config.oauth_client_secret,
|
||||
redirect_uri=security_config.oauth_redirect_uri,
|
||||
scopes=security_config.oauth_scopes or defaults.get("scopes", ["openid", "email", "profile"]),
|
||||
|
||||
# Endpoints (use configured or defaults)
|
||||
authorization_endpoint=security_config.oauth_authorization_endpoint or defaults.get("authorization_endpoint", ""),
|
||||
token_endpoint=security_config.oauth_token_endpoint or defaults.get("token_endpoint", ""),
|
||||
userinfo_endpoint=security_config.oauth_userinfo_endpoint or defaults.get("userinfo_endpoint"),
|
||||
jwks_uri=security_config.oauth_jwks_uri or defaults.get("jwks_uri"),
|
||||
|
||||
# Discovery
|
||||
discovery_url=security_config.oidc_discovery_url or defaults.get("discovery_url"),
|
||||
|
||||
# Settings
|
||||
pkce_enabled=security_config.oauth_pkce_enabled,
|
||||
nonce_enabled=security_config.oauth_nonce_enabled,
|
||||
|
||||
# User mapping
|
||||
user_id_claim=security_config.oauth_user_id_claim or defaults.get("user_id_claim", "sub"),
|
||||
email_claim=security_config.oauth_email_claim or defaults.get("email_claim", "email"),
|
||||
name_claim=security_config.oauth_name_claim or defaults.get("name_claim", "name"),
|
||||
roles_claim=security_config.oauth_roles_claim,
|
||||
default_roles=security_config.oauth_default_roles
|
||||
)
|
||||
|
||||
async def initialize(self) -> bool:
|
||||
"""Initialize OAuth client
|
||||
|
||||
Returns:
|
||||
True if initialization successful
|
||||
"""
|
||||
if not self.enabled:
|
||||
return True
|
||||
|
||||
try:
|
||||
# Create HTTP session
|
||||
self._session = aiohttp.ClientSession()
|
||||
|
||||
# Start state manager
|
||||
await self.state_manager.start()
|
||||
|
||||
# Perform OIDC discovery if configured
|
||||
if self.provider_config.discovery_url:
|
||||
await self._discover_oidc_endpoints()
|
||||
|
||||
logger.info("OAuth client initialization completed")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize OAuth client: {e}")
|
||||
return False
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shutdown OAuth client"""
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
try:
|
||||
# Stop state manager
|
||||
await self.state_manager.stop()
|
||||
|
||||
# Close HTTP session
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
|
||||
logger.info("OAuth client shutdown completed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during OAuth client shutdown: {e}")
|
||||
|
||||
async def _discover_oidc_endpoints(self):
|
||||
"""Discover OIDC endpoints using discovery URL"""
|
||||
try:
|
||||
# Check cache first
|
||||
if (self._discovery_cache and self._discovery_cache_time and
|
||||
datetime.utcnow() - self._discovery_cache_time < timedelta(hours=1)):
|
||||
return self._discovery_cache
|
||||
|
||||
logger.info(f"Discovering OIDC endpoints: {self.provider_config.discovery_url}")
|
||||
|
||||
async with self._session.get(self.provider_config.discovery_url) as response:
|
||||
response.raise_for_status()
|
||||
data = await response.json()
|
||||
|
||||
discovery = OIDCDiscovery(
|
||||
issuer=data["issuer"],
|
||||
authorization_endpoint=data["authorization_endpoint"],
|
||||
token_endpoint=data["token_endpoint"],
|
||||
userinfo_endpoint=data.get("userinfo_endpoint"),
|
||||
jwks_uri=data.get("jwks_uri"),
|
||||
scopes_supported=data.get("scopes_supported"),
|
||||
response_types_supported=data.get("response_types_supported"),
|
||||
subject_types_supported=data.get("subject_types_supported"),
|
||||
id_token_signing_alg_values_supported=data.get("id_token_signing_alg_values_supported")
|
||||
)
|
||||
|
||||
# Update provider configuration with discovered endpoints
|
||||
if not self.provider_config.authorization_endpoint:
|
||||
self.provider_config.authorization_endpoint = discovery.authorization_endpoint
|
||||
if not self.provider_config.token_endpoint:
|
||||
self.provider_config.token_endpoint = discovery.token_endpoint
|
||||
if not self.provider_config.userinfo_endpoint:
|
||||
self.provider_config.userinfo_endpoint = discovery.userinfo_endpoint
|
||||
if not self.provider_config.jwks_uri:
|
||||
self.provider_config.jwks_uri = discovery.jwks_uri
|
||||
|
||||
# Cache discovery result
|
||||
self._discovery_cache = discovery
|
||||
self._discovery_cache_time = datetime.utcnow()
|
||||
|
||||
logger.info("OIDC endpoint discovery completed successfully")
|
||||
return discovery
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OIDC endpoint discovery failed: {e}")
|
||||
raise
|
||||
|
||||
def build_authorization_url(self) -> Tuple[str, OAuthState]:
|
||||
"""Build OAuth authorization URL
|
||||
|
||||
Returns:
|
||||
Tuple of (authorization_url, oauth_state)
|
||||
"""
|
||||
if not self.enabled:
|
||||
raise ValueError("OAuth client is not enabled")
|
||||
|
||||
# Create state for CSRF protection
|
||||
oauth_state = self.state_manager.create_state(
|
||||
redirect_uri=self.provider_config.redirect_uri,
|
||||
pkce_enabled=self.provider_config.pkce_enabled,
|
||||
nonce_enabled=self.provider_config.nonce_enabled
|
||||
)
|
||||
|
||||
# Build authorization parameters
|
||||
params = {
|
||||
'response_type': 'code',
|
||||
'client_id': self.provider_config.client_id,
|
||||
'redirect_uri': self.provider_config.redirect_uri,
|
||||
'scope': ' '.join(self.provider_config.scopes),
|
||||
'state': oauth_state.state
|
||||
}
|
||||
|
||||
# Add PKCE challenge
|
||||
if oauth_state.pkce_challenge:
|
||||
params['code_challenge'] = oauth_state.pkce_challenge
|
||||
params['code_challenge_method'] = 'S256'
|
||||
|
||||
# Add nonce for OIDC
|
||||
if oauth_state.nonce:
|
||||
params['nonce'] = oauth_state.nonce
|
||||
|
||||
# Build URL
|
||||
authorization_url = f"{self.provider_config.authorization_endpoint}?{urlencode(params)}"
|
||||
|
||||
logger.info(f"Built OAuth authorization URL for state: {oauth_state.state}")
|
||||
return authorization_url, oauth_state
|
||||
|
||||
async def exchange_code_for_tokens(self, code: str, state: str) -> Tuple[OAuthTokens, OAuthState]:
|
||||
"""Exchange authorization code for tokens
|
||||
|
||||
Args:
|
||||
code: Authorization code
|
||||
state: State parameter
|
||||
|
||||
Returns:
|
||||
Tuple of (OAuth tokens, OAuth state)
|
||||
|
||||
Raises:
|
||||
ValueError: If state is invalid or exchange fails
|
||||
"""
|
||||
if not self.enabled:
|
||||
raise ValueError("OAuth client is not enabled")
|
||||
|
||||
# Validate and consume state
|
||||
oauth_state = self.state_manager.consume_state(state)
|
||||
if not oauth_state:
|
||||
raise ValueError("Invalid or expired state parameter")
|
||||
|
||||
try:
|
||||
# Prepare token request
|
||||
data = {
|
||||
'grant_type': 'authorization_code',
|
||||
'client_id': self.provider_config.client_id,
|
||||
'client_secret': self.provider_config.client_secret,
|
||||
'code': code,
|
||||
'redirect_uri': oauth_state.redirect_uri
|
||||
}
|
||||
|
||||
# Add PKCE verifier
|
||||
if oauth_state.pkce_verifier:
|
||||
data['code_verifier'] = oauth_state.pkce_verifier
|
||||
|
||||
# Make token request
|
||||
async with self._session.post(
|
||||
self.provider_config.token_endpoint,
|
||||
data=data,
|
||||
headers={'Content-Type': 'application/x-www-form-urlencoded'}
|
||||
) as response:
|
||||
response_data = await response.json()
|
||||
|
||||
if response.status != 200:
|
||||
error_msg = response_data.get('error_description', response_data.get('error', 'Token exchange failed'))
|
||||
raise ValueError(f"Token exchange failed: {error_msg}")
|
||||
|
||||
tokens = OAuthTokens(
|
||||
access_token=response_data['access_token'],
|
||||
token_type=response_data.get('token_type', 'Bearer'),
|
||||
expires_in=response_data.get('expires_in'),
|
||||
refresh_token=response_data.get('refresh_token'),
|
||||
scope=response_data.get('scope'),
|
||||
id_token=response_data.get('id_token')
|
||||
)
|
||||
|
||||
logger.info("Successfully exchanged authorization code for tokens")
|
||||
return tokens, oauth_state
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Token exchange failed: {e}")
|
||||
raise ValueError(f"Token exchange failed: {str(e)}")
|
||||
|
||||
async def get_user_info(self, tokens: OAuthTokens) -> OAuthUserInfo:
|
||||
"""Get user information from OAuth provider
|
||||
|
||||
Args:
|
||||
tokens: OAuth tokens
|
||||
|
||||
Returns:
|
||||
OAuth user information
|
||||
"""
|
||||
if not self.enabled:
|
||||
raise ValueError("OAuth client is not enabled")
|
||||
|
||||
if not self.provider_config.userinfo_endpoint:
|
||||
raise ValueError("Userinfo endpoint not configured")
|
||||
|
||||
try:
|
||||
# Make userinfo request
|
||||
headers = {'Authorization': f'{tokens.token_type} {tokens.access_token}'}
|
||||
|
||||
async with self._session.get(
|
||||
self.provider_config.userinfo_endpoint,
|
||||
headers=headers
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
user_data = await response.json()
|
||||
|
||||
# Extract user information using configured claims
|
||||
user_info = OAuthUserInfo(
|
||||
sub=str(user_data.get(self.provider_config.user_id_claim, '')),
|
||||
email=user_data.get(self.provider_config.email_claim),
|
||||
name=user_data.get(self.provider_config.name_claim),
|
||||
given_name=user_data.get('given_name'),
|
||||
family_name=user_data.get('family_name'),
|
||||
picture=user_data.get('picture'),
|
||||
locale=user_data.get('locale'),
|
||||
email_verified=user_data.get('email_verified'),
|
||||
roles=user_data.get(self.provider_config.roles_claim, self.provider_config.default_roles.copy()),
|
||||
raw_claims=user_data
|
||||
)
|
||||
|
||||
logger.info(f"Retrieved user info for user: {user_info.sub}")
|
||||
return user_info
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get user info: {e}")
|
||||
raise ValueError(f"Failed to get user info: {str(e)}")
|
||||
|
||||
async def refresh_tokens(self, refresh_token: str) -> OAuthTokens:
|
||||
"""Refresh OAuth tokens
|
||||
|
||||
Args:
|
||||
refresh_token: Refresh token
|
||||
|
||||
Returns:
|
||||
New OAuth tokens
|
||||
"""
|
||||
if not self.enabled:
|
||||
raise ValueError("OAuth client is not enabled")
|
||||
|
||||
try:
|
||||
data = {
|
||||
'grant_type': 'refresh_token',
|
||||
'client_id': self.provider_config.client_id,
|
||||
'client_secret': self.provider_config.client_secret,
|
||||
'refresh_token': refresh_token
|
||||
}
|
||||
|
||||
async with self._session.post(
|
||||
self.provider_config.token_endpoint,
|
||||
data=data,
|
||||
headers={'Content-Type': 'application/x-www-form-urlencoded'}
|
||||
) as response:
|
||||
response_data = await response.json()
|
||||
|
||||
if response.status != 200:
|
||||
error_msg = response_data.get('error_description', response_data.get('error', 'Token refresh failed'))
|
||||
raise ValueError(f"Token refresh failed: {error_msg}")
|
||||
|
||||
tokens = OAuthTokens(
|
||||
access_token=response_data['access_token'],
|
||||
token_type=response_data.get('token_type', 'Bearer'),
|
||||
expires_in=response_data.get('expires_in'),
|
||||
refresh_token=response_data.get('refresh_token', refresh_token), # Keep old if not provided
|
||||
scope=response_data.get('scope'),
|
||||
id_token=response_data.get('id_token')
|
||||
)
|
||||
|
||||
logger.info("Successfully refreshed OAuth tokens")
|
||||
return tokens
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Token refresh failed: {e}")
|
||||
raise ValueError(f"Token refresh failed: {str(e)}")
|
||||
312
doris_mcp_server/auth/oauth_handlers.py
Normal file
@@ -0,0 +1,312 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
OAuth HTTP Handlers
|
||||
Provides HTTP endpoints for OAuth authentication flow
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
import json
|
||||
|
||||
from starlette.responses import JSONResponse, RedirectResponse, HTMLResponse
|
||||
from starlette.requests import Request
|
||||
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class OAuthHandlers:
|
||||
"""OAuth HTTP request handlers"""
|
||||
|
||||
def __init__(self, security_manager):
|
||||
"""Initialize OAuth handlers
|
||||
|
||||
Args:
|
||||
security_manager: DorisSecurityManager instance
|
||||
"""
|
||||
self.security_manager = security_manager
|
||||
logger.info("OAuth handlers initialized")
|
||||
|
||||
async def handle_login(self, request: Request) -> JSONResponse:
|
||||
"""Handle OAuth login initiation
|
||||
|
||||
Returns JSON with authorization URL and state
|
||||
"""
|
||||
try:
|
||||
# Check if OAuth is enabled
|
||||
oauth_info = self.security_manager.get_oauth_provider_info()
|
||||
if not oauth_info.get("enabled"):
|
||||
return JSONResponse(
|
||||
{"error": "OAuth authentication is not enabled"},
|
||||
status_code=400
|
||||
)
|
||||
|
||||
# Get authorization URL
|
||||
authorization_url, state = self.security_manager.get_oauth_authorization_url()
|
||||
|
||||
return JSONResponse({
|
||||
"authorization_url": authorization_url,
|
||||
"state": state,
|
||||
"provider": oauth_info.get("provider"),
|
||||
"message": "Navigate to authorization_url to complete OAuth login"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OAuth login initiation failed: {e}")
|
||||
return JSONResponse(
|
||||
{"error": f"OAuth login failed: {str(e)}"},
|
||||
status_code=500
|
||||
)
|
||||
|
||||
async def handle_callback(self, request: Request) -> JSONResponse:
|
||||
"""Handle OAuth callback
|
||||
|
||||
Processes the OAuth callback and returns authentication result
|
||||
"""
|
||||
try:
|
||||
# Get query parameters
|
||||
query_params = dict(request.query_params)
|
||||
|
||||
# Check for error in callback
|
||||
if "error" in query_params:
|
||||
error_description = query_params.get("error_description", "Unknown error")
|
||||
logger.warning(f"OAuth callback error: {query_params['error']} - {error_description}")
|
||||
return JSONResponse(
|
||||
{
|
||||
"error": query_params["error"],
|
||||
"error_description": error_description,
|
||||
"error_uri": query_params.get("error_uri")
|
||||
},
|
||||
status_code=400
|
||||
)
|
||||
|
||||
# Extract required parameters
|
||||
code = query_params.get("code")
|
||||
state = query_params.get("state")
|
||||
|
||||
if not code or not state:
|
||||
return JSONResponse(
|
||||
{"error": "Missing required parameters: code and state"},
|
||||
status_code=400
|
||||
)
|
||||
|
||||
# Handle OAuth callback
|
||||
auth_context = await self.security_manager.handle_oauth_callback(code, state)
|
||||
|
||||
# Return successful authentication response
|
||||
return JSONResponse({
|
||||
"success": True,
|
||||
"user_id": auth_context.user_id,
|
||||
"roles": auth_context.roles,
|
||||
"permissions": auth_context.permissions,
|
||||
"security_level": auth_context.security_level.value,
|
||||
"session_id": auth_context.session_id,
|
||||
"message": "OAuth authentication successful"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OAuth callback handling failed: {e}")
|
||||
return JSONResponse(
|
||||
{"error": f"OAuth callback failed: {str(e)}"},
|
||||
status_code=500
|
||||
)
|
||||
|
||||
async def handle_provider_info(self, request: Request) -> JSONResponse:
|
||||
"""Handle OAuth provider information request
|
||||
|
||||
Returns information about the configured OAuth provider
|
||||
"""
|
||||
try:
|
||||
provider_info = self.security_manager.get_oauth_provider_info()
|
||||
return JSONResponse(provider_info)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get OAuth provider info: {e}")
|
||||
return JSONResponse(
|
||||
{"error": f"Failed to get provider info: {str(e)}"},
|
||||
status_code=500
|
||||
)
|
||||
|
||||
async def handle_demo_page(self, request: Request) -> HTMLResponse:
|
||||
"""Handle OAuth demo page
|
||||
|
||||
Returns a simple HTML page for testing OAuth flow
|
||||
"""
|
||||
oauth_info = self.security_manager.get_oauth_provider_info()
|
||||
if not oauth_info.get("enabled"):
|
||||
return HTMLResponse("""
|
||||
<html>
|
||||
<head><title>OAuth Demo</title></head>
|
||||
<body>
|
||||
<h1>OAuth Demo</h1>
|
||||
<p style="color: red;">OAuth authentication is not enabled.</p>
|
||||
<p>Please configure OAuth settings in your security configuration.</p>
|
||||
</body>
|
||||
</html>
|
||||
""")
|
||||
|
||||
html_content = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Doris MCP Server - OAuth Demo</title>
|
||||
<style>
|
||||
body {{
|
||||
font-family: Arial, sans-serif;
|
||||
max-width: 800px;
|
||||
margin: 0 auto;
|
||||
padding: 20px;
|
||||
}}
|
||||
.info {{
|
||||
background-color: #f0f8ff;
|
||||
padding: 15px;
|
||||
border-left: 4px solid #0066cc;
|
||||
margin: 20px 0;
|
||||
}}
|
||||
.error {{
|
||||
background-color: #ffe6e6;
|
||||
padding: 15px;
|
||||
border-left: 4px solid #cc0000;
|
||||
margin: 20px 0;
|
||||
}}
|
||||
.success {{
|
||||
background-color: #e6ffe6;
|
||||
padding: 15px;
|
||||
border-left: 4px solid #00cc00;
|
||||
margin: 20px 0;
|
||||
}}
|
||||
button {{
|
||||
background-color: #0066cc;
|
||||
color: white;
|
||||
padding: 10px 20px;
|
||||
border: none;
|
||||
border-radius: 4px;
|
||||
cursor: pointer;
|
||||
font-size: 16px;
|
||||
}}
|
||||
button:hover {{
|
||||
background-color: #0052a3;
|
||||
}}
|
||||
pre {{
|
||||
background-color: #f5f5f5;
|
||||
padding: 10px;
|
||||
border-radius: 4px;
|
||||
overflow-x: auto;
|
||||
}}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Doris MCP Server - OAuth Demo</h1>
|
||||
|
||||
<div class="info">
|
||||
<h3>OAuth Configuration</h3>
|
||||
<p><strong>Provider:</strong> {oauth_info.get('provider', 'N/A')}</p>
|
||||
<p><strong>Client ID:</strong> {oauth_info.get('client_id', 'N/A')}</p>
|
||||
<p><strong>Scopes:</strong> {', '.join(oauth_info.get('scopes', []))}</p>
|
||||
<p><strong>PKCE Enabled:</strong> {oauth_info.get('pkce_enabled', False)}</p>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<h3>OAuth Authentication Test</h3>
|
||||
<p>Click the button below to start OAuth authentication flow:</p>
|
||||
<button onclick="startOAuthFlow()">Start OAuth Login</button>
|
||||
</div>
|
||||
|
||||
<div id="result" style="margin-top: 20px;"></div>
|
||||
|
||||
<div>
|
||||
<h3>API Endpoints</h3>
|
||||
<ul>
|
||||
<li><code>GET /auth/login</code> - Initiate OAuth login</li>
|
||||
<li><code>GET /auth/callback</code> - OAuth callback handler</li>
|
||||
<li><code>GET /auth/provider</code> - Provider information</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
async function startOAuthFlow() {{
|
||||
const resultDiv = document.getElementById('result');
|
||||
resultDiv.innerHTML = '<div class="info">Initiating OAuth flow...</div>';
|
||||
|
||||
try {{
|
||||
const response = await fetch('/auth/login');
|
||||
const data = await response.json();
|
||||
|
||||
if (response.ok) {{
|
||||
resultDiv.innerHTML = `
|
||||
<div class="success">
|
||||
<h4>OAuth URL Generated Successfully</h4>
|
||||
<p><strong>State:</strong> ${{data.state}}</p>
|
||||
<p><strong>Provider:</strong> ${{data.provider}}</p>
|
||||
<p><a href="${{data.authorization_url}}" target="_blank">Click here to authenticate</a></p>
|
||||
<p><em>Note: After authentication, you will be redirected to the callback URL.</em></p>
|
||||
</div>
|
||||
`;
|
||||
|
||||
// Automatically redirect to OAuth provider
|
||||
// window.open(data.authorization_url, '_blank');
|
||||
}} else {{
|
||||
resultDiv.innerHTML = `
|
||||
<div class="error">
|
||||
<h4>Error</h4>
|
||||
<p>${{data.error}}</p>
|
||||
</div>
|
||||
`;
|
||||
}}
|
||||
}} catch (error) {{
|
||||
resultDiv.innerHTML = `
|
||||
<div class="error">
|
||||
<h4>Network Error</h4>
|
||||
<p>${{error.message}}</p>
|
||||
</div>
|
||||
`;
|
||||
}}
|
||||
}}
|
||||
|
||||
// Handle OAuth callback result if present in URL
|
||||
window.addEventListener('load', function() {{
|
||||
const urlParams = new URLSearchParams(window.location.search);
|
||||
if (urlParams.has('code') && urlParams.has('state')) {{
|
||||
const resultDiv = document.getElementById('result');
|
||||
resultDiv.innerHTML = `
|
||||
<div class="success">
|
||||
<h4>OAuth Callback Received</h4>
|
||||
<p>Code: ${{urlParams.get('code')}}</p>
|
||||
<p>State: ${{urlParams.get('state')}}</p>
|
||||
<p>The authentication was successful!</p>
|
||||
</div>
|
||||
`;
|
||||
}} else if (urlParams.has('error')) {{
|
||||
const resultDiv = document.getElementById('result');
|
||||
resultDiv.innerHTML = `
|
||||
<div class="error">
|
||||
<h4>OAuth Error</h4>
|
||||
<p>Error: ${{urlParams.get('error')}}</p>
|
||||
<p>Description: ${{urlParams.get('error_description') || 'No description'}}</p>
|
||||
</div>
|
||||
`;
|
||||
}}
|
||||
}});
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
return HTMLResponse(html_content)
|
||||
289
doris_mcp_server/auth/oauth_provider.py
Normal file
@@ -0,0 +1,289 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
OAuth Authentication Provider
|
||||
Integrates OAuth 2.0/OIDC authentication with the existing authentication framework
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
from datetime import datetime
|
||||
|
||||
from .oauth_client import OAuthClient
|
||||
from .oauth_types import OAuthTokens, OAuthUserInfo, OAuthState
|
||||
from ..utils.security import AuthContext, SecurityLevel
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class OAuthAuthenticationProvider:
|
||||
"""OAuth authentication provider for Doris MCP Server"""
|
||||
|
||||
def __init__(self, config):
|
||||
"""Initialize OAuth authentication provider
|
||||
|
||||
Args:
|
||||
config: DorisConfig with OAuth configuration
|
||||
"""
|
||||
self.config = config
|
||||
self.oauth_client = OAuthClient(config)
|
||||
self.enabled = self.oauth_client.enabled
|
||||
|
||||
logger.info(f"OAuthAuthenticationProvider initialized (enabled: {self.enabled})")
|
||||
|
||||
async def initialize(self) -> bool:
|
||||
"""Initialize OAuth authentication provider
|
||||
|
||||
Returns:
|
||||
True if initialization successful
|
||||
"""
|
||||
if not self.enabled:
|
||||
return True
|
||||
|
||||
success = await self.oauth_client.initialize()
|
||||
if success:
|
||||
logger.info("OAuth authentication provider initialized successfully")
|
||||
else:
|
||||
logger.error("Failed to initialize OAuth authentication provider")
|
||||
return success
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shutdown OAuth authentication provider"""
|
||||
if self.enabled:
|
||||
await self.oauth_client.shutdown()
|
||||
logger.info("OAuth authentication provider shutdown completed")
|
||||
|
||||
def get_authorization_url(self) -> Tuple[str, str]:
|
||||
"""Get OAuth authorization URL
|
||||
|
||||
Returns:
|
||||
Tuple of (authorization_url, state)
|
||||
"""
|
||||
if not self.enabled:
|
||||
raise ValueError("OAuth authentication is not enabled")
|
||||
|
||||
authorization_url, oauth_state = self.oauth_client.build_authorization_url()
|
||||
return authorization_url, oauth_state.state
|
||||
|
||||
async def handle_callback(self, code: str, state: str) -> AuthContext:
|
||||
"""Handle OAuth callback and create authentication context
|
||||
|
||||
Args:
|
||||
code: Authorization code from OAuth provider
|
||||
state: State parameter for CSRF protection
|
||||
|
||||
Returns:
|
||||
AuthContext for the authenticated user
|
||||
|
||||
Raises:
|
||||
ValueError: If authentication fails
|
||||
"""
|
||||
if not self.enabled:
|
||||
raise ValueError("OAuth authentication is not enabled")
|
||||
|
||||
try:
|
||||
# Exchange code for tokens
|
||||
tokens, oauth_state = await self.oauth_client.exchange_code_for_tokens(code, state)
|
||||
|
||||
# Get user information
|
||||
user_info = await self.oauth_client.get_user_info(tokens)
|
||||
|
||||
# Create authentication context
|
||||
auth_context = await self._create_auth_context(user_info, tokens)
|
||||
|
||||
logger.info(f"OAuth authentication successful for user: {auth_context.user_id}")
|
||||
return auth_context
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OAuth callback handling failed: {e}")
|
||||
raise ValueError(f"OAuth authentication failed: {str(e)}")
|
||||
|
||||
async def authenticate_with_token(self, access_token: str) -> AuthContext:
|
||||
"""Authenticate using OAuth access token
|
||||
|
||||
Args:
|
||||
access_token: OAuth access token
|
||||
|
||||
Returns:
|
||||
AuthContext for the authenticated user
|
||||
"""
|
||||
if not self.enabled:
|
||||
raise ValueError("OAuth authentication is not enabled")
|
||||
|
||||
try:
|
||||
# Create token object
|
||||
tokens = OAuthTokens(access_token=access_token)
|
||||
|
||||
# Get user information
|
||||
user_info = await self.oauth_client.get_user_info(tokens)
|
||||
|
||||
# Create authentication context
|
||||
auth_context = await self._create_auth_context(user_info, tokens)
|
||||
|
||||
logger.info(f"OAuth token authentication successful for user: {auth_context.user_id}")
|
||||
return auth_context
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OAuth token authentication failed: {e}")
|
||||
raise ValueError(f"OAuth token authentication failed: {str(e)}")
|
||||
|
||||
async def refresh_authentication(self, refresh_token: str) -> Tuple[AuthContext, str]:
|
||||
"""Refresh OAuth authentication
|
||||
|
||||
Args:
|
||||
refresh_token: OAuth refresh token
|
||||
|
||||
Returns:
|
||||
Tuple of (AuthContext, new_access_token)
|
||||
"""
|
||||
if not self.enabled:
|
||||
raise ValueError("OAuth authentication is not enabled")
|
||||
|
||||
try:
|
||||
# Refresh tokens
|
||||
tokens = await self.oauth_client.refresh_tokens(refresh_token)
|
||||
|
||||
# Get updated user information
|
||||
user_info = await self.oauth_client.get_user_info(tokens)
|
||||
|
||||
# Create authentication context
|
||||
auth_context = await self._create_auth_context(user_info, tokens)
|
||||
|
||||
logger.info(f"OAuth refresh successful for user: {auth_context.user_id}")
|
||||
return auth_context, tokens.access_token
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OAuth refresh failed: {e}")
|
||||
raise ValueError(f"OAuth refresh failed: {str(e)}")
|
||||
|
||||
async def _create_auth_context(self, user_info: OAuthUserInfo, tokens: OAuthTokens) -> AuthContext:
|
||||
"""Create authentication context from OAuth user info
|
||||
|
||||
Args:
|
||||
user_info: OAuth user information
|
||||
tokens: OAuth tokens
|
||||
|
||||
Returns:
|
||||
AuthContext for the user
|
||||
"""
|
||||
# Determine security level based on roles or email domain
|
||||
security_level = await self._determine_security_level(user_info)
|
||||
|
||||
# Map OAuth roles to application permissions
|
||||
permissions = await self._map_permissions(user_info.roles)
|
||||
|
||||
# Generate session ID
|
||||
session_id = f"oauth_{user_info.sub}_{datetime.utcnow().timestamp()}"
|
||||
|
||||
return AuthContext(
|
||||
token_id=f"oauth_{user_info.sub}",
|
||||
user_id=user_info.sub,
|
||||
roles=user_info.roles,
|
||||
permissions=permissions,
|
||||
security_level=security_level,
|
||||
session_id=session_id,
|
||||
login_time=datetime.utcnow(),
|
||||
last_activity=datetime.utcnow(),
|
||||
token="" # OAuth doesn't have raw token, use empty string
|
||||
)
|
||||
|
||||
async def _determine_security_level(self, user_info: OAuthUserInfo) -> SecurityLevel:
|
||||
"""Determine security level for OAuth user
|
||||
|
||||
Args:
|
||||
user_info: OAuth user information
|
||||
|
||||
Returns:
|
||||
SecurityLevel for the user
|
||||
"""
|
||||
# Check if user has admin roles
|
||||
admin_roles = {"admin", "administrator", "data_admin", "super_admin"}
|
||||
if any(role.lower() in admin_roles for role in user_info.roles):
|
||||
return SecurityLevel.SECRET
|
||||
|
||||
# Check email domain for internal users
|
||||
if user_info.email:
|
||||
# You can configure trusted domains for internal access
|
||||
trusted_domains = ["yourcompany.com", "internal.org"] # Configure as needed
|
||||
email_domain = user_info.email.split("@")[-1].lower()
|
||||
if email_domain in trusted_domains:
|
||||
return SecurityLevel.CONFIDENTIAL
|
||||
|
||||
# Check for special roles
|
||||
elevated_roles = {"data_analyst", "developer", "manager"}
|
||||
if any(role.lower() in elevated_roles for role in user_info.roles):
|
||||
return SecurityLevel.CONFIDENTIAL
|
||||
|
||||
# Default to internal level for OAuth users
|
||||
return SecurityLevel.INTERNAL
|
||||
|
||||
async def _map_permissions(self, roles: list[str]) -> list[str]:
|
||||
"""Map OAuth roles to application permissions
|
||||
|
||||
Args:
|
||||
roles: OAuth user roles
|
||||
|
||||
Returns:
|
||||
List of application permissions
|
||||
"""
|
||||
permissions = set()
|
||||
|
||||
# Role to permission mapping
|
||||
role_permissions = {
|
||||
"admin": ["admin", "read_data", "write_data", "manage_users"],
|
||||
"administrator": ["admin", "read_data", "write_data", "manage_users"],
|
||||
"data_admin": ["admin", "read_data", "write_data"],
|
||||
"super_admin": ["admin", "read_data", "write_data", "manage_users", "system_admin"],
|
||||
"data_analyst": ["read_data", "query_database"],
|
||||
"developer": ["read_data", "query_database", "debug"],
|
||||
"viewer": ["read_data"],
|
||||
"user": ["read_data"],
|
||||
"oauth_user": ["read_data"] # Default OAuth user permission
|
||||
}
|
||||
|
||||
# Map roles to permissions
|
||||
for role in roles:
|
||||
role_lower = role.lower()
|
||||
if role_lower in role_permissions:
|
||||
permissions.update(role_permissions[role_lower])
|
||||
|
||||
# Ensure OAuth users have at least basic permissions
|
||||
if not permissions:
|
||||
permissions.add("read_data")
|
||||
|
||||
return list(permissions)
|
||||
|
||||
def get_provider_info(self) -> Dict[str, Any]:
|
||||
"""Get OAuth provider information
|
||||
|
||||
Returns:
|
||||
Provider information dictionary
|
||||
"""
|
||||
if not self.enabled:
|
||||
return {"enabled": False}
|
||||
|
||||
config = self.oauth_client.provider_config
|
||||
return {
|
||||
"enabled": True,
|
||||
"provider": config.provider.value,
|
||||
"client_id": config.client_id,
|
||||
"scopes": config.scopes,
|
||||
"redirect_uri": config.redirect_uri,
|
||||
"pkce_enabled": config.pkce_enabled,
|
||||
"nonce_enabled": config.nonce_enabled
|
||||
}
|
||||
196
doris_mcp_server/auth/oauth_types.py
Normal file
@@ -0,0 +1,196 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
OAuth 2.0/OIDC Type Definitions
|
||||
Provides data types and models for OAuth authentication flow
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Dict, Any, Optional, List
|
||||
|
||||
|
||||
class OAuthProvider(Enum):
|
||||
"""OAuth provider enumeration"""
|
||||
GOOGLE = "google"
|
||||
MICROSOFT = "microsoft"
|
||||
GITHUB = "github"
|
||||
CUSTOM = "custom"
|
||||
|
||||
|
||||
class OAuthGrantType(Enum):
|
||||
"""OAuth grant type enumeration"""
|
||||
AUTHORIZATION_CODE = "authorization_code"
|
||||
REFRESH_TOKEN = "refresh_token"
|
||||
|
||||
|
||||
@dataclass
|
||||
class OAuthState:
|
||||
"""OAuth state parameter for CSRF protection"""
|
||||
state: str
|
||||
nonce: Optional[str] = None
|
||||
pkce_verifier: Optional[str] = None
|
||||
pkce_challenge: Optional[str] = None
|
||||
redirect_uri: str = ""
|
||||
created_at: datetime = None
|
||||
expires_at: datetime = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.created_at is None:
|
||||
self.created_at = datetime.utcnow()
|
||||
|
||||
|
||||
@dataclass
|
||||
class OAuthTokens:
|
||||
"""OAuth token response"""
|
||||
access_token: str
|
||||
token_type: str = "Bearer"
|
||||
expires_in: Optional[int] = None
|
||||
refresh_token: Optional[str] = None
|
||||
scope: Optional[str] = None
|
||||
id_token: Optional[str] = None # OIDC ID token
|
||||
created_at: datetime = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.created_at is None:
|
||||
self.created_at = datetime.utcnow()
|
||||
|
||||
|
||||
@dataclass
|
||||
class OAuthUserInfo:
|
||||
"""OAuth/OIDC user information"""
|
||||
sub: str # Subject identifier
|
||||
email: Optional[str] = None
|
||||
email_verified: Optional[bool] = None
|
||||
name: Optional[str] = None
|
||||
given_name: Optional[str] = None
|
||||
family_name: Optional[str] = None
|
||||
picture: Optional[str] = None
|
||||
locale: Optional[str] = None
|
||||
roles: List[str] = None
|
||||
raw_claims: Dict[str, Any] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.roles is None:
|
||||
self.roles = []
|
||||
if self.raw_claims is None:
|
||||
self.raw_claims = {}
|
||||
|
||||
|
||||
@dataclass
|
||||
class OIDCDiscovery:
|
||||
"""OIDC Discovery document"""
|
||||
issuer: str
|
||||
authorization_endpoint: str
|
||||
token_endpoint: str
|
||||
userinfo_endpoint: Optional[str] = None
|
||||
jwks_uri: Optional[str] = None
|
||||
scopes_supported: List[str] = None
|
||||
response_types_supported: List[str] = None
|
||||
subject_types_supported: List[str] = None
|
||||
id_token_signing_alg_values_supported: List[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.scopes_supported is None:
|
||||
self.scopes_supported = ["openid"]
|
||||
if self.response_types_supported is None:
|
||||
self.response_types_supported = ["code"]
|
||||
if self.subject_types_supported is None:
|
||||
self.subject_types_supported = ["public"]
|
||||
if self.id_token_signing_alg_values_supported is None:
|
||||
self.id_token_signing_alg_values_supported = ["RS256"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class OAuthError:
|
||||
"""OAuth error response"""
|
||||
error: str
|
||||
error_description: Optional[str] = None
|
||||
error_uri: Optional[str] = None
|
||||
state: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class OAuthProviderConfig:
|
||||
"""OAuth provider configuration"""
|
||||
provider: OAuthProvider
|
||||
client_id: str
|
||||
client_secret: str
|
||||
redirect_uri: str
|
||||
scopes: List[str]
|
||||
|
||||
# Endpoints
|
||||
authorization_endpoint: str
|
||||
token_endpoint: str
|
||||
userinfo_endpoint: Optional[str] = None
|
||||
jwks_uri: Optional[str] = None
|
||||
|
||||
# Discovery
|
||||
discovery_url: Optional[str] = None
|
||||
|
||||
# Settings
|
||||
pkce_enabled: bool = True
|
||||
nonce_enabled: bool = True
|
||||
|
||||
# User mapping
|
||||
user_id_claim: str = "sub"
|
||||
email_claim: str = "email"
|
||||
name_claim: str = "name"
|
||||
roles_claim: str = "roles"
|
||||
default_roles: List[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.default_roles is None:
|
||||
self.default_roles = ["oauth_user"]
|
||||
|
||||
|
||||
# Pre-defined provider configurations
|
||||
OAUTH_PROVIDERS = {
|
||||
OAuthProvider.GOOGLE: {
|
||||
"authorization_endpoint": "https://accounts.google.com/o/oauth2/auth",
|
||||
"token_endpoint": "https://oauth2.googleapis.com/token",
|
||||
"userinfo_endpoint": "https://openidconnect.googleapis.com/v1/userinfo",
|
||||
"jwks_uri": "https://www.googleapis.com/oauth2/v3/certs",
|
||||
"discovery_url": "https://accounts.google.com/.well-known/openid_configuration",
|
||||
"scopes": ["openid", "email", "profile"],
|
||||
"user_id_claim": "sub",
|
||||
"email_claim": "email",
|
||||
"name_claim": "name"
|
||||
},
|
||||
OAuthProvider.MICROSOFT: {
|
||||
"authorization_endpoint": "https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
|
||||
"token_endpoint": "https://login.microsoftonline.com/common/oauth2/v2.0/token",
|
||||
"userinfo_endpoint": "https://graph.microsoft.com/v1.0/me",
|
||||
"jwks_uri": "https://login.microsoftonline.com/common/discovery/v2.0/keys",
|
||||
"discovery_url": "https://login.microsoftonline.com/common/v2.0/.well-known/openid_configuration",
|
||||
"scopes": ["openid", "profile", "email", "User.Read"],
|
||||
"user_id_claim": "sub",
|
||||
"email_claim": "email",
|
||||
"name_claim": "name"
|
||||
},
|
||||
OAuthProvider.GITHUB: {
|
||||
"authorization_endpoint": "https://github.com/login/oauth/authorize",
|
||||
"token_endpoint": "https://github.com/login/oauth/access_token",
|
||||
"userinfo_endpoint": "https://api.github.com/user",
|
||||
"scopes": ["user:email", "read:user"],
|
||||
"user_id_claim": "id",
|
||||
"email_claim": "email",
|
||||
"name_claim": "name"
|
||||
}
|
||||
}
|
||||
677
doris_mcp_server/auth/token_handlers.py
Normal file
@@ -0,0 +1,677 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Token Authentication HTTP Handlers
|
||||
|
||||
Provides HTTP endpoints for token management including creation, revocation,
|
||||
listing, and statistics. Used for administrative token management in HTTP mode.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Dict, Any
|
||||
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, HTMLResponse
|
||||
|
||||
from ..utils.logger import get_logger
|
||||
from ..utils.security import SecurityLevel
|
||||
from ..utils.config import DatabaseConfig
|
||||
from .token_security_middleware import TokenSecurityMiddleware
|
||||
|
||||
|
||||
class TokenHandlers:
|
||||
"""Token Authentication HTTP Handlers"""
|
||||
|
||||
def __init__(self, security_manager, config=None):
|
||||
self.security_manager = security_manager
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
# Initialize security middleware if config is provided
|
||||
if config:
|
||||
self.security_middleware = TokenSecurityMiddleware(config)
|
||||
else:
|
||||
self.security_middleware = None
|
||||
self.logger.warning("Token handlers initialized without security middleware - access control disabled")
|
||||
|
||||
async def handle_create_token(self, request: Request) -> JSONResponse:
|
||||
"""Handle token creation request"""
|
||||
# Apply security checks
|
||||
if self.security_middleware:
|
||||
security_response = await self.security_middleware.check_token_management_access(request)
|
||||
if security_response:
|
||||
return security_response
|
||||
|
||||
try:
|
||||
# Check if token manager is available
|
||||
if not self.security_manager.auth_provider.token_manager:
|
||||
return JSONResponse({
|
||||
"error": "Token authentication is not enabled"
|
||||
}, status_code=503)
|
||||
|
||||
# Parse request data
|
||||
if request.method == "GET":
|
||||
# GET request with query parameters
|
||||
query_params = dict(request.query_params)
|
||||
token_id = query_params.get("token_id")
|
||||
expires_hours_str = query_params.get("expires_hours")
|
||||
description = query_params.get("description", "")
|
||||
custom_token = query_params.get("custom_token")
|
||||
# Database configuration from query params
|
||||
db_config = None
|
||||
if query_params.get("db_host"):
|
||||
db_config = DatabaseConfig(
|
||||
host=query_params.get("db_host", "localhost"),
|
||||
port=int(query_params.get("db_port", "9030")),
|
||||
user=query_params.get("db_user", "root"),
|
||||
password=query_params.get("db_password", ""),
|
||||
database=query_params.get("db_database", "information_schema"),
|
||||
fe_http_port=int(query_params.get("db_fe_http_port", "8030"))
|
||||
)
|
||||
else:
|
||||
# POST request with JSON body
|
||||
try:
|
||||
body = await request.json()
|
||||
except:
|
||||
return JSONResponse({
|
||||
"error": "Invalid JSON body"
|
||||
}, status_code=400)
|
||||
|
||||
token_id = body.get("token_id")
|
||||
expires_hours_str = body.get("expires_hours")
|
||||
description = body.get("description", "")
|
||||
custom_token = body.get("custom_token")
|
||||
# Database configuration from JSON body
|
||||
db_config = None
|
||||
if body.get("database_config"):
|
||||
db_data = body["database_config"]
|
||||
try:
|
||||
db_config = DatabaseConfig(
|
||||
host=db_data.get("host", "localhost"),
|
||||
port=int(db_data.get("port", 9030)),
|
||||
user=db_data.get("user", "root"),
|
||||
password=db_data.get("password", ""),
|
||||
database=db_data.get("database", "information_schema"),
|
||||
fe_http_port=int(db_data.get("fe_http_port", 8030))
|
||||
)
|
||||
except (ValueError, TypeError) as e:
|
||||
return JSONResponse({
|
||||
"error": f"Invalid database configuration: {str(e)}"
|
||||
}, status_code=400)
|
||||
|
||||
# Validate required fields
|
||||
if not token_id:
|
||||
return JSONResponse({
|
||||
"error": "token_id is required"
|
||||
}, status_code=400)
|
||||
|
||||
# Parse expires_hours
|
||||
expires_hours = None
|
||||
if expires_hours_str:
|
||||
try:
|
||||
expires_hours = int(expires_hours_str)
|
||||
except ValueError:
|
||||
return JSONResponse({
|
||||
"error": "expires_hours must be an integer"
|
||||
}, status_code=400)
|
||||
|
||||
# Create token using the actual API
|
||||
try:
|
||||
token = await self.security_manager.create_token(
|
||||
token_id=token_id,
|
||||
expires_hours=expires_hours,
|
||||
description=description,
|
||||
custom_token=custom_token,
|
||||
database_config=db_config
|
||||
)
|
||||
|
||||
return JSONResponse({
|
||||
"success": True,
|
||||
"token_id": token_id,
|
||||
"token": token,
|
||||
"expires_hours": expires_hours,
|
||||
"description": description,
|
||||
"message": "Token created successfully"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Token creation failed: {e}")
|
||||
return JSONResponse({
|
||||
"error": f"Token creation failed: {str(e)}"
|
||||
}, status_code=400)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in handle_create_token: {e}")
|
||||
return JSONResponse({
|
||||
"error": f"Internal server error: {str(e)}"
|
||||
}, status_code=500)
|
||||
|
||||
async def handle_revoke_token(self, request: Request) -> JSONResponse:
|
||||
"""Handle token revocation request"""
|
||||
# Apply security checks
|
||||
if self.security_middleware:
|
||||
security_response = await self.security_middleware.check_token_management_access(request)
|
||||
if security_response:
|
||||
return security_response
|
||||
|
||||
try:
|
||||
# Check if token manager is available
|
||||
if not self.security_manager.auth_provider.token_manager:
|
||||
return JSONResponse({
|
||||
"error": "Token authentication is not enabled"
|
||||
}, status_code=503)
|
||||
|
||||
# Get token_id from query parameters or path
|
||||
token_id = request.query_params.get("token_id")
|
||||
if not token_id and request.method == "DELETE":
|
||||
# Try to get from path: /token/revoke/{token_id}
|
||||
path_parts = str(request.url.path).split("/")
|
||||
if len(path_parts) >= 4:
|
||||
token_id = path_parts[-1]
|
||||
|
||||
if not token_id:
|
||||
return JSONResponse({
|
||||
"error": "token_id is required"
|
||||
}, status_code=400)
|
||||
|
||||
# Revoke token
|
||||
success = await self.security_manager.revoke_token(token_id)
|
||||
|
||||
if success:
|
||||
return JSONResponse({
|
||||
"success": True,
|
||||
"token_id": token_id,
|
||||
"message": "Token revoked successfully"
|
||||
})
|
||||
else:
|
||||
return JSONResponse({
|
||||
"success": False,
|
||||
"token_id": token_id,
|
||||
"message": "Token not found or already revoked"
|
||||
}, status_code=404)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in handle_revoke_token: {e}")
|
||||
return JSONResponse({
|
||||
"error": f"Internal server error: {str(e)}"
|
||||
}, status_code=500)
|
||||
|
||||
async def handle_list_tokens(self, request: Request) -> JSONResponse:
|
||||
"""Handle token listing request"""
|
||||
# Apply security checks
|
||||
if self.security_middleware:
|
||||
security_response = await self.security_middleware.check_token_management_access(request)
|
||||
if security_response:
|
||||
return security_response
|
||||
|
||||
try:
|
||||
# Check if token manager is available
|
||||
if not self.security_manager.auth_provider.token_manager:
|
||||
return JSONResponse({
|
||||
"error": "Token authentication is not enabled"
|
||||
}, status_code=503)
|
||||
|
||||
# Get tokens list
|
||||
tokens = await self.security_manager.list_tokens()
|
||||
|
||||
return JSONResponse({
|
||||
"success": True,
|
||||
"count": len(tokens),
|
||||
"tokens": tokens
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in handle_list_tokens: {e}")
|
||||
return JSONResponse({
|
||||
"error": f"Internal server error: {str(e)}"
|
||||
}, status_code=500)
|
||||
|
||||
async def handle_token_stats(self, request: Request) -> JSONResponse:
|
||||
"""Handle token statistics request"""
|
||||
# Apply security checks
|
||||
if self.security_middleware:
|
||||
security_response = await self.security_middleware.check_token_management_access(request)
|
||||
if security_response:
|
||||
return security_response
|
||||
|
||||
try:
|
||||
# Check if token manager is available
|
||||
if not self.security_manager.auth_provider.token_manager:
|
||||
return JSONResponse({
|
||||
"error": "Token authentication is not enabled"
|
||||
}, status_code=503)
|
||||
|
||||
# Get token statistics
|
||||
stats = self.security_manager.get_token_stats()
|
||||
|
||||
return JSONResponse({
|
||||
"success": True,
|
||||
"stats": stats
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in handle_token_stats: {e}")
|
||||
return JSONResponse({
|
||||
"error": f"Internal server error: {str(e)}"
|
||||
}, status_code=500)
|
||||
|
||||
async def handle_cleanup_tokens(self, request: Request) -> JSONResponse:
|
||||
"""Handle expired tokens cleanup request"""
|
||||
# Apply security checks
|
||||
if self.security_middleware:
|
||||
security_response = await self.security_middleware.check_token_management_access(request)
|
||||
if security_response:
|
||||
return security_response
|
||||
|
||||
try:
|
||||
# Check if token manager is available
|
||||
if not self.security_manager.auth_provider.token_manager:
|
||||
return JSONResponse({
|
||||
"error": "Token authentication is not enabled"
|
||||
}, status_code=503)
|
||||
|
||||
# Cleanup expired tokens
|
||||
cleaned_count = await self.security_manager.cleanup_expired_tokens()
|
||||
|
||||
return JSONResponse({
|
||||
"success": True,
|
||||
"cleaned_count": cleaned_count,
|
||||
"message": f"Cleaned up {cleaned_count} expired tokens"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in handle_cleanup_tokens: {e}")
|
||||
return JSONResponse({
|
||||
"error": f"Internal server error: {str(e)}"
|
||||
}, status_code=500)
|
||||
|
||||
async def handle_management_page(self, request: Request) -> HTMLResponse:
|
||||
"""Handle token management demo page"""
|
||||
# Apply security checks
|
||||
if self.security_middleware:
|
||||
security_response = await self.security_middleware.check_token_management_access(request)
|
||||
if security_response:
|
||||
# Convert JSON response to HTML for demo page
|
||||
error_data = security_response.body.decode('utf-8') if hasattr(security_response, 'body') else '{"error": "Access denied"}'
|
||||
try:
|
||||
error_info = json.loads(error_data)
|
||||
except:
|
||||
error_info = {"error": "Access denied"}
|
||||
|
||||
error_html = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Access Denied - Token Management</title>
|
||||
<style>
|
||||
body {{ font-family: Arial, sans-serif; margin: 50px; background: #f5f5f5; }}
|
||||
.container {{ max-width: 600px; margin: 0 auto; background: white; padding: 30px; border-radius: 8px; }}
|
||||
.error {{ color: #dc3545; background: #f8d7da; border: 1px solid #f5c6cb; padding: 15px; border-radius: 5px; }}
|
||||
.security-info {{ background: #d1ecf1; border: 1px solid #bee5eb; padding: 15px; border-radius: 5px; margin-top: 20px; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>🔐 Token Management - Access Denied</h1>
|
||||
<div class="error">
|
||||
<h3>Access Denied</h3>
|
||||
<p><strong>Error:</strong> {error_info.get('error', 'Access denied')}</p>
|
||||
<p><strong>Message:</strong> {error_info.get('message', 'Token management access is restricted')}</p>
|
||||
{'<p><strong>Your IP:</strong> ' + str(error_info.get('client_ip', 'Unknown')) + '</p>' if 'client_ip' in error_info else ''}
|
||||
</div>
|
||||
|
||||
<div class="security-info">
|
||||
<h3>🛡️ Security Information</h3>
|
||||
<p>Token management endpoints are protected by the following security measures:</p>
|
||||
<ul>
|
||||
<li><strong>IP Restrictions:</strong> Only localhost/127.0.0.1 access allowed</li>
|
||||
<li><strong>Admin Authentication:</strong> Valid admin token required</li>
|
||||
<li><strong>Configuration Control:</strong> Must be explicitly enabled</li>
|
||||
</ul>
|
||||
<p>If you need access, please:</p>
|
||||
<ol>
|
||||
<li>Access from the server host (127.0.0.1)</li>
|
||||
<li>Ensure HTTP token management is enabled in configuration</li>
|
||||
<li>Provide valid admin authentication</li>
|
||||
</ol>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
return HTMLResponse(error_html, status_code=security_response.status_code)
|
||||
|
||||
try:
|
||||
# Check if token manager is available
|
||||
if not self.security_manager.auth_provider.token_manager:
|
||||
html_content = """
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Token Management - Not Available</title>
|
||||
<style>
|
||||
body { font-family: Arial, sans-serif; margin: 50px; }
|
||||
.error { color: red; font-size: 18px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Token Management</h1>
|
||||
<div class="error">Token authentication is not enabled on this server.</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
return HTMLResponse(html_content)
|
||||
|
||||
# Get current stats for demo
|
||||
stats = self.security_manager.get_token_stats()
|
||||
|
||||
html_content = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Doris MCP Server - Token Management</title>
|
||||
<style>
|
||||
body {{ font-family: Arial, sans-serif; margin: 50px; background: #f5f5f5; }}
|
||||
.container {{ max-width: 1200px; margin: 0 auto; background: white; padding: 30px; border-radius: 8px; }}
|
||||
h1 {{ color: #333; }}
|
||||
.section {{ margin: 30px 0; padding: 20px; border: 1px solid #ddd; border-radius: 5px; }}
|
||||
.stats {{ display: grid; grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); gap: 15px; }}
|
||||
.stat-item {{ padding: 15px; background: #f8f9fa; border-radius: 5px; text-align: center; }}
|
||||
.stat-value {{ font-size: 24px; font-weight: bold; color: #007bff; }}
|
||||
.form-group {{ margin: 15px 0; }}
|
||||
.form-group label {{ display: block; margin-bottom: 5px; font-weight: bold; }}
|
||||
.form-group input, .form-group textarea {{ width: 100%; padding: 8px; border: 1px solid #ddd; border-radius: 4px; }}
|
||||
button {{ padding: 10px 20px; margin: 5px; border: none; border-radius: 4px; cursor: pointer; }}
|
||||
.btn-primary {{ background: #007bff; color: white; }}
|
||||
.btn-danger {{ background: #dc3545; color: white; }}
|
||||
.btn-success {{ background: #28a745; color: white; }}
|
||||
.response {{ margin: 15px 0; padding: 15px; border-radius: 5px; }}
|
||||
.response.success {{ background: #d4edda; border: 1px solid #c3e6cb; }}
|
||||
.response.error {{ background: #f8d7da; border: 1px solid #f5c6cb; }}
|
||||
.token-list {{ margin: 15px 0; }}
|
||||
.token-item {{ padding: 10px; margin: 5px 0; background: #f8f9fa; border-radius: 4px; }}
|
||||
pre {{ background: #f8f9fa; padding: 10px; border-radius: 4px; overflow-x: auto; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>🔐 Doris MCP Server - Token Management</h1>
|
||||
|
||||
<div class="section">
|
||||
<h2>📊 Token Statistics</h2>
|
||||
<div class="stats">
|
||||
<div class="stat-item">
|
||||
<div class="stat-value">{stats.get('total_tokens', 0)}</div>
|
||||
<div>Total Tokens</div>
|
||||
</div>
|
||||
<div class="stat-item">
|
||||
<div class="stat-value">{stats.get('active_tokens', 0)}</div>
|
||||
<div>Active Tokens</div>
|
||||
</div>
|
||||
<div class="stat-item">
|
||||
<div class="stat-value">{stats.get('expired_tokens', 0)}</div>
|
||||
<div>Expired Tokens</div>
|
||||
</div>
|
||||
</div>
|
||||
<p><strong>Token Expiry:</strong> {'Enabled' if stats.get('expiry_enabled') else 'Disabled'}</p>
|
||||
<p><strong>Default Expiry:</strong> {stats.get('default_expiry_hours', 0)} hours</p>
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>➕ Create New Token</h2>
|
||||
<form id="createTokenForm">
|
||||
<div class="form-group">
|
||||
<label for="token_id">Token ID (required):</label>
|
||||
<input type="text" id="token_id" name="token_id" placeholder="e.g., my-app-token" required>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="expires_hours">Expires Hours (optional):</label>
|
||||
<input type="number" id="expires_hours" name="expires_hours" placeholder="e.g., 720 (30 days), leave empty for default">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="description">Description (optional):</label>
|
||||
<textarea id="description" name="description" placeholder="Token description"></textarea>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="custom_token">Custom Token (optional):</label>
|
||||
<input type="text" id="custom_token" name="custom_token" placeholder="Leave empty to auto-generate">
|
||||
<small style="color: #666; display: block; margin-top: 5px;">If not provided, a secure token will be generated automatically</small>
|
||||
</div>
|
||||
|
||||
<div class="section" style="margin: 20px 0; padding: 15px; background: #f8f9fa; border-radius: 5px;">
|
||||
<h3>🗄️ Database Configuration (Optional)</h3>
|
||||
<p style="color: #666; font-size: 14px; margin-bottom: 15px;">Configure database connection for this token. Leave empty to use system defaults.</p>
|
||||
|
||||
<div style="display: grid; grid-template-columns: 1fr 1fr; gap: 10px;">
|
||||
<div class="form-group">
|
||||
<label for="db_host">Host:</label>
|
||||
<input type="text" id="db_host" name="db_host" placeholder="localhost">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="db_port">Port:</label>
|
||||
<input type="number" id="db_port" name="db_port" placeholder="9030">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="db_user">User:</label>
|
||||
<input type="text" id="db_user" name="db_user" placeholder="root">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="db_password">Password:</label>
|
||||
<input type="password" id="db_password" name="db_password" placeholder="(optional)">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="db_database">Database:</label>
|
||||
<input type="text" id="db_database" name="db_database" placeholder="information_schema">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="db_fe_http_port">FE HTTP Port:</label>
|
||||
<input type="number" id="db_fe_http_port" name="db_fe_http_port" placeholder="8030">
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<button type="submit" class="btn-primary">Create Token</button>
|
||||
</form>
|
||||
<div id="createTokenResponse"></div>
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>📋 Token Management</h2>
|
||||
<button id="listTokensBtn" class="btn-success">Refresh Token List</button>
|
||||
<button id="cleanupTokensBtn" class="btn-primary">Cleanup Expired Tokens</button>
|
||||
<div id="tokenListResponse"></div>
|
||||
|
||||
<h3>Revoke Token</h3>
|
||||
<div class="form-group">
|
||||
<input type="text" id="revokeTokenId" placeholder="Enter token ID to revoke">
|
||||
<button id="revokeTokenBtn" class="btn-danger">Revoke Token</button>
|
||||
</div>
|
||||
<div id="revokeTokenResponse"></div>
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>🔧 API Endpoints</h2>
|
||||
<p>Use these endpoints for programmatic token management:</p>
|
||||
<ul>
|
||||
<li><strong>POST /token/create</strong> - Create new token</li>
|
||||
<li><strong>DELETE /token/revoke?token_id=...</strong> - Revoke token</li>
|
||||
<li><strong>GET /token/list</strong> - List all tokens</li>
|
||||
<li><strong>GET /token/stats</strong> - Get token statistics</li>
|
||||
<li><strong>POST /token/cleanup</strong> - Cleanup expired tokens</li>
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
// Get admin token from URL parameters
|
||||
const urlParams = new URLSearchParams(window.location.search);
|
||||
const adminToken = urlParams.get('admin_token');
|
||||
|
||||
// Create request headers with admin token
|
||||
function getAuthHeaders() {{
|
||||
if (adminToken) {{
|
||||
return {{
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': `Bearer ${{adminToken}}`
|
||||
}};
|
||||
}} else {{
|
||||
return {{'Content-Type': 'application/json'}};
|
||||
}}
|
||||
}}
|
||||
|
||||
// Create URL with admin token parameter
|
||||
function getAuthURL(baseUrl) {{
|
||||
if (adminToken) {{
|
||||
const separator = baseUrl.includes('?') ? '&' : '?';
|
||||
return `${{baseUrl}}${{separator}}admin_token=${{encodeURIComponent(adminToken)}}`;
|
||||
}}
|
||||
return baseUrl;
|
||||
}}
|
||||
|
||||
function showResponse(elementId, data, isSuccess = true) {{
|
||||
const element = document.getElementById(elementId);
|
||||
element.innerHTML = '<pre>' + JSON.stringify(data, null, 2) + '</pre>';
|
||||
element.className = 'response ' + (isSuccess ? 'success' : 'error');
|
||||
}}
|
||||
|
||||
// Create token form - updated to match actual API
|
||||
document.getElementById('createTokenForm').addEventListener('submit', async (e) => {{
|
||||
e.preventDefault();
|
||||
const formData = new FormData(e.target);
|
||||
const data = Object.fromEntries(formData.entries());
|
||||
|
||||
// Remove empty fields for optional parameters
|
||||
if (!data.expires_hours) delete data.expires_hours;
|
||||
if (!data.description) delete data.description;
|
||||
if (!data.custom_token) delete data.custom_token;
|
||||
|
||||
// Handle database configuration
|
||||
if (data.db_host) {{
|
||||
data.database_config = {{
|
||||
host: data.db_host,
|
||||
port: data.db_port ? parseInt(data.db_port) : 9030,
|
||||
user: data.db_user || 'root',
|
||||
password: data.db_password || '',
|
||||
database: data.db_database || 'information_schema',
|
||||
fe_http_port: data.db_fe_http_port ? parseInt(data.db_fe_http_port) : 8030
|
||||
}};
|
||||
}}
|
||||
|
||||
// Remove individual database fields from data
|
||||
delete data.db_host;
|
||||
delete data.db_port;
|
||||
delete data.db_user;
|
||||
delete data.db_password;
|
||||
delete data.db_database;
|
||||
delete data.db_fe_http_port;
|
||||
|
||||
try {{
|
||||
const response = await fetch(getAuthURL('/token/create'), {{
|
||||
method: 'POST',
|
||||
headers: getAuthHeaders(),
|
||||
body: JSON.stringify(data)
|
||||
}});
|
||||
const result = await response.json();
|
||||
showResponse('createTokenResponse', result, response.ok);
|
||||
|
||||
// Refresh token list if creation was successful
|
||||
if (response.ok) {{
|
||||
document.getElementById('listTokensBtn').click();
|
||||
}}
|
||||
}} catch (error) {{
|
||||
showResponse('createTokenResponse', {{error: error.message}}, false);
|
||||
}}
|
||||
}});
|
||||
|
||||
// List tokens
|
||||
document.getElementById('listTokensBtn').addEventListener('click', async () => {{
|
||||
try {{
|
||||
const response = await fetch(getAuthURL('/token/list'), {{
|
||||
headers: getAuthHeaders()
|
||||
}});
|
||||
const result = await response.json();
|
||||
showResponse('tokenListResponse', result, response.ok);
|
||||
}} catch (error) {{
|
||||
showResponse('tokenListResponse', {{error: error.message}}, false);
|
||||
}}
|
||||
}});
|
||||
|
||||
// Cleanup tokens
|
||||
document.getElementById('cleanupTokensBtn').addEventListener('click', async () => {{
|
||||
try {{
|
||||
const response = await fetch(getAuthURL('/token/cleanup'), {{
|
||||
method: 'POST',
|
||||
headers: getAuthHeaders()
|
||||
}});
|
||||
const result = await response.json();
|
||||
showResponse('tokenListResponse', result, response.ok);
|
||||
}} catch (error) {{
|
||||
showResponse('tokenListResponse', {{error: error.message}}, false);
|
||||
}}
|
||||
}});
|
||||
|
||||
// Revoke token
|
||||
document.getElementById('revokeTokenBtn').addEventListener('click', async () => {{
|
||||
const tokenId = document.getElementById('revokeTokenId').value;
|
||||
if (!tokenId) {{
|
||||
showResponse('revokeTokenResponse', {{error: 'Token ID is required'}}, false);
|
||||
return;
|
||||
}}
|
||||
|
||||
try {{
|
||||
const response = await fetch(getAuthURL(`/token/revoke?token_id=${{encodeURIComponent(tokenId)}}`), {{
|
||||
method: 'DELETE',
|
||||
headers: getAuthHeaders()
|
||||
}});
|
||||
const result = await response.json();
|
||||
showResponse('revokeTokenResponse', result, response.ok);
|
||||
|
||||
// Refresh token list if revocation was successful
|
||||
if (response.ok) {{
|
||||
document.getElementById('listTokensBtn').click();
|
||||
document.getElementById('revokeTokenId').value = '';
|
||||
}}
|
||||
}} catch (error) {{
|
||||
showResponse('revokeTokenResponse', {{error: error.message}}, false);
|
||||
}}
|
||||
}});
|
||||
|
||||
// Load token list on page load
|
||||
document.getElementById('listTokensBtn').click();
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
return HTMLResponse(html_content)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in handle_demo_page: {e}")
|
||||
error_html = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Token Management Error</title>
|
||||
<style>body {{ font-family: Arial, sans-serif; margin: 50px; }}</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Token Management Error</h1>
|
||||
<p>Error loading token management page: {str(e)}</p>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
return HTMLResponse(error_html, status_code=500)
|
||||
827
doris_mcp_server/auth/token_manager.py
Normal file
@@ -0,0 +1,827 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Token Authentication Management Module
|
||||
|
||||
Provides enterprise-grade token authentication system with configurable tokens,
|
||||
expiration management, role-based access control and secure token storage.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import secrets
|
||||
import time
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any
|
||||
from pathlib import Path
|
||||
|
||||
from ..utils.logger import get_logger
|
||||
from ..utils.security import SecurityLevel
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatabaseConfig:
|
||||
"""Database connection configuration for token binding"""
|
||||
|
||||
host: str
|
||||
port: int = 9030
|
||||
user: str = ""
|
||||
password: str = ""
|
||||
database: str = "information_schema"
|
||||
charset: str = "UTF8"
|
||||
fe_http_port: int = 8030
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenInfo:
|
||||
"""Token information structure with optional database binding"""
|
||||
|
||||
token_id: str # Unique token identifier for audit and management
|
||||
created_at: datetime = field(default_factory=datetime.utcnow)
|
||||
expires_at: Optional[datetime] = None
|
||||
last_used: Optional[datetime] = None
|
||||
description: str = "" # Optional description for token purpose
|
||||
is_active: bool = True
|
||||
database_config: Optional[DatabaseConfig] = None # Optional database binding
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenValidationResult:
|
||||
"""Token validation result"""
|
||||
|
||||
is_valid: bool
|
||||
token_info: Optional[TokenInfo] = None
|
||||
error_message: Optional[str] = None
|
||||
|
||||
|
||||
class TokenManager:
|
||||
"""Enterprise Token Authentication Manager
|
||||
|
||||
Features:
|
||||
- Configurable token storage (file-based or environment variables)
|
||||
- Token expiration management
|
||||
- Secure token hashing
|
||||
- Role-based access control
|
||||
- Token lifecycle management
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
# Token storage
|
||||
self._tokens: Dict[str, TokenInfo] = {} # token_hash -> TokenInfo
|
||||
self._token_ids: Dict[str, str] = {} # token_id -> token_hash
|
||||
|
||||
# Configuration
|
||||
self.token_file_path = getattr(config.security, 'token_file_path', 'tokens.json')
|
||||
self.enable_token_expiry = getattr(config.security, 'enable_token_expiry', True)
|
||||
self.default_token_expiry_hours = getattr(config.security, 'default_token_expiry_hours', 24 * 30) # 30 days
|
||||
self.token_hash_algorithm = getattr(config.security, 'token_hash_algorithm', 'sha256')
|
||||
|
||||
# Hot reload configuration
|
||||
self.enable_hot_reload = True
|
||||
self.hot_reload_interval = 10 # Check every 10 seconds
|
||||
self._file_last_modified = 0
|
||||
self._hot_reload_task = None
|
||||
|
||||
# Initialize with default tokens if none exist
|
||||
self._initialize_default_tokens()
|
||||
|
||||
# Load tokens from configuration
|
||||
self._load_tokens()
|
||||
|
||||
# Start hot reload monitoring
|
||||
if self.enable_hot_reload:
|
||||
self._start_hot_reload()
|
||||
|
||||
self.logger.info(f"TokenManager initialized with {len(self._tokens)} tokens, hot reload: {self.enable_hot_reload}")
|
||||
|
||||
def _initialize_default_tokens(self):
|
||||
"""Initialize default tokens for basic authentication (configurable via environment)"""
|
||||
# Default token configurations (can be overridden by environment variables)
|
||||
default_tokens = [
|
||||
{
|
||||
'token_id': 'admin-token',
|
||||
'token': os.getenv('DEFAULT_ADMIN_TOKEN', 'doris_admin_token_123456'),
|
||||
'description': os.getenv('DEFAULT_ADMIN_DESCRIPTION', 'Default admin API access token'),
|
||||
'expires_hours': None # Never expires
|
||||
},
|
||||
{
|
||||
'token_id': 'analyst-token',
|
||||
'token': os.getenv('DEFAULT_ANALYST_TOKEN', 'doris_analyst_token_123456'),
|
||||
'description': os.getenv('DEFAULT_ANALYST_DESCRIPTION', 'Default data analysis API access token'),
|
||||
'expires_hours': None # Never expires
|
||||
},
|
||||
{
|
||||
'token_id': 'readonly-token',
|
||||
'token': os.getenv('DEFAULT_READONLY_TOKEN', 'doris_readonly_token_123456'),
|
||||
'description': os.getenv('DEFAULT_READONLY_DESCRIPTION', 'Default read-only API access token'),
|
||||
'expires_hours': None # Never expires
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
# Only add default tokens if no custom tokens are defined via environment variables
|
||||
# Check if any TOKEN_* environment variables exist (excluding system and legacy configs)
|
||||
excluded_prefixes = ('DEFAULT_', 'TOKEN_FILE_PATH', 'TOKEN_HASH_')
|
||||
excluded_vars = {'TOKEN_SECRET', 'TOKEN_EXPIRY'}
|
||||
|
||||
custom_tokens_exist = any(
|
||||
key.startswith('TOKEN_') and
|
||||
not key.startswith(excluded_prefixes) and
|
||||
not key.endswith(('_EXPIRES_HOURS', '_DESCRIPTION')) and
|
||||
key not in excluded_vars
|
||||
for key in os.environ.keys()
|
||||
)
|
||||
|
||||
# Also check if token file exists and has content
|
||||
token_file_exists = False
|
||||
if os.path.exists(self.token_file_path):
|
||||
try:
|
||||
with open(self.token_file_path, 'r') as f:
|
||||
content = f.read().strip()
|
||||
if content and content != '{}':
|
||||
token_file_exists = True
|
||||
except:
|
||||
pass
|
||||
|
||||
# Add default tokens only if no custom configuration exists
|
||||
if not custom_tokens_exist and not token_file_exists:
|
||||
for token_config in default_tokens:
|
||||
self._add_token_from_config(token_config)
|
||||
|
||||
self.logger.info(f"Initialized {len(default_tokens)} default tokens (no custom config found)")
|
||||
else:
|
||||
self.logger.info("Skipped default tokens initialization (custom tokens detected)")
|
||||
|
||||
def _add_token_from_config(self, token_config: Dict[str, Any]):
|
||||
"""Add token from configuration with optional database binding"""
|
||||
try:
|
||||
# Calculate expiration time
|
||||
expires_at = None
|
||||
if self.enable_token_expiry:
|
||||
expires_hours = token_config.get('expires_hours', self.default_token_expiry_hours)
|
||||
if expires_hours is not None:
|
||||
expires_at = datetime.utcnow() + timedelta(hours=expires_hours)
|
||||
|
||||
# Parse database configuration if provided
|
||||
database_config = None
|
||||
if 'database_config' in token_config:
|
||||
db_config = token_config['database_config']
|
||||
database_config = DatabaseConfig(
|
||||
host=db_config.get('host', 'localhost'),
|
||||
port=db_config.get('port', 9030),
|
||||
user=db_config.get('user', 'root'),
|
||||
password=db_config.get('password', ''),
|
||||
database=db_config.get('database', 'information_schema'),
|
||||
charset=db_config.get('charset', 'UTF8'),
|
||||
fe_http_port=db_config.get('fe_http_port', 8030)
|
||||
)
|
||||
|
||||
# Create token info
|
||||
token_info = TokenInfo(
|
||||
token_id=token_config['token_id'],
|
||||
expires_at=expires_at,
|
||||
description=token_config.get('description', ''),
|
||||
is_active=token_config.get('is_active', True),
|
||||
database_config=database_config
|
||||
)
|
||||
|
||||
# Hash the token
|
||||
raw_token = token_config['token']
|
||||
token_hash = self._hash_token(raw_token)
|
||||
|
||||
# Store token
|
||||
self._tokens[token_hash] = token_info
|
||||
self._token_ids[token_info.token_id] = token_hash
|
||||
|
||||
db_info = f" with DB binding ({database_config.host})" if database_config else ""
|
||||
self.logger.debug(f"Added token '{token_info.token_id}'{db_info}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to add token from config: {e}")
|
||||
raise
|
||||
|
||||
def _load_tokens(self):
|
||||
"""Load tokens from configuration sources"""
|
||||
# 1. Load from environment variables
|
||||
self._load_tokens_from_env()
|
||||
|
||||
# 2. Load from token file if exists
|
||||
if os.path.exists(self.token_file_path):
|
||||
self._load_tokens_from_file()
|
||||
|
||||
self.logger.info(f"Token loading completed, total tokens: {len(self._tokens)}")
|
||||
|
||||
def _load_tokens_from_env(self):
|
||||
"""Load tokens from environment variables
|
||||
|
||||
Simplified format:
|
||||
TOKEN_<ID>=<token>
|
||||
TOKEN_<ID>_EXPIRES_HOURS=<hours>
|
||||
TOKEN_<ID>_DESCRIPTION=<description>
|
||||
"""
|
||||
token_prefixes = set()
|
||||
|
||||
# Find all TOKEN_ environment variables (exclude legacy and system variables)
|
||||
excluded_token_vars = {
|
||||
'TOKEN_SECRET', # Legacy token secret
|
||||
'TOKEN_EXPIRY', # Legacy token expiry
|
||||
'TOKEN_FILE_PATH', # System config
|
||||
'TOKEN_HASH_ALGORITHM' # System config
|
||||
}
|
||||
|
||||
for key in os.environ:
|
||||
if (key.startswith('TOKEN_') and
|
||||
not key.endswith(('_EXPIRES_HOURS', '_DESCRIPTION')) and
|
||||
key not in excluded_token_vars):
|
||||
token_id = key[6:] # Remove 'TOKEN_' prefix
|
||||
token_prefixes.add(token_id)
|
||||
|
||||
# Load each token
|
||||
for token_id in token_prefixes:
|
||||
try:
|
||||
token = os.environ.get(f'TOKEN_{token_id}')
|
||||
if not token:
|
||||
continue
|
||||
|
||||
expires_hours_str = os.environ.get(f'TOKEN_{token_id}_EXPIRES_HOURS', str(self.default_token_expiry_hours))
|
||||
description = os.environ.get(f'TOKEN_{token_id}_DESCRIPTION', f'Environment token {token_id}')
|
||||
|
||||
expires_hours = None
|
||||
try:
|
||||
if expires_hours_str and expires_hours_str.lower() != 'none':
|
||||
expires_hours = int(expires_hours_str)
|
||||
except ValueError:
|
||||
expires_hours = self.default_token_expiry_hours
|
||||
|
||||
# Add token
|
||||
token_config = {
|
||||
'token_id': token_id.lower(),
|
||||
'token': token,
|
||||
'expires_hours': expires_hours,
|
||||
'description': description
|
||||
}
|
||||
|
||||
self._add_token_from_config(token_config)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to load token {token_id} from environment: {e}")
|
||||
|
||||
def _load_tokens_from_file(self):
|
||||
"""Load tokens from JSON file"""
|
||||
try:
|
||||
with open(self.token_file_path, 'r', encoding='utf-8') as f:
|
||||
tokens_data = json.load(f)
|
||||
|
||||
if isinstance(tokens_data, dict) and 'tokens' in tokens_data:
|
||||
tokens_list = tokens_data['tokens']
|
||||
elif isinstance(tokens_data, list):
|
||||
tokens_list = tokens_data
|
||||
else:
|
||||
self.logger.error(f"Invalid token file format: {self.token_file_path}")
|
||||
return
|
||||
|
||||
for token_config in tokens_list:
|
||||
self._add_token_from_config(token_config)
|
||||
|
||||
self.logger.info(f"Loaded {len(tokens_list)} tokens from file: {self.token_file_path}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to load tokens from file {self.token_file_path}: {e}")
|
||||
|
||||
def _hash_token(self, token: str) -> str:
|
||||
"""Hash token for secure storage"""
|
||||
if self.token_hash_algorithm == 'sha256':
|
||||
return hashlib.sha256(token.encode('utf-8')).hexdigest()
|
||||
elif self.token_hash_algorithm == 'sha512':
|
||||
return hashlib.sha512(token.encode('utf-8')).hexdigest()
|
||||
else:
|
||||
# Fallback to sha256
|
||||
return hashlib.sha256(token.encode('utf-8')).hexdigest()
|
||||
|
||||
async def validate_token(self, token: str) -> TokenValidationResult:
|
||||
"""Validate token and return user information"""
|
||||
try:
|
||||
# Hash the provided token
|
||||
token_hash = self._hash_token(token)
|
||||
|
||||
# Find token info
|
||||
token_info = self._tokens.get(token_hash)
|
||||
if not token_info:
|
||||
return TokenValidationResult(
|
||||
is_valid=False,
|
||||
error_message="Invalid token"
|
||||
)
|
||||
|
||||
# Check if token is active
|
||||
if not token_info.is_active:
|
||||
return TokenValidationResult(
|
||||
is_valid=False,
|
||||
error_message="Token is inactive"
|
||||
)
|
||||
|
||||
# Check expiration
|
||||
if token_info.expires_at and datetime.utcnow() > token_info.expires_at:
|
||||
return TokenValidationResult(
|
||||
is_valid=False,
|
||||
error_message="Token has expired"
|
||||
)
|
||||
|
||||
# Update last used time
|
||||
token_info.last_used = datetime.utcnow()
|
||||
|
||||
return TokenValidationResult(
|
||||
is_valid=True,
|
||||
token_info=token_info
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Token validation error: {e}")
|
||||
return TokenValidationResult(
|
||||
is_valid=False,
|
||||
error_message=f"Token validation failed: {str(e)}"
|
||||
)
|
||||
|
||||
def generate_token(self, length: int = 32) -> str:
|
||||
"""Generate a cryptographically secure random token"""
|
||||
return secrets.token_urlsafe(length)
|
||||
|
||||
async def create_token(
|
||||
self,
|
||||
token_id: str,
|
||||
expires_hours: Optional[int] = None,
|
||||
description: str = "",
|
||||
custom_token: Optional[str] = None,
|
||||
database_config: Optional[DatabaseConfig] = None
|
||||
) -> str:
|
||||
"""Create a new token"""
|
||||
try:
|
||||
# Check if token_id already exists
|
||||
if token_id in self._token_ids:
|
||||
raise ValueError(f"Token ID '{token_id}' already exists")
|
||||
|
||||
# Generate or use provided token
|
||||
if custom_token:
|
||||
raw_token = custom_token
|
||||
else:
|
||||
raw_token = self.generate_token()
|
||||
|
||||
# Calculate expiration
|
||||
expires_at = None
|
||||
if expires_hours is not None:
|
||||
expires_at = datetime.utcnow() + timedelta(hours=expires_hours)
|
||||
elif self.enable_token_expiry:
|
||||
expires_at = datetime.utcnow() + timedelta(hours=self.default_token_expiry_hours)
|
||||
|
||||
# Create token info
|
||||
token_info = TokenInfo(
|
||||
token_id=token_id,
|
||||
expires_at=expires_at,
|
||||
description=description,
|
||||
database_config=database_config
|
||||
)
|
||||
|
||||
# Hash and store token
|
||||
token_hash = self._hash_token(raw_token)
|
||||
self._tokens[token_hash] = token_info
|
||||
self._token_ids[token_id] = token_hash
|
||||
|
||||
self.logger.info(f"Created new token '{token_id}'")
|
||||
|
||||
# Save token to file
|
||||
self._save_token_to_file(token_id, raw_token, token_info)
|
||||
|
||||
return raw_token
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to create token: {e}")
|
||||
raise
|
||||
|
||||
async def revoke_token(self, token_id: str) -> bool:
|
||||
"""Revoke a token by token ID"""
|
||||
try:
|
||||
if token_id not in self._token_ids:
|
||||
self.logger.warning(f"Token ID '{token_id}' not found")
|
||||
return False
|
||||
|
||||
# Get token hash and remove from storage
|
||||
token_hash = self._token_ids[token_id]
|
||||
if token_hash in self._tokens:
|
||||
del self._tokens[token_hash]
|
||||
del self._token_ids[token_id]
|
||||
|
||||
self.logger.info(f"Revoked token '{token_id}'")
|
||||
|
||||
# Save updated tokens to file
|
||||
self._remove_token_from_file(token_id)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to revoke token '{token_id}': {e}")
|
||||
return False
|
||||
|
||||
def _save_tokens_to_file(self):
|
||||
"""Save current tokens to JSON file"""
|
||||
try:
|
||||
# Convert current tokens to file format
|
||||
tokens_list = []
|
||||
|
||||
for token_hash, token_info in self._tokens.items():
|
||||
# Find the raw token for this token_info
|
||||
raw_token = None
|
||||
for tid, thash in self._token_ids.items():
|
||||
if thash == token_hash and tid == token_info.token_id:
|
||||
# We can't recover the original token from hash,
|
||||
# so we'll create a placeholder for existing tokens
|
||||
raw_token = f"<existing_token_hash_{token_hash[:8]}>"
|
||||
break
|
||||
|
||||
if raw_token is None:
|
||||
continue
|
||||
|
||||
token_config = {
|
||||
"token_id": token_info.token_id,
|
||||
"token": raw_token,
|
||||
"description": token_info.description,
|
||||
"expires_hours": None,
|
||||
"is_active": token_info.is_active
|
||||
}
|
||||
|
||||
# Add expiration info
|
||||
if token_info.expires_at:
|
||||
# Calculate remaining hours from now
|
||||
remaining = token_info.expires_at - datetime.utcnow()
|
||||
if remaining.total_seconds() > 0:
|
||||
token_config["expires_hours"] = int(remaining.total_seconds() / 3600)
|
||||
else:
|
||||
token_config["expires_hours"] = 0
|
||||
|
||||
# Add database config if present
|
||||
if token_info.database_config:
|
||||
token_config["database_config"] = {
|
||||
"host": token_info.database_config.host,
|
||||
"port": token_info.database_config.port,
|
||||
"user": token_info.database_config.user,
|
||||
"password": token_info.database_config.password,
|
||||
"database": token_info.database_config.database,
|
||||
"charset": token_info.database_config.charset,
|
||||
"fe_http_port": token_info.database_config.fe_http_port
|
||||
}
|
||||
|
||||
tokens_list.append(token_config)
|
||||
|
||||
# Create file content
|
||||
file_content = {
|
||||
"version": "1.0",
|
||||
"description": "Doris MCP Server Token configuration file",
|
||||
"created_at": datetime.utcnow().isoformat() + "Z",
|
||||
"tokens": tokens_list,
|
||||
"notes": [
|
||||
"This file is automatically updated when tokens are created or revoked",
|
||||
"Please backup this file before making manual changes",
|
||||
"Tokens with hash placeholders were loaded from previous configuration"
|
||||
]
|
||||
}
|
||||
|
||||
# Save to file
|
||||
with open(self.token_file_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(file_content, f, indent=2, ensure_ascii=False)
|
||||
|
||||
self.logger.info(f"Saved {len(tokens_list)} tokens to file: {self.token_file_path}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to save tokens to file {self.token_file_path}: {e}")
|
||||
|
||||
def _save_token_to_file(self, token_id: str, raw_token: str, token_info: TokenInfo):
|
||||
"""Save a single new token to file (for newly created tokens only)"""
|
||||
try:
|
||||
# Load existing file
|
||||
existing_data = {"tokens": []}
|
||||
if os.path.exists(self.token_file_path):
|
||||
try:
|
||||
with open(self.token_file_path, 'r', encoding='utf-8') as f:
|
||||
existing_data = json.load(f)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Could not load existing token file: {e}")
|
||||
|
||||
# Ensure tokens list exists
|
||||
if 'tokens' not in existing_data or not isinstance(existing_data['tokens'], list):
|
||||
existing_data['tokens'] = []
|
||||
|
||||
# Check if token already exists in file
|
||||
token_exists = False
|
||||
for i, token_config in enumerate(existing_data['tokens']):
|
||||
if token_config.get('token_id') == token_id:
|
||||
# Update existing token
|
||||
existing_data['tokens'][i] = self._token_info_to_config(token_id, raw_token, token_info)
|
||||
token_exists = True
|
||||
break
|
||||
|
||||
# Add new token if it doesn't exist
|
||||
if not token_exists:
|
||||
new_token_config = self._token_info_to_config(token_id, raw_token, token_info)
|
||||
existing_data['tokens'].append(new_token_config)
|
||||
|
||||
# Update metadata
|
||||
existing_data.update({
|
||||
"version": "1.0",
|
||||
"description": "Doris MCP Server Token configuration file",
|
||||
"updated_at": datetime.utcnow().isoformat() + "Z"
|
||||
})
|
||||
|
||||
# Save to file
|
||||
with open(self.token_file_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(existing_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
self.logger.info(f"Saved token '{token_id}' to file: {self.token_file_path}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to save token '{token_id}' to file: {e}")
|
||||
|
||||
def _token_info_to_config(self, token_id: str, raw_token: str, token_info: TokenInfo) -> dict:
|
||||
"""Convert TokenInfo to file configuration format"""
|
||||
token_config = {
|
||||
"token_id": token_id,
|
||||
"token": raw_token,
|
||||
"description": token_info.description,
|
||||
"expires_hours": None,
|
||||
"is_active": token_info.is_active
|
||||
}
|
||||
|
||||
# Add expiration info
|
||||
if token_info.expires_at:
|
||||
# Calculate remaining hours from creation time
|
||||
remaining = token_info.expires_at - token_info.created_at
|
||||
token_config["expires_hours"] = int(remaining.total_seconds() / 3600) if remaining.total_seconds() > 0 else None
|
||||
|
||||
# Add database config if present
|
||||
if token_info.database_config:
|
||||
token_config["database_config"] = {
|
||||
"host": token_info.database_config.host,
|
||||
"port": token_info.database_config.port,
|
||||
"user": token_info.database_config.user,
|
||||
"password": token_info.database_config.password,
|
||||
"database": token_info.database_config.database,
|
||||
"charset": token_info.database_config.charset,
|
||||
"fe_http_port": token_info.database_config.fe_http_port
|
||||
}
|
||||
|
||||
return token_config
|
||||
|
||||
def _remove_token_from_file(self, token_id: str):
|
||||
"""Remove a token from the JSON file"""
|
||||
try:
|
||||
if not os.path.exists(self.token_file_path):
|
||||
return
|
||||
|
||||
# Load existing file
|
||||
with open(self.token_file_path, 'r', encoding='utf-8') as f:
|
||||
existing_data = json.load(f)
|
||||
|
||||
if 'tokens' not in existing_data or not isinstance(existing_data['tokens'], list):
|
||||
return
|
||||
|
||||
# Remove the token
|
||||
original_count = len(existing_data['tokens'])
|
||||
existing_data['tokens'] = [
|
||||
token for token in existing_data['tokens']
|
||||
if token.get('token_id') != token_id
|
||||
]
|
||||
|
||||
if len(existing_data['tokens']) < original_count:
|
||||
# Update metadata
|
||||
existing_data.update({
|
||||
"version": "1.0",
|
||||
"description": "Doris MCP Server Token configuration file",
|
||||
"updated_at": datetime.utcnow().isoformat() + "Z"
|
||||
})
|
||||
|
||||
# Save to file
|
||||
with open(self.token_file_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(existing_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
self.logger.info(f"Removed token '{token_id}' from file: {self.token_file_path}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to remove token '{token_id}' from file: {e}")
|
||||
|
||||
async def list_tokens(self) -> List[Dict[str, Any]]:
|
||||
"""List all tokens (without sensitive data)"""
|
||||
tokens = []
|
||||
|
||||
for token_hash, token_info in self._tokens.items():
|
||||
token_data = {
|
||||
'token_id': token_info.token_id,
|
||||
'created_at': token_info.created_at.isoformat(),
|
||||
'expires_at': token_info.expires_at.isoformat() if token_info.expires_at else None,
|
||||
'last_used': token_info.last_used.isoformat() if token_info.last_used else None,
|
||||
'is_active': token_info.is_active,
|
||||
'description': token_info.description,
|
||||
'is_expired': token_info.expires_at and datetime.utcnow() > token_info.expires_at if token_info.expires_at else False
|
||||
}
|
||||
|
||||
# Add database binding info (without sensitive data)
|
||||
if token_info.database_config:
|
||||
token_data['database_binding'] = {
|
||||
'host': token_info.database_config.host,
|
||||
'port': token_info.database_config.port,
|
||||
'user': token_info.database_config.user,
|
||||
'database': token_info.database_config.database,
|
||||
'has_password': bool(token_info.database_config.password)
|
||||
}
|
||||
else:
|
||||
token_data['database_binding'] = None
|
||||
|
||||
tokens.append(token_data)
|
||||
|
||||
# Sort by creation time
|
||||
tokens.sort(key=lambda x: x['created_at'], reverse=True)
|
||||
|
||||
return tokens
|
||||
|
||||
async def cleanup_expired_tokens(self) -> int:
|
||||
"""Remove expired tokens and return count"""
|
||||
if not self.enable_token_expiry:
|
||||
return 0
|
||||
|
||||
now = datetime.utcnow()
|
||||
expired_tokens = []
|
||||
|
||||
# Find expired tokens
|
||||
for token_hash, token_info in self._tokens.items():
|
||||
if token_info.expires_at and now > token_info.expires_at:
|
||||
expired_tokens.append((token_hash, token_info.token_id))
|
||||
|
||||
# Remove expired tokens
|
||||
for token_hash, token_id in expired_tokens:
|
||||
del self._tokens[token_hash]
|
||||
if token_id in self._token_ids:
|
||||
del self._token_ids[token_id]
|
||||
|
||||
if expired_tokens:
|
||||
self.logger.info(f"Cleaned up {len(expired_tokens)} expired tokens")
|
||||
|
||||
return len(expired_tokens)
|
||||
|
||||
async def save_tokens_to_file(self, file_path: Optional[str] = None) -> bool:
|
||||
"""Save current tokens to JSON file"""
|
||||
try:
|
||||
file_path = file_path or self.token_file_path
|
||||
tokens_list = await self.list_tokens()
|
||||
|
||||
tokens_data = {
|
||||
'version': '1.0',
|
||||
'created_at': datetime.utcnow().isoformat(),
|
||||
'tokens': tokens_list
|
||||
}
|
||||
|
||||
with open(file_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(tokens_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
self.logger.info(f"Saved {len(tokens_list)} tokens to file: {file_path}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to save tokens to file: {e}")
|
||||
return False
|
||||
|
||||
def get_database_config_by_token(self, token: str) -> Optional[DatabaseConfig]:
|
||||
"""Get database configuration bound to a token
|
||||
|
||||
Args:
|
||||
token: The raw token string
|
||||
|
||||
Returns:
|
||||
DatabaseConfig if token exists and has database binding, None otherwise
|
||||
"""
|
||||
try:
|
||||
token_hash = self._hash_token(token)
|
||||
token_info = self._tokens.get(token_hash)
|
||||
|
||||
if not token_info or not token_info.is_active:
|
||||
return None
|
||||
|
||||
# Check expiration
|
||||
if token_info.expires_at and datetime.utcnow() > token_info.expires_at:
|
||||
return None
|
||||
|
||||
return token_info.database_config
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to get database config for token: {e}")
|
||||
return None
|
||||
|
||||
def get_token_stats(self) -> Dict[str, Any]:
|
||||
"""Get token statistics"""
|
||||
now = datetime.utcnow()
|
||||
total_tokens = len(self._tokens)
|
||||
active_tokens = sum(1 for info in self._tokens.values() if info.is_active)
|
||||
expired_tokens = sum(1 for info in self._tokens.values()
|
||||
if info.expires_at and now > info.expires_at)
|
||||
tokens_with_db = sum(1 for info in self._tokens.values()
|
||||
if info.database_config is not None)
|
||||
|
||||
return {
|
||||
'total_tokens': total_tokens,
|
||||
'active_tokens': active_tokens,
|
||||
'expired_tokens': expired_tokens,
|
||||
'tokens_with_database_binding': tokens_with_db,
|
||||
'expiry_enabled': self.enable_token_expiry,
|
||||
'default_expiry_hours': self.default_token_expiry_hours,
|
||||
'hot_reload_enabled': self.enable_hot_reload,
|
||||
'last_file_check': datetime.fromtimestamp(self._file_last_modified).isoformat() if self._file_last_modified else None
|
||||
}
|
||||
|
||||
def _start_hot_reload(self):
|
||||
"""Start hot reload monitoring task"""
|
||||
if self._hot_reload_task:
|
||||
return # Already running
|
||||
|
||||
# Update initial file modification time
|
||||
self._update_file_modified_time()
|
||||
|
||||
# Start monitoring task
|
||||
self._hot_reload_task = asyncio.create_task(self._hot_reload_monitor())
|
||||
self.logger.info(f"Started hot reload monitoring for {self.token_file_path}")
|
||||
|
||||
def stop_hot_reload(self):
|
||||
"""Stop hot reload monitoring"""
|
||||
if self._hot_reload_task:
|
||||
self._hot_reload_task.cancel()
|
||||
self._hot_reload_task = None
|
||||
self.logger.info("Stopped hot reload monitoring")
|
||||
|
||||
def _update_file_modified_time(self):
|
||||
"""Update the last modified time of tokens file"""
|
||||
try:
|
||||
if os.path.exists(self.token_file_path):
|
||||
self._file_last_modified = os.path.getmtime(self.token_file_path)
|
||||
except Exception as e:
|
||||
self.logger.debug(f"Failed to get file modification time: {e}")
|
||||
|
||||
async def _hot_reload_monitor(self):
|
||||
"""Background task to monitor tokens.json file changes"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(self.hot_reload_interval)
|
||||
|
||||
if not os.path.exists(self.token_file_path):
|
||||
continue
|
||||
|
||||
# Check if file was modified
|
||||
current_mtime = os.path.getmtime(self.token_file_path)
|
||||
if current_mtime > self._file_last_modified:
|
||||
self.logger.info(f"Detected changes in {self.token_file_path}, reloading tokens...")
|
||||
|
||||
try:
|
||||
# Backup current tokens
|
||||
old_tokens = self._tokens.copy()
|
||||
old_token_ids = self._token_ids.copy()
|
||||
|
||||
# Clear and reload
|
||||
self._tokens.clear()
|
||||
self._token_ids.clear()
|
||||
|
||||
# Reinitialize default tokens
|
||||
self._initialize_default_tokens()
|
||||
|
||||
# Load from file
|
||||
self._load_tokens_from_file()
|
||||
|
||||
# Update modification time
|
||||
self._file_last_modified = current_mtime
|
||||
|
||||
self.logger.info(f"Hot reload completed, {len(self._tokens)} tokens loaded")
|
||||
|
||||
except Exception as reload_error:
|
||||
# Restore backup on failure
|
||||
self.logger.error(f"Hot reload failed, restoring previous tokens: {reload_error}")
|
||||
self._tokens = old_tokens
|
||||
self._token_ids = old_token_ids
|
||||
|
||||
except asyncio.CancelledError:
|
||||
self.logger.info("Hot reload monitor stopped")
|
||||
break
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in hot reload monitor: {e}")
|
||||
227
doris_mcp_server/auth/token_security_middleware.py
Normal file
@@ -0,0 +1,227 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Token Management Security Middleware
|
||||
|
||||
Provides comprehensive security controls for token management endpoints including
|
||||
IP restrictions, admin authentication, and configuration-based access control.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import ipaddress
|
||||
import secrets
|
||||
import time
|
||||
from typing import Optional, List, Dict, Any
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from ..utils.logger import get_logger
|
||||
from ..utils.config import DorisConfig
|
||||
|
||||
|
||||
class TokenSecurityMiddleware:
|
||||
"""Security middleware for token management endpoints"""
|
||||
|
||||
def __init__(self, config: DorisConfig):
|
||||
self.config = config
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
# Initialize admin token hash if provided
|
||||
self._admin_token_hash = None
|
||||
if config.security.token_management_admin_token:
|
||||
self._admin_token_hash = self._hash_token(config.security.token_management_admin_token)
|
||||
|
||||
# Normalize allowed IPs
|
||||
self._allowed_networks = self._parse_allowed_networks(
|
||||
config.security.token_management_allowed_ips
|
||||
)
|
||||
|
||||
self.logger.info(f"Token management security initialized: "
|
||||
f"HTTP endpoints {'enabled' if config.security.enable_http_token_management else 'disabled'}, "
|
||||
f"Admin auth {'required' if config.security.require_admin_auth else 'optional'}, "
|
||||
f"Allowed networks: {len(self._allowed_networks)}")
|
||||
|
||||
def _hash_token(self, token: str) -> str:
|
||||
"""Hash token using SHA-256"""
|
||||
return hashlib.sha256(token.encode('utf-8')).hexdigest()
|
||||
|
||||
def _parse_allowed_networks(self, allowed_ips: List[str]) -> List[ipaddress.IPv4Network | ipaddress.IPv6Network]:
|
||||
"""Parse allowed IP addresses and networks"""
|
||||
networks = []
|
||||
for ip_str in allowed_ips:
|
||||
try:
|
||||
# Try to parse as network (CIDR notation)
|
||||
if '/' in ip_str:
|
||||
networks.append(ipaddress.ip_network(ip_str, strict=False))
|
||||
else:
|
||||
# Parse as single IP and convert to /32 network
|
||||
ip = ipaddress.ip_address(ip_str)
|
||||
if isinstance(ip, ipaddress.IPv4Address):
|
||||
networks.append(ipaddress.IPv4Network(f"{ip}/32"))
|
||||
else:
|
||||
networks.append(ipaddress.IPv6Network(f"{ip}/128"))
|
||||
except ValueError as e:
|
||||
self.logger.warning(f"Invalid IP/network '{ip_str}': {e}")
|
||||
|
||||
return networks
|
||||
|
||||
def _get_client_ip(self, request: Request) -> str:
|
||||
"""Extract client IP from request, considering proxies"""
|
||||
# Check X-Forwarded-For header first (for proxy setups)
|
||||
forwarded_for = request.headers.get('X-Forwarded-For')
|
||||
if forwarded_for:
|
||||
# Take the first IP (original client)
|
||||
client_ip = forwarded_for.split(',')[0].strip()
|
||||
elif request.headers.get('X-Real-IP'):
|
||||
client_ip = request.headers.get('X-Real-IP')
|
||||
else:
|
||||
# Direct connection
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
|
||||
return client_ip
|
||||
|
||||
def _is_ip_allowed(self, client_ip: str) -> bool:
|
||||
"""Check if client IP is in allowed networks"""
|
||||
try:
|
||||
client_addr = ipaddress.ip_address(client_ip)
|
||||
|
||||
for network in self._allowed_networks:
|
||||
if client_addr in network:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except ValueError:
|
||||
self.logger.warning(f"Invalid client IP format: {client_ip}")
|
||||
return False
|
||||
|
||||
def _extract_admin_token(self, request: Request) -> Optional[str]:
|
||||
"""Extract admin token from request headers"""
|
||||
# Try Authorization header first
|
||||
auth_header = request.headers.get('Authorization', '')
|
||||
if auth_header.startswith('Bearer '):
|
||||
return auth_header[7:]
|
||||
elif auth_header.startswith('Token '):
|
||||
return auth_header[6:]
|
||||
|
||||
# Try X-Admin-Token header
|
||||
admin_token = request.headers.get('X-Admin-Token', '')
|
||||
if admin_token:
|
||||
return admin_token
|
||||
|
||||
# Try query parameter as fallback (not recommended for production)
|
||||
admin_token = request.query_params.get('admin_token', '')
|
||||
if admin_token:
|
||||
self.logger.warning("Admin token passed via query parameter - this is insecure for production")
|
||||
return admin_token
|
||||
|
||||
return None
|
||||
|
||||
def _verify_admin_token(self, provided_token: str) -> bool:
|
||||
"""Verify provided admin token against configured token"""
|
||||
if not self._admin_token_hash:
|
||||
self.logger.warning("No admin token configured for token management")
|
||||
return False
|
||||
|
||||
provided_hash = self._hash_token(provided_token)
|
||||
|
||||
# Use constant-time comparison to prevent timing attacks
|
||||
return hmac.compare_digest(self._admin_token_hash, provided_hash)
|
||||
|
||||
async def check_token_management_access(self, request: Request) -> Optional[JSONResponse]:
|
||||
"""
|
||||
Check if request is authorized for token management operations
|
||||
|
||||
Returns:
|
||||
None if access is granted
|
||||
JSONResponse with error if access is denied
|
||||
"""
|
||||
|
||||
# Check if HTTP token management is enabled
|
||||
if not self.config.security.enable_http_token_management:
|
||||
self.logger.warning(f"Token management endpoint access denied - HTTP management disabled: {request.url.path}")
|
||||
return JSONResponse({
|
||||
"error": "Token management endpoints are disabled for security",
|
||||
"message": "HTTP token management is disabled. Use file-based token management instead.",
|
||||
"suggestion": "Edit tokens.json file directly or enable HTTP management with proper security configuration"
|
||||
}, status_code=403)
|
||||
|
||||
# Extract client IP
|
||||
client_ip = self._get_client_ip(request)
|
||||
|
||||
# Check IP restrictions
|
||||
if not self._is_ip_allowed(client_ip):
|
||||
self.logger.warning(f"Token management access denied for IP {client_ip}: not in allowed list")
|
||||
return JSONResponse({
|
||||
"error": "Access denied - IP not allowed",
|
||||
"client_ip": client_ip,
|
||||
"message": "Token management is restricted to specific IP addresses",
|
||||
"allowed_networks": [str(net) for net in self._allowed_networks]
|
||||
}, status_code=403)
|
||||
|
||||
# Check admin authentication if required
|
||||
if self.config.security.require_admin_auth:
|
||||
admin_token = self._extract_admin_token(request)
|
||||
|
||||
if not admin_token:
|
||||
self.logger.warning(f"Token management access denied for IP {client_ip}: missing admin token")
|
||||
return JSONResponse({
|
||||
"error": "Admin authentication required",
|
||||
"message": "Token management requires admin authentication",
|
||||
"hint": "Provide admin token in Authorization header: 'Bearer <admin_token>' or 'X-Admin-Token: <admin_token>'"
|
||||
}, status_code=401)
|
||||
|
||||
if not self._verify_admin_token(admin_token):
|
||||
self.logger.warning(f"Token management access denied for IP {client_ip}: invalid admin token")
|
||||
return JSONResponse({
|
||||
"error": "Invalid admin token",
|
||||
"message": "The provided admin token is invalid"
|
||||
}, status_code=401)
|
||||
|
||||
# Log successful access
|
||||
self.logger.info(f"Token management access granted for IP {client_ip} to {request.url.path}")
|
||||
|
||||
# Access granted
|
||||
return None
|
||||
|
||||
def get_security_info(self) -> Dict[str, Any]:
|
||||
"""Get current security configuration info (for demo/status pages)"""
|
||||
return {
|
||||
"http_token_management_enabled": self.config.security.enable_http_token_management,
|
||||
"admin_auth_required": self.config.security.require_admin_auth,
|
||||
"admin_token_configured": bool(self._admin_token_hash),
|
||||
"allowed_networks_count": len(self._allowed_networks),
|
||||
"allowed_networks": [str(net) for net in self._allowed_networks],
|
||||
"security_features": [
|
||||
"IP address restrictions",
|
||||
"Admin token authentication" if self.config.security.require_admin_auth else "Optional admin authentication",
|
||||
"Secure token hashing",
|
||||
"Request logging and auditing"
|
||||
]
|
||||
}
|
||||
|
||||
def generate_admin_token(self) -> str:
|
||||
"""Generate a secure admin token"""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
# Convenience function for middleware creation
|
||||
def create_token_security_middleware(config: DorisConfig) -> TokenSecurityMiddleware:
|
||||
"""Create token security middleware with configuration"""
|
||||
return TokenSecurityMiddleware(config)
|
||||
365
doris_mcp_server/auth/token_validators.py
Normal file
@@ -0,0 +1,365 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
JWT Token Validation Module
|
||||
Provides token validation, blacklist management and security features
|
||||
"""
|
||||
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Dict, Set, Optional, Any
|
||||
from datetime import datetime, timedelta
|
||||
from collections import defaultdict
|
||||
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class TokenBlacklist:
|
||||
"""JWT Token Blacklist Manager
|
||||
|
||||
Manages revoked tokens to prevent revoked tokens from being used again
|
||||
Supports both in-memory and persistent storage
|
||||
"""
|
||||
|
||||
def __init__(self, cleanup_interval: int = 3600):
|
||||
"""Initialize token blacklist
|
||||
|
||||
Args:
|
||||
cleanup_interval: Interval for cleaning up expired tokens (seconds)
|
||||
"""
|
||||
self.cleanup_interval = cleanup_interval
|
||||
# Storage format: {token_jti: expiry_timestamp}
|
||||
self._blacklisted_tokens: Dict[str, float] = {}
|
||||
self._cleanup_task = None
|
||||
|
||||
logger.info("TokenBlacklist initialized")
|
||||
|
||||
async def start(self):
|
||||
"""Start blacklist manager"""
|
||||
self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
|
||||
logger.info("TokenBlacklist started with periodic cleanup")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop blacklist manager"""
|
||||
if self._cleanup_task:
|
||||
self._cleanup_task.cancel()
|
||||
try:
|
||||
await self._cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("TokenBlacklist stopped")
|
||||
|
||||
async def add_token(self, jti: str, exp: float):
|
||||
"""Add token to blacklist
|
||||
|
||||
Args:
|
||||
jti: JWT ID (unique identifier)
|
||||
exp: Token expiration timestamp
|
||||
"""
|
||||
self._blacklisted_tokens[jti] = exp
|
||||
logger.info(f"Token {jti} added to blacklist")
|
||||
|
||||
async def is_blacklisted(self, jti: str) -> bool:
|
||||
"""Check if token is blacklisted
|
||||
|
||||
Args:
|
||||
jti: JWT ID
|
||||
|
||||
Returns:
|
||||
True if blacklisted, False otherwise
|
||||
"""
|
||||
return jti in self._blacklisted_tokens
|
||||
|
||||
async def remove_token(self, jti: str) -> bool:
|
||||
"""Remove token from blacklist
|
||||
|
||||
Args:
|
||||
jti: JWT ID
|
||||
|
||||
Returns:
|
||||
True if removed, False if not found
|
||||
"""
|
||||
if jti in self._blacklisted_tokens:
|
||||
del self._blacklisted_tokens[jti]
|
||||
logger.info(f"Token {jti} removed from blacklist")
|
||||
return True
|
||||
return False
|
||||
|
||||
async def cleanup_expired(self) -> int:
|
||||
"""Clean up expired blacklisted tokens
|
||||
|
||||
Returns:
|
||||
Number of tokens cleaned up
|
||||
"""
|
||||
current_time = time.time()
|
||||
expired_tokens = [
|
||||
jti for jti, exp in self._blacklisted_tokens.items()
|
||||
if exp <= current_time
|
||||
]
|
||||
|
||||
for jti in expired_tokens:
|
||||
del self._blacklisted_tokens[jti]
|
||||
|
||||
if expired_tokens:
|
||||
logger.info(f"Cleaned up {len(expired_tokens)} expired tokens from blacklist")
|
||||
|
||||
return len(expired_tokens)
|
||||
|
||||
async def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get blacklist statistics"""
|
||||
current_time = time.time()
|
||||
active_tokens = sum(1 for exp in self._blacklisted_tokens.values() if exp > current_time)
|
||||
|
||||
return {
|
||||
"total_blacklisted": len(self._blacklisted_tokens),
|
||||
"active_blacklisted": active_tokens,
|
||||
"expired_blacklisted": len(self._blacklisted_tokens) - active_tokens,
|
||||
"cleanup_interval": self.cleanup_interval
|
||||
}
|
||||
|
||||
async def _periodic_cleanup(self):
|
||||
"""Periodically clean up expired tokens"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(self.cleanup_interval)
|
||||
await self.cleanup_expired()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error during periodic cleanup: {e}")
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""Token usage rate limiter"""
|
||||
|
||||
def __init__(self, max_requests: int = 100, time_window: int = 3600):
|
||||
"""Initialize rate limiter
|
||||
|
||||
Args:
|
||||
max_requests: Maximum requests within time window
|
||||
time_window: Time window in seconds
|
||||
"""
|
||||
self.max_requests = max_requests
|
||||
self.time_window = time_window
|
||||
# Storage format: {user_id: [timestamp1, timestamp2, ...]}
|
||||
self._request_history: Dict[str, list] = defaultdict(list)
|
||||
|
||||
logger.info(f"RateLimiter initialized: {max_requests} requests per {time_window} seconds")
|
||||
|
||||
async def is_allowed(self, user_id: str) -> bool:
|
||||
"""Check if user is allowed to make request
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
True if allowed, False otherwise
|
||||
"""
|
||||
current_time = time.time()
|
||||
user_requests = self._request_history[user_id]
|
||||
|
||||
# Clean up expired request records
|
||||
cutoff_time = current_time - self.time_window
|
||||
user_requests[:] = [t for t in user_requests if t > cutoff_time]
|
||||
|
||||
# Check if limit exceeded
|
||||
if len(user_requests) >= self.max_requests:
|
||||
logger.warning(f"Rate limit exceeded for user {user_id}")
|
||||
return False
|
||||
|
||||
# Record current request
|
||||
user_requests.append(current_time)
|
||||
return True
|
||||
|
||||
async def get_usage(self, user_id: str) -> Dict[str, Any]:
|
||||
"""Get user usage information
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
Usage statistics
|
||||
"""
|
||||
current_time = time.time()
|
||||
user_requests = self._request_history[user_id]
|
||||
|
||||
# Clean up expired records
|
||||
cutoff_time = current_time - self.time_window
|
||||
active_requests = [t for t in user_requests if t > cutoff_time]
|
||||
|
||||
return {
|
||||
"user_id": user_id,
|
||||
"requests_in_window": len(active_requests),
|
||||
"max_requests": self.max_requests,
|
||||
"time_window": self.time_window,
|
||||
"remaining_requests": max(0, self.max_requests - len(active_requests))
|
||||
}
|
||||
|
||||
|
||||
class TokenValidator:
|
||||
"""JWT Token Validator
|
||||
|
||||
Provides comprehensive JWT token validation functionality, including signature verification,
|
||||
claim validation, blacklist checking and rate limiting
|
||||
"""
|
||||
|
||||
def __init__(self, config, blacklist: Optional[TokenBlacklist] = None):
|
||||
"""Initialize token validator
|
||||
|
||||
Args:
|
||||
config: DorisConfig configuration object (with security attribute)
|
||||
blacklist: Token blacklist manager
|
||||
"""
|
||||
self.config = config
|
||||
self.blacklist = blacklist or TokenBlacklist()
|
||||
self.rate_limiter = RateLimiter()
|
||||
|
||||
# Access JWT settings through the security configuration
|
||||
if hasattr(config, 'security'):
|
||||
security_config = config.security
|
||||
else:
|
||||
# Fallback if config is passed directly as SecurityConfig
|
||||
security_config = config
|
||||
|
||||
# Validation options
|
||||
self.verify_signature = security_config.jwt_verify_signature
|
||||
self.verify_audience = security_config.jwt_verify_audience
|
||||
self.verify_issuer = security_config.jwt_verify_issuer
|
||||
self.require_exp = security_config.jwt_require_exp
|
||||
self.require_iat = security_config.jwt_require_iat
|
||||
self.require_nbf = security_config.jwt_require_nbf
|
||||
self.leeway = security_config.jwt_leeway
|
||||
|
||||
# Expected values
|
||||
self.expected_audience = security_config.jwt_audience
|
||||
self.expected_issuer = security_config.jwt_issuer
|
||||
|
||||
logger.info("TokenValidator initialized")
|
||||
|
||||
async def validate_claims(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate JWT claims
|
||||
|
||||
Args:
|
||||
payload: JWT payload
|
||||
|
||||
Returns:
|
||||
Validation result
|
||||
|
||||
Raises:
|
||||
ValueError: Validation failed
|
||||
"""
|
||||
current_time = time.time()
|
||||
|
||||
# Validate issuer
|
||||
if self.verify_issuer:
|
||||
if payload.get('iss') != self.expected_issuer:
|
||||
raise ValueError(f"Invalid issuer: expected {self.expected_issuer}")
|
||||
|
||||
# Validate audience
|
||||
if self.verify_audience:
|
||||
aud = payload.get('aud')
|
||||
if isinstance(aud, list):
|
||||
if self.expected_audience not in aud:
|
||||
raise ValueError(f"Invalid audience: {self.expected_audience} not in {aud}")
|
||||
elif aud != self.expected_audience:
|
||||
raise ValueError(f"Invalid audience: expected {self.expected_audience}")
|
||||
|
||||
# Validate expiration time
|
||||
if self.require_exp or 'exp' in payload:
|
||||
exp = payload.get('exp')
|
||||
if not exp:
|
||||
raise ValueError("Missing 'exp' claim")
|
||||
if current_time > exp + self.leeway:
|
||||
raise ValueError("Token has expired")
|
||||
|
||||
# Validate not before time
|
||||
if self.require_nbf or 'nbf' in payload:
|
||||
nbf = payload.get('nbf')
|
||||
if not nbf:
|
||||
raise ValueError("Missing 'nbf' claim")
|
||||
if current_time < nbf - self.leeway:
|
||||
raise ValueError("Token not yet valid")
|
||||
|
||||
# Validate issued at time
|
||||
if self.require_iat or 'iat' in payload:
|
||||
iat = payload.get('iat')
|
||||
if not iat:
|
||||
raise ValueError("Missing 'iat' claim")
|
||||
# Allow some clock skew, but cannot be future time
|
||||
if iat > current_time + self.leeway:
|
||||
raise ValueError("Token issued in the future")
|
||||
|
||||
# Check blacklist
|
||||
jti = payload.get('jti')
|
||||
if jti and await self.blacklist.is_blacklisted(jti):
|
||||
raise ValueError("Token has been revoked")
|
||||
|
||||
# Rate limit check
|
||||
user_id = payload.get('sub')
|
||||
if user_id:
|
||||
if not await self.rate_limiter.is_allowed(user_id):
|
||||
raise ValueError("Rate limit exceeded")
|
||||
|
||||
return {
|
||||
"valid": True,
|
||||
"user_id": user_id,
|
||||
"payload": payload
|
||||
}
|
||||
|
||||
async def start(self):
|
||||
"""Start validator"""
|
||||
await self.blacklist.start()
|
||||
logger.info("TokenValidator started")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop validator"""
|
||||
await self.blacklist.stop()
|
||||
logger.info("TokenValidator stopped")
|
||||
|
||||
async def revoke_token(self, jti: str, exp: float):
|
||||
"""Revoke token
|
||||
|
||||
Args:
|
||||
jti: JWT ID
|
||||
exp: Token expiration time
|
||||
"""
|
||||
await self.blacklist.add_token(jti, exp)
|
||||
logger.info(f"Token {jti} has been revoked")
|
||||
|
||||
async def get_validation_stats(self) -> Dict[str, Any]:
|
||||
"""Get validation statistics"""
|
||||
blacklist_stats = await self.blacklist.get_stats()
|
||||
|
||||
return {
|
||||
"blacklist": blacklist_stats,
|
||||
"validation_config": {
|
||||
"verify_signature": self.verify_signature,
|
||||
"verify_audience": self.verify_audience,
|
||||
"verify_issuer": self.verify_issuer,
|
||||
"require_exp": self.require_exp,
|
||||
"require_iat": self.require_iat,
|
||||
"require_nbf": self.require_nbf,
|
||||
"leeway": self.leeway
|
||||
}
|
||||
}
|
||||
|
||||
async def get_user_rate_limit_info(self, user_id: str) -> Dict[str, Any]:
|
||||
"""Get user rate limit information"""
|
||||
return await self.rate_limiter.get_usage(user_id)
|
||||
@@ -28,15 +28,183 @@ import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from mcp.server import Server
|
||||
from mcp.server.models import InitializationOptions
|
||||
# MCP version compatibility handling
|
||||
MCP_VERSION = 'unknown'
|
||||
Server = None
|
||||
InitializationOptions = None
|
||||
Prompt = None
|
||||
Resource = None
|
||||
TextContent = None
|
||||
Tool = None
|
||||
|
||||
from mcp.types import (
|
||||
Prompt,
|
||||
Resource,
|
||||
TextContent,
|
||||
Tool,
|
||||
)
|
||||
def _import_mcp_with_compatibility():
|
||||
"""Import MCP components with multi-version compatibility"""
|
||||
global MCP_VERSION, Server, InitializationOptions, Prompt, Resource, TextContent, Tool
|
||||
|
||||
try:
|
||||
# Strategy 1: Try direct server-only imports to avoid client-side issues
|
||||
from mcp.server import Server as _Server
|
||||
from mcp.server.models import InitializationOptions as _InitOptions
|
||||
from mcp.types import (
|
||||
Prompt as _Prompt,
|
||||
Resource as _Resource,
|
||||
TextContent as _TextContent,
|
||||
Tool as _Tool,
|
||||
)
|
||||
|
||||
# Assign to globals
|
||||
Server = _Server
|
||||
InitializationOptions = _InitOptions
|
||||
Prompt = _Prompt
|
||||
Resource = _Resource
|
||||
TextContent = _TextContent
|
||||
Tool = _Tool
|
||||
|
||||
# Try to get version safely
|
||||
try:
|
||||
import mcp
|
||||
MCP_VERSION = getattr(mcp, '__version__', None)
|
||||
if not MCP_VERSION:
|
||||
# Fallback: try to get version from package metadata
|
||||
try:
|
||||
import importlib.metadata
|
||||
MCP_VERSION = importlib.metadata.version('mcp')
|
||||
except Exception:
|
||||
# Second fallback: try pkg_resources
|
||||
try:
|
||||
import pkg_resources
|
||||
MCP_VERSION = pkg_resources.get_distribution('mcp').version
|
||||
except Exception:
|
||||
MCP_VERSION = 'detected-but-version-unknown'
|
||||
except Exception:
|
||||
# Version detection failed, but imports worked
|
||||
try:
|
||||
import importlib.metadata
|
||||
MCP_VERSION = importlib.metadata.version('mcp')
|
||||
except Exception:
|
||||
try:
|
||||
import pkg_resources
|
||||
MCP_VERSION = pkg_resources.get_distribution('mcp').version
|
||||
except Exception:
|
||||
MCP_VERSION = 'imported-successfully'
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info(f"MCP components imported successfully, version: {MCP_VERSION}")
|
||||
return True
|
||||
|
||||
except Exception as import_error:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Strategy 2: Handle RequestContext compatibility issues in 1.9.x versions
|
||||
error_str = str(import_error).lower()
|
||||
if 'requestcontext' in error_str and 'too few arguments' in error_str:
|
||||
logger.warning(f"Detected MCP RequestContext compatibility issue: {import_error}")
|
||||
logger.info("Attempting comprehensive workaround for MCP 1.9.x RequestContext issue...")
|
||||
|
||||
try:
|
||||
# Comprehensive monkey patch approach
|
||||
import sys
|
||||
import types
|
||||
|
||||
# Create and install mock modules before any MCP imports
|
||||
if 'mcp.shared.context' not in sys.modules:
|
||||
mock_context_module = types.ModuleType('mcp.shared.context')
|
||||
|
||||
class FlexibleRequestContext:
|
||||
"""Flexible RequestContext that accepts variable arguments"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
def __class_getitem__(cls, params):
|
||||
# Accept any number of parameters and return cls
|
||||
return cls
|
||||
|
||||
# Add other methods that might be called
|
||||
def __getattr__(self, name):
|
||||
return lambda *args, **kwargs: None
|
||||
|
||||
mock_context_module.RequestContext = FlexibleRequestContext
|
||||
sys.modules['mcp.shared.context'] = mock_context_module
|
||||
|
||||
# Also patch the typing system to be more permissive
|
||||
original_check_generic = None
|
||||
try:
|
||||
import typing
|
||||
if hasattr(typing, '_check_generic'):
|
||||
original_check_generic = typing._check_generic
|
||||
def permissive_check_generic(cls, params, elen):
|
||||
# Don't enforce strict parameter count checking
|
||||
return
|
||||
typing._check_generic = permissive_check_generic
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Clear any cached imports that might have failed
|
||||
modules_to_clear = [k for k in sys.modules.keys() if k.startswith('mcp.')]
|
||||
for module in modules_to_clear:
|
||||
if module in sys.modules:
|
||||
del sys.modules[module]
|
||||
|
||||
# Now try importing again with the patches in place
|
||||
from mcp.server import Server as _Server
|
||||
from mcp.server.models import InitializationOptions as _InitOptions
|
||||
from mcp.types import (
|
||||
Prompt as _Prompt,
|
||||
Resource as _Resource,
|
||||
TextContent as _TextContent,
|
||||
Tool as _Tool,
|
||||
)
|
||||
|
||||
# Assign to globals
|
||||
Server = _Server
|
||||
InitializationOptions = _InitOptions
|
||||
Prompt = _Prompt
|
||||
Resource = _Resource
|
||||
TextContent = _TextContent
|
||||
Tool = _Tool
|
||||
|
||||
# Try to detect actual version even in compatibility mode
|
||||
try:
|
||||
import importlib.metadata
|
||||
actual_version = importlib.metadata.version('mcp')
|
||||
MCP_VERSION = f'compatibility-mode-{actual_version}'
|
||||
except Exception:
|
||||
try:
|
||||
import pkg_resources
|
||||
actual_version = pkg_resources.get_distribution('mcp').version
|
||||
MCP_VERSION = f'compatibility-mode-{actual_version}'
|
||||
except Exception:
|
||||
MCP_VERSION = 'compatibility-mode-1.9.x'
|
||||
|
||||
logger.info("MCP 1.9.x compatibility workaround successful!")
|
||||
|
||||
# Restore original typing function if we patched it
|
||||
if original_check_generic:
|
||||
typing._check_generic = original_check_generic
|
||||
|
||||
return True
|
||||
|
||||
except Exception as workaround_error:
|
||||
logger.error(f"MCP compatibility workaround failed: {workaround_error}")
|
||||
|
||||
# Restore original typing function if we patched it
|
||||
if original_check_generic:
|
||||
try:
|
||||
import typing
|
||||
typing._check_generic = original_check_generic
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.error(f"Failed to import MCP components: {import_error}")
|
||||
return False
|
||||
|
||||
# Perform MCP import with compatibility handling
|
||||
if not _import_mcp_with_compatibility():
|
||||
raise ImportError(
|
||||
"Failed to import MCP components. Please ensure MCP is properly installed. "
|
||||
"Supported versions: 1.8.x, 1.9.x"
|
||||
)
|
||||
|
||||
from .tools.tools_manager import DorisToolsManager
|
||||
from .tools.prompts_manager import DorisPromptsManager
|
||||
@@ -44,11 +212,16 @@ from .tools.resources_manager import DorisResourcesManager
|
||||
from .utils.config import DorisConfig
|
||||
from .utils.db import DorisConnectionManager
|
||||
from .utils.security import DorisSecurityManager
|
||||
import os
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
# Configure logging - will be properly initialized later
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Create a default config instance for getting default values
|
||||
_default_config = DorisConfig()
|
||||
|
||||
|
||||
|
||||
|
||||
class DorisServer:
|
||||
"""Apache Doris MCP Server main class"""
|
||||
@@ -57,20 +230,101 @@ class DorisServer:
|
||||
self.config = config
|
||||
self.server = Server("doris-mcp-server")
|
||||
|
||||
# Initialize security manager
|
||||
# Initialize security manager (without connection_manager initially)
|
||||
self.security_manager = DorisSecurityManager(config)
|
||||
|
||||
# Initialize connection manager, pass in security manager
|
||||
self.connection_manager = DorisConnectionManager(config, self.security_manager)
|
||||
# Initialize connection manager, pass in security manager and token manager for token-bound DB config
|
||||
token_manager = self.security_manager.auth_provider.token_manager if hasattr(self.security_manager, 'auth_provider') and hasattr(self.security_manager.auth_provider, 'token_manager') else None
|
||||
self.connection_manager = DorisConnectionManager(config, self.security_manager, token_manager)
|
||||
|
||||
# Set connection manager reference in security manager for database validation
|
||||
self.security_manager.connection_manager = self.connection_manager
|
||||
|
||||
# Initialize independent managers
|
||||
self.resources_manager = DorisResourcesManager(self.connection_manager)
|
||||
self.tools_manager = DorisToolsManager(self.connection_manager)
|
||||
self.prompts_manager = DorisPromptsManager(self.connection_manager)
|
||||
|
||||
self.logger = logging.getLogger(f"{__name__}.DorisServer")
|
||||
# Import here to avoid circular imports
|
||||
from .utils.logger import get_logger
|
||||
self.logger = get_logger(f"{__name__}.DorisServer")
|
||||
self._setup_handlers()
|
||||
|
||||
async def _extract_auth_info_from_scope(self, scope, headers):
|
||||
"""Extract authentication information from ASGI scope and headers"""
|
||||
auth_info = {}
|
||||
|
||||
# Extract client IP
|
||||
client = scope.get("client")
|
||||
if client:
|
||||
auth_info["client_ip"] = client[0]
|
||||
else:
|
||||
auth_info["client_ip"] = "unknown"
|
||||
|
||||
# Extract token from Authorization header
|
||||
authorization = headers.get(b'authorization', b'').decode('utf-8')
|
||||
if authorization:
|
||||
if authorization.startswith('Bearer '):
|
||||
auth_info["token"] = authorization[7:]
|
||||
auth_info["authorization"] = authorization
|
||||
elif authorization.startswith('Token '):
|
||||
auth_info["token"] = authorization[6:]
|
||||
auth_info["authorization"] = authorization
|
||||
|
||||
# Extract token from query parameters (for compatibility)
|
||||
query_string = scope.get("query_string", b"").decode('utf-8')
|
||||
if query_string and "token=" in query_string:
|
||||
import urllib.parse
|
||||
query_params = urllib.parse.parse_qs(query_string)
|
||||
if "token" in query_params:
|
||||
auth_info["token"] = query_params["token"][0]
|
||||
|
||||
# If no token found, this will be handled by the authentication system
|
||||
# (either return anonymous context if auth disabled, or raise error if auth enabled)
|
||||
|
||||
return auth_info
|
||||
|
||||
def _get_mcp_capabilities(self):
|
||||
"""Get MCP capabilities with version compatibility"""
|
||||
try:
|
||||
# For MCP 1.9.x and newer
|
||||
from mcp.server.lowlevel.server import NotificationOptions
|
||||
|
||||
return self.server.get_capabilities(
|
||||
notification_options=NotificationOptions(
|
||||
prompts_changed=True,
|
||||
resources_changed=True,
|
||||
tools_changed=True
|
||||
),
|
||||
experimental_capabilities={}
|
||||
)
|
||||
except TypeError:
|
||||
try:
|
||||
# For MCP 1.8.x
|
||||
from mcp.server.lowlevel.server import NotificationOptions
|
||||
|
||||
return self.server.get_capabilities(
|
||||
notification_options=NotificationOptions(
|
||||
prompts_changed=True,
|
||||
resources_changed=True,
|
||||
tools_changed=True
|
||||
),
|
||||
experimental_capabilities={}
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Could not get capabilities with NotificationOptions: {e}")
|
||||
# Fallback for older versions
|
||||
try:
|
||||
return self.server.get_capabilities()
|
||||
except Exception as fallback_e:
|
||||
self.logger.error(f"Failed to get capabilities: {fallback_e}")
|
||||
# Return minimal capabilities
|
||||
return {
|
||||
"resources": {},
|
||||
"tools": {},
|
||||
"prompts": {}
|
||||
}
|
||||
|
||||
def _setup_handlers(self):
|
||||
"""Setup MCP protocol handlers"""
|
||||
|
||||
@@ -174,12 +428,24 @@ class DorisServer:
|
||||
self.logger.info("Starting Doris MCP Server (stdio mode)")
|
||||
|
||||
try:
|
||||
# Ensure connection manager is initialized
|
||||
await self.connection_manager.initialize()
|
||||
self.logger.info("Connection manager initialization completed")
|
||||
# Initialize security manager first (includes JWT setup if enabled)
|
||||
await self.security_manager.initialize()
|
||||
self.logger.info("Security manager initialization completed")
|
||||
|
||||
# For stdio mode, we must establish a working database connection
|
||||
# Use the dedicated stdio mode initialization method
|
||||
await self.connection_manager.initialize_for_stdio_mode()
|
||||
|
||||
# Start stdio server - using simpler approach
|
||||
from mcp.server.stdio import stdio_server
|
||||
# Start stdio server - using compatible import approach
|
||||
try:
|
||||
from mcp.server.stdio import stdio_server
|
||||
except ImportError:
|
||||
# Fallback for different MCP versions
|
||||
try:
|
||||
from mcp.server import stdio_server
|
||||
except ImportError as stdio_import_error:
|
||||
self.logger.error(f"Failed to import stdio_server: {stdio_import_error}")
|
||||
raise RuntimeError("stdio_server module not available in this MCP version")
|
||||
|
||||
self.logger.info("Creating stdio_server transport...")
|
||||
|
||||
@@ -189,22 +455,12 @@ class DorisServer:
|
||||
read_stream, write_stream = streams
|
||||
self.logger.info("stdio_server streams created successfully")
|
||||
|
||||
# Create initialization options
|
||||
# MCP 1.8.0 requires parameters for get_capabilities
|
||||
from mcp.server.lowlevel.server import NotificationOptions
|
||||
|
||||
capabilities = self.server.get_capabilities(
|
||||
notification_options=NotificationOptions(
|
||||
prompts_changed=True,
|
||||
resources_changed=True,
|
||||
tools_changed=True
|
||||
),
|
||||
experimental_capabilities={}
|
||||
)
|
||||
# Create initialization options with version compatibility
|
||||
capabilities = self._get_mcp_capabilities()
|
||||
|
||||
init_options = InitializationOptions(
|
||||
server_name="doris-mcp-server",
|
||||
server_version="1.0.0",
|
||||
server_version=os.getenv("SERVER_VERSION", _default_config.server_version),
|
||||
capabilities=capabilities,
|
||||
)
|
||||
self.logger.info("Initialization options created successfully")
|
||||
@@ -237,13 +493,21 @@ class DorisServer:
|
||||
|
||||
|
||||
|
||||
async def start_http(self, host: str = "localhost", port: int = 3000):
|
||||
"""Start Streamable HTTP transport mode"""
|
||||
self.logger.info(f"Starting Doris MCP Server (Streamable HTTP mode) - {host}:{port}")
|
||||
async def start_http(self, host: str = os.getenv("SERVER_HOST", _default_config.database.host), port: int = os.getenv("SERVER_PORT", _default_config.server_port), workers: int = 1):
|
||||
"""Start Streamable HTTP transport mode with workers support"""
|
||||
self.logger.info(f"Starting Doris MCP Server (Streamable HTTP mode) - {host}:{port}, workers: {workers}")
|
||||
|
||||
try:
|
||||
# Ensure connection manager is initialized
|
||||
await self.connection_manager.initialize()
|
||||
# Initialize security manager first (includes JWT setup if enabled)
|
||||
await self.security_manager.initialize()
|
||||
self.logger.info("Security manager initialization completed")
|
||||
|
||||
# For HTTP mode, try to initialize global connection pool with graceful degradation
|
||||
global_pool_created = await self.connection_manager.initialize_for_http_mode()
|
||||
if global_pool_created:
|
||||
self.logger.info("Global database connection pool available for HTTP mode")
|
||||
else:
|
||||
self.logger.info("HTTP mode running without global database pool, will use token-bound configurations")
|
||||
|
||||
# Use Starlette and StreamableHTTPSessionManager according to official example
|
||||
import uvicorn
|
||||
@@ -251,9 +515,9 @@ class DorisServer:
|
||||
from collections.abc import AsyncIterator
|
||||
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
|
||||
from starlette.applications import Starlette
|
||||
from starlette.routing import Mount, Route
|
||||
from starlette.routing import Route
|
||||
from starlette.responses import JSONResponse, Response
|
||||
from starlette.types import Receive, Scope, Send
|
||||
from starlette.types import Scope
|
||||
|
||||
# Create session manager
|
||||
session_manager = StreamableHTTPSessionManager(
|
||||
@@ -268,6 +532,44 @@ class DorisServer:
|
||||
async def health_check(request):
|
||||
return JSONResponse({"status": "healthy", "service": "doris-mcp-server"})
|
||||
|
||||
# OAuth endpoints
|
||||
from .auth.oauth_handlers import OAuthHandlers
|
||||
oauth_handlers = OAuthHandlers(self.security_manager)
|
||||
|
||||
async def oauth_login(request):
|
||||
return await oauth_handlers.handle_login(request)
|
||||
|
||||
async def oauth_callback(request):
|
||||
return await oauth_handlers.handle_callback(request)
|
||||
|
||||
async def oauth_provider_info(request):
|
||||
return await oauth_handlers.handle_provider_info(request)
|
||||
|
||||
async def oauth_demo(request):
|
||||
return await oauth_handlers.handle_demo_page(request)
|
||||
|
||||
# Token management endpoints
|
||||
from .auth.token_handlers import TokenHandlers
|
||||
token_handlers = TokenHandlers(self.security_manager, self.config)
|
||||
|
||||
async def token_create(request):
|
||||
return await token_handlers.handle_create_token(request)
|
||||
|
||||
async def token_revoke(request):
|
||||
return await token_handlers.handle_revoke_token(request)
|
||||
|
||||
async def token_list(request):
|
||||
return await token_handlers.handle_list_tokens(request)
|
||||
|
||||
async def token_stats(request):
|
||||
return await token_handlers.handle_token_stats(request)
|
||||
|
||||
async def token_cleanup(request):
|
||||
return await token_handlers.handle_cleanup_tokens(request)
|
||||
|
||||
async def token_management(request):
|
||||
return await token_handlers.handle_management_page(request)
|
||||
|
||||
# Lifecycle manager - simplified since we manage session_manager externally
|
||||
@contextlib.asynccontextmanager
|
||||
async def lifespan(app: Starlette) -> AsyncIterator[None]:
|
||||
@@ -283,6 +585,18 @@ class DorisServer:
|
||||
debug=True,
|
||||
routes=[
|
||||
Route("/health", health_check, methods=["GET"]),
|
||||
# OAuth endpoints
|
||||
Route("/auth/login", oauth_login, methods=["GET"]),
|
||||
Route("/auth/callback", oauth_callback, methods=["GET"]),
|
||||
Route("/auth/provider", oauth_provider_info, methods=["GET"]),
|
||||
Route("/auth/demo", oauth_demo, methods=["GET"]),
|
||||
# Token management endpoints
|
||||
Route("/token/create", token_create, methods=["GET", "POST"]),
|
||||
Route("/token/revoke", token_revoke, methods=["GET", "DELETE"]),
|
||||
Route("/token/list", token_list, methods=["GET"]),
|
||||
Route("/token/stats", token_stats, methods=["GET"]),
|
||||
Route("/token/cleanup", token_cleanup, methods=["GET", "POST"]),
|
||||
Route("/token/management", token_management, methods=["GET"]),
|
||||
],
|
||||
lifespan=lifespan,
|
||||
)
|
||||
@@ -300,8 +614,10 @@ class DorisServer:
|
||||
self.logger.info(f"Received request for path: {path}")
|
||||
|
||||
try:
|
||||
# Handle health check
|
||||
if path.startswith("/health"):
|
||||
# Handle health check, auth, and token management endpoints
|
||||
if (path.startswith("/health") or
|
||||
path.startswith("/auth/") or
|
||||
path.startswith("/token/")):
|
||||
await starlette_app(scope, receive, send)
|
||||
return
|
||||
|
||||
@@ -314,6 +630,39 @@ class DorisServer:
|
||||
self.logger.info(f"MCP Request - Method: {method}")
|
||||
self.logger.info(f"MCP Request - Headers: {headers}")
|
||||
|
||||
# Authentication check for MCP requests
|
||||
try:
|
||||
# Extract authentication information
|
||||
auth_info = await self._extract_auth_info_from_scope(scope, headers)
|
||||
|
||||
# Authenticate the request
|
||||
auth_context = await self.security_manager.authenticate_request(auth_info)
|
||||
self.logger.info(f"MCP request authenticated: token_id={auth_context.token_id}, client_ip={auth_context.client_ip}")
|
||||
|
||||
# Store auth context in scope for potential use by tools/resources
|
||||
scope["auth_context"] = auth_context
|
||||
|
||||
# FIX for Issue #62 Bug 1: Set auth_context in context variable
|
||||
# This allows tools to access token information for token-bound database configuration
|
||||
# CRITICAL: Use the global ContextVar from security.py to ensure same instance is used everywhere
|
||||
try:
|
||||
from .utils.security import mcp_auth_context_var
|
||||
mcp_auth_context_var.set(auth_context)
|
||||
self.logger.debug(f"Set auth_context in context variable with token: {bool(hasattr(auth_context, 'token') and auth_context.token)}")
|
||||
except Exception as ctx_error:
|
||||
self.logger.warning(f"Failed to set auth_context in context variable: {ctx_error}")
|
||||
|
||||
except Exception as auth_error:
|
||||
self.logger.error(f"MCP authentication failed: {auth_error}")
|
||||
# Return 401 Unauthorized
|
||||
from starlette.responses import JSONResponse
|
||||
response = JSONResponse(
|
||||
{"error": "Authentication required", "message": str(auth_error)},
|
||||
status_code=401
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
# Handle Dify compatibility for GET requests
|
||||
if method == "GET":
|
||||
accept_header = headers.get(b'accept', b'').decode('utf-8')
|
||||
@@ -356,19 +705,35 @@ class DorisServer:
|
||||
self.logger.warning(f"Unsupported scope type: {scope['type']}")
|
||||
return
|
||||
|
||||
# Start uvicorn server with session manager lifecycle
|
||||
config = uvicorn.Config(
|
||||
app=mcp_app,
|
||||
host=host,
|
||||
port=port,
|
||||
log_level="info"
|
||||
)
|
||||
server = uvicorn.Server(config)
|
||||
|
||||
# Run session manager and server together
|
||||
async with session_manager.run():
|
||||
self.logger.info("Session manager started, now starting HTTP server")
|
||||
await server.serve()
|
||||
# Choose startup method based on worker count
|
||||
if workers > 1:
|
||||
self.logger.info(f"Using multi-process mode with {workers} workers")
|
||||
self.logger.info("Note: Multi-worker mode provides full MCP functionality with independent worker processes")
|
||||
|
||||
# Use the dedicated multiworker app module with full MCP support
|
||||
uvicorn.run(
|
||||
"doris_mcp_server.multiworker_app:app",
|
||||
host=host,
|
||||
port=port,
|
||||
workers=workers,
|
||||
log_level="info"
|
||||
)
|
||||
|
||||
else:
|
||||
self.logger.info("Using single-process mode")
|
||||
# Single worker mode, use original logic with session manager lifecycle
|
||||
config = uvicorn.Config(
|
||||
app=mcp_app,
|
||||
host=host,
|
||||
port=port,
|
||||
log_level="info"
|
||||
)
|
||||
server = uvicorn.Server(config)
|
||||
|
||||
# Run session manager and server together
|
||||
async with session_manager.run():
|
||||
self.logger.info("Session manager started, now starting HTTP server")
|
||||
await server.serve()
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Streamable HTTP server startup failed: {e}")
|
||||
@@ -383,10 +748,16 @@ class DorisServer:
|
||||
self.logger.error(f" Exception {i+1}: {type(exc).__name__}: {exc}")
|
||||
raise
|
||||
|
||||
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shutdown server"""
|
||||
self.logger.info("Shutting down Doris MCP Server")
|
||||
try:
|
||||
# Shutdown security manager first (includes JWT cleanup)
|
||||
await self.security_manager.shutdown()
|
||||
self.logger.info("Security manager shutdown completed")
|
||||
|
||||
await self.connection_manager.close()
|
||||
self.logger.info("Doris MCP Server has been shut down")
|
||||
except Exception as e:
|
||||
@@ -406,6 +777,11 @@ Transport Modes:
|
||||
Examples:
|
||||
python -m doris_mcp_server --transport stdio
|
||||
python -m doris_mcp_server --transport http --host 0.0.0.0 --port 3000
|
||||
python -m doris_mcp_server --transport stdio --doris-host localhost --doris-port 9030
|
||||
python -m doris_mcp_server --transport http --doris-user admin --doris-database test_db
|
||||
|
||||
# Backward compatibility: --db-* parameters are also supported
|
||||
python -m doris_mcp_server --transport stdio --db-host localhost --db-port 9030
|
||||
"""
|
||||
)
|
||||
|
||||
@@ -413,91 +789,151 @@ Examples:
|
||||
"--transport",
|
||||
type=str,
|
||||
choices=["stdio", "http"],
|
||||
default="stdio",
|
||||
help="Transport protocol type: stdio (local), http (Streamable HTTP)",
|
||||
default=os.getenv("TRANSPORT", _default_config.transport),
|
||||
help=f"Transport protocol type: stdio (local), http (Streamable HTTP) (default: {_default_config.transport})",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default="localhost",
|
||||
help="Host address for HTTP mode (default: localhost)",
|
||||
default=os.getenv("SERVER_HOST", _default_config.server_host),
|
||||
help=f"Host address for HTTP mode (default: {_default_config.server_host})",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--port", type=int, default=3000, help="Port number for HTTP mode (default: 3000)"
|
||||
"--port",
|
||||
type=int,
|
||||
default=3000,
|
||||
help="Port number for HTTP mode (default: 3000)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--db-host",
|
||||
"--workers",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of worker processes for HTTP mode (default: 1, use 0 for auto-detect CPU cores)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--doris-host", "--db-host",
|
||||
type=str,
|
||||
default="localhost",
|
||||
help="Doris database host address (default: localhost)",
|
||||
default=os.getenv("DORIS_HOST", _default_config.database.host),
|
||||
help=f"Doris database host address (default: {_default_config.database.host})",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--db-port", type=int, default=9030, help="Doris database port number (default: 9030)"
|
||||
"--doris-port", "--db-port", type=int, default=9030, help="Doris database port number (default: 9030)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--db-user", type=str, default="root", help="Doris database username (default: root)"
|
||||
"--doris-user", "--db-user", type=str, default=os.getenv("DORIS_USER", _default_config.database.user), help=f"Doris database username (default: {_default_config.database.user})"
|
||||
)
|
||||
|
||||
parser.add_argument("--db-password", type=str, default="", help="Doris database password")
|
||||
parser.add_argument("--doris-password", "--db-password", type=str, default=os.getenv("DORIS_PASSWORD", ""), help="Doris database password")
|
||||
|
||||
parser.add_argument(
|
||||
"--db-database",
|
||||
"--doris-database", "--db-database",
|
||||
type=str,
|
||||
default="information_schema",
|
||||
help="Doris database name (default: information_schema)",
|
||||
default=os.getenv("DORIS_DATABASE", _default_config.database.database),
|
||||
help=f"Doris database name (default: {_default_config.database.database})",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--log-level",
|
||||
type=str,
|
||||
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
|
||||
default="INFO",
|
||||
help="Log level (default: INFO)",
|
||||
default=os.getenv("LOG_LEVEL", _default_config.logging.level),
|
||||
help=f"Log level (default: {_default_config.logging.level})",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main function"""
|
||||
def update_configuration(config: DorisConfig):
|
||||
"""Update doris configuration object"""
|
||||
# For some arguments, if not specified, environment variables or default configurations will be used as default values
|
||||
parser = create_arg_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set log level
|
||||
logging.getLogger().setLevel(getattr(logging, args.log_level))
|
||||
|
||||
# Create configuration - priority: command line arguments > .env file > default values
|
||||
config = DorisConfig.from_env() # First load from .env file and environment variables
|
||||
|
||||
# Update config values
|
||||
# Command line arguments override configuration (if provided)
|
||||
if args.db_host != "localhost": # If not default value, use command line argument
|
||||
config.database.host = args.db_host
|
||||
if args.db_port != 9030:
|
||||
config.database.port = args.db_port
|
||||
if args.db_user != "root":
|
||||
config.database.user = args.db_user
|
||||
if args.db_password: # Use password if provided
|
||||
config.database.password = args.db_password
|
||||
if args.db_database != "information_schema":
|
||||
config.database.database = args.db_database
|
||||
if args.log_level != "INFO":
|
||||
# basic
|
||||
if args.transport != _default_config.transport:
|
||||
config.transport = args.transport
|
||||
if args.host != _default_config.server_host:
|
||||
config.server_host = args.host
|
||||
if args.port != _default_config.server_port:
|
||||
config.server_port = args.port
|
||||
server_name = os.getenv("SERVER_NAME")
|
||||
if server_name:
|
||||
config.server_name = server_name
|
||||
server_version = os.getenv("SERVER_VERSION")
|
||||
if server_version:
|
||||
config.server_version = server_version
|
||||
|
||||
# database
|
||||
if args.doris_host != _default_config.database.host: # If not default value, use command line argument
|
||||
config.database.host = args.doris_host
|
||||
if args.doris_port != _default_config.database.port:
|
||||
config.database.port = args.doris_port
|
||||
if args.doris_user != _default_config.database.user:
|
||||
config.database.user = args.doris_user
|
||||
if args.doris_password: # Use password if provided
|
||||
config.database.password = args.doris_password
|
||||
if args.doris_database != _default_config.database.database:
|
||||
config.database.database = args.doris_database
|
||||
|
||||
# logging
|
||||
if args.log_level != _default_config.logging.level:
|
||||
config.logging.level = args.log_level
|
||||
|
||||
# workers (add to config for HTTP mode)
|
||||
if hasattr(args, 'workers'):
|
||||
config.workers = args.workers
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main function"""
|
||||
# Create configuration - priority: command line arguments > env variables > .env file > default values
|
||||
# First load from .env file and environment variables
|
||||
config = DorisConfig.from_env()
|
||||
|
||||
# Then parse the command line arguments, and update the config object.
|
||||
update_configuration(config)
|
||||
|
||||
# Initialize enhanced logging system
|
||||
from .utils.config import ConfigManager
|
||||
config_manager = ConfigManager(config)
|
||||
config_manager.setup_logging()
|
||||
|
||||
# Get logger with proper configuration
|
||||
from .utils.logger import get_logger, log_system_info
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Log system information for debugging
|
||||
log_system_info()
|
||||
|
||||
logger.info("Starting Doris MCP Server...")
|
||||
logger.info(f"Transport: {config.transport}")
|
||||
logger.info(f"Log Level: {config.logging.level}")
|
||||
|
||||
# Create server instance
|
||||
server = DorisServer(config)
|
||||
|
||||
try:
|
||||
if args.transport == "stdio":
|
||||
if config.transport == "stdio":
|
||||
await server.start_stdio()
|
||||
elif args.transport == "http":
|
||||
await server.start_http(args.host, args.port)
|
||||
elif config.transport == "http":
|
||||
# Get workers configuration with auto-detection support
|
||||
workers = getattr(config, 'workers', 1)
|
||||
if workers == 0:
|
||||
import multiprocessing
|
||||
workers = multiprocessing.cpu_count()
|
||||
logger.info(f"Auto-detected {workers} CPU cores for worker processes")
|
||||
|
||||
await server.start_http(config.server_host, config.server_port, workers)
|
||||
else:
|
||||
logger.error(f"Unsupported transport protocol: {args.transport}")
|
||||
logger.error(f"Unsupported transport protocol: {config.transport}")
|
||||
await server.shutdown()
|
||||
return 1
|
||||
|
||||
@@ -517,6 +953,10 @@ async def main():
|
||||
await server.shutdown()
|
||||
except Exception as shutdown_error:
|
||||
logger.error(f"Error occurred while shutting down server: {shutdown_error}")
|
||||
|
||||
# Shutdown logging system
|
||||
from .utils.logger import shutdown_logging
|
||||
shutdown_logging()
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
627
doris_mcp_server/multiworker_app.py
Normal file
@@ -0,0 +1,627 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Multi-worker application module for doris-mcp-server
|
||||
|
||||
This module provides full MCP functionality with multi-worker support.
|
||||
Each worker process creates its own MCP server and session manager using the same
|
||||
robust architecture as the single-worker mode.
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
# Import MCP components with compatibility handling
|
||||
# Use the same import strategy as main.py for consistency
|
||||
MCP_VERSION = 'unknown'
|
||||
Server = None
|
||||
InitializationOptions = None
|
||||
Prompt = None
|
||||
Resource = None
|
||||
TextContent = None
|
||||
Tool = None
|
||||
|
||||
def _import_mcp_with_compatibility():
|
||||
"""Import MCP components with multi-version compatibility"""
|
||||
global MCP_VERSION, Server, InitializationOptions, Prompt, Resource, TextContent, Tool
|
||||
|
||||
try:
|
||||
# Strategy 1: Try direct server-only imports to avoid client-side issues
|
||||
from mcp.server import Server as _Server
|
||||
from mcp.server.models import InitializationOptions as _InitOptions
|
||||
from mcp.types import (
|
||||
Prompt as _Prompt,
|
||||
Resource as _Resource,
|
||||
TextContent as _TextContent,
|
||||
Tool as _Tool,
|
||||
)
|
||||
|
||||
# Assign to globals
|
||||
Server = _Server
|
||||
InitializationOptions = _InitOptions
|
||||
Prompt = _Prompt
|
||||
Resource = _Resource
|
||||
TextContent = _TextContent
|
||||
Tool = _Tool
|
||||
|
||||
# Try to get version safely
|
||||
try:
|
||||
import mcp
|
||||
MCP_VERSION = getattr(mcp, '__version__', None)
|
||||
if not MCP_VERSION:
|
||||
# Fallback: try to get version from package metadata
|
||||
try:
|
||||
import importlib.metadata
|
||||
MCP_VERSION = importlib.metadata.version('mcp')
|
||||
except Exception:
|
||||
# Second fallback: try pkg_resources
|
||||
try:
|
||||
import pkg_resources
|
||||
MCP_VERSION = pkg_resources.get_distribution('mcp').version
|
||||
except Exception:
|
||||
MCP_VERSION = 'detected-but-version-unknown'
|
||||
except Exception:
|
||||
# Version detection failed, but imports worked
|
||||
try:
|
||||
import importlib.metadata
|
||||
MCP_VERSION = importlib.metadata.version('mcp')
|
||||
except Exception:
|
||||
try:
|
||||
import pkg_resources
|
||||
MCP_VERSION = pkg_resources.get_distribution('mcp').version
|
||||
except Exception:
|
||||
MCP_VERSION = 'imported-successfully'
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info(f"MCP components imported successfully in multiworker, version: {MCP_VERSION}")
|
||||
return True
|
||||
|
||||
except Exception as import_error:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Strategy 2: Handle RequestContext compatibility issues in 1.9.x versions
|
||||
error_str = str(import_error).lower()
|
||||
if 'requestcontext' in error_str and 'too few arguments' in error_str:
|
||||
logger.warning(f"Detected MCP RequestContext compatibility issue: {import_error}")
|
||||
logger.info("Attempting comprehensive workaround for MCP 1.9.x RequestContext issue...")
|
||||
|
||||
try:
|
||||
# Comprehensive monkey patch approach
|
||||
import sys
|
||||
import types
|
||||
|
||||
# Create and install mock modules before any MCP imports
|
||||
if 'mcp.shared.context' not in sys.modules:
|
||||
mock_context_module = types.ModuleType('mcp.shared.context')
|
||||
|
||||
class FlexibleRequestContext:
|
||||
"""Flexible RequestContext that accepts variable arguments"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
def __class_getitem__(cls, params):
|
||||
# Accept any number of parameters and return cls
|
||||
return cls
|
||||
|
||||
# Add other methods that might be called
|
||||
def __getattr__(self, name):
|
||||
return lambda *args, **kwargs: None
|
||||
|
||||
mock_context_module.RequestContext = FlexibleRequestContext
|
||||
sys.modules['mcp.shared.context'] = mock_context_module
|
||||
|
||||
# Also patch the typing system to be more permissive
|
||||
original_check_generic = None
|
||||
try:
|
||||
import typing
|
||||
if hasattr(typing, '_check_generic'):
|
||||
original_check_generic = typing._check_generic
|
||||
def permissive_check_generic(cls, params, elen):
|
||||
# Don't enforce strict parameter count checking
|
||||
return
|
||||
typing._check_generic = permissive_check_generic
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Clear any cached imports that might have failed
|
||||
modules_to_clear = [k for k in sys.modules.keys() if k.startswith('mcp.')]
|
||||
for module in modules_to_clear:
|
||||
if module in sys.modules:
|
||||
del sys.modules[module]
|
||||
|
||||
# Now try importing again with the patches in place
|
||||
from mcp.server import Server as _Server
|
||||
from mcp.server.models import InitializationOptions as _InitOptions
|
||||
from mcp.types import (
|
||||
Prompt as _Prompt,
|
||||
Resource as _Resource,
|
||||
TextContent as _TextContent,
|
||||
Tool as _Tool,
|
||||
)
|
||||
|
||||
# Assign to globals
|
||||
Server = _Server
|
||||
InitializationOptions = _InitOptions
|
||||
Prompt = _Prompt
|
||||
Resource = _Resource
|
||||
TextContent = _TextContent
|
||||
Tool = _Tool
|
||||
|
||||
# Try to detect actual version even in compatibility mode
|
||||
try:
|
||||
import importlib.metadata
|
||||
actual_version = importlib.metadata.version('mcp')
|
||||
MCP_VERSION = f'compatibility-mode-{actual_version}'
|
||||
except Exception:
|
||||
try:
|
||||
import pkg_resources
|
||||
actual_version = pkg_resources.get_distribution('mcp').version
|
||||
MCP_VERSION = f'compatibility-mode-{actual_version}'
|
||||
except Exception:
|
||||
MCP_VERSION = 'compatibility-mode-1.9.x'
|
||||
|
||||
logger.info("MCP 1.9.x compatibility workaround successful in multiworker!")
|
||||
|
||||
# Restore original typing function if we patched it
|
||||
if original_check_generic:
|
||||
typing._check_generic = original_check_generic
|
||||
|
||||
return True
|
||||
|
||||
except Exception as workaround_error:
|
||||
logger.error(f"MCP compatibility workaround failed in multiworker: {workaround_error}")
|
||||
|
||||
# Restore original typing function if we patched it
|
||||
if original_check_generic:
|
||||
try:
|
||||
import typing
|
||||
typing._check_generic = original_check_generic
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.error(f"Failed to import MCP components in multiworker: {import_error}")
|
||||
return False
|
||||
|
||||
# Perform MCP import with compatibility handling
|
||||
if not _import_mcp_with_compatibility():
|
||||
raise ImportError(
|
||||
"Failed to import MCP components in multiworker. Please ensure MCP is properly installed. "
|
||||
"Supported versions: 1.8.x, 1.9.x"
|
||||
)
|
||||
|
||||
from starlette.applications import Starlette
|
||||
from starlette.routing import Route
|
||||
from starlette.responses import JSONResponse, Response
|
||||
|
||||
# Import Doris MCP components
|
||||
from .tools.tools_manager import DorisToolsManager
|
||||
from .tools.prompts_manager import DorisPromptsManager
|
||||
from .tools.resources_manager import DorisResourcesManager
|
||||
from .utils.config import DorisConfig
|
||||
from .utils.db import DorisConnectionManager
|
||||
from .utils.security import DorisSecurityManager
|
||||
|
||||
# Global variables for worker-specific instances
|
||||
_worker_server = None
|
||||
_worker_session_manager = None
|
||||
_worker_connection_manager = None
|
||||
_worker_security_manager = None
|
||||
_worker_session_manager_context = None
|
||||
_worker_initialized = False
|
||||
|
||||
def get_mcp_capabilities():
|
||||
"""Get MCP capabilities for worker - use the same logic as main.py"""
|
||||
try:
|
||||
# For MCP 1.9.x and newer
|
||||
from mcp.server.lowlevel.server import NotificationOptions
|
||||
|
||||
capabilities = {
|
||||
"resources": {},
|
||||
"tools": {},
|
||||
"prompts": {},
|
||||
"notification_options": {
|
||||
"prompts_changed": True,
|
||||
"resources_changed": True,
|
||||
"tools_changed": True
|
||||
}
|
||||
}
|
||||
return capabilities
|
||||
except Exception as e:
|
||||
# Import logger properly
|
||||
from .utils.logger import get_logger
|
||||
logger = get_logger(__name__)
|
||||
logger.warning(f"Failed to get full capabilities in multiworker: {e}")
|
||||
return {
|
||||
"resources": {},
|
||||
"tools": {},
|
||||
"prompts": {}
|
||||
}
|
||||
|
||||
async def initialize_worker():
|
||||
"""Initialize MCP server and managers for this worker process"""
|
||||
global _worker_server, _worker_session_manager, _worker_connection_manager, _worker_security_manager, _worker_session_manager_context, _worker_initialized, _oauth_handlers, _token_handlers
|
||||
|
||||
if _worker_initialized:
|
||||
return
|
||||
|
||||
try:
|
||||
# Import logger properly
|
||||
from .utils.logger import get_logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
logger.info(f"Initializing MCP worker process {os.getpid()}")
|
||||
|
||||
# Create configuration
|
||||
config = DorisConfig.from_env()
|
||||
|
||||
# Initialize enhanced logging system
|
||||
from .utils.config import ConfigManager
|
||||
config_manager = ConfigManager(config)
|
||||
config_manager.setup_logging()
|
||||
|
||||
# Create security manager
|
||||
_worker_security_manager = DorisSecurityManager(config)
|
||||
|
||||
# Initialize security manager first (includes JWT setup if enabled)
|
||||
await _worker_security_manager.initialize()
|
||||
logger.info(f"Worker {os.getpid()} security manager initialization completed")
|
||||
|
||||
# Create connection manager with token manager for token-bound DB config
|
||||
token_manager = _worker_security_manager.auth_provider.token_manager if hasattr(_worker_security_manager, 'auth_provider') and hasattr(_worker_security_manager.auth_provider, 'token_manager') else None
|
||||
_worker_connection_manager = DorisConnectionManager(config, _worker_security_manager, token_manager)
|
||||
|
||||
# Set connection manager reference in security manager for database validation
|
||||
_worker_security_manager.connection_manager = _worker_connection_manager
|
||||
|
||||
await _worker_connection_manager.initialize()
|
||||
|
||||
# Create MCP server
|
||||
_worker_server = Server("doris-mcp-server")
|
||||
|
||||
# Create managers
|
||||
resources_manager = DorisResourcesManager(_worker_connection_manager)
|
||||
tools_manager = DorisToolsManager(_worker_connection_manager)
|
||||
prompts_manager = DorisPromptsManager(_worker_connection_manager)
|
||||
|
||||
# Setup MCP handlers
|
||||
@_worker_server.list_resources()
|
||||
async def handle_list_resources() -> list[Resource]:
|
||||
"""Handle resource list request"""
|
||||
try:
|
||||
logger.info("Handling resource list request in worker")
|
||||
resources = await resources_manager.list_resources()
|
||||
logger.info(f"Returning {len(resources)} resources from worker")
|
||||
return resources
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to handle resource list request in worker: {e}")
|
||||
return []
|
||||
|
||||
@_worker_server.read_resource()
|
||||
async def handle_read_resource(uri: str) -> str:
|
||||
"""Handle resource read request"""
|
||||
try:
|
||||
logger.info(f"Handling resource read request in worker: {uri}")
|
||||
content = await resources_manager.read_resource(uri)
|
||||
return content
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to handle resource read request in worker: {e}")
|
||||
return json.dumps(
|
||||
{"error": f"Failed to read resource: {str(e)}", "uri": uri},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
@_worker_server.list_tools()
|
||||
async def handle_list_tools() -> list[Tool]:
|
||||
"""Handle tool list request"""
|
||||
try:
|
||||
logger.info("Handling tool list request in worker")
|
||||
tools = await tools_manager.list_tools()
|
||||
logger.info(f"Returning {len(tools)} tools from worker")
|
||||
return tools
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to handle tool list request in worker: {e}")
|
||||
return []
|
||||
|
||||
@_worker_server.call_tool()
|
||||
async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
|
||||
"""Handle tool call request"""
|
||||
try:
|
||||
logger.info(f"Handling tool call request in worker: {name}")
|
||||
result = await tools_manager.call_tool(name, arguments)
|
||||
return [TextContent(type="text", text=result)]
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to handle tool call request in worker: {e}")
|
||||
error_result = json.dumps(
|
||||
{
|
||||
"error": f"Tool call failed: {str(e)}",
|
||||
"tool_name": name,
|
||||
"arguments": arguments,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
return [TextContent(type="text", text=error_result)]
|
||||
|
||||
@_worker_server.list_prompts()
|
||||
async def handle_list_prompts() -> list[Prompt]:
|
||||
"""Handle prompt list request"""
|
||||
try:
|
||||
logger.info("Handling prompt list request in worker")
|
||||
prompts = await prompts_manager.list_prompts()
|
||||
logger.info(f"Returning {len(prompts)} prompts from worker")
|
||||
return prompts
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to handle prompt list request in worker: {e}")
|
||||
return []
|
||||
|
||||
@_worker_server.get_prompt()
|
||||
async def handle_get_prompt(name: str, arguments: dict[str, Any]) -> str:
|
||||
"""Handle prompt get request"""
|
||||
try:
|
||||
logger.info(f"Handling prompt get request in worker: {name}")
|
||||
result = await prompts_manager.get_prompt(name, arguments)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to handle prompt get request in worker: {e}")
|
||||
error_result = json.dumps(
|
||||
{
|
||||
"error": f"Failed to get prompt: {str(e)}",
|
||||
"prompt_name": name,
|
||||
"arguments": arguments,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
return error_result
|
||||
|
||||
# Create session manager for this worker
|
||||
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
|
||||
|
||||
_worker_session_manager = StreamableHTTPSessionManager(
|
||||
app=_worker_server,
|
||||
json_response=True,
|
||||
stateless=True # Use stateless mode for multi-worker compatibility
|
||||
)
|
||||
|
||||
# Start the session manager context
|
||||
_worker_session_manager_context = _worker_session_manager.run()
|
||||
await _worker_session_manager_context.__aenter__()
|
||||
|
||||
# Initialize OAuth and Token handlers
|
||||
from .auth.oauth_handlers import OAuthHandlers
|
||||
from .auth.token_handlers import TokenHandlers
|
||||
_oauth_handlers = OAuthHandlers(_worker_security_manager)
|
||||
_token_handlers = TokenHandlers(_worker_security_manager, config)
|
||||
|
||||
_worker_initialized = True
|
||||
logger.info(f"Worker {os.getpid()} MCP initialization completed successfully")
|
||||
|
||||
except Exception as e:
|
||||
from .utils.logger import get_logger
|
||||
logger = get_logger(__name__)
|
||||
logger.error(f"Failed to initialize worker {os.getpid()}: {e}")
|
||||
import traceback
|
||||
logger.error("Complete error stack:")
|
||||
logger.error(traceback.format_exc())
|
||||
raise
|
||||
|
||||
async def health_check(request):
|
||||
"""Health check endpoint that shows worker PID"""
|
||||
return JSONResponse({
|
||||
"status": "healthy",
|
||||
"service": "doris-mcp-server",
|
||||
"worker_pid": os.getpid(),
|
||||
"worker_mode": "multi-process-full-mcp",
|
||||
"mcp_initialized": _worker_initialized,
|
||||
"mcp_version": MCP_VERSION
|
||||
})
|
||||
|
||||
# OAuth and Token handlers (initialize after worker setup)
|
||||
_oauth_handlers = None
|
||||
_token_handlers = None
|
||||
|
||||
async def oauth_login(request):
|
||||
"""OAuth login endpoint"""
|
||||
if not _oauth_handlers:
|
||||
return JSONResponse({"error": "OAuth not initialized"}, status_code=503)
|
||||
return await _oauth_handlers.handle_login(request)
|
||||
|
||||
async def oauth_callback(request):
|
||||
"""OAuth callback endpoint"""
|
||||
if not _oauth_handlers:
|
||||
return JSONResponse({"error": "OAuth not initialized"}, status_code=503)
|
||||
return await _oauth_handlers.handle_callback(request)
|
||||
|
||||
async def oauth_provider_info(request):
|
||||
"""OAuth provider info endpoint"""
|
||||
if not _oauth_handlers:
|
||||
return JSONResponse({"error": "OAuth not initialized"}, status_code=503)
|
||||
return await _oauth_handlers.handle_provider_info(request)
|
||||
|
||||
async def oauth_demo(request):
|
||||
"""OAuth demo page endpoint"""
|
||||
if not _oauth_handlers:
|
||||
from starlette.responses import HTMLResponse
|
||||
return HTMLResponse("<h1>OAuth not initialized</h1>")
|
||||
return await _oauth_handlers.handle_demo_page(request)
|
||||
|
||||
# Token management endpoints
|
||||
async def token_create(request):
|
||||
"""Token creation endpoint"""
|
||||
if not _token_handlers:
|
||||
return JSONResponse({"error": "Token handlers not initialized"}, status_code=503)
|
||||
return await _token_handlers.handle_create_token(request)
|
||||
|
||||
async def token_revoke(request):
|
||||
"""Token revocation endpoint"""
|
||||
if not _token_handlers:
|
||||
return JSONResponse({"error": "Token handlers not initialized"}, status_code=503)
|
||||
return await _token_handlers.handle_revoke_token(request)
|
||||
|
||||
async def token_list(request):
|
||||
"""Token listing endpoint"""
|
||||
if not _token_handlers:
|
||||
return JSONResponse({"error": "Token handlers not initialized"}, status_code=503)
|
||||
return await _token_handlers.handle_list_tokens(request)
|
||||
|
||||
async def token_stats(request):
|
||||
"""Token statistics endpoint"""
|
||||
if not _token_handlers:
|
||||
return JSONResponse({"error": "Token handlers not initialized"}, status_code=503)
|
||||
return await _token_handlers.handle_token_stats(request)
|
||||
|
||||
async def token_cleanup(request):
|
||||
"""Token cleanup endpoint"""
|
||||
if not _token_handlers:
|
||||
return JSONResponse({"error": "Token handlers not initialized"}, status_code=503)
|
||||
return await _token_handlers.handle_cleanup_tokens(request)
|
||||
|
||||
async def token_management(request):
|
||||
"""Token management page endpoint"""
|
||||
if not _token_handlers:
|
||||
from starlette.responses import HTMLResponse
|
||||
return HTMLResponse("<h1>Token handlers not initialized</h1>")
|
||||
return await _token_handlers.handle_management_page(request)
|
||||
|
||||
async def root_info(request):
|
||||
"""Root endpoint"""
|
||||
return JSONResponse({
|
||||
"service": "doris-mcp-server",
|
||||
"mode": "multi-worker-full-mcp",
|
||||
"worker_pid": os.getpid(),
|
||||
"mcp_initialized": _worker_initialized,
|
||||
"mcp_version": MCP_VERSION,
|
||||
"endpoints": {
|
||||
"health": "/health",
|
||||
"mcp": "/mcp"
|
||||
}
|
||||
})
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app):
|
||||
"""Application lifespan manager"""
|
||||
# Startup
|
||||
try:
|
||||
await initialize_worker()
|
||||
# Import logger properly
|
||||
from .utils.logger import get_logger
|
||||
logger = get_logger(__name__)
|
||||
logger.info(f"Worker {os.getpid()} startup completed")
|
||||
|
||||
yield
|
||||
|
||||
finally:
|
||||
# Shutdown
|
||||
from .utils.logger import get_logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Close session manager context
|
||||
if _worker_session_manager_context:
|
||||
try:
|
||||
await _worker_session_manager_context.__aexit__(None, None, None)
|
||||
logger.info(f"Worker {os.getpid()} session manager context closed")
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing worker session manager context: {e}")
|
||||
|
||||
if _worker_connection_manager:
|
||||
try:
|
||||
await _worker_connection_manager.close()
|
||||
logger.info(f"Worker {os.getpid()} connection manager closed")
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing worker connection manager: {e}")
|
||||
|
||||
if _worker_security_manager:
|
||||
try:
|
||||
await _worker_security_manager.shutdown()
|
||||
logger.info(f"Worker {os.getpid()} security manager shutdown completed")
|
||||
except Exception as e:
|
||||
logger.error(f"Error shutting down worker security manager: {e}")
|
||||
|
||||
# Shutdown logging system
|
||||
try:
|
||||
from .utils.logger import shutdown_logging
|
||||
shutdown_logging()
|
||||
except Exception as e:
|
||||
logger.error(f"Error shutting down logging system: {e}")
|
||||
|
||||
async def mcp_asgi_app(scope, receive, send):
|
||||
"""ASGI app that handles MCP requests"""
|
||||
if not _worker_initialized:
|
||||
# Send error response if worker not initialized
|
||||
await send({
|
||||
'type': 'http.response.start',
|
||||
'status': 503,
|
||||
'headers': [(b'content-type', b'application/json')]
|
||||
})
|
||||
await send({
|
||||
'type': 'http.response.body',
|
||||
'body': b'{"error": "Worker not initialized"}'
|
||||
})
|
||||
return
|
||||
|
||||
# Import logger properly
|
||||
from .utils.logger import get_logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Get request path for logging
|
||||
path = scope.get('path', '')
|
||||
method = scope.get('method', 'UNKNOWN')
|
||||
logger.debug(f"Worker {os.getpid()} handling MCP request: {method} {path}")
|
||||
|
||||
# Handle the request directly without nested run context
|
||||
await _worker_session_manager.handle_request(scope, receive, send)
|
||||
|
||||
# Create Starlette app with basic routes
|
||||
basic_app = Starlette(
|
||||
debug=True,
|
||||
routes=[
|
||||
Route("/", root_info, methods=["GET"]),
|
||||
Route("/health", health_check, methods=["GET"]),
|
||||
# OAuth endpoints
|
||||
Route("/auth/login", oauth_login, methods=["GET"]),
|
||||
Route("/auth/callback", oauth_callback, methods=["GET"]),
|
||||
Route("/auth/provider", oauth_provider_info, methods=["GET"]),
|
||||
Route("/auth/demo", oauth_demo, methods=["GET"]),
|
||||
# Token management endpoints
|
||||
Route("/token/create", token_create, methods=["GET", "POST"]),
|
||||
Route("/token/revoke", token_revoke, methods=["GET", "DELETE"]),
|
||||
Route("/token/list", token_list, methods=["GET"]),
|
||||
Route("/token/stats", token_stats, methods=["GET"]),
|
||||
Route("/token/cleanup", token_cleanup, methods=["GET", "POST"]),
|
||||
Route("/token/management", token_management, methods=["GET"]),
|
||||
],
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
# Create main ASGI app that routes between basic app and MCP
|
||||
async def app(scope, receive, send):
|
||||
"""Main ASGI app that routes requests"""
|
||||
path = scope.get('path', '/')
|
||||
|
||||
if path == "/mcp" or path.startswith('/mcp/'):
|
||||
# Handle MCP requests with session manager
|
||||
await mcp_asgi_app(scope, receive, send)
|
||||
else:
|
||||
# Handle other requests with basic Starlette app (includes auth endpoints)
|
||||
await basic_app(scope, receive, send)
|
||||
@@ -31,6 +31,7 @@ from mcp.types import (
|
||||
)
|
||||
|
||||
from ..utils.db import DorisConnectionManager
|
||||
from ..utils.sql_security_utils import get_auth_context
|
||||
|
||||
|
||||
class PromptTemplate:
|
||||
@@ -422,7 +423,8 @@ Please generate accurate and efficient SQL queries based on the above requiremen
|
||||
AND table_type = 'BASE TABLE'
|
||||
"""
|
||||
|
||||
db_result = await connection.execute(db_info_sql)
|
||||
auth_context = get_auth_context()
|
||||
db_result = await connection.execute(db_info_sql, auth_context=auth_context)
|
||||
db_info = db_result.data[0] if db_result.data else {}
|
||||
|
||||
# Get main table list
|
||||
@@ -438,7 +440,7 @@ Please generate accurate and efficient SQL queries based on the above requiremen
|
||||
LIMIT 10
|
||||
"""
|
||||
|
||||
tables_result = await connection.execute(tables_sql)
|
||||
tables_result = await connection.execute(tables_sql, auth_context=auth_context)
|
||||
|
||||
context = f"""Current database statistics:
|
||||
- Total number of tables: {db_info.get("table_count", 0)}
|
||||
|
||||
@@ -26,6 +26,7 @@ from typing import Any
|
||||
from mcp.types import Resource
|
||||
|
||||
from ..utils.db import DorisConnectionManager
|
||||
from ..utils.sql_security_utils import get_auth_context
|
||||
|
||||
|
||||
class TableMetadata:
|
||||
@@ -169,7 +170,8 @@ class DorisResourcesManager:
|
||||
ORDER BY table_name
|
||||
"""
|
||||
|
||||
result = await connection.execute(tables_query)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(tables_query, auth_context=auth_context)
|
||||
tables = []
|
||||
|
||||
for row in result.data:
|
||||
@@ -204,7 +206,8 @@ class DorisResourcesManager:
|
||||
ORDER BY ordinal_position
|
||||
"""
|
||||
|
||||
result = await connection.execute(columns_query, (table_name,))
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(columns_query, params=(table_name,), auth_context=auth_context)
|
||||
return [dict(row) for row in result.data]
|
||||
|
||||
async def _get_view_metadata(self) -> list[ViewMetadata]:
|
||||
@@ -226,7 +229,8 @@ class DorisResourcesManager:
|
||||
ORDER BY table_name
|
||||
"""
|
||||
|
||||
result = await connection.execute(views_query)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(views_query, auth_context=auth_context)
|
||||
views = []
|
||||
|
||||
for row in result.data:
|
||||
@@ -257,7 +261,8 @@ class DorisResourcesManager:
|
||||
AND table_name = %s
|
||||
"""
|
||||
|
||||
table_result = await connection.execute(table_info_query, (table_name,))
|
||||
auth_context = get_auth_context()
|
||||
table_result = await connection.execute(table_info_query, params=(table_name,), auth_context=auth_context)
|
||||
if not table_result.data:
|
||||
raise ValueError(f"Table {table_name} does not exist")
|
||||
|
||||
@@ -295,7 +300,8 @@ class DorisResourcesManager:
|
||||
ORDER BY index_name, seq_in_index
|
||||
"""
|
||||
|
||||
result = await connection.execute(indexes_query, (table_name,))
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(indexes_query, params=(table_name,), auth_context=auth_context)
|
||||
return [dict(row) for row in result.data]
|
||||
|
||||
async def _get_view_definition(self, view_name: str) -> str:
|
||||
@@ -312,7 +318,8 @@ class DorisResourcesManager:
|
||||
AND table_name = %s
|
||||
"""
|
||||
|
||||
result = await connection.execute(view_query, (view_name,))
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(view_query, params=(view_name,), auth_context=auth_context)
|
||||
if not result.data:
|
||||
raise ValueError(f"View {view_name} does not exist")
|
||||
|
||||
@@ -340,7 +347,8 @@ class DorisResourcesManager:
|
||||
AND table_type = 'BASE TABLE'
|
||||
"""
|
||||
|
||||
table_result = await connection.execute(table_stats_query)
|
||||
auth_context = get_auth_context()
|
||||
table_result = await connection.execute(table_stats_query, auth_context=auth_context)
|
||||
table_stats = table_result.data[0] if table_result.data else {}
|
||||
|
||||
# Get view statistics
|
||||
@@ -350,7 +358,7 @@ class DorisResourcesManager:
|
||||
WHERE table_schema = DATABASE()
|
||||
"""
|
||||
|
||||
view_result = await connection.execute(view_stats_query)
|
||||
view_result = await connection.execute(view_stats_query, auth_context=auth_context)
|
||||
view_stats = view_result.data[0] if view_result.data else {}
|
||||
|
||||
stats_info = {
|
||||
|
||||
542
doris_mcp_server/utils/adbc_query_tools.py
Normal file
@@ -0,0 +1,542 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Apache Doris ADBC Query Tools
|
||||
High-performance data querying using Apache Arrow Flight SQL protocol
|
||||
"""
|
||||
|
||||
import os
|
||||
import socket
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from ..utils.logger import get_logger
|
||||
from ..utils.db import DorisConnectionManager
|
||||
from ..utils.sql_security_utils import get_auth_context
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _convert_numpy_types(obj):
|
||||
"""Convert numpy types to native Python types for JSON serialization"""
|
||||
try:
|
||||
# Import numpy only when needed
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
if isinstance(obj, np.integer):
|
||||
return int(obj)
|
||||
elif isinstance(obj, np.floating):
|
||||
return float(obj)
|
||||
elif isinstance(obj, np.bool_):
|
||||
return bool(obj)
|
||||
elif isinstance(obj, np.ndarray):
|
||||
return obj.tolist()
|
||||
elif isinstance(obj, (pd.Timestamp, pd.NaT.__class__)):
|
||||
return str(obj)
|
||||
elif pd.isna(obj):
|
||||
return None
|
||||
else:
|
||||
return obj
|
||||
except ImportError:
|
||||
# If numpy/pandas not available, return as-is
|
||||
return obj
|
||||
|
||||
|
||||
def _convert_dataframe_to_json_serializable(df):
|
||||
"""Convert DataFrame to JSON serializable format"""
|
||||
try:
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
# Convert DataFrame to records
|
||||
records = df.to_dict('records')
|
||||
|
||||
# Convert each record's values
|
||||
converted_records = []
|
||||
for record in records:
|
||||
converted_record = {}
|
||||
for key, value in record.items():
|
||||
converted_record[key] = _convert_numpy_types(value)
|
||||
converted_records.append(converted_record)
|
||||
|
||||
return converted_records
|
||||
except ImportError:
|
||||
# Fallback to basic dict conversion
|
||||
return df.to_dict('records')
|
||||
|
||||
|
||||
class DorisADBCQueryTools:
|
||||
"""ADBC Query Tools for high-performance data transfer using Arrow Flight SQL"""
|
||||
|
||||
def __init__(self, connection_manager: DorisConnectionManager):
|
||||
self.connection_manager = connection_manager
|
||||
self.adbc_client = None
|
||||
self.flight_sql_module = None
|
||||
self.adbc_manager_module = None
|
||||
|
||||
async def exec_adbc_query(
|
||||
self,
|
||||
sql: str,
|
||||
max_rows: int | None = None,
|
||||
timeout: int | None = None,
|
||||
return_format: str | None = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Execute SQL query using ADBC (Arrow Flight SQL) protocol
|
||||
|
||||
Args:
|
||||
sql: SQL statement to execute
|
||||
max_rows: Maximum number of rows to return (uses config default if None)
|
||||
timeout: Query timeout in seconds (uses config default if None)
|
||||
return_format: Format for returned data ("arrow", "pandas", "dict", uses config default if None)
|
||||
|
||||
Returns:
|
||||
Query results in specified format with metadata
|
||||
"""
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
# Use configuration defaults if parameters not specified
|
||||
adbc_config = self.connection_manager.config.adbc
|
||||
max_rows = max_rows if max_rows is not None else adbc_config.default_max_rows
|
||||
timeout = timeout if timeout is not None else adbc_config.default_timeout
|
||||
return_format = return_format if return_format is not None else adbc_config.default_return_format
|
||||
|
||||
# Step 1: Check environment variables and port availability
|
||||
port_check_result = await self._check_arrow_flight_ports()
|
||||
if not port_check_result["success"]:
|
||||
return port_check_result
|
||||
|
||||
# Step 2: Import required ADBC modules
|
||||
import_result = await self._import_adbc_modules()
|
||||
if not import_result["success"]:
|
||||
return import_result
|
||||
|
||||
# Step 3: Create ADBC connection
|
||||
connection_result = await self._create_adbc_connection()
|
||||
if not connection_result["success"]:
|
||||
return connection_result
|
||||
|
||||
# Step 4: Execute query using ADBC
|
||||
query_result = await self._execute_query_with_adbc(
|
||||
sql, max_rows, timeout, return_format
|
||||
)
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
if query_result["success"]:
|
||||
query_result["execution_time"] = round(execution_time, 3)
|
||||
query_result["protocol"] = "ADBC_Arrow_Flight_SQL"
|
||||
query_result["timestamp"] = datetime.now().isoformat()
|
||||
|
||||
return query_result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ADBC query execution failed: {str(e)}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"ADBC query execution failed: {str(e)}",
|
||||
"error_type": "execution_error",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
async def _check_arrow_flight_ports(self) -> Dict[str, Any]:
|
||||
"""Check Arrow Flight SQL port configuration and availability"""
|
||||
try:
|
||||
# Check environment variables
|
||||
fe_port = os.getenv("FE_ARROW_FLIGHT_SQL_PORT")
|
||||
be_port = os.getenv("BE_ARROW_FLIGHT_SQL_PORT")
|
||||
|
||||
if not fe_port:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Missing environment variable FE_ARROW_FLIGHT_SQL_PORT, please configure Arrow Flight SQL FE port in .env file",
|
||||
"error_type": "missing_fe_port_config"
|
||||
}
|
||||
|
||||
if not be_port:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Missing environment variable BE_ARROW_FLIGHT_SQL_PORT, please configure Arrow Flight SQL BE port in .env file",
|
||||
"error_type": "missing_be_port_config"
|
||||
}
|
||||
|
||||
# Convert to integer and validate
|
||||
try:
|
||||
fe_port = int(fe_port)
|
||||
be_port = int(be_port)
|
||||
except ValueError:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Invalid Arrow Flight SQL port configuration, please ensure FE_ARROW_FLIGHT_SQL_PORT and BE_ARROW_FLIGHT_SQL_PORT are valid numbers",
|
||||
"error_type": "invalid_port_format"
|
||||
}
|
||||
|
||||
# Get host address
|
||||
db_config = self.connection_manager.config.database
|
||||
fe_host = db_config.host
|
||||
|
||||
# Check FE Arrow Flight SQL port availability
|
||||
fe_available = self._check_port_connectivity(fe_host, fe_port)
|
||||
if not fe_available:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Cannot connect to FE Arrow Flight SQL port {fe_host}:{fe_port}, please check if service is running",
|
||||
"error_type": "fe_port_unavailable",
|
||||
"fe_host": fe_host,
|
||||
"fe_port": fe_port
|
||||
}
|
||||
|
||||
# Get BE host list
|
||||
be_hosts = await self._get_be_hosts()
|
||||
if not be_hosts:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Cannot get BE node information, please check cluster status",
|
||||
"error_type": "no_be_hosts"
|
||||
}
|
||||
|
||||
# Check at least one BE Arrow Flight SQL port availability
|
||||
be_available_count = 0
|
||||
be_check_results = []
|
||||
|
||||
for be_host in be_hosts[:3]: # Check first 3 BE nodes
|
||||
be_available = self._check_port_connectivity(be_host, be_port)
|
||||
be_check_results.append({
|
||||
"host": be_host,
|
||||
"port": be_port,
|
||||
"available": be_available
|
||||
})
|
||||
if be_available:
|
||||
be_available_count += 1
|
||||
|
||||
if be_available_count == 0:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Cannot connect to any BE Arrow Flight SQL port (port: {be_port}), please check if BE services are running",
|
||||
"error_type": "no_be_ports_available",
|
||||
"be_check_results": be_check_results
|
||||
}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"fe_host": fe_host,
|
||||
"fe_port": fe_port,
|
||||
"be_port": be_port,
|
||||
"be_hosts": be_hosts,
|
||||
"be_available_count": be_available_count,
|
||||
"be_check_results": be_check_results
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Arrow Flight port check failed: {str(e)}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Arrow Flight port check failed: {str(e)}",
|
||||
"error_type": "port_check_error"
|
||||
}
|
||||
|
||||
def _check_port_connectivity(self, host: str, port: int, timeout: int | None = None) -> bool:
|
||||
"""Check port connectivity"""
|
||||
try:
|
||||
# Use config timeout if not specified
|
||||
if timeout is None:
|
||||
timeout = self.connection_manager.config.adbc.connection_timeout
|
||||
|
||||
with socket.create_connection((host, port), timeout=timeout):
|
||||
return True
|
||||
except (socket.timeout, socket.error, OSError):
|
||||
return False
|
||||
|
||||
async def _get_be_hosts(self) -> List[str]:
|
||||
"""Get BE host list"""
|
||||
try:
|
||||
db_config = self.connection_manager.config.database
|
||||
|
||||
# Use configured BE hosts first
|
||||
if db_config.be_hosts:
|
||||
logger.info(f"Using configured BE hosts: {db_config.be_hosts}")
|
||||
return db_config.be_hosts
|
||||
|
||||
# Get BE nodes via SHOW BACKENDS
|
||||
logger.info("No BE hosts configured, getting BE node information via SHOW BACKENDS")
|
||||
connection = await self.connection_manager.get_connection("query")
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute("SHOW BACKENDS", auth_context=auth_context)
|
||||
|
||||
be_hosts = []
|
||||
for row in result.data:
|
||||
host = row.get("Host")
|
||||
alive = row.get("Alive", "").lower()
|
||||
if host and alive == "true":
|
||||
be_hosts.append(host)
|
||||
|
||||
logger.info(f"Got {len(be_hosts)} active BE nodes from SHOW BACKENDS")
|
||||
return be_hosts
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get BE hosts: {str(e)}")
|
||||
return []
|
||||
|
||||
async def _import_adbc_modules(self) -> Dict[str, Any]:
|
||||
"""Import ADBC related modules"""
|
||||
try:
|
||||
# Import ADBC Driver Manager
|
||||
try:
|
||||
import adbc_driver_manager
|
||||
self.adbc_manager_module = adbc_driver_manager
|
||||
except ImportError:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Missing adbc_driver_manager module, please install: pip install adbc_driver_manager",
|
||||
"error_type": "missing_adbc_manager"
|
||||
}
|
||||
|
||||
# Import ADBC Flight SQL Driver
|
||||
try:
|
||||
import adbc_driver_flightsql.dbapi as flight_sql
|
||||
self.flight_sql_module = flight_sql
|
||||
except ImportError:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Missing adbc_driver_flightsql module, please install: pip install adbc_driver_flightsql",
|
||||
"error_type": "missing_flight_sql_driver"
|
||||
}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"adbc_manager_version": getattr(adbc_driver_manager, '__version__', 'unknown'),
|
||||
"flight_sql_version": getattr(flight_sql, '__version__', 'unknown')
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ADBC module import failed: {str(e)}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"ADBC module import failed: {str(e)}",
|
||||
"error_type": "import_error"
|
||||
}
|
||||
|
||||
async def _create_adbc_connection(self) -> Dict[str, Any]:
|
||||
"""Create ADBC connection"""
|
||||
try:
|
||||
db_config = self.connection_manager.config.database
|
||||
fe_port = int(os.getenv("FE_ARROW_FLIGHT_SQL_PORT"))
|
||||
|
||||
# Build connection URI
|
||||
uri = f"grpc://{db_config.host}:{fe_port}"
|
||||
|
||||
# Create database connection parameters
|
||||
db_kwargs = {
|
||||
self.adbc_manager_module.DatabaseOptions.USERNAME.value: db_config.user,
|
||||
self.adbc_manager_module.DatabaseOptions.PASSWORD.value: db_config.password,
|
||||
}
|
||||
|
||||
# Create connection
|
||||
self.adbc_client = self.flight_sql_module.connect(
|
||||
uri=uri,
|
||||
db_kwargs=db_kwargs
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"uri": uri,
|
||||
"connection_established": True
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create ADBC connection: {str(e)}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Failed to create ADBC connection: {str(e)}",
|
||||
"error_type": "connection_error"
|
||||
}
|
||||
|
||||
async def _execute_query_with_adbc(
|
||||
self,
|
||||
sql: str,
|
||||
max_rows: int,
|
||||
timeout: int,
|
||||
return_format: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute query using ADBC"""
|
||||
try:
|
||||
if not self.adbc_client:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "ADBC connection not established",
|
||||
"error_type": "no_connection"
|
||||
}
|
||||
|
||||
# SECURITY FIX: Perform SQL security validation before executing
|
||||
auth_context = get_auth_context()
|
||||
if self.connection_manager.security_manager:
|
||||
# Always perform security validation, even without auth_context
|
||||
# Use a default context for basic SQL security checks
|
||||
validation_result = await self.connection_manager.security_manager.validate_sql_security(sql, auth_context)
|
||||
if not validation_result.is_valid:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"SQL security validation failed: {validation_result.error_message}",
|
||||
"error_type": "security_violation",
|
||||
"risk_level": validation_result.risk_level
|
||||
}
|
||||
|
||||
cursor = self.adbc_client.cursor()
|
||||
start_time = time.time()
|
||||
|
||||
# Execute query
|
||||
cursor.execute(sql)
|
||||
|
||||
# Get results based on return format
|
||||
if return_format == "arrow":
|
||||
# Return Arrow format
|
||||
arrow_data = cursor.fetchallarrow()
|
||||
|
||||
# Limit rows
|
||||
if len(arrow_data) > max_rows:
|
||||
arrow_data = arrow_data.slice(0, max_rows)
|
||||
|
||||
# Convert Arrow data to serializable format
|
||||
preview_df = arrow_data.to_pandas().head(10) if len(arrow_data) > 0 else None
|
||||
result_data = {
|
||||
"format": "arrow",
|
||||
"num_rows": len(arrow_data),
|
||||
"num_columns": len(arrow_data.schema),
|
||||
"column_names": arrow_data.schema.names,
|
||||
"column_types": [str(field.type) for field in arrow_data.schema],
|
||||
"data_preview": _convert_dataframe_to_json_serializable(preview_df) if preview_df is not None else [],
|
||||
"total_bytes": arrow_data.nbytes if hasattr(arrow_data, 'nbytes') else 0
|
||||
}
|
||||
|
||||
elif return_format == "pandas":
|
||||
# Return Pandas DataFrame
|
||||
df = cursor.fetch_df()
|
||||
|
||||
# Limit rows
|
||||
if len(df) > max_rows:
|
||||
df = df.head(max_rows)
|
||||
|
||||
result_data = {
|
||||
"format": "pandas",
|
||||
"num_rows": len(df),
|
||||
"num_columns": len(df.columns),
|
||||
"column_names": df.columns.tolist(),
|
||||
"column_types": df.dtypes.astype(str).tolist(),
|
||||
"data": _convert_dataframe_to_json_serializable(df),
|
||||
"memory_usage": int(df.memory_usage(deep=True).sum())
|
||||
}
|
||||
|
||||
else: # return_format == "dict"
|
||||
# Return dictionary format
|
||||
arrow_data = cursor.fetchallarrow()
|
||||
df = arrow_data.to_pandas()
|
||||
|
||||
# Limit rows
|
||||
if len(df) > max_rows:
|
||||
df = df.head(max_rows)
|
||||
|
||||
result_data = {
|
||||
"format": "dict",
|
||||
"num_rows": len(df),
|
||||
"num_columns": len(df.columns),
|
||||
"column_names": df.columns.tolist(),
|
||||
"column_types": df.dtypes.astype(str).tolist(),
|
||||
"data": _convert_dataframe_to_json_serializable(df)
|
||||
}
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
cursor.close()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"result": result_data,
|
||||
"execution_time": round(execution_time, 3),
|
||||
"sql": sql,
|
||||
"max_rows_applied": len(result_data.get("data", [])) >= max_rows
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ADBC query execution failed: {str(e)}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"ADBC query execution failed: {str(e)}",
|
||||
"error_type": "query_execution_error",
|
||||
"sql": sql
|
||||
}
|
||||
|
||||
async def get_adbc_connection_info(self) -> Dict[str, Any]:
|
||||
"""Get ADBC connection information and status"""
|
||||
try:
|
||||
# Check port status
|
||||
port_status = await self._check_arrow_flight_ports()
|
||||
|
||||
# Check module status
|
||||
module_status = await self._import_adbc_modules()
|
||||
|
||||
# Get configuration information
|
||||
db_config = self.connection_manager.config.database
|
||||
fe_port = os.getenv("FE_ARROW_FLIGHT_SQL_PORT")
|
||||
be_port = os.getenv("BE_ARROW_FLIGHT_SQL_PORT")
|
||||
|
||||
connection_info = {
|
||||
"adbc_available": module_status["success"],
|
||||
"ports_available": port_status["success"],
|
||||
"configuration": {
|
||||
"fe_host": db_config.host,
|
||||
"fe_arrow_flight_port": fe_port,
|
||||
"be_arrow_flight_port": be_port,
|
||||
"user": db_config.user
|
||||
},
|
||||
"port_status": port_status,
|
||||
"module_status": module_status,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
if port_status["success"] and module_status["success"]:
|
||||
connection_info["status"] = "ready"
|
||||
connection_info["message"] = "ADBC Arrow Flight SQL connection ready"
|
||||
else:
|
||||
connection_info["status"] = "not_ready"
|
||||
errors = []
|
||||
if not port_status["success"]:
|
||||
errors.append(port_status["error"])
|
||||
if not module_status["success"]:
|
||||
errors.append(module_status["error"])
|
||||
connection_info["message"] = "; ".join(errors)
|
||||
|
||||
return connection_info
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get ADBC connection information: {str(e)}")
|
||||
return {
|
||||
"status": "error",
|
||||
"error": f"Failed to get ADBC connection information: {str(e)}",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
def __del__(self):
|
||||
"""Cleanup resources"""
|
||||
try:
|
||||
if self.adbc_client:
|
||||
self.adbc_client.close()
|
||||
except:
|
||||
pass
|
||||
@@ -32,6 +32,8 @@ try:
|
||||
except ImportError:
|
||||
load_dotenv = None
|
||||
|
||||
from .logger import get_logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatabaseConfig:
|
||||
@@ -41,38 +43,105 @@ class DatabaseConfig:
|
||||
port: int = 9030
|
||||
user: str = "root"
|
||||
password: str = ""
|
||||
database: str = "test"
|
||||
charset: str = "utf8mb4"
|
||||
database: str = "information_schema"
|
||||
charset: str = "UTF8"
|
||||
|
||||
# FE HTTP API port for profile and other HTTP APIs
|
||||
fe_http_port: int = 8030
|
||||
|
||||
# BE nodes configuration for external access
|
||||
# If be_hosts is empty, will use "show backends" to get BE nodes
|
||||
be_hosts: list[str] = field(default_factory=list)
|
||||
be_webserver_port: int = 8040
|
||||
|
||||
# Arrow Flight SQL Configuration (Required for ADBC tools)
|
||||
fe_arrow_flight_sql_port: int | None = None
|
||||
be_arrow_flight_sql_port: int | None = None
|
||||
|
||||
# Connection pool configuration
|
||||
min_connections: int = 5
|
||||
# Note: min_connections is fixed at 0 to avoid at_eof connection issues
|
||||
# This prevents pre-creation of connections which can cause state problems
|
||||
_min_connections: int = field(default=0, init=False) # Internal use only, always 0
|
||||
max_connections: int = 20
|
||||
connection_timeout: int = 30
|
||||
health_check_interval: int = 60
|
||||
max_connection_age: int = 3600
|
||||
|
||||
@property
|
||||
def min_connections(self) -> int:
|
||||
"""Minimum connections is always 0 to prevent at_eof issues"""
|
||||
return self._min_connections
|
||||
|
||||
|
||||
@dataclass
|
||||
class SecurityConfig:
|
||||
"""Security configuration"""
|
||||
|
||||
# Authentication configuration
|
||||
auth_type: str = "token" # token, basic, oauth
|
||||
token_secret: str = "default_secret"
|
||||
# Independent authentication switches - any one enabled allows that method
|
||||
enable_token_auth: bool = False # Enable token-based authentication (default: disabled)
|
||||
enable_jwt_auth: bool = False # Enable JWT authentication (default: disabled)
|
||||
enable_oauth_auth: bool = False # Enable OAuth 2.0/OIDC authentication (default: disabled)
|
||||
|
||||
# Legacy configuration (kept for backward compatibility)
|
||||
auth_type: str = "token" # jwt, token, basic, oauth (deprecated: use individual switches)
|
||||
token_secret: str = "default_secret" # Legacy token secret for backward compatibility
|
||||
token_expiry: int = 3600
|
||||
|
||||
# Enhanced Token Authentication Configuration
|
||||
token_file_path: str = "tokens.json" # Path to token configuration file
|
||||
enable_token_expiry: bool = True # Enable token expiration
|
||||
default_token_expiry_hours: int = 24 * 30 # Default expiry: 30 days
|
||||
token_hash_algorithm: str = "sha256" # Token hashing algorithm: sha256, sha512
|
||||
|
||||
# Token Management Security (New in v0.6.0)
|
||||
enable_http_token_management: bool = False # Enable HTTP token management endpoints (default: disabled for security)
|
||||
token_management_admin_token: str = "" # Admin token for token management endpoints (required if HTTP management enabled)
|
||||
token_management_allowed_ips: list[str] = field(default_factory=lambda: ["127.0.0.1", "::1", "localhost"]) # Allowed IPs for token management
|
||||
require_admin_auth: bool = True # Require admin authentication for token management (default: true)
|
||||
|
||||
# JWT Configuration
|
||||
jwt_algorithm: str = "RS256" # RS256, ES256, HS256
|
||||
jwt_issuer: str = "doris-mcp-server"
|
||||
jwt_audience: str = "doris-mcp-client"
|
||||
jwt_private_key_path: str = ""
|
||||
jwt_public_key_path: str = ""
|
||||
jwt_secret_key: str = "" # Only used for HS256 algorithm
|
||||
jwt_access_token_expiry: int = 3600 # 1 hour
|
||||
jwt_refresh_token_expiry: int = 86400 # 24 hours
|
||||
enable_token_refresh: bool = True
|
||||
enable_token_revocation: bool = True
|
||||
key_rotation_interval: int = 30 * 24 * 3600 # 30 days in seconds
|
||||
|
||||
# JWT Security Features
|
||||
jwt_require_iat: bool = True # Require "issued at" claim
|
||||
jwt_require_exp: bool = True # Require "expires at" claim
|
||||
jwt_require_nbf: bool = False # Require "not before" claim
|
||||
jwt_leeway: int = 10 # Clock skew tolerance in seconds
|
||||
jwt_verify_signature: bool = True # Verify JWT signature
|
||||
jwt_verify_audience: bool = True # Verify audience claim
|
||||
jwt_verify_issuer: bool = True # Verify issuer claim
|
||||
|
||||
# SQL security configuration
|
||||
enable_security_check: bool = True # Main switch: whether to enable SQL security check
|
||||
blocked_keywords: list[str] = field(
|
||||
default_factory=lambda: [
|
||||
# DDL Operations (Data Definition Language)
|
||||
"DROP",
|
||||
"DELETE",
|
||||
"TRUNCATE",
|
||||
"CREATE",
|
||||
"ALTER",
|
||||
"CREATE",
|
||||
"TRUNCATE",
|
||||
# DML Operations (Data Manipulation Language)
|
||||
"DELETE",
|
||||
"INSERT",
|
||||
"UPDATE",
|
||||
# DCL Operations (Data Control Language)
|
||||
"GRANT",
|
||||
"REVOKE",
|
||||
# System Operations
|
||||
"EXEC",
|
||||
"EXECUTE",
|
||||
"SHUTDOWN",
|
||||
"KILL",
|
||||
]
|
||||
)
|
||||
max_query_complexity: int = 100
|
||||
@@ -85,6 +154,45 @@ class SecurityConfig:
|
||||
enable_masking: bool = True
|
||||
masking_rules: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
# OAuth 2.0/OIDC Configuration
|
||||
oauth_enabled: bool = False
|
||||
oauth_provider: str = "" # 'google', 'microsoft', 'github', 'custom'
|
||||
oauth_client_id: str = ""
|
||||
oauth_client_secret: str = ""
|
||||
oauth_redirect_uri: str = "http://localhost:3000/auth/callback"
|
||||
|
||||
# OIDC Discovery
|
||||
oidc_discovery_url: str = "" # e.g., https://accounts.google.com/.well-known/openid_configuration
|
||||
oauth_authorization_endpoint: str = ""
|
||||
oauth_token_endpoint: str = ""
|
||||
oauth_userinfo_endpoint: str = ""
|
||||
oauth_jwks_uri: str = ""
|
||||
|
||||
# OAuth Scopes and Settings
|
||||
oauth_scopes: list[str] = field(default_factory=list)
|
||||
oauth_state_expiry: int = 600 # State parameter expiry in seconds (10 minutes)
|
||||
oauth_pkce_enabled: bool = True # Enable PKCE for better security
|
||||
oauth_nonce_enabled: bool = True # Enable nonce for OIDC
|
||||
|
||||
# User Mapping Configuration
|
||||
oauth_user_id_claim: str = "sub" # JWT claim for user ID
|
||||
oauth_email_claim: str = "email"
|
||||
oauth_name_claim: str = "name"
|
||||
oauth_roles_claim: str = "roles" # Custom claim for roles
|
||||
oauth_default_roles: list[str] = field(default_factory=lambda: ["oauth_user"])
|
||||
|
||||
def __post_init__(self):
|
||||
"""Initialize default OAuth scopes based on provider"""
|
||||
if not self.oauth_scopes and self.oauth_provider:
|
||||
if self.oauth_provider == "google":
|
||||
self.oauth_scopes = ["openid", "email", "profile"]
|
||||
elif self.oauth_provider == "microsoft":
|
||||
self.oauth_scopes = ["openid", "profile", "email", "User.Read"]
|
||||
elif self.oauth_provider == "github":
|
||||
self.oauth_scopes = ["user:email", "read:user"]
|
||||
else:
|
||||
self.oauth_scopes = ["openid", "email", "profile"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class PerformanceConfig:
|
||||
@@ -102,6 +210,52 @@ class PerformanceConfig:
|
||||
# Connection pool optimization configuration
|
||||
connection_pool_size: int = 20
|
||||
idle_timeout: int = 1800
|
||||
|
||||
# Response content size limit (characters)
|
||||
max_response_content_size: int = 4096
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataQualityConfig:
|
||||
"""Data quality analysis configuration"""
|
||||
|
||||
# Column analysis configuration
|
||||
max_columns_per_batch: int = 20 # Maximum columns to analyze in a single batch
|
||||
default_sample_size: int = 100000 # Default sample size for analysis
|
||||
|
||||
# Sampling strategy configuration
|
||||
small_table_threshold: int = 100000 # Tables smaller than this use full table analysis
|
||||
medium_table_threshold: int = 1000000 # Tables smaller than this use simple LIMIT sampling
|
||||
# Tables larger than medium_table_threshold use systematic sampling
|
||||
|
||||
# Performance optimization
|
||||
enable_batch_analysis: bool = True # Enable batch analysis for multiple columns
|
||||
batch_timeout: int = 300 # Timeout for batch analysis in seconds
|
||||
|
||||
# Accuracy vs Performance trade-off
|
||||
enable_fast_mode: bool = False # Use approximate algorithms for faster results
|
||||
fast_mode_sample_size: int = 10000 # Sample size for fast mode
|
||||
|
||||
# Statistical analysis configuration
|
||||
enable_distribution_analysis: bool = True # Enable distribution analysis
|
||||
histogram_bins: int = 20 # Number of bins for histogram analysis
|
||||
percentile_levels: list[float] = field(default_factory=lambda: [0.25, 0.5, 0.75, 0.95, 0.99]) # Percentile levels to calculate
|
||||
|
||||
|
||||
@dataclass
|
||||
class ADBCConfig:
|
||||
"""ADBC (Arrow Flight SQL) configuration"""
|
||||
|
||||
# Default query parameters
|
||||
default_max_rows: int = 100000
|
||||
default_timeout: int = 60
|
||||
default_return_format: str = "arrow" # "arrow", "pandas", "dict"
|
||||
|
||||
# Connection timeout for ADBC
|
||||
connection_timeout: int = 30
|
||||
|
||||
# Whether to enable ADBC tools
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -117,6 +271,11 @@ class LoggingConfig:
|
||||
# Audit log configuration
|
||||
enable_audit: bool = True
|
||||
audit_file_path: str | None = None
|
||||
|
||||
# Log cleanup configuration
|
||||
enable_cleanup: bool = True
|
||||
max_age_days: int = 30
|
||||
cleanup_interval_hours: int = 24
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -125,11 +284,11 @@ class MonitoringConfig:
|
||||
|
||||
# Metrics collection configuration
|
||||
enable_metrics: bool = True
|
||||
metrics_port: int = 8081
|
||||
metrics_port: int = 3001
|
||||
metrics_path: str = "/metrics"
|
||||
|
||||
# Health check configuration
|
||||
health_check_port: int = 8082
|
||||
health_check_port: int = 3002
|
||||
health_check_path: str = "/health"
|
||||
|
||||
# Alert configuration
|
||||
@@ -143,15 +302,22 @@ class DorisConfig:
|
||||
|
||||
# Basic configuration
|
||||
server_name: str = "doris-mcp-server"
|
||||
server_version: str = "1.0.0"
|
||||
server_port: int = 8080
|
||||
server_version: str = "0.4.1"
|
||||
server_host: str = "localhost"
|
||||
server_port: int = 3000
|
||||
transport: str = "stdio"
|
||||
|
||||
# Temporary files configuration
|
||||
temp_files_dir: str = "tmp" # Temporary files directory for Explain and Profile outputs
|
||||
|
||||
# Sub-configuration modules
|
||||
database: DatabaseConfig = field(default_factory=DatabaseConfig)
|
||||
security: SecurityConfig = field(default_factory=SecurityConfig)
|
||||
performance: PerformanceConfig = field(default_factory=PerformanceConfig)
|
||||
data_quality: DataQualityConfig = field(default_factory=DataQualityConfig)
|
||||
logging: LoggingConfig = field(default_factory=LoggingConfig)
|
||||
monitoring: MonitoringConfig = field(default_factory=MonitoringConfig)
|
||||
adbc: ADBCConfig = field(default_factory=ADBCConfig)
|
||||
|
||||
# Custom configuration
|
||||
custom_config: dict[str, Any] = field(default_factory=dict)
|
||||
@@ -180,6 +346,9 @@ class DorisConfig:
|
||||
@classmethod
|
||||
def from_env(cls, env_file: str | None = None) -> "DorisConfig":
|
||||
"""Load configuration from environment variables
|
||||
|
||||
The kv pairs in the. env file will be loaded as environment variables,
|
||||
but the existing environment variables will not be overridden.
|
||||
|
||||
Args:
|
||||
env_file: .env file path, if None, search in the following order:
|
||||
@@ -199,7 +368,7 @@ class DorisConfig:
|
||||
env_files = [".env", ".env.local", ".env.production", ".env.development"]
|
||||
for env_path in env_files:
|
||||
if Path(env_path).exists():
|
||||
load_dotenv(env_path)
|
||||
load_dotenv(env_path, override=False)
|
||||
logging.getLogger(__name__).info(f"Loaded environment configuration file: {env_path}")
|
||||
break
|
||||
else:
|
||||
@@ -209,17 +378,45 @@ class DorisConfig:
|
||||
|
||||
config = cls()
|
||||
|
||||
# Database configuration
|
||||
config.database.host = os.getenv("DORIS_HOST", config.database.host)
|
||||
config.database.port = int(os.getenv("DORIS_PORT", str(config.database.port)))
|
||||
config.database.user = os.getenv("DORIS_USER", config.database.user)
|
||||
config.database.password = os.getenv("DORIS_PASSWORD", config.database.password)
|
||||
config.database.database = os.getenv("DORIS_DATABASE", config.database.database)
|
||||
# Database configuration - handle empty strings properly
|
||||
doris_host = os.getenv("DORIS_HOST", "").strip()
|
||||
config.database.host = doris_host if doris_host else config.database.host
|
||||
|
||||
doris_port = os.getenv("DORIS_PORT", "").strip()
|
||||
if doris_port and doris_port.isdigit():
|
||||
config.database.port = int(doris_port)
|
||||
|
||||
doris_user = os.getenv("DORIS_USER", "").strip()
|
||||
config.database.user = doris_user if doris_user else config.database.user
|
||||
|
||||
doris_password = os.getenv("DORIS_PASSWORD", "")
|
||||
config.database.password = doris_password if doris_password else config.database.password
|
||||
|
||||
doris_database = os.getenv("DORIS_DATABASE", "").strip()
|
||||
config.database.database = doris_database if doris_database else config.database.database
|
||||
|
||||
doris_fe_http_port = os.getenv("DORIS_FE_HTTP_PORT", "").strip()
|
||||
if doris_fe_http_port and doris_fe_http_port.isdigit():
|
||||
config.database.fe_http_port = int(doris_fe_http_port)
|
||||
|
||||
# BE nodes configuration
|
||||
be_hosts_env = os.getenv("DORIS_BE_HOSTS", "")
|
||||
if be_hosts_env:
|
||||
config.database.be_hosts = [host.strip() for host in be_hosts_env.split(",") if host.strip()]
|
||||
be_webserver_port = os.getenv("DORIS_BE_WEBSERVER_PORT", "").strip()
|
||||
if be_webserver_port and be_webserver_port.isdigit():
|
||||
config.database.be_webserver_port = int(be_webserver_port)
|
||||
|
||||
# Arrow Flight SQL Configuration
|
||||
fe_arrow_port_env = os.getenv("FE_ARROW_FLIGHT_SQL_PORT")
|
||||
if fe_arrow_port_env:
|
||||
config.database.fe_arrow_flight_sql_port = int(fe_arrow_port_env)
|
||||
|
||||
be_arrow_port_env = os.getenv("BE_ARROW_FLIGHT_SQL_PORT")
|
||||
if be_arrow_port_env:
|
||||
config.database.be_arrow_flight_sql_port = int(be_arrow_port_env)
|
||||
|
||||
# Connection pool configuration
|
||||
config.database.min_connections = int(
|
||||
os.getenv("DORIS_MIN_CONNECTIONS", str(config.database.min_connections))
|
||||
)
|
||||
config.database.max_connections = int(
|
||||
os.getenv("DORIS_MAX_CONNECTIONS", str(config.database.max_connections))
|
||||
)
|
||||
@@ -234,6 +431,10 @@ class DorisConfig:
|
||||
)
|
||||
|
||||
# Security configuration
|
||||
# Independent authentication switches
|
||||
config.security.enable_token_auth = os.getenv("ENABLE_TOKEN_AUTH", str(config.security.enable_token_auth)).lower() == "true"
|
||||
config.security.enable_jwt_auth = os.getenv("ENABLE_JWT_AUTH", str(config.security.enable_jwt_auth)).lower() == "true"
|
||||
config.security.enable_oauth_auth = os.getenv("ENABLE_OAUTH_AUTH", str(config.security.enable_oauth_auth)).lower() == "true"
|
||||
config.security.auth_type = os.getenv("AUTH_TYPE", config.security.auth_type)
|
||||
config.security.token_secret = os.getenv("TOKEN_SECRET", config.security.token_secret)
|
||||
config.security.token_expiry = int(
|
||||
@@ -245,9 +446,50 @@ class DorisConfig:
|
||||
config.security.max_query_complexity = int(
|
||||
os.getenv("MAX_QUERY_COMPLEXITY", str(config.security.max_query_complexity))
|
||||
)
|
||||
config.security.enable_security_check = (
|
||||
os.getenv("ENABLE_SECURITY_CHECK", str(config.security.enable_security_check).lower()).lower() == "true"
|
||||
)
|
||||
|
||||
# Handle blocked keywords environment variable configuration
|
||||
# Format: BLOCKED_KEYWORDS="DROP,DELETE,TRUNCATE,ALTER,CREATE,INSERT,UPDATE,GRANT,REVOKE"
|
||||
blocked_keywords_env = os.getenv("BLOCKED_KEYWORDS", "")
|
||||
if blocked_keywords_env:
|
||||
# If environment variable is provided, use keywords list from environment variable
|
||||
config.security.blocked_keywords = [
|
||||
keyword.strip().upper()
|
||||
for keyword in blocked_keywords_env.split(",")
|
||||
if keyword.strip()
|
||||
]
|
||||
# If environment variable is empty, keep default configuration unchanged
|
||||
|
||||
config.security.enable_masking = (
|
||||
os.getenv("ENABLE_MASKING", str(config.security.enable_masking).lower()).lower() == "true"
|
||||
)
|
||||
|
||||
# Enhanced Token Authentication configuration
|
||||
config.security.token_file_path = os.getenv("TOKEN_FILE_PATH", config.security.token_file_path)
|
||||
config.security.enable_token_expiry = (
|
||||
os.getenv("ENABLE_TOKEN_EXPIRY", str(config.security.enable_token_expiry).lower()).lower() == "true"
|
||||
)
|
||||
config.security.default_token_expiry_hours = int(
|
||||
os.getenv("DEFAULT_TOKEN_EXPIRY_HOURS", str(config.security.default_token_expiry_hours))
|
||||
)
|
||||
config.security.token_hash_algorithm = os.getenv("TOKEN_HASH_ALGORITHM", config.security.token_hash_algorithm)
|
||||
|
||||
# Token Management Security Configuration (New in v0.6.0)
|
||||
config.security.enable_http_token_management = (
|
||||
os.getenv("ENABLE_HTTP_TOKEN_MANAGEMENT", str(config.security.enable_http_token_management).lower()).lower() == "true"
|
||||
)
|
||||
config.security.token_management_admin_token = os.getenv("TOKEN_MANAGEMENT_ADMIN_TOKEN", config.security.token_management_admin_token)
|
||||
|
||||
# Parse allowed IPs from comma-separated string
|
||||
allowed_ips_str = os.getenv("TOKEN_MANAGEMENT_ALLOWED_IPS", "")
|
||||
if allowed_ips_str:
|
||||
config.security.token_management_allowed_ips = [ip.strip() for ip in allowed_ips_str.split(",") if ip.strip()]
|
||||
|
||||
config.security.require_admin_auth = (
|
||||
os.getenv("REQUIRE_ADMIN_AUTH", str(config.security.require_admin_auth).lower()).lower() == "true"
|
||||
)
|
||||
|
||||
# Performance configuration
|
||||
config.performance.enable_query_cache = (
|
||||
@@ -265,6 +507,9 @@ class DorisConfig:
|
||||
config.performance.query_timeout = int(
|
||||
os.getenv("QUERY_TIMEOUT", str(config.performance.query_timeout))
|
||||
)
|
||||
config.performance.max_response_content_size = int(
|
||||
os.getenv("MAX_RESPONSE_CONTENT_SIZE", str(config.performance.max_response_content_size))
|
||||
)
|
||||
|
||||
# Logging configuration
|
||||
config.logging.level = os.getenv("LOG_LEVEL", config.logging.level)
|
||||
@@ -273,6 +518,15 @@ class DorisConfig:
|
||||
os.getenv("ENABLE_AUDIT", str(config.logging.enable_audit).lower()).lower() == "true"
|
||||
)
|
||||
config.logging.audit_file_path = os.getenv("AUDIT_FILE_PATH", config.logging.audit_file_path)
|
||||
config.logging.enable_cleanup = (
|
||||
os.getenv("ENABLE_LOG_CLEANUP", str(config.logging.enable_cleanup).lower()).lower() == "true"
|
||||
)
|
||||
config.logging.max_age_days = int(
|
||||
os.getenv("LOG_MAX_AGE_DAYS", str(config.logging.max_age_days))
|
||||
)
|
||||
config.logging.cleanup_interval_hours = int(
|
||||
os.getenv("LOG_CLEANUP_INTERVAL_HOURS", str(config.logging.cleanup_interval_hours))
|
||||
)
|
||||
|
||||
# Monitoring configuration
|
||||
config.monitoring.enable_metrics = (
|
||||
@@ -289,10 +543,60 @@ class DorisConfig:
|
||||
)
|
||||
config.monitoring.alert_webhook_url = os.getenv("ALERT_WEBHOOK_URL", config.monitoring.alert_webhook_url)
|
||||
|
||||
# ADBC configuration
|
||||
config.adbc.default_max_rows = int(
|
||||
os.getenv("ADBC_DEFAULT_MAX_ROWS", str(config.adbc.default_max_rows))
|
||||
)
|
||||
config.adbc.default_timeout = int(
|
||||
os.getenv("ADBC_DEFAULT_TIMEOUT", str(config.adbc.default_timeout))
|
||||
)
|
||||
config.adbc.default_return_format = os.getenv("ADBC_DEFAULT_RETURN_FORMAT", config.adbc.default_return_format)
|
||||
config.adbc.connection_timeout = int(
|
||||
os.getenv("ADBC_CONNECTION_TIMEOUT", str(config.adbc.connection_timeout))
|
||||
)
|
||||
config.adbc.enabled = (
|
||||
os.getenv("ADBC_ENABLED", str(config.adbc.enabled).lower()).lower() == "true"
|
||||
)
|
||||
|
||||
# Data quality configuration
|
||||
config.data_quality.max_columns_per_batch = int(
|
||||
os.getenv("DATA_QUALITY_MAX_COLUMNS_PER_BATCH", str(config.data_quality.max_columns_per_batch))
|
||||
)
|
||||
config.data_quality.default_sample_size = int(
|
||||
os.getenv("DATA_QUALITY_DEFAULT_SAMPLE_SIZE", str(config.data_quality.default_sample_size))
|
||||
)
|
||||
config.data_quality.small_table_threshold = int(
|
||||
os.getenv("DATA_QUALITY_SMALL_TABLE_THRESHOLD", str(config.data_quality.small_table_threshold))
|
||||
)
|
||||
config.data_quality.medium_table_threshold = int(
|
||||
os.getenv("DATA_QUALITY_MEDIUM_TABLE_THRESHOLD", str(config.data_quality.medium_table_threshold))
|
||||
)
|
||||
config.data_quality.enable_batch_analysis = (
|
||||
os.getenv("DATA_QUALITY_ENABLE_BATCH_ANALYSIS", str(config.data_quality.enable_batch_analysis).lower()).lower() == "true"
|
||||
)
|
||||
config.data_quality.batch_timeout = int(
|
||||
os.getenv("DATA_QUALITY_BATCH_TIMEOUT", str(config.data_quality.batch_timeout))
|
||||
)
|
||||
config.data_quality.enable_fast_mode = (
|
||||
os.getenv("DATA_QUALITY_ENABLE_FAST_MODE", str(config.data_quality.enable_fast_mode).lower()).lower() == "true"
|
||||
)
|
||||
config.data_quality.fast_mode_sample_size = int(
|
||||
os.getenv("DATA_QUALITY_FAST_MODE_SAMPLE_SIZE", str(config.data_quality.fast_mode_sample_size))
|
||||
)
|
||||
config.data_quality.enable_distribution_analysis = (
|
||||
os.getenv("DATA_QUALITY_ENABLE_DISTRIBUTION_ANALYSIS", str(config.data_quality.enable_distribution_analysis).lower()).lower() == "true"
|
||||
)
|
||||
config.data_quality.histogram_bins = int(
|
||||
os.getenv("DATA_QUALITY_HISTOGRAM_BINS", str(config.data_quality.histogram_bins))
|
||||
)
|
||||
|
||||
# Server configuration
|
||||
config.server_name = os.getenv("SERVER_NAME", config.server_name)
|
||||
config.server_version = os.getenv("SERVER_VERSION", config.server_version)
|
||||
config.server_port = int(os.getenv("SERVER_PORT", str(config.server_port)))
|
||||
server_port = os.getenv("SERVER_PORT", "").strip()
|
||||
if server_port and server_port.isdigit():
|
||||
config.server_port = int(server_port)
|
||||
config.temp_files_dir = os.getenv("TEMP_FILES_DIR", config.temp_files_dir)
|
||||
|
||||
return config
|
||||
|
||||
@@ -302,7 +606,7 @@ class DorisConfig:
|
||||
config = cls()
|
||||
|
||||
# Update basic configuration
|
||||
for key in ["server_name", "server_version", "server_port"]:
|
||||
for key in ["server_name", "server_version", "server_port", "temp_files_dir"]:
|
||||
if key in config_data:
|
||||
setattr(config, key, config_data[key])
|
||||
|
||||
@@ -327,6 +631,13 @@ class DorisConfig:
|
||||
if hasattr(config.performance, key):
|
||||
setattr(config.performance, key, value)
|
||||
|
||||
# Update data quality configuration
|
||||
if "data_quality" in config_data:
|
||||
dq_config = config_data["data_quality"]
|
||||
for key, value in dq_config.items():
|
||||
if hasattr(config.data_quality, key):
|
||||
setattr(config.data_quality, key, value)
|
||||
|
||||
# Update logging configuration
|
||||
if "logging" in config_data:
|
||||
log_config = config_data["logging"]
|
||||
@@ -341,6 +652,13 @@ class DorisConfig:
|
||||
if hasattr(config.monitoring, key):
|
||||
setattr(config.monitoring, key, value)
|
||||
|
||||
# Update ADBC configuration
|
||||
if "adbc" in config_data:
|
||||
adbc_config = config_data["adbc"]
|
||||
for key, value in adbc_config.items():
|
||||
if hasattr(config.adbc, key):
|
||||
setattr(config.adbc, key, value)
|
||||
|
||||
# Custom configuration
|
||||
config.custom_config = config_data.get("custom", {})
|
||||
|
||||
@@ -352,6 +670,7 @@ class DorisConfig:
|
||||
"server_name": self.server_name,
|
||||
"server_version": self.server_version,
|
||||
"server_port": self.server_port,
|
||||
"temp_files_dir": self.temp_files_dir,
|
||||
"database": {
|
||||
"host": self.database.host,
|
||||
"port": self.database.port,
|
||||
@@ -359,7 +678,12 @@ class DorisConfig:
|
||||
"password": "***", # Hide password
|
||||
"database": self.database.database,
|
||||
"charset": self.database.charset,
|
||||
"min_connections": self.database.min_connections,
|
||||
"fe_http_port": self.database.fe_http_port,
|
||||
"be_hosts": self.database.be_hosts,
|
||||
"be_webserver_port": self.database.be_webserver_port,
|
||||
"fe_arrow_flight_sql_port": self.database.fe_arrow_flight_sql_port,
|
||||
"be_arrow_flight_sql_port": self.database.be_arrow_flight_sql_port,
|
||||
"min_connections": self.database.min_connections, # Always 0, shown for reference
|
||||
"max_connections": self.database.max_connections,
|
||||
"connection_timeout": self.database.connection_timeout,
|
||||
"health_check_interval": self.database.health_check_interval,
|
||||
@@ -369,6 +693,7 @@ class DorisConfig:
|
||||
"auth_type": self.security.auth_type,
|
||||
"token_secret": "***", # Hide secret key
|
||||
"token_expiry": self.security.token_expiry,
|
||||
"enable_security_check": self.security.enable_security_check,
|
||||
"blocked_keywords": self.security.blocked_keywords,
|
||||
"max_query_complexity": self.security.max_query_complexity,
|
||||
"max_result_rows": self.security.max_result_rows,
|
||||
@@ -384,6 +709,20 @@ class DorisConfig:
|
||||
"query_timeout": self.performance.query_timeout,
|
||||
"connection_pool_size": self.performance.connection_pool_size,
|
||||
"idle_timeout": self.performance.idle_timeout,
|
||||
"max_response_content_size": self.performance.max_response_content_size,
|
||||
},
|
||||
"data_quality": {
|
||||
"max_columns_per_batch": self.data_quality.max_columns_per_batch,
|
||||
"default_sample_size": self.data_quality.default_sample_size,
|
||||
"small_table_threshold": self.data_quality.small_table_threshold,
|
||||
"medium_table_threshold": self.data_quality.medium_table_threshold,
|
||||
"enable_batch_analysis": self.data_quality.enable_batch_analysis,
|
||||
"batch_timeout": self.data_quality.batch_timeout,
|
||||
"enable_fast_mode": self.data_quality.enable_fast_mode,
|
||||
"fast_mode_sample_size": self.data_quality.fast_mode_sample_size,
|
||||
"enable_distribution_analysis": self.data_quality.enable_distribution_analysis,
|
||||
"histogram_bins": self.data_quality.histogram_bins,
|
||||
"percentile_levels": self.data_quality.percentile_levels,
|
||||
},
|
||||
"logging": {
|
||||
"level": self.logging.level,
|
||||
@@ -393,6 +732,9 @@ class DorisConfig:
|
||||
"backup_count": self.logging.backup_count,
|
||||
"enable_audit": self.logging.enable_audit,
|
||||
"audit_file_path": self.logging.audit_file_path,
|
||||
"enable_cleanup": self.logging.enable_cleanup,
|
||||
"max_age_days": self.logging.max_age_days,
|
||||
"cleanup_interval_hours": self.logging.cleanup_interval_hours,
|
||||
},
|
||||
"monitoring": {
|
||||
"enable_metrics": self.monitoring.enable_metrics,
|
||||
@@ -403,6 +745,13 @@ class DorisConfig:
|
||||
"enable_alerts": self.monitoring.enable_alerts,
|
||||
"alert_webhook_url": self.monitoring.alert_webhook_url,
|
||||
},
|
||||
"adbc": {
|
||||
"default_max_rows": self.adbc.default_max_rows,
|
||||
"default_timeout": self.adbc.default_timeout,
|
||||
"default_return_format": self.adbc.default_return_format,
|
||||
"connection_timeout": self.adbc.connection_timeout,
|
||||
"enabled": self.adbc.enabled,
|
||||
},
|
||||
"custom": self.custom_config,
|
||||
}
|
||||
|
||||
@@ -435,11 +784,8 @@ class DorisConfig:
|
||||
if not self.database.user:
|
||||
errors.append("Database username cannot be empty")
|
||||
|
||||
if self.database.min_connections <= 0:
|
||||
errors.append("Minimum connections must be greater than 0")
|
||||
|
||||
if self.database.max_connections <= self.database.min_connections:
|
||||
errors.append("Maximum connections must be greater than minimum connections")
|
||||
if self.database.max_connections <= 0:
|
||||
errors.append("Maximum connections must be greater than 0")
|
||||
|
||||
# Validate security configuration
|
||||
if self.security.auth_type not in ["token", "basic", "oauth"]:
|
||||
@@ -464,6 +810,31 @@ class DorisConfig:
|
||||
if self.performance.query_timeout <= 0:
|
||||
errors.append("Query timeout must be greater than 0")
|
||||
|
||||
# Validate data quality configuration
|
||||
if self.data_quality.max_columns_per_batch <= 0:
|
||||
errors.append("Max columns per batch must be greater than 0")
|
||||
|
||||
if self.data_quality.default_sample_size <= 0:
|
||||
errors.append("Default sample size must be greater than 0")
|
||||
|
||||
if self.data_quality.small_table_threshold <= 0:
|
||||
errors.append("Small table threshold must be greater than 0")
|
||||
|
||||
if self.data_quality.medium_table_threshold <= 0:
|
||||
errors.append("Medium table threshold must be greater than 0")
|
||||
|
||||
if self.data_quality.small_table_threshold >= self.data_quality.medium_table_threshold:
|
||||
errors.append("Small table threshold must be less than medium table threshold")
|
||||
|
||||
if self.data_quality.batch_timeout <= 0:
|
||||
errors.append("Batch timeout must be greater than 0")
|
||||
|
||||
if self.data_quality.fast_mode_sample_size <= 0:
|
||||
errors.append("Fast mode sample size must be greater than 0")
|
||||
|
||||
if self.data_quality.histogram_bins <= 0:
|
||||
errors.append("Histogram bins must be greater than 0")
|
||||
|
||||
# Validate logging configuration
|
||||
if self.logging.level not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]:
|
||||
errors.append("Log level must be one of DEBUG, INFO, WARNING, ERROR, or CRITICAL")
|
||||
@@ -473,6 +844,12 @@ class DorisConfig:
|
||||
|
||||
if self.logging.backup_count < 0:
|
||||
errors.append("Log backup count cannot be negative")
|
||||
|
||||
if self.logging.max_age_days <= 0:
|
||||
errors.append("Log max age days must be greater than 0")
|
||||
|
||||
if self.logging.cleanup_interval_hours <= 0:
|
||||
errors.append("Log cleanup interval hours must be greater than 0")
|
||||
|
||||
# Validate monitoring configuration
|
||||
if not (1 <= self.monitoring.metrics_port <= 65535):
|
||||
@@ -481,6 +858,19 @@ class DorisConfig:
|
||||
if not (1 <= self.monitoring.health_check_port <= 65535):
|
||||
errors.append("Health check port must be in the range 1-65535")
|
||||
|
||||
# Validate ADBC configuration
|
||||
if self.adbc.default_max_rows <= 0:
|
||||
errors.append("ADBC default max rows must be greater than 0")
|
||||
|
||||
if self.adbc.default_timeout <= 0:
|
||||
errors.append("ADBC default timeout must be greater than 0")
|
||||
|
||||
if self.adbc.default_return_format not in ["arrow", "pandas", "dict"]:
|
||||
errors.append("ADBC default return format must be one of arrow, pandas, or dict")
|
||||
|
||||
if self.adbc.connection_timeout <= 0:
|
||||
errors.append("ADBC connection timeout must be greater than 0")
|
||||
|
||||
return errors
|
||||
|
||||
def get_connection_string(self) -> str:
|
||||
@@ -492,7 +882,7 @@ class DorisConfig:
|
||||
return {
|
||||
"server": f"{self.server_name} v{self.server_version}",
|
||||
"database": f"{self.database.host}:{self.database.port}/{self.database.database}",
|
||||
"connection_pool": f"{self.database.min_connections}-{self.database.max_connections}",
|
||||
"connection_pool": f"0-{self.database.max_connections} (min fixed at 0 for stability)",
|
||||
"security": {
|
||||
"auth_type": self.security.auth_type,
|
||||
"masking_enabled": self.security.enable_masking,
|
||||
@@ -518,56 +908,50 @@ class ConfigManager:
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
def setup_logging(self):
|
||||
"""Setup logging configuration"""
|
||||
# Configure root logger
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.setLevel(getattr(logging, self.config.logging.level.upper()))
|
||||
|
||||
# Clear existing handlers
|
||||
for handler in root_logger.handlers[:]:
|
||||
root_logger.removeHandler(handler)
|
||||
|
||||
# Create formatter
|
||||
formatter = logging.Formatter(self.config.logging.format)
|
||||
|
||||
# Console handler
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setFormatter(formatter)
|
||||
root_logger.addHandler(console_handler)
|
||||
|
||||
# File handler (if configured)
|
||||
"""Setup logging configuration using enhanced logger"""
|
||||
from .logger import setup_logging, get_logger
|
||||
import sys
|
||||
|
||||
# Determine log directory
|
||||
log_dir = "logs"
|
||||
if self.config.logging.file_path:
|
||||
try:
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
||||
file_handler = RotatingFileHandler(
|
||||
self.config.logging.file_path,
|
||||
maxBytes=self.config.logging.max_file_size,
|
||||
backupCount=self.config.logging.backup_count,
|
||||
encoding="utf-8",
|
||||
)
|
||||
file_handler.setFormatter(formatter)
|
||||
root_logger.addHandler(file_handler)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to setup file logging: {e}")
|
||||
|
||||
# Audit log handler (if configured)
|
||||
if self.config.logging.enable_audit and self.config.logging.audit_file_path:
|
||||
try:
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
||||
audit_logger = logging.getLogger("audit")
|
||||
audit_handler = RotatingFileHandler(
|
||||
self.config.logging.audit_file_path,
|
||||
maxBytes=self.config.logging.max_file_size,
|
||||
backupCount=self.config.logging.backup_count,
|
||||
encoding="utf-8",
|
||||
)
|
||||
audit_handler.setFormatter(formatter)
|
||||
audit_logger.addHandler(audit_handler)
|
||||
audit_logger.setLevel(logging.INFO)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to setup audit logging: {e}")
|
||||
# Extract directory from file path if provided
|
||||
from pathlib import Path
|
||||
log_dir = str(Path(self.config.logging.file_path).parent)
|
||||
|
||||
# Detect if we're in stdio mode by checking if this is likely MCP stdio communication
|
||||
# In stdio mode, we shouldn't output to console as it interferes with JSON protocol
|
||||
is_stdio_mode = (
|
||||
self.config.transport == "stdio" or
|
||||
"--transport" in sys.argv and "stdio" in sys.argv or
|
||||
not sys.stdout.isatty() # Not a terminal (likely piped/redirected)
|
||||
)
|
||||
|
||||
# Setup enhanced logging with cleanup functionality
|
||||
setup_logging(
|
||||
level=self.config.logging.level,
|
||||
log_dir=log_dir,
|
||||
enable_console=not is_stdio_mode, # Disable console logging in stdio mode
|
||||
enable_file=True,
|
||||
enable_audit=self.config.logging.enable_audit,
|
||||
audit_file=self.config.logging.audit_file_path,
|
||||
max_file_size=self.config.logging.max_file_size,
|
||||
backup_count=self.config.logging.backup_count,
|
||||
enable_cleanup=self.config.logging.enable_cleanup,
|
||||
max_age_days=self.config.logging.max_age_days,
|
||||
cleanup_interval_hours=self.config.logging.cleanup_interval_hours
|
||||
)
|
||||
|
||||
# Update logger to use new system
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
self.logger.info("Enhanced logging system with cleanup initialized successfully")
|
||||
self.logger.info(f"Log directory: {log_dir}")
|
||||
self.logger.info(f"Log level: {self.config.logging.level}")
|
||||
self.logger.info(f"Audit logging: {'Enabled' if self.config.logging.enable_audit else 'Disabled'}")
|
||||
self.logger.info(f"Log cleanup: {'Enabled' if self.config.logging.enable_cleanup else 'Disabled'}")
|
||||
if self.config.logging.enable_cleanup:
|
||||
self.logger.info(f"Cleanup config: Max age {self.config.logging.max_age_days} days, interval {self.config.logging.cleanup_interval_hours}h")
|
||||
|
||||
def validate_config(self) -> bool:
|
||||
"""Validate configuration"""
|
||||
|
||||
771
doris_mcp_server/utils/data_exploration_tools.py
Normal file
@@ -0,0 +1,771 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Data Exploration Tools Module
|
||||
Provides table data distribution analysis and exploration capabilities
|
||||
"""
|
||||
|
||||
import time
|
||||
import math
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from .db import DorisConnectionManager
|
||||
from .logger import get_logger
|
||||
from .sql_security_utils import (
|
||||
SQLSecurityError,
|
||||
validate_identifier,
|
||||
quote_identifier,
|
||||
build_table_reference,
|
||||
get_auth_context
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DataExplorationTools:
|
||||
"""Data exploration tools for table distribution analysis"""
|
||||
|
||||
def __init__(self, connection_manager: DorisConnectionManager):
|
||||
self.connection_manager = connection_manager
|
||||
logger.info("DataExplorationTools initialized")
|
||||
|
||||
|
||||
|
||||
# ==================== Private Helper Methods ====================
|
||||
|
||||
def _build_full_table_name(self, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> str:
|
||||
"""Build full table name with catalog and database using three-part naming convention"""
|
||||
# SECURITY FIX: Use build_table_reference for safe identifier handling
|
||||
effective_catalog = catalog_name if catalog_name else "internal"
|
||||
|
||||
if db_name:
|
||||
return build_table_reference(table_name, db_name, effective_catalog)
|
||||
else:
|
||||
return build_table_reference(table_name, catalog_name=effective_catalog)
|
||||
|
||||
async def _get_table_basic_info(self, connection, table_name: str) -> Optional[Dict]:
|
||||
"""Get basic table information including row count"""
|
||||
try:
|
||||
# SECURITY FIX: Get auth_context for security validation
|
||||
# table_name should already be validated by _build_full_table_name
|
||||
auth_context = get_auth_context()
|
||||
|
||||
count_sql = f"SELECT COUNT(*) as row_count FROM {table_name}"
|
||||
result = await connection.execute(count_sql, auth_context=auth_context)
|
||||
|
||||
if result.data:
|
||||
return {"row_count": result.data[0]["row_count"]}
|
||||
return None
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Security validation failed for table {table_name}: {str(e)}")
|
||||
return {"row_count": 0}
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get basic info for table {table_name}: {str(e)}")
|
||||
return {"row_count": 0}
|
||||
|
||||
async def _get_table_columns_info(self, connection, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> List[Dict]:
|
||||
"""Get detailed column information"""
|
||||
try:
|
||||
# SECURITY FIX: Validate identifiers and use parameterized query
|
||||
auth_context = get_auth_context()
|
||||
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
if db_name:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected: {e}")
|
||||
return []
|
||||
|
||||
# Build parameterized query
|
||||
params = [table_name]
|
||||
where_conditions = ["table_name = %s"]
|
||||
|
||||
if db_name:
|
||||
where_conditions.append("table_schema = %s")
|
||||
params.append(db_name)
|
||||
else:
|
||||
where_conditions.append("table_schema = DATABASE()")
|
||||
|
||||
columns_sql = f"""
|
||||
SELECT
|
||||
column_name,
|
||||
data_type,
|
||||
is_nullable,
|
||||
column_comment,
|
||||
ordinal_position
|
||||
FROM information_schema.columns
|
||||
WHERE {' AND '.join(where_conditions)}
|
||||
ORDER BY ordinal_position
|
||||
"""
|
||||
|
||||
result = await connection.execute(columns_sql, params=tuple(params), auth_context=auth_context)
|
||||
return result.data if result.data else []
|
||||
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Security validation failed: {str(e)}")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get columns info for table {table_name}: {str(e)}")
|
||||
return []
|
||||
|
||||
async def _determine_sampling_strategy(self, connection, table_name: str, total_rows: int, sample_size: int) -> Dict[str, Any]:
|
||||
"""Determine optimal sampling strategy based on table size"""
|
||||
if total_rows <= sample_size:
|
||||
# Use all data if table is small enough
|
||||
return {
|
||||
"total_rows": total_rows,
|
||||
"sample_size": total_rows,
|
||||
"sampling_method": "full_scan",
|
||||
"sampling_ratio": 1.0,
|
||||
"use_sampling": False,
|
||||
"sample_table_expression": table_name
|
||||
}
|
||||
else:
|
||||
# Use random sampling for large tables
|
||||
sampling_ratio = sample_size / total_rows
|
||||
return {
|
||||
"total_rows": total_rows,
|
||||
"sample_size": sample_size,
|
||||
"sampling_method": "random_sample",
|
||||
"sampling_ratio": round(sampling_ratio, 4),
|
||||
"use_sampling": True,
|
||||
"sample_table_expression": f"(SELECT * FROM {table_name} ORDER BY RAND() LIMIT {sample_size}) as sample_table"
|
||||
}
|
||||
|
||||
def _select_analysis_columns(self, columns_info: List[Dict], include_all: bool) -> List[Dict]:
|
||||
"""Select columns for analysis based on strategy"""
|
||||
if include_all:
|
||||
return columns_info
|
||||
|
||||
# If not analyzing all columns, prioritize key columns
|
||||
priority_keywords = ['id', 'key', 'code', 'status', 'type', 'amount', 'count', 'date', 'time']
|
||||
|
||||
priority_columns = []
|
||||
other_columns = []
|
||||
|
||||
for col in columns_info:
|
||||
col_name_lower = col["column_name"].lower()
|
||||
if any(keyword in col_name_lower for keyword in priority_keywords):
|
||||
priority_columns.append(col)
|
||||
else:
|
||||
other_columns.append(col)
|
||||
|
||||
# Return priority columns plus first 10 other columns
|
||||
return priority_columns + other_columns[:10]
|
||||
|
||||
def _is_numeric_type(self, data_type: str) -> bool:
|
||||
"""Check if column type is numeric"""
|
||||
numeric_types = [
|
||||
'tinyint', 'smallint', 'int', 'bigint', 'largeint',
|
||||
'float', 'double', 'decimal', 'numeric'
|
||||
]
|
||||
return any(num_type in data_type.lower() for num_type in numeric_types)
|
||||
|
||||
def _is_categorical_type(self, data_type: str) -> bool:
|
||||
"""Check if column type is categorical"""
|
||||
categorical_types = ['varchar', 'char', 'string', 'text', 'enum']
|
||||
return any(cat_type in data_type.lower() for cat_type in categorical_types)
|
||||
|
||||
def _is_temporal_type(self, data_type: str) -> bool:
|
||||
"""Check if column type is temporal"""
|
||||
temporal_types = ['date', 'datetime', 'timestamp', 'time']
|
||||
return any(temp_type in data_type.lower() for temp_type in temporal_types)
|
||||
|
||||
async def _analyze_numeric_distributions(self, connection, table_name: str, numeric_columns: List[Dict], sampling_info: Dict) -> Dict[str, Any]:
|
||||
"""Analyze distribution patterns for numeric columns"""
|
||||
numeric_analysis = {}
|
||||
|
||||
for column in numeric_columns:
|
||||
col_name = column["column_name"]
|
||||
try:
|
||||
# Basic statistics
|
||||
table_expr = sampling_info.get("sample_table_expression", table_name)
|
||||
stats_sql = f"""
|
||||
SELECT
|
||||
COUNT({col_name}) as count,
|
||||
MIN({col_name}) as min_value,
|
||||
MAX({col_name}) as max_value,
|
||||
AVG({col_name}) as mean_value,
|
||||
STDDEV({col_name}) as std_dev
|
||||
FROM {table_expr}
|
||||
WHERE {col_name} IS NOT NULL
|
||||
"""
|
||||
|
||||
auth_context = get_auth_context()
|
||||
stats_result = await connection.execute(stats_sql, auth_context=auth_context)
|
||||
|
||||
if stats_result.data and stats_result.data[0]["count"] > 0:
|
||||
stats = stats_result.data[0]
|
||||
|
||||
# Percentiles calculation
|
||||
percentiles = await self._calculate_percentiles(connection, table_name, col_name, sampling_info)
|
||||
|
||||
# Outlier detection
|
||||
outliers = await self._detect_numeric_outliers(connection, table_name, col_name, percentiles, sampling_info)
|
||||
|
||||
# Distribution shape analysis
|
||||
distribution_shape = await self._analyze_distribution_shape(
|
||||
connection, table_name, col_name, stats, percentiles, sampling_info
|
||||
)
|
||||
|
||||
numeric_analysis[col_name] = {
|
||||
"data_type": column["data_type"],
|
||||
"statistics": {
|
||||
"count": stats["count"],
|
||||
"mean": round(float(stats["mean_value"]), 4) if stats["mean_value"] else None,
|
||||
"std": round(float(stats["std_dev"]), 4) if stats["std_dev"] else None,
|
||||
"min": float(stats["min_value"]) if stats["min_value"] else None,
|
||||
"max": float(stats["max_value"]) if stats["max_value"] else None,
|
||||
**percentiles
|
||||
},
|
||||
"distribution_shape": distribution_shape,
|
||||
"outliers": outliers
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to analyze numeric column {col_name}: {str(e)}")
|
||||
numeric_analysis[col_name] = {"error": str(e)}
|
||||
|
||||
return numeric_analysis
|
||||
|
||||
async def _calculate_percentiles(self, connection, table_name: str, col_name: str, sampling_info: Dict) -> Dict[str, float]:
|
||||
"""Calculate percentiles for numeric column"""
|
||||
try:
|
||||
table_expr = sampling_info.get("sample_table_expression", table_name)
|
||||
percentile_sql = f"""
|
||||
SELECT
|
||||
PERCENTILE({col_name}, 0.25) as p25,
|
||||
PERCENTILE({col_name}, 0.50) as p50,
|
||||
PERCENTILE({col_name}, 0.75) as p75,
|
||||
PERCENTILE({col_name}, 0.90) as p90,
|
||||
PERCENTILE({col_name}, 0.95) as p95,
|
||||
PERCENTILE({col_name}, 0.99) as p99
|
||||
FROM {table_expr}
|
||||
WHERE {col_name} IS NOT NULL
|
||||
"""
|
||||
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(percentile_sql, auth_context=auth_context)
|
||||
|
||||
if result.data:
|
||||
data = result.data[0]
|
||||
return {
|
||||
"25%": round(float(data["p25"]), 4) if data["p25"] else None,
|
||||
"50%": round(float(data["p50"]), 4) if data["p50"] else None,
|
||||
"75%": round(float(data["p75"]), 4) if data["p75"] else None,
|
||||
"90%": round(float(data["p90"]), 4) if data["p90"] else None,
|
||||
"95%": round(float(data["p95"]), 4) if data["p95"] else None,
|
||||
"99%": round(float(data["p99"]), 4) if data["p99"] else None
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to calculate percentiles for {col_name}: {str(e)}")
|
||||
|
||||
return {}
|
||||
|
||||
async def _detect_numeric_outliers(self, connection, table_name: str, col_name: str, percentiles: Dict, sampling_info: Dict) -> Dict[str, Any]:
|
||||
"""Detect outliers using IQR method"""
|
||||
try:
|
||||
if "25%" not in percentiles or "75%" not in percentiles:
|
||||
return {"outlier_count": 0, "outlier_rate": 0.0}
|
||||
|
||||
q1 = percentiles["25%"]
|
||||
q3 = percentiles["75%"]
|
||||
iqr = q3 - q1
|
||||
|
||||
lower_bound = q1 - 1.5 * iqr
|
||||
upper_bound = q3 + 1.5 * iqr
|
||||
|
||||
table_expr = sampling_info.get("sample_table_expression", table_name)
|
||||
outlier_sql = f"""
|
||||
SELECT
|
||||
COUNT(*) as total_count,
|
||||
SUM(CASE WHEN {col_name} < {lower_bound} OR {col_name} > {upper_bound} THEN 1 ELSE 0 END) as outlier_count
|
||||
FROM {table_expr}
|
||||
WHERE {col_name} IS NOT NULL
|
||||
"""
|
||||
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(outlier_sql, auth_context=auth_context)
|
||||
|
||||
if result.data:
|
||||
data = result.data[0]
|
||||
total_count = data["total_count"]
|
||||
outlier_count = data["outlier_count"]
|
||||
outlier_rate = outlier_count / total_count if total_count > 0 else 0
|
||||
|
||||
return {
|
||||
"outlier_count": outlier_count,
|
||||
"outlier_rate": round(outlier_rate, 4),
|
||||
"outlier_threshold_lower": round(lower_bound, 4),
|
||||
"outlier_threshold_upper": round(upper_bound, 4),
|
||||
"iqr": round(iqr, 4)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to detect outliers for {col_name}: {str(e)}")
|
||||
|
||||
return {"outlier_count": 0, "outlier_rate": 0.0}
|
||||
|
||||
async def _analyze_distribution_shape(self, connection, table_name: str, col_name: str, stats: Dict, percentiles: Dict, sampling_info: Dict) -> Dict[str, Any]:
|
||||
"""Analyze the shape of data distribution"""
|
||||
try:
|
||||
mean = stats.get("mean_value", 0)
|
||||
median = percentiles.get("50%", 0)
|
||||
|
||||
if mean is None or median is None:
|
||||
return {"distribution_type": "unknown"}
|
||||
|
||||
# Calculate skewness indicator
|
||||
if abs(mean - median) < 0.01:
|
||||
skew_indicator = "symmetric"
|
||||
elif mean > median:
|
||||
skew_indicator = "right_skewed"
|
||||
else:
|
||||
skew_indicator = "left_skewed"
|
||||
|
||||
# Estimate kurtosis based on percentile spread
|
||||
if "25%" in percentiles and "75%" in percentiles:
|
||||
iqr = percentiles["75%"] - percentiles["25%"]
|
||||
range_90 = percentiles.get("90%", percentiles["75%"]) - percentiles.get("10%", percentiles["25%"])
|
||||
|
||||
if iqr > 0:
|
||||
kurtosis_indicator = "normal" if 2.5 <= range_90/iqr <= 3.5 else ("heavy_tailed" if range_90/iqr > 3.5 else "light_tailed")
|
||||
else:
|
||||
kurtosis_indicator = "unknown"
|
||||
else:
|
||||
kurtosis_indicator = "unknown"
|
||||
|
||||
return {
|
||||
"skewness_indicator": skew_indicator,
|
||||
"kurtosis_indicator": kurtosis_indicator,
|
||||
"distribution_type": self._classify_distribution_type(skew_indicator, kurtosis_indicator),
|
||||
"mean_median_ratio": round(mean / median, 4) if median != 0 else None
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to analyze distribution shape for {col_name}: {str(e)}")
|
||||
return {"distribution_type": "unknown"}
|
||||
|
||||
def _classify_distribution_type(self, skew: str, kurtosis: str) -> str:
|
||||
"""Classify distribution type based on skewness and kurtosis"""
|
||||
if skew == "symmetric" and kurtosis == "normal":
|
||||
return "approximately_normal"
|
||||
elif skew == "right_skewed":
|
||||
return "right_skewed"
|
||||
elif skew == "left_skewed":
|
||||
return "left_skewed"
|
||||
elif kurtosis == "heavy_tailed":
|
||||
return "heavy_tailed"
|
||||
else:
|
||||
return "non_normal"
|
||||
|
||||
async def _analyze_categorical_distributions(self, connection, table_name: str, categorical_columns: List[Dict], sampling_info: Dict) -> Dict[str, Any]:
|
||||
"""Analyze distribution patterns for categorical columns"""
|
||||
categorical_analysis = {}
|
||||
|
||||
for column in categorical_columns:
|
||||
col_name = column["column_name"]
|
||||
try:
|
||||
# Basic cardinality and distribution
|
||||
cardinality_sql = f"""
|
||||
SELECT
|
||||
COUNT(DISTINCT {col_name}) as cardinality,
|
||||
COUNT({col_name}) as non_null_count
|
||||
FROM {table_name}
|
||||
WHERE {col_name} IS NOT NULL
|
||||
{sampling_info.get('sample_query_suffix', '')}
|
||||
"""
|
||||
|
||||
auth_context = get_auth_context()
|
||||
cardinality_result = await connection.execute(cardinality_sql, auth_context=auth_context)
|
||||
|
||||
if cardinality_result.data:
|
||||
cardinality_data = cardinality_result.data[0]
|
||||
cardinality = cardinality_data["cardinality"]
|
||||
non_null_count = cardinality_data["non_null_count"]
|
||||
|
||||
# Value distribution (top values)
|
||||
value_distribution = await self._get_categorical_value_distribution(
|
||||
connection, table_name, col_name, sampling_info, non_null_count
|
||||
)
|
||||
|
||||
# Calculate entropy and concentration
|
||||
entropy = self._calculate_entropy(value_distribution)
|
||||
concentration_ratio = value_distribution[0]["percentage"] if value_distribution else 0
|
||||
|
||||
categorical_analysis[col_name] = {
|
||||
"data_type": column["data_type"],
|
||||
"cardinality": cardinality,
|
||||
"non_null_count": non_null_count,
|
||||
"value_distribution": value_distribution,
|
||||
"entropy": round(entropy, 3),
|
||||
"concentration_ratio": round(concentration_ratio, 4),
|
||||
"diversity_score": round(cardinality / non_null_count, 4) if non_null_count > 0 else 0
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to analyze categorical column {col_name}: {str(e)}")
|
||||
categorical_analysis[col_name] = {"error": str(e)}
|
||||
|
||||
return categorical_analysis
|
||||
|
||||
async def _get_categorical_value_distribution(self, connection, table_name: str, col_name: str, sampling_info: Dict, total_count: int) -> List[Dict]:
|
||||
"""Get value distribution for categorical column"""
|
||||
try:
|
||||
# Use sample table expression if sampling is enabled
|
||||
table_expr = sampling_info.get("sample_table_expression", table_name)
|
||||
|
||||
distribution_sql = f"""
|
||||
SELECT
|
||||
{col_name} as value,
|
||||
COUNT(*) as count
|
||||
FROM {table_expr}
|
||||
WHERE {col_name} IS NOT NULL
|
||||
GROUP BY {col_name}
|
||||
ORDER BY COUNT(*) DESC
|
||||
LIMIT 20
|
||||
"""
|
||||
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(distribution_sql, auth_context=auth_context)
|
||||
|
||||
if result.data:
|
||||
distribution = []
|
||||
for row in result.data:
|
||||
count = row["count"]
|
||||
percentage = count / total_count if total_count > 0 else 0
|
||||
distribution.append({
|
||||
"value": str(row["value"]),
|
||||
"count": count,
|
||||
"percentage": round(percentage, 4)
|
||||
})
|
||||
return distribution
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get value distribution for {col_name}: {str(e)}")
|
||||
|
||||
return []
|
||||
|
||||
def _calculate_entropy(self, value_distribution: List[Dict]) -> float:
|
||||
"""Calculate Shannon entropy for categorical distribution"""
|
||||
if not value_distribution:
|
||||
return 0.0
|
||||
|
||||
entropy = 0.0
|
||||
for item in value_distribution:
|
||||
p = item["percentage"]
|
||||
if p > 0:
|
||||
entropy -= p * math.log2(p)
|
||||
|
||||
return entropy
|
||||
|
||||
async def _analyze_temporal_distributions(self, connection, table_name: str, temporal_columns: List[Dict], sampling_info: Dict) -> Dict[str, Any]:
|
||||
"""Analyze distribution patterns for temporal columns"""
|
||||
temporal_analysis = {}
|
||||
|
||||
for column in temporal_columns:
|
||||
col_name = column["column_name"]
|
||||
try:
|
||||
# Date range analysis
|
||||
table_expr = sampling_info.get("sample_table_expression", table_name)
|
||||
range_sql = f"""
|
||||
SELECT
|
||||
MIN({col_name}) as earliest,
|
||||
MAX({col_name}) as latest,
|
||||
COUNT({col_name}) as non_null_count
|
||||
FROM {table_expr}
|
||||
WHERE {col_name} IS NOT NULL
|
||||
"""
|
||||
|
||||
auth_context = get_auth_context()
|
||||
range_result = await connection.execute(range_sql, auth_context=auth_context)
|
||||
|
||||
if range_result.data and range_result.data[0]["non_null_count"] > 0:
|
||||
range_data = range_result.data[0]
|
||||
earliest = range_data["earliest"]
|
||||
latest = range_data["latest"]
|
||||
|
||||
# Calculate span
|
||||
date_span_info = self._calculate_date_span(earliest, latest)
|
||||
|
||||
# Temporal patterns analysis
|
||||
temporal_patterns = await self._analyze_temporal_patterns(
|
||||
connection, table_name, col_name, sampling_info
|
||||
)
|
||||
|
||||
temporal_analysis[col_name] = {
|
||||
"data_type": column["data_type"],
|
||||
"non_null_count": range_data["non_null_count"],
|
||||
"date_range": {
|
||||
"earliest": str(earliest),
|
||||
"latest": str(latest),
|
||||
**date_span_info
|
||||
},
|
||||
"temporal_patterns": temporal_patterns
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to analyze temporal column {col_name}: {str(e)}")
|
||||
temporal_analysis[col_name] = {"error": str(e)}
|
||||
|
||||
return temporal_analysis
|
||||
|
||||
def _calculate_date_span(self, earliest, latest) -> Dict[str, Any]:
|
||||
"""Calculate date span information"""
|
||||
try:
|
||||
if isinstance(earliest, str):
|
||||
earliest = datetime.fromisoformat(earliest.replace('Z', '+00:00'))
|
||||
if isinstance(latest, str):
|
||||
latest = datetime.fromisoformat(latest.replace('Z', '+00:00'))
|
||||
|
||||
span = latest - earliest
|
||||
span_days = span.days
|
||||
|
||||
return {
|
||||
"span_days": span_days,
|
||||
"span_years": round(span_days / 365.25, 2),
|
||||
"span_description": self._describe_time_span(span_days)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to calculate date span: {str(e)}")
|
||||
return {"span_days": 0}
|
||||
|
||||
def _describe_time_span(self, days: int) -> str:
|
||||
"""Describe time span in human readable format"""
|
||||
if days < 1:
|
||||
return "less_than_day"
|
||||
elif days < 7:
|
||||
return "days"
|
||||
elif days < 30:
|
||||
return "weeks"
|
||||
elif days < 365:
|
||||
return "months"
|
||||
else:
|
||||
return "years"
|
||||
|
||||
async def _analyze_temporal_patterns(self, connection, table_name: str, col_name: str, sampling_info: Dict) -> Dict[str, Any]:
|
||||
"""Analyze temporal patterns like seasonality and trends"""
|
||||
try:
|
||||
table_expr = sampling_info.get("sample_table_expression", table_name)
|
||||
# Weekly pattern analysis
|
||||
weekly_pattern_sql = f"""
|
||||
SELECT
|
||||
DAYOFWEEK({col_name}) as day_of_week,
|
||||
COUNT(*) as count
|
||||
FROM {table_expr}
|
||||
WHERE {col_name} IS NOT NULL
|
||||
GROUP BY DAYOFWEEK({col_name})
|
||||
ORDER BY day_of_week
|
||||
"""
|
||||
|
||||
auth_context = get_auth_context()
|
||||
weekly_result = await connection.execute(weekly_pattern_sql, auth_context=auth_context)
|
||||
|
||||
weekly_pattern = []
|
||||
if weekly_result.data:
|
||||
total_records = sum(row["count"] for row in weekly_result.data)
|
||||
for row in weekly_result.data:
|
||||
percentage = row["count"] / total_records if total_records > 0 else 0
|
||||
weekly_pattern.append(round(percentage, 3))
|
||||
|
||||
# Monthly trend analysis (simplified)
|
||||
monthly_trend_sql = f"""
|
||||
SELECT
|
||||
YEAR({col_name}) as year,
|
||||
MONTH({col_name}) as month,
|
||||
COUNT(*) as count
|
||||
FROM {table_expr}
|
||||
WHERE {col_name} IS NOT NULL
|
||||
GROUP BY YEAR({col_name}), MONTH({col_name})
|
||||
ORDER BY year, month
|
||||
LIMIT 12
|
||||
"""
|
||||
|
||||
monthly_result = await connection.execute(monthly_trend_sql, auth_context=auth_context)
|
||||
monthly_trend = "stable" # Simplified trend analysis
|
||||
|
||||
if monthly_result.data and len(monthly_result.data) > 3:
|
||||
counts = [row["count"] for row in monthly_result.data]
|
||||
if len(counts) > 1:
|
||||
trend_direction = "increasing" if counts[-1] > counts[0] else "decreasing"
|
||||
monthly_trend = trend_direction
|
||||
|
||||
return {
|
||||
"weekly_pattern": weekly_pattern,
|
||||
"monthly_trend": monthly_trend,
|
||||
"seasonal_component": self._estimate_seasonality(weekly_pattern)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to analyze temporal patterns for {col_name}: {str(e)}")
|
||||
return {"weekly_pattern": [], "monthly_trend": "unknown"}
|
||||
|
||||
def _estimate_seasonality(self, weekly_pattern: List[float]) -> float:
|
||||
"""Estimate seasonality strength based on weekly pattern variance"""
|
||||
if len(weekly_pattern) < 7:
|
||||
return 0.0
|
||||
|
||||
mean_percentage = sum(weekly_pattern) / len(weekly_pattern)
|
||||
variance = sum((x - mean_percentage) ** 2 for x in weekly_pattern) / len(weekly_pattern)
|
||||
|
||||
# Normalize variance to 0-1 scale as seasonality indicator
|
||||
seasonality = min(variance * 10, 1.0) # Scaling factor
|
||||
return round(seasonality, 3)
|
||||
|
||||
async def _generate_data_quality_insights(self, connection, table_name: str, columns: List[Dict], sampling_info: Dict) -> Dict[str, Any]:
|
||||
"""Generate overall data quality insights"""
|
||||
try:
|
||||
total_columns = len(columns)
|
||||
|
||||
# Calculate null rates across all columns
|
||||
null_analysis = await self._analyze_overall_null_rates(connection, table_name, columns, sampling_info)
|
||||
|
||||
# Identify potential data quality issues
|
||||
quality_issues = []
|
||||
|
||||
# High null rate columns
|
||||
high_null_columns = [col for col, rate in null_analysis["column_null_rates"].items() if rate > 0.2]
|
||||
if high_null_columns:
|
||||
quality_issues.append({
|
||||
"issue_type": "high_null_rates",
|
||||
"severity": "medium",
|
||||
"affected_columns": high_null_columns,
|
||||
"description": f"{len(high_null_columns)} columns have null rates > 20%"
|
||||
})
|
||||
|
||||
# Calculate overall data quality score
|
||||
avg_null_rate = sum(null_analysis["column_null_rates"].values()) / len(null_analysis["column_null_rates"]) if null_analysis["column_null_rates"] else 0
|
||||
data_quality_score = max(0, 1 - avg_null_rate)
|
||||
|
||||
return {
|
||||
"total_columns_analyzed": total_columns,
|
||||
"null_analysis": null_analysis,
|
||||
"data_quality_score": round(data_quality_score, 3),
|
||||
"quality_issues": quality_issues,
|
||||
"recommendations": self._generate_quality_recommendations(quality_issues, null_analysis)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to generate data quality insights: {str(e)}")
|
||||
return {"data_quality_score": 0.0, "error": str(e)}
|
||||
|
||||
async def _analyze_overall_null_rates(self, connection, table_name: str, columns: List[Dict], sampling_info: Dict) -> Dict[str, Any]:
|
||||
"""Analyze null rates across all columns"""
|
||||
column_null_rates = {}
|
||||
total_null_count = 0
|
||||
total_cell_count = 0
|
||||
|
||||
for column in columns:
|
||||
col_name = column["column_name"]
|
||||
try:
|
||||
table_expr = sampling_info.get("sample_table_expression", table_name)
|
||||
null_sql = f"""
|
||||
SELECT
|
||||
COUNT(*) as total_count,
|
||||
COUNT({col_name}) as non_null_count
|
||||
FROM {table_expr}
|
||||
"""
|
||||
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(null_sql, auth_context=auth_context)
|
||||
if result.data:
|
||||
data = result.data[0]
|
||||
total_count = data["total_count"]
|
||||
non_null_count = data["non_null_count"]
|
||||
null_count = total_count - non_null_count
|
||||
null_rate = null_count / total_count if total_count > 0 else 0
|
||||
|
||||
column_null_rates[col_name] = round(null_rate, 4)
|
||||
total_null_count += null_count
|
||||
total_cell_count += total_count
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to analyze null rate for column {col_name}: {str(e)}")
|
||||
column_null_rates[col_name] = 0.0
|
||||
|
||||
overall_null_rate = total_null_count / total_cell_count if total_cell_count > 0 else 0
|
||||
|
||||
return {
|
||||
"column_null_rates": column_null_rates,
|
||||
"overall_null_rate": round(overall_null_rate, 4),
|
||||
"columns_with_nulls": len([rate for rate in column_null_rates.values() if rate > 0])
|
||||
}
|
||||
|
||||
def _generate_quality_recommendations(self, quality_issues: List[Dict], null_analysis: Dict) -> List[Dict]:
|
||||
"""Generate data quality improvement recommendations"""
|
||||
recommendations = []
|
||||
|
||||
# Recommendations based on null analysis
|
||||
overall_null_rate = null_analysis.get("overall_null_rate", 0)
|
||||
if overall_null_rate > 0.1:
|
||||
recommendations.append({
|
||||
"type": "data_completeness",
|
||||
"priority": "high" if overall_null_rate > 0.3 else "medium",
|
||||
"description": f"Overall null rate is {overall_null_rate:.1%}",
|
||||
"action": "Review data collection and validation processes"
|
||||
})
|
||||
|
||||
# Recommendations based on quality issues
|
||||
for issue in quality_issues:
|
||||
if issue["issue_type"] == "high_null_rates":
|
||||
recommendations.append({
|
||||
"type": "column_completeness",
|
||||
"priority": issue["severity"],
|
||||
"description": issue["description"],
|
||||
"action": f"Focus on improving data completeness for: {', '.join(issue['affected_columns'][:3])}"
|
||||
})
|
||||
|
||||
return recommendations
|
||||
|
||||
def _generate_analysis_summary(self, distribution_analysis: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Generate high-level summary of distribution analysis"""
|
||||
summary = {
|
||||
"numeric_columns_count": len(distribution_analysis.get("numeric_columns", {})),
|
||||
"categorical_columns_count": len(distribution_analysis.get("categorical_columns", {})),
|
||||
"temporal_columns_count": len(distribution_analysis.get("temporal_columns", {}))
|
||||
}
|
||||
|
||||
# Identify interesting patterns
|
||||
patterns = []
|
||||
|
||||
# Check for highly skewed numeric columns
|
||||
numeric_cols = distribution_analysis.get("numeric_columns", {})
|
||||
skewed_cols = [
|
||||
col for col, info in numeric_cols.items()
|
||||
if isinstance(info, dict) and
|
||||
info.get("distribution_shape", {}).get("skewness_indicator") in ["right_skewed", "left_skewed"]
|
||||
]
|
||||
|
||||
if skewed_cols:
|
||||
patterns.append(f"Found {len(skewed_cols)} skewed numeric columns")
|
||||
|
||||
# Check for high cardinality categorical columns
|
||||
categorical_cols = distribution_analysis.get("categorical_columns", {})
|
||||
high_cardinality_cols = [
|
||||
col for col, info in categorical_cols.items()
|
||||
if isinstance(info, dict) and info.get("cardinality", 0) > 1000
|
||||
]
|
||||
|
||||
if high_cardinality_cols:
|
||||
patterns.append(f"Found {len(high_cardinality_cols)} high cardinality categorical columns")
|
||||
|
||||
summary["notable_patterns"] = patterns
|
||||
|
||||
return summary
|
||||
1022
doris_mcp_server/utils/data_governance_tools.py
Normal file
1173
doris_mcp_server/utils/data_quality_tools.py
Normal file
1025
doris_mcp_server/utils/dependency_analysis_tools.py
Normal file
@@ -15,77 +15,573 @@
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Logging configuration for Doris MCP Server.
|
||||
Enhanced Logging configuration for Doris MCP Server.
|
||||
Features:
|
||||
- Log level-based file separation
|
||||
- Timestamped log entries
|
||||
- Automatic log rotation
|
||||
- Comprehensive logging coverage
|
||||
"""
|
||||
|
||||
import logging
|
||||
import logging.config
|
||||
import logging.handlers
|
||||
import sys
|
||||
import os
|
||||
import asyncio
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
import threading
|
||||
|
||||
|
||||
def setup_logging(
|
||||
level: str = "INFO",
|
||||
log_file: str | None = None,
|
||||
log_format: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Setup logging configuration.
|
||||
class TimestampedFormatter(logging.Formatter):
|
||||
"""Custom formatter with enhanced timestamp and structured format"""
|
||||
|
||||
def __init__(self, fmt=None, datefmt=None, style='%'):
|
||||
if fmt is None:
|
||||
fmt = "%(asctime)s.%(msecs)03d %(level_aligned)s %(name)s:%(lineno)d - %(message)s"
|
||||
if datefmt is None:
|
||||
datefmt = "%Y-%m-%d %H:%M:%S"
|
||||
super().__init__(fmt, datefmt, style)
|
||||
|
||||
def format(self, record):
|
||||
"""Format log record with enhanced information and proper alignment"""
|
||||
# Add process info if available
|
||||
if hasattr(record, 'process') and record.process:
|
||||
record.process_info = f"[PID:{record.process}]"
|
||||
else:
|
||||
record.process_info = ""
|
||||
|
||||
# Add thread info if available
|
||||
if hasattr(record, 'thread') and record.thread:
|
||||
record.thread_info = f"[TID:{record.thread}]"
|
||||
else:
|
||||
record.thread_info = ""
|
||||
|
||||
# Format with proper alignment after the level name
|
||||
# Calculate padding needed for alignment
|
||||
level_name = record.levelname
|
||||
max_level_length = 8 # Length of "CRITICAL"
|
||||
padding = max_level_length - len(level_name)
|
||||
record.level_aligned = f"[{level_name}]{' ' * padding}"
|
||||
|
||||
return super().format(record)
|
||||
|
||||
Args:
|
||||
level: Logging level (DEBUG, INFO, WARNING, ERROR)
|
||||
log_file: Optional log file path
|
||||
log_format: Optional custom log format
|
||||
"""
|
||||
if log_format is None:
|
||||
log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
|
||||
# Base configuration
|
||||
config: dict[str, Any] = {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"formatters": {
|
||||
"default": {"format": log_format, "datefmt": "%Y-%m-%d %H:%M:%S"}
|
||||
},
|
||||
"handlers": {
|
||||
"console": {
|
||||
"class": "logging.StreamHandler",
|
||||
"level": level,
|
||||
"formatter": "default",
|
||||
"stream": sys.stdout,
|
||||
}
|
||||
},
|
||||
"root": {"level": level, "handlers": ["console"]},
|
||||
"loggers": {
|
||||
"doris_mcp_server": {
|
||||
"level": level,
|
||||
"handlers": ["console"],
|
||||
"propagate": False,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
# Add file handler if log_file is specified
|
||||
if log_file:
|
||||
class LevelBasedFileHandler(logging.Handler):
|
||||
"""Custom handler that writes different log levels to different files"""
|
||||
|
||||
def __init__(self, log_dir: str, base_name: str = "doris_mcp_server",
|
||||
max_bytes: int = 10*1024*1024, backup_count: int = 5):
|
||||
super().__init__()
|
||||
self.log_dir = Path(log_dir)
|
||||
self.base_name = base_name
|
||||
self.max_bytes = max_bytes
|
||||
self.backup_count = backup_count
|
||||
|
||||
# Ensure log directory exists
|
||||
log_path = Path(log_file)
|
||||
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
config["handlers"]["file"] = {
|
||||
"class": "logging.handlers.RotatingFileHandler",
|
||||
"level": level,
|
||||
"formatter": "default",
|
||||
"filename": log_file,
|
||||
"maxBytes": 10485760, # 10MB
|
||||
"backupCount": 5,
|
||||
self.log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create handlers for different log levels
|
||||
self.handlers = {}
|
||||
self._setup_level_handlers()
|
||||
|
||||
def _setup_level_handlers(self):
|
||||
"""Setup rotating file handlers for different log levels"""
|
||||
level_files = {
|
||||
'DEBUG': 'debug.log',
|
||||
'INFO': 'info.log',
|
||||
'WARNING': 'warning.log',
|
||||
'ERROR': 'error.log',
|
||||
'CRITICAL': 'critical.log'
|
||||
}
|
||||
|
||||
formatter = TimestampedFormatter()
|
||||
|
||||
for level, filename in level_files.items():
|
||||
file_path = self.log_dir / f"{self.base_name}_{filename}"
|
||||
handler = logging.handlers.RotatingFileHandler(
|
||||
file_path,
|
||||
maxBytes=self.max_bytes,
|
||||
backupCount=self.backup_count,
|
||||
encoding='utf-8'
|
||||
)
|
||||
handler.setFormatter(formatter)
|
||||
handler.setLevel(getattr(logging, level))
|
||||
self.handlers[level] = handler
|
||||
|
||||
def emit(self, record):
|
||||
"""Emit log record to appropriate level-based file"""
|
||||
level_name = record.levelname
|
||||
if level_name in self.handlers:
|
||||
try:
|
||||
self.handlers[level_name].emit(record)
|
||||
except Exception:
|
||||
self.handleError(record)
|
||||
|
||||
def close(self):
|
||||
"""Close all handlers"""
|
||||
for handler in self.handlers.values():
|
||||
handler.close()
|
||||
super().close()
|
||||
|
||||
# Add file handler to root and package loggers
|
||||
config["root"]["handlers"].append("file")
|
||||
config["loggers"]["doris_mcp_server"]["handlers"].append("file")
|
||||
|
||||
logging.config.dictConfig(config)
|
||||
class LogCleanupManager:
|
||||
"""Log file cleanup manager for automatic maintenance"""
|
||||
|
||||
def __init__(self, log_dir: str, max_age_days: int = 30, cleanup_interval_hours: int = 24):
|
||||
"""
|
||||
Initialize log cleanup manager.
|
||||
|
||||
Args:
|
||||
log_dir: Directory containing log files
|
||||
max_age_days: Maximum age of log files in days (default: 30 days)
|
||||
cleanup_interval_hours: Cleanup interval in hours (default: 24 hours)
|
||||
"""
|
||||
self.log_dir = Path(log_dir)
|
||||
self.max_age_days = max_age_days
|
||||
self.cleanup_interval_hours = cleanup_interval_hours
|
||||
self.cleanup_thread = None
|
||||
self.stop_event = threading.Event()
|
||||
self.logger = None
|
||||
|
||||
def start_cleanup_scheduler(self):
|
||||
"""Start the cleanup scheduler in a background thread"""
|
||||
if self.cleanup_thread and self.cleanup_thread.is_alive():
|
||||
return
|
||||
|
||||
self.stop_event.clear()
|
||||
self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True)
|
||||
self.cleanup_thread.start()
|
||||
|
||||
# Get logger for this class
|
||||
if not self.logger:
|
||||
self.logger = logging.getLogger("doris_mcp_server.log_cleanup")
|
||||
|
||||
self.logger.info(f"Log cleanup scheduler started - cleanup every {self.cleanup_interval_hours}h, max age {self.max_age_days} days")
|
||||
|
||||
def stop_cleanup_scheduler(self):
|
||||
"""Stop the cleanup scheduler"""
|
||||
if self.cleanup_thread and self.cleanup_thread.is_alive():
|
||||
self.stop_event.set()
|
||||
self.cleanup_thread.join(timeout=5)
|
||||
if self.logger:
|
||||
self.logger.info("Log cleanup scheduler stopped")
|
||||
|
||||
def _cleanup_loop(self):
|
||||
"""Background loop for periodic cleanup"""
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
self.cleanup_old_logs()
|
||||
# Sleep for the specified interval, but check stop event every 60 seconds
|
||||
for _ in range(self.cleanup_interval_hours * 60): # Convert hours to minutes
|
||||
if self.stop_event.wait(60): # Wait 60 seconds or until stop event
|
||||
break
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.error(f"Error in log cleanup loop: {e}")
|
||||
# Sleep for 5 minutes before retrying
|
||||
self.stop_event.wait(300)
|
||||
|
||||
def cleanup_old_logs(self):
|
||||
"""Clean up old log files based on age"""
|
||||
if not self.log_dir.exists():
|
||||
return
|
||||
|
||||
current_time = datetime.now()
|
||||
cutoff_time = current_time - timedelta(days=self.max_age_days)
|
||||
|
||||
cleaned_files = []
|
||||
cleaned_size = 0
|
||||
|
||||
# Pattern for log files (including backup files)
|
||||
log_patterns = [
|
||||
"doris_mcp_server_*.log",
|
||||
"doris_mcp_server_*.log.*" # Backup files
|
||||
]
|
||||
|
||||
for pattern in log_patterns:
|
||||
for log_file in self.log_dir.glob(pattern):
|
||||
try:
|
||||
# Get file modification time
|
||||
file_mtime = datetime.fromtimestamp(log_file.stat().st_mtime)
|
||||
|
||||
if file_mtime < cutoff_time:
|
||||
file_size = log_file.stat().st_size
|
||||
log_file.unlink() # Delete the file
|
||||
cleaned_files.append(log_file.name)
|
||||
cleaned_size += file_size
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.warning(f"Failed to cleanup log file {log_file}: {e}")
|
||||
|
||||
if cleaned_files and self.logger:
|
||||
size_mb = cleaned_size / (1024 * 1024)
|
||||
self.logger.info(f"Cleaned up {len(cleaned_files)} old log files, freed {size_mb:.2f} MB")
|
||||
self.logger.debug(f"Cleaned files: {', '.join(cleaned_files)}")
|
||||
|
||||
def get_cleanup_stats(self) -> dict:
|
||||
"""Get statistics about log files and cleanup status"""
|
||||
if not self.log_dir.exists():
|
||||
return {"error": "Log directory does not exist"}
|
||||
|
||||
stats = {
|
||||
"log_directory": str(self.log_dir.absolute()),
|
||||
"max_age_days": self.max_age_days,
|
||||
"cleanup_interval_hours": self.cleanup_interval_hours,
|
||||
"scheduler_running": self.cleanup_thread and self.cleanup_thread.is_alive(),
|
||||
"total_files": 0,
|
||||
"total_size_mb": 0,
|
||||
"files_by_age": {"recent": 0, "old": 0},
|
||||
"oldest_file": None,
|
||||
"newest_file": None
|
||||
}
|
||||
|
||||
current_time = datetime.now()
|
||||
cutoff_time = current_time - timedelta(days=self.max_age_days)
|
||||
oldest_time = None
|
||||
newest_time = None
|
||||
|
||||
log_patterns = ["doris_mcp_server_*.log", "doris_mcp_server_*.log.*"]
|
||||
|
||||
for pattern in log_patterns:
|
||||
for log_file in self.log_dir.glob(pattern):
|
||||
try:
|
||||
file_stat = log_file.stat()
|
||||
file_mtime = datetime.fromtimestamp(file_stat.st_mtime)
|
||||
|
||||
stats["total_files"] += 1
|
||||
stats["total_size_mb"] += file_stat.st_size / (1024 * 1024)
|
||||
|
||||
if file_mtime < cutoff_time:
|
||||
stats["files_by_age"]["old"] += 1
|
||||
else:
|
||||
stats["files_by_age"]["recent"] += 1
|
||||
|
||||
if oldest_time is None or file_mtime < oldest_time:
|
||||
oldest_time = file_mtime
|
||||
stats["oldest_file"] = {"name": log_file.name, "age_days": (current_time - file_mtime).days}
|
||||
|
||||
if newest_time is None or file_mtime > newest_time:
|
||||
newest_time = file_mtime
|
||||
stats["newest_file"] = {"name": log_file.name, "age_days": (current_time - file_mtime).days}
|
||||
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
stats["total_size_mb"] = round(stats["total_size_mb"], 2)
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
class DorisLoggerManager:
|
||||
"""Centralized logger manager for Doris MCP Server"""
|
||||
|
||||
def __init__(self):
|
||||
self.is_initialized = False
|
||||
self.log_dir = None
|
||||
self.config = None
|
||||
self.loggers = {}
|
||||
self.cleanup_manager = None
|
||||
|
||||
def setup_logging(self,
|
||||
level: str = "INFO",
|
||||
log_dir: str = "logs",
|
||||
enable_console: bool = True,
|
||||
enable_file: bool = True,
|
||||
enable_audit: bool = True,
|
||||
audit_file: Optional[str] = None,
|
||||
max_file_size: int = 10*1024*1024,
|
||||
backup_count: int = 5,
|
||||
enable_cleanup: bool = True,
|
||||
max_age_days: int = 30,
|
||||
cleanup_interval_hours: int = 24) -> None:
|
||||
"""
|
||||
Setup comprehensive logging configuration.
|
||||
|
||||
Args:
|
||||
level: Base logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
|
||||
log_dir: Directory for log files
|
||||
enable_console: Enable console output
|
||||
enable_file: Enable file logging
|
||||
enable_audit: Enable audit logging
|
||||
audit_file: Custom audit log file path
|
||||
max_file_size: Maximum size per log file (bytes)
|
||||
backup_count: Number of backup files to keep
|
||||
enable_cleanup: Enable automatic log cleanup
|
||||
max_age_days: Maximum age of log files in days (default: 30)
|
||||
cleanup_interval_hours: Cleanup interval in hours (default: 24)
|
||||
"""
|
||||
if self.is_initialized:
|
||||
return
|
||||
|
||||
self.log_dir = Path(log_dir)
|
||||
log_dir_writable = True # Initialize the variable
|
||||
|
||||
# Try to create log directory, fallback to console-only if fails
|
||||
try:
|
||||
self.log_dir.mkdir(parents=True, exist_ok=True)
|
||||
except (OSError, PermissionError) as e:
|
||||
# If we can't create log directory (e.g., read-only filesystem in stdio mode),
|
||||
# fall back to console-only logging
|
||||
log_dir_writable = False
|
||||
enable_file = False
|
||||
enable_audit = False
|
||||
enable_cleanup = False
|
||||
# Don't use print() in stdio mode as it interferes with MCP JSON protocol
|
||||
# Log the warning through the logging system instead, which will be handled after setup
|
||||
|
||||
# Clear existing handlers
|
||||
root_logger = logging.getLogger()
|
||||
for handler in root_logger.handlers[:]:
|
||||
root_logger.removeHandler(handler)
|
||||
|
||||
# Set root logger level
|
||||
root_logger.setLevel(logging.DEBUG) # Allow all levels, handlers will filter
|
||||
|
||||
handlers = []
|
||||
|
||||
# Console handler
|
||||
if enable_console:
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
console_handler.setLevel(getattr(logging, level.upper()))
|
||||
console_formatter = TimestampedFormatter(
|
||||
fmt="%(asctime)s.%(msecs)03d %(level_aligned)s %(name)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
console_handler.setFormatter(console_formatter)
|
||||
handlers.append(console_handler)
|
||||
|
||||
# Level-based file handlers
|
||||
if enable_file:
|
||||
level_handler = LevelBasedFileHandler(
|
||||
log_dir=str(self.log_dir),
|
||||
base_name="doris_mcp_server",
|
||||
max_bytes=max_file_size,
|
||||
backup_count=backup_count
|
||||
)
|
||||
level_handler.setLevel(logging.DEBUG) # Accept all levels
|
||||
handlers.append(level_handler)
|
||||
|
||||
# Combined application log (all levels in one file)
|
||||
if enable_file:
|
||||
app_log_file = self.log_dir / "doris_mcp_server_all.log"
|
||||
app_handler = logging.handlers.RotatingFileHandler(
|
||||
app_log_file,
|
||||
maxBytes=max_file_size,
|
||||
backupCount=backup_count,
|
||||
encoding='utf-8'
|
||||
)
|
||||
app_handler.setLevel(getattr(logging, level.upper()))
|
||||
app_formatter = TimestampedFormatter()
|
||||
app_handler.setFormatter(app_formatter)
|
||||
handlers.append(app_handler)
|
||||
|
||||
# Audit logger (separate from main logging)
|
||||
if enable_audit:
|
||||
audit_file_path = audit_file or str(self.log_dir / "doris_mcp_server_audit.log")
|
||||
audit_logger = logging.getLogger("audit")
|
||||
audit_logger.setLevel(logging.INFO)
|
||||
|
||||
# Clear existing audit handlers
|
||||
for handler in audit_logger.handlers[:]:
|
||||
audit_logger.removeHandler(handler)
|
||||
|
||||
audit_handler = logging.handlers.RotatingFileHandler(
|
||||
audit_file_path,
|
||||
maxBytes=max_file_size,
|
||||
backupCount=backup_count,
|
||||
encoding='utf-8'
|
||||
)
|
||||
audit_formatter = TimestampedFormatter(
|
||||
fmt="%(asctime)s.%(msecs)03d [AUDIT] %(name)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
audit_handler.setFormatter(audit_formatter)
|
||||
audit_logger.addHandler(audit_handler)
|
||||
audit_logger.propagate = False # Don't propagate to root logger
|
||||
|
||||
# Add all handlers to root logger
|
||||
for handler in handlers:
|
||||
root_logger.addHandler(handler)
|
||||
|
||||
# Setup package-specific loggers
|
||||
self._setup_package_loggers(level)
|
||||
|
||||
# Setup log cleanup manager
|
||||
if enable_cleanup and enable_file:
|
||||
self.cleanup_manager = LogCleanupManager(
|
||||
log_dir=str(self.log_dir),
|
||||
max_age_days=max_age_days,
|
||||
cleanup_interval_hours=cleanup_interval_hours
|
||||
)
|
||||
self.cleanup_manager.start_cleanup_scheduler()
|
||||
|
||||
self.is_initialized = True
|
||||
|
||||
# Log initialization message
|
||||
logger = self.get_logger("doris_mcp_server.logger")
|
||||
logger.info("=" * 80)
|
||||
logger.info("Doris MCP Server Logging System Initialized")
|
||||
logger.info(f"Log Level: {level}")
|
||||
if log_dir_writable:
|
||||
logger.info(f"Log Directory: {self.log_dir.absolute()}")
|
||||
else:
|
||||
logger.info("Log Directory: Not available (console-only mode)")
|
||||
logger.info(f"Console Logging: {'Enabled' if enable_console else 'Disabled'}")
|
||||
logger.info(f"File Logging: {'Enabled' if enable_file else 'Disabled (fallback mode)'}")
|
||||
logger.info(f"Audit Logging: {'Enabled' if enable_audit else 'Disabled (fallback mode)'}")
|
||||
logger.info(f"Log Cleanup: {'Enabled' if enable_cleanup and enable_file else 'Disabled (fallback mode)'}")
|
||||
if enable_cleanup and enable_file:
|
||||
logger.info(f"Cleanup Settings: Max age {max_age_days} days, interval {cleanup_interval_hours}h")
|
||||
if not log_dir_writable:
|
||||
logger.warning("Running in console-only logging mode due to filesystem permissions")
|
||||
logger.warning(f"Could not create log directory '{log_dir}' - stdio mode fallback enabled")
|
||||
logger.info("=" * 80)
|
||||
|
||||
def _setup_package_loggers(self, level: str):
|
||||
"""Setup specific loggers for different modules"""
|
||||
package_loggers = [
|
||||
"doris_mcp_server",
|
||||
"doris_mcp_server.main",
|
||||
"doris_mcp_server.utils",
|
||||
"doris_mcp_server.tools",
|
||||
"doris_mcp_client"
|
||||
]
|
||||
|
||||
for logger_name in package_loggers:
|
||||
logger = logging.getLogger(logger_name)
|
||||
logger.setLevel(getattr(logging, level.upper()))
|
||||
# Don't add handlers here - they inherit from root logger
|
||||
|
||||
def get_logger(self, name: str) -> logging.Logger:
|
||||
"""
|
||||
Get a logger instance with proper configuration.
|
||||
|
||||
Args:
|
||||
name: Logger name (usually __name__)
|
||||
|
||||
Returns:
|
||||
Configured logger instance
|
||||
"""
|
||||
if name not in self.loggers:
|
||||
logger = logging.getLogger(name)
|
||||
self.loggers[name] = logger
|
||||
|
||||
return self.loggers[name]
|
||||
|
||||
def get_audit_logger(self) -> logging.Logger:
|
||||
"""Get the audit logger"""
|
||||
return logging.getLogger("audit")
|
||||
|
||||
def log_system_info(self):
|
||||
"""Log system information for debugging"""
|
||||
logger = self.get_logger("doris_mcp_server.system")
|
||||
logger.info("System Information:")
|
||||
logger.info(f"Python Version: {sys.version}")
|
||||
logger.info(f"Platform: {sys.platform}")
|
||||
logger.info(f"Working Directory: {os.getcwd()}")
|
||||
logger.info(f"Process ID: {os.getpid()}")
|
||||
|
||||
# Log environment variables (filtered)
|
||||
env_vars = ["LOG_LEVEL", "LOG_FILE_PATH", "ENABLE_AUDIT", "AUDIT_FILE_PATH"]
|
||||
for var in env_vars:
|
||||
value = os.getenv(var, "Not Set")
|
||||
logger.info(f"Environment {var}: {value}")
|
||||
|
||||
def get_cleanup_stats(self) -> dict:
|
||||
"""Get log cleanup statistics"""
|
||||
if self.cleanup_manager:
|
||||
return self.cleanup_manager.get_cleanup_stats()
|
||||
else:
|
||||
return {"error": "Log cleanup is not enabled"}
|
||||
|
||||
def manual_cleanup(self) -> dict:
|
||||
"""Manually trigger log cleanup and return statistics"""
|
||||
if self.cleanup_manager:
|
||||
self.cleanup_manager.cleanup_old_logs()
|
||||
return self.cleanup_manager.get_cleanup_stats()
|
||||
else:
|
||||
return {"error": "Log cleanup is not enabled"}
|
||||
|
||||
def shutdown(self):
|
||||
"""Shutdown logging system"""
|
||||
if not self.is_initialized:
|
||||
return
|
||||
|
||||
logger = self.get_logger("doris_mcp_server.logger")
|
||||
logger.info("Shutting down logging system...")
|
||||
|
||||
# Stop cleanup manager
|
||||
if self.cleanup_manager:
|
||||
self.cleanup_manager.stop_cleanup_scheduler()
|
||||
|
||||
# Close all handlers
|
||||
root_logger = logging.getLogger()
|
||||
for handler in root_logger.handlers[:]:
|
||||
try:
|
||||
handler.close()
|
||||
except Exception as e:
|
||||
print(f"Error closing handler: {e}")
|
||||
|
||||
# Close audit logger handlers
|
||||
audit_logger = logging.getLogger("audit")
|
||||
for handler in audit_logger.handlers[:]:
|
||||
try:
|
||||
handler.close()
|
||||
except Exception as e:
|
||||
print(f"Error closing audit handler: {e}")
|
||||
|
||||
self.is_initialized = False
|
||||
|
||||
|
||||
# Global logger manager instance
|
||||
_logger_manager = DorisLoggerManager()
|
||||
|
||||
|
||||
def setup_logging(level: str = "INFO",
|
||||
log_dir: str = "logs",
|
||||
enable_console: bool = True,
|
||||
enable_file: bool = True,
|
||||
enable_audit: bool = True,
|
||||
audit_file: Optional[str] = None,
|
||||
max_file_size: int = 10*1024*1024,
|
||||
backup_count: int = 5,
|
||||
enable_cleanup: bool = True,
|
||||
max_age_days: int = 30,
|
||||
cleanup_interval_hours: int = 24) -> None:
|
||||
"""
|
||||
Setup logging configuration (convenience function).
|
||||
|
||||
Args:
|
||||
level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
|
||||
log_dir: Directory for log files
|
||||
enable_console: Enable console output
|
||||
enable_file: Enable file logging
|
||||
enable_audit: Enable audit logging
|
||||
audit_file: Custom audit log file path
|
||||
max_file_size: Maximum size per log file (bytes)
|
||||
backup_count: Number of backup files to keep
|
||||
enable_cleanup: Enable automatic log cleanup
|
||||
max_age_days: Maximum age of log files in days (default: 30)
|
||||
cleanup_interval_hours: Cleanup interval in hours (default: 24)
|
||||
"""
|
||||
_logger_manager.setup_logging(
|
||||
level=level,
|
||||
log_dir=log_dir,
|
||||
enable_console=enable_console,
|
||||
enable_file=enable_file,
|
||||
enable_audit=enable_audit,
|
||||
audit_file=audit_file,
|
||||
max_file_size=max_file_size,
|
||||
backup_count=backup_count,
|
||||
enable_cleanup=enable_cleanup,
|
||||
max_age_days=max_age_days,
|
||||
cleanup_interval_hours=cleanup_interval_hours
|
||||
)
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
@@ -93,9 +589,60 @@ def get_logger(name: str) -> logging.Logger:
|
||||
Get a logger instance.
|
||||
|
||||
Args:
|
||||
name: Logger name
|
||||
name: Logger name (usually __name__)
|
||||
|
||||
Returns:
|
||||
Logger instance
|
||||
Configured logger instance
|
||||
"""
|
||||
return logging.getLogger(name)
|
||||
return _logger_manager.get_logger(name)
|
||||
|
||||
|
||||
def get_audit_logger() -> logging.Logger:
|
||||
"""Get the audit logger"""
|
||||
return _logger_manager.get_audit_logger()
|
||||
|
||||
|
||||
def log_system_info():
|
||||
"""Log system information for debugging"""
|
||||
_logger_manager.log_system_info()
|
||||
|
||||
|
||||
def get_cleanup_stats() -> dict:
|
||||
"""Get log cleanup statistics"""
|
||||
return _logger_manager.get_cleanup_stats()
|
||||
|
||||
|
||||
def manual_cleanup() -> dict:
|
||||
"""Manually trigger log cleanup and return statistics"""
|
||||
return _logger_manager.manual_cleanup()
|
||||
|
||||
|
||||
def shutdown_logging():
|
||||
"""Shutdown logging system"""
|
||||
_logger_manager.shutdown()
|
||||
|
||||
|
||||
# Compatibility function for existing code
|
||||
def setup_logging_old(level: str = "INFO",
|
||||
log_file: str | None = None,
|
||||
log_format: str | None = None) -> None:
|
||||
"""
|
||||
Legacy setup function for backward compatibility.
|
||||
|
||||
Args:
|
||||
level: Logging level (DEBUG, INFO, WARNING, ERROR)
|
||||
log_file: Optional log file path (deprecated - use log_dir instead)
|
||||
log_format: Optional custom log format (deprecated)
|
||||
"""
|
||||
# Extract directory from log_file if provided
|
||||
log_dir = "logs"
|
||||
if log_file:
|
||||
log_dir = str(Path(log_file).parent)
|
||||
|
||||
setup_logging(
|
||||
level=level,
|
||||
log_dir=log_dir,
|
||||
enable_console=True,
|
||||
enable_file=True,
|
||||
enable_audit=True
|
||||
)
|
||||
|
||||
1611
doris_mcp_server/utils/monitoring_tools.py
Normal file
1810
doris_mcp_server/utils/performance_analytics_tools.py
Normal file
@@ -33,7 +33,11 @@ from datetime import datetime, timedelta, date
|
||||
from typing import Any, Dict
|
||||
from decimal import Decimal
|
||||
|
||||
import sqlparse
|
||||
|
||||
from .db import DorisConnectionManager, QueryResult
|
||||
from .logger import get_logger
|
||||
from .sql_security_utils import get_auth_context
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -92,7 +96,7 @@ class QueryCache:
|
||||
self.max_size = max_size
|
||||
self.default_ttl = default_ttl
|
||||
self.cache: dict[str, CachedQuery] = {}
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
def _generate_cache_key(
|
||||
self, sql: str, parameters: dict[str, Any] | None = None
|
||||
@@ -194,7 +198,7 @@ class QueryOptimizer:
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.logger = get_logger(__name__)
|
||||
self.optimization_rules = self._load_optimization_rules()
|
||||
|
||||
def _load_optimization_rules(self) -> list[dict[str, Any]]:
|
||||
@@ -318,7 +322,7 @@ class DorisQueryExecutor:
|
||||
def __init__(self, connection_manager: DorisConnectionManager, config=None):
|
||||
self.connection_manager = connection_manager
|
||||
self.config = config or self._create_default_config()
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
# Initialize components
|
||||
cache_config = getattr(self.config, 'performance', None)
|
||||
@@ -425,27 +429,27 @@ class DorisQueryExecutor:
|
||||
self, query_request: QueryRequest, auth_context
|
||||
) -> QueryResult:
|
||||
"""Internal query execution"""
|
||||
|
||||
# Database configuration should already be handled during authentication
|
||||
# No need to configure again during query execution
|
||||
|
||||
# Optimize query
|
||||
optimized_sql = await self.query_optimizer.optimize_query(
|
||||
query_request.sql, {"user_roles": getattr(auth_context, 'roles', [])}
|
||||
)
|
||||
|
||||
# Execute query
|
||||
connection = await self.connection_manager.get_connection(
|
||||
query_request.session_id
|
||||
)
|
||||
|
||||
# Set timeout if specified
|
||||
if query_request.timeout:
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
connection.execute(optimized_sql, query_request.parameters, auth_context),
|
||||
self.connection_manager.execute_query(query_request.session_id, optimized_sql, query_request.parameters, auth_context),
|
||||
timeout=query_request.timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
raise Exception(f"Query timeout after {query_request.timeout} seconds")
|
||||
else:
|
||||
result = await connection.execute(optimized_sql, query_request.parameters, auth_context)
|
||||
result = await self.connection_manager.execute_query(query_request.session_id, optimized_sql, query_request.parameters, auth_context)
|
||||
|
||||
return result
|
||||
|
||||
@@ -466,6 +470,51 @@ class DorisQueryExecutor:
|
||||
f"Slow query detected: {execution_time:.2f}s (threshold: {self.slow_query_threshold}s)"
|
||||
)
|
||||
|
||||
async def execute_batch_sqls_for_mcp(
|
||||
self, sqls: list[str],
|
||||
timeout: int = 30,
|
||||
session_id: str = "mcp_session",
|
||||
user_id: str = "mcp_user",
|
||||
auth_context=None
|
||||
) -> dict[str, Any]:
|
||||
"""Execute multiple sqls in batch"""
|
||||
if not sqls:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "SQL query is required",
|
||||
"data": None
|
||||
}
|
||||
query_requests = [
|
||||
QueryRequest(
|
||||
sql=sql,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
timeout=timeout,
|
||||
cache_enabled=False
|
||||
)
|
||||
for sql in sqls
|
||||
]
|
||||
query_results = await self.execute_batch_queries(query_requests, auth_context)
|
||||
# Serialize data for JSON response
|
||||
results = [
|
||||
{
|
||||
"data": [self._serialize_row_data(data) for data in result.data],
|
||||
"row_count": result.row_count,
|
||||
"execution_time": result.execution_time,
|
||||
"metadata": {
|
||||
"columns": result.metadata.get("columns", []),
|
||||
"query": result.sql
|
||||
}
|
||||
}
|
||||
for result in query_results
|
||||
]
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"multiple_results": True,
|
||||
"results": results
|
||||
}
|
||||
|
||||
async def execute_batch_queries(
|
||||
self, query_requests: list[QueryRequest], auth_context=None
|
||||
) -> list[QueryResult]:
|
||||
@@ -483,20 +532,24 @@ class DorisQueryExecutor:
|
||||
self.execute_query(request, auth_context) for request in query_requests
|
||||
]
|
||||
|
||||
try:
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Batch query execution failed: {e}")
|
||||
raise
|
||||
query_results = []
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
for result in results:
|
||||
if isinstance(result, Exception):
|
||||
self.logger.error(f"Batch query execution failed: {result}")
|
||||
raise result
|
||||
else:
|
||||
query_results.append(result)
|
||||
|
||||
return results
|
||||
return query_results
|
||||
|
||||
async def explain_query(self, sql: str, session_id: str) -> dict[str, Any]:
|
||||
"""Get query execution plan"""
|
||||
explain_sql = f"EXPLAIN {sql}"
|
||||
|
||||
connection = await self.connection_manager.get_connection(session_id)
|
||||
result = await connection.execute(explain_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(explain_sql, auth_context=auth_context)
|
||||
|
||||
return {
|
||||
"query": sql,
|
||||
@@ -540,87 +593,192 @@ class DorisQueryExecutor:
|
||||
await self.query_cache.clear_all()
|
||||
|
||||
async def execute_sql_for_mcp(
|
||||
self,
|
||||
sql: str,
|
||||
limit: int = 1000,
|
||||
self,
|
||||
sql: str,
|
||||
limit: int = 1000,
|
||||
timeout: int = 30,
|
||||
session_id: str = "mcp_session",
|
||||
user_id: str = "mcp_user"
|
||||
user_id: str = "mcp_user",
|
||||
auth_context = None # FIX for Issue #62 Bug 1: Accept auth_context with token
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute SQL query for MCP interface - unified method"""
|
||||
try:
|
||||
if not sql:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "SQL query is required",
|
||||
"data": None
|
||||
}
|
||||
"""Execute SQL query for MCP interface - unified method
|
||||
|
||||
# Add LIMIT if not present and it's a SELECT query
|
||||
if sql.upper().startswith("SELECT") and "LIMIT" not in sql.upper():
|
||||
if sql.endswith(";"):
|
||||
sql = sql[:-1]
|
||||
sql = f"{sql} LIMIT {limit}"
|
||||
FIX for Issue #62 Bug 1: Now accepts auth_context parameter to support token-bound database configuration
|
||||
"""
|
||||
max_retries = 2
|
||||
retry_count = 0
|
||||
|
||||
# Create auth context for MCP calls
|
||||
class MockAuthContext:
|
||||
def __init__(self):
|
||||
self.user_id = user_id
|
||||
self.roles = ["data_analyst"]
|
||||
self.permissions = ["read_data", "execute_query"]
|
||||
self.session_id = session_id
|
||||
self.security_level = "internal"
|
||||
while retry_count <= max_retries:
|
||||
try:
|
||||
if not sql:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "SQL query is required",
|
||||
"data": None
|
||||
}
|
||||
|
||||
auth_context = MockAuthContext()
|
||||
|
||||
# Create query request
|
||||
query_request = QueryRequest(
|
||||
sql=sql,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
timeout=timeout,
|
||||
cache_enabled=True
|
||||
)
|
||||
|
||||
# Execute query
|
||||
result = await self.execute_query(query_request, auth_context)
|
||||
|
||||
# Process results
|
||||
processed_data = []
|
||||
if result.data:
|
||||
# Import required security modules
|
||||
from .security import DorisSecurityManager, AuthContext, SecurityLevel
|
||||
|
||||
# FIX: Use provided auth_context if available (contains token for DB config)
|
||||
# Otherwise create default auth context for backward compatibility
|
||||
if auth_context is None:
|
||||
auth_context = AuthContext(
|
||||
user_id=user_id,
|
||||
roles=["read_only_user"], # Restrictive role for MCP interface
|
||||
permissions=["read_data"], # Only read permissions
|
||||
session_id=session_id,
|
||||
security_level=SecurityLevel.INTERNAL,
|
||||
token="" # No token in default context
|
||||
)
|
||||
else:
|
||||
# Use provided auth_context (may contain token for database configuration)
|
||||
self.logger.debug(f"Using provided auth_context with token: {bool(hasattr(auth_context, 'token') and auth_context.token)}")
|
||||
|
||||
# Perform SQL security validation if enabled
|
||||
if hasattr(self.connection_manager, 'config') and hasattr(self.connection_manager.config, 'security'):
|
||||
if self.connection_manager.config.security.enable_security_check:
|
||||
try:
|
||||
# 🔧 FIX: Use existing security_manager to avoid creating multiple TokenManager instances
|
||||
# Creating new DorisSecurityManager each time causes multiple hot reload monitors
|
||||
security_manager = getattr(self.connection_manager, 'security_manager', None)
|
||||
if not security_manager:
|
||||
# Fallback: create new one only if not available (should rarely happen)
|
||||
self.logger.warning("No existing security_manager, creating new instance")
|
||||
security_manager = DorisSecurityManager(self.connection_manager.config)
|
||||
validation_result = await security_manager.validate_sql_security(sql, auth_context)
|
||||
|
||||
if not validation_result.is_valid:
|
||||
self.logger.warning(f"SQL security validation failed for query: {sql[:100]}...")
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"SQL security validation failed: {validation_result.error_message}",
|
||||
"error_type": "security_violation",
|
||||
"blocked_operations": validation_result.blocked_operations,
|
||||
"risk_level": validation_result.risk_level,
|
||||
"data": None,
|
||||
"metadata": {
|
||||
"query": sql,
|
||||
"validation_details": {
|
||||
"blocked_operations": validation_result.blocked_operations,
|
||||
"risk_level": validation_result.risk_level
|
||||
}
|
||||
}
|
||||
}
|
||||
else:
|
||||
self.logger.debug(f"SQL security validation passed for query: {sql[:100]}...")
|
||||
except Exception as security_error:
|
||||
self.logger.error(f"Security validation error: {str(security_error)}")
|
||||
# In case of security validation error, fail safe
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Security validation system error: {str(security_error)}",
|
||||
"error_type": "security_system_error",
|
||||
"data": None,
|
||||
"metadata": {
|
||||
"query": sql,
|
||||
"security_error": str(security_error)
|
||||
}
|
||||
}
|
||||
else:
|
||||
self.logger.info("SQL security check is disabled in configuration")
|
||||
else:
|
||||
self.logger.warning("Security configuration not found, proceeding without validation")
|
||||
|
||||
# Add LIMIT if not present and it's a SELECT query
|
||||
if sql.upper().startswith("SELECT") and "LIMIT" not in sql.upper():
|
||||
if sql.endswith(";"):
|
||||
sql = sql[:-1]
|
||||
sql = f"{sql} LIMIT {limit}"
|
||||
|
||||
all_statements = [
|
||||
s.strip()
|
||||
for s in sqlparse.split(sql)
|
||||
if s.strip()
|
||||
]
|
||||
if len(all_statements) > 1:
|
||||
return await self.execute_batch_sqls_for_mcp(sqls=all_statements, timeout=timeout,
|
||||
session_id=session_id, user_id=user_id,
|
||||
auth_context=auth_context)
|
||||
# Create query request
|
||||
query_request = QueryRequest(
|
||||
sql=sql,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
timeout=timeout,
|
||||
cache_enabled=False # Disable cache for MCP calls to ensure fresh data
|
||||
)
|
||||
|
||||
# Execute query with retry logic
|
||||
result = await self.execute_query(query_request, auth_context)
|
||||
|
||||
# Serialize data for JSON response
|
||||
serialized_data = []
|
||||
for row in result.data:
|
||||
processed_row = self._serialize_row_data(row)
|
||||
processed_data.append(processed_row)
|
||||
serialized_data.append(self._serialize_row_data(row))
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": processed_data,
|
||||
"metadata": {
|
||||
return {
|
||||
"success": True,
|
||||
"data": serialized_data,
|
||||
"row_count": result.row_count,
|
||||
"execution_time": result.execution_time,
|
||||
"columns": result.metadata.get("columns", []),
|
||||
"query": sql
|
||||
},
|
||||
"error": None
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
self.logger.error(f"SQL execution error: {error_msg}")
|
||||
|
||||
# Analyze error for better user feedback
|
||||
error_analysis = self._analyze_error(error_msg)
|
||||
|
||||
return {
|
||||
"success": False,
|
||||
"error": error_analysis.get("user_message", error_msg),
|
||||
"error_type": error_analysis.get("error_type", "execution_error"),
|
||||
"data": None,
|
||||
"metadata": {
|
||||
"query": sql,
|
||||
"error_details": error_msg
|
||||
"metadata": {
|
||||
"columns": result.metadata.get("columns", []),
|
||||
"query": sql
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
error_str = error_msg.lower()
|
||||
|
||||
# Check if it's a connection-related error that we should retry
|
||||
connection_errors = [
|
||||
"at_eof", "connection", "closed", "nonetype",
|
||||
"transport", "reader", "broken pipe", "connection reset"
|
||||
]
|
||||
|
||||
is_connection_error = any(err in error_str for err in connection_errors)
|
||||
|
||||
if is_connection_error and retry_count < max_retries:
|
||||
retry_count += 1
|
||||
self.logger.warning(f"Connection error detected, retrying ({retry_count}/{max_retries}): {e}")
|
||||
|
||||
# Release the problematic connection
|
||||
try:
|
||||
await self.connection_manager.release_connection(session_id)
|
||||
except Exception:
|
||||
pass # Ignore cleanup errors
|
||||
|
||||
# Wait a bit before retry
|
||||
await asyncio.sleep(0.5 * retry_count)
|
||||
continue
|
||||
else:
|
||||
# If we've exhausted retries or it's not a connection error, return error
|
||||
error_analysis = self._analyze_error(error_msg)
|
||||
|
||||
return {
|
||||
"success": False,
|
||||
"error": error_analysis.get("user_message", error_msg),
|
||||
"error_type": error_analysis.get("error_type", "general_error"),
|
||||
"data": None,
|
||||
"metadata": {
|
||||
"query": sql,
|
||||
"error_details": error_msg,
|
||||
"retry_count": retry_count
|
||||
}
|
||||
}
|
||||
|
||||
# This should never be reached, but just in case
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Maximum retries exceeded",
|
||||
"data": None,
|
||||
"metadata": {
|
||||
"query": sql,
|
||||
"retry_count": retry_count
|
||||
}
|
||||
}
|
||||
|
||||
def _serialize_row_data(self, row_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Serialize row data for JSON response"""
|
||||
@@ -649,7 +807,12 @@ class DorisQueryExecutor:
|
||||
"""Analyze error message and provide user-friendly feedback"""
|
||||
error_msg_lower = error_message.lower()
|
||||
|
||||
if "table" in error_msg_lower and "doesn't exist" in error_msg_lower:
|
||||
if "at_eof" in error_msg_lower or "nonetype" in error_msg_lower and "at_eof" in error_msg_lower:
|
||||
return {
|
||||
"error_type": "connection_lost",
|
||||
"user_message": "Database connection was lost. The query has been automatically retried. If this persists, please restart the server."
|
||||
}
|
||||
elif "table" in error_msg_lower and "doesn't exist" in error_msg_lower:
|
||||
return {
|
||||
"error_type": "table_not_found",
|
||||
"user_message": "The specified table does not exist. Please check the table name and database."
|
||||
@@ -674,6 +837,11 @@ class DorisQueryExecutor:
|
||||
"error_type": "timeout",
|
||||
"user_message": "Query execution timed out. Try simplifying your query or adding more specific filters."
|
||||
}
|
||||
elif "connection" in error_msg_lower and ("closed" in error_msg_lower or "reset" in error_msg_lower):
|
||||
return {
|
||||
"error_type": "connection_error",
|
||||
"user_message": "Database connection was interrupted. The query has been automatically retried."
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"error_type": "general_error",
|
||||
@@ -701,7 +869,7 @@ class QueryPerformanceMonitor:
|
||||
|
||||
def __init__(self, query_executor: DorisQueryExecutor):
|
||||
self.query_executor = query_executor
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.logger = get_logger(__name__)
|
||||
self.performance_records = []
|
||||
|
||||
async def record_query_performance(
|
||||
@@ -785,32 +953,51 @@ class QueryPerformanceMonitor:
|
||||
|
||||
# Unified convenience function for MCP integration
|
||||
async def execute_sql_query(sql: str, connection_manager: DorisConnectionManager, **kwargs) -> Dict[str, Any]:
|
||||
"""Execute SQL query - unified convenience function for MCP tools"""
|
||||
"""Execute SQL query - unified convenience function for MCP tools
|
||||
|
||||
This function now includes security validation to ensure safe query execution.
|
||||
All queries are validated against the configured security policies before execution.
|
||||
|
||||
FIX for Issue #62 Bug 1: Now supports auth_context parameter for token-bound database configuration
|
||||
FIX for Issue #58 Problem 2: Removed executor.close() to prevent ClosedResourceError in multi-worker mode
|
||||
"""
|
||||
try:
|
||||
# Create query executor
|
||||
# Create query executor with the connection manager's configuration
|
||||
executor = DorisQueryExecutor(connection_manager)
|
||||
|
||||
try:
|
||||
# Extract parameters from kwargs or use defaults
|
||||
limit = kwargs.get("limit", 1000)
|
||||
timeout = kwargs.get("timeout", 30)
|
||||
session_id = kwargs.get("session_id", "mcp_session")
|
||||
user_id = kwargs.get("user_id", "mcp_user")
|
||||
|
||||
result = await executor.execute_sql_for_mcp(
|
||||
sql=sql,
|
||||
limit=limit,
|
||||
timeout=timeout,
|
||||
session_id=session_id,
|
||||
user_id=user_id
|
||||
)
|
||||
return result
|
||||
finally:
|
||||
await executor.close()
|
||||
|
||||
|
||||
# Extract parameters from kwargs or use defaults
|
||||
limit = kwargs.get("limit", 1000)
|
||||
timeout = kwargs.get("timeout", 30)
|
||||
session_id = kwargs.get("session_id", "mcp_session")
|
||||
user_id = kwargs.get("user_id", "mcp_user")
|
||||
auth_context = kwargs.get("auth_context", None) # FIX: Extract auth_context
|
||||
|
||||
# The execute_sql_for_mcp method now includes security validation
|
||||
result = await executor.execute_sql_for_mcp(
|
||||
sql=sql,
|
||||
limit=limit,
|
||||
timeout=timeout,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
auth_context=auth_context # FIX: Pass auth_context with token
|
||||
)
|
||||
|
||||
# FIX for Issue #58 Problem 2: Do NOT close executor here
|
||||
# In multi-worker mode, closing here causes ClosedResourceError
|
||||
# The executor's resources (cache, background tasks) will be managed
|
||||
# by the connection_manager lifecycle and Python's garbage collection
|
||||
# This prevents premature cleanup while MCP session manager is still processing
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Query execution failed: {str(e)}",
|
||||
"data": None
|
||||
"error_type": "execution_error",
|
||||
"data": None,
|
||||
"metadata": {
|
||||
"query": sql,
|
||||
"execution_error": str(e)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,18 +20,25 @@ Doris Security Management Module
|
||||
Implements enterprise-level authentication, authorization, SQL security validation and data masking functionality
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from contextvars import ContextVar
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
import sqlparse
|
||||
from sqlparse.sql import Statement
|
||||
from sqlparse.tokens import Keyword, Name
|
||||
|
||||
from .logger import get_logger
|
||||
from .config import DatabaseConfig
|
||||
|
||||
# Global ContextVar for auth_context - must be a single instance shared across all modules
|
||||
# This allows token-bound database configuration to work correctly in concurrent requests
|
||||
mcp_auth_context_var: ContextVar['AuthContext'] = ContextVar('mcp_auth_context', default=None)
|
||||
|
||||
|
||||
class SecurityLevel(Enum):
|
||||
"""Security level enumeration"""
|
||||
@@ -44,15 +51,18 @@ class SecurityLevel(Enum):
|
||||
|
||||
@dataclass
|
||||
class AuthContext:
|
||||
"""Authentication context"""
|
||||
"""Authentication context for audit and session tracking"""
|
||||
|
||||
user_id: str
|
||||
roles: list[str]
|
||||
permissions: list[str]
|
||||
session_id: str
|
||||
login_time: datetime | None = None
|
||||
token_id: str = "" # Token identifier for audit logging
|
||||
user_id: str = "" # User identifier
|
||||
roles: list[str] = field(default_factory=list) # User roles
|
||||
permissions: list[str] = field(default_factory=list) # User permissions
|
||||
security_level: 'SecurityLevel' = field(default_factory=lambda: SecurityLevel.INTERNAL) # Security level
|
||||
client_ip: str = "unknown" # Client IP address
|
||||
session_id: str = "" # Session identifier
|
||||
login_time: datetime = field(default_factory=datetime.utcnow)
|
||||
last_activity: datetime | None = None
|
||||
security_level: SecurityLevel = SecurityLevel.INTERNAL
|
||||
token: str = "" # Raw token for token-bound database configuration
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -85,12 +95,13 @@ class DorisSecurityManager:
|
||||
Provides complete security control functionality, including authentication, authorization, SQL security validation and data masking
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, connection_manager=None):
|
||||
self.config = config
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.logger = get_logger(__name__)
|
||||
self.connection_manager = connection_manager
|
||||
|
||||
# Initialize security components
|
||||
self.auth_provider = AuthenticationProvider(config)
|
||||
self.auth_provider = AuthenticationProvider(config, self)
|
||||
self.authz_provider = AuthorizationProvider(config)
|
||||
self.sql_validator = SQLSecurityValidator(config)
|
||||
self.masking_processor = DataMaskingProcessor(config)
|
||||
@@ -99,32 +110,56 @@ class DorisSecurityManager:
|
||||
self.blocked_keywords = self._load_blocked_keywords()
|
||||
self.sensitive_tables = self._load_sensitive_tables()
|
||||
self.masking_rules = self._load_masking_rules()
|
||||
|
||||
# Track initialization state
|
||||
self._initialized = False
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize security manager components"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
try:
|
||||
# Initialize authentication provider (for JWT setup)
|
||||
await self.auth_provider.initialize()
|
||||
|
||||
self._initialized = True
|
||||
self.logger.info("DorisSecurityManager initialized successfully")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to initialize DorisSecurityManager: {e}")
|
||||
raise
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shutdown security manager components"""
|
||||
try:
|
||||
await self.auth_provider.shutdown()
|
||||
self._initialized = False
|
||||
self.logger.info("DorisSecurityManager shutdown completed")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error during DorisSecurityManager shutdown: {e}")
|
||||
raise
|
||||
|
||||
def _load_blocked_keywords(self) -> set[str]:
|
||||
"""Load blocked SQL keywords"""
|
||||
default_blocked = {
|
||||
"DROP",
|
||||
"DELETE",
|
||||
"TRUNCATE",
|
||||
"ALTER",
|
||||
"CREATE",
|
||||
"INSERT",
|
||||
"UPDATE",
|
||||
"GRANT",
|
||||
"REVOKE",
|
||||
"EXEC",
|
||||
"EXECUTE",
|
||||
"SHUTDOWN",
|
||||
"KILL",
|
||||
}
|
||||
|
||||
# Load custom rules from configuration file
|
||||
"""Load blocked SQL keywords from configuration"""
|
||||
# Load keywords from configuration, unified source of truth
|
||||
if hasattr(self.config, 'get'):
|
||||
custom_blocked = set(self.config.get("blocked_keywords", []))
|
||||
# Dictionary-style configuration
|
||||
blocked_keywords = self.config.get("blocked_keywords", [])
|
||||
elif hasattr(self.config, 'security') and hasattr(self.config.security, 'blocked_keywords'):
|
||||
# DorisConfig object, get through security.blocked_keywords
|
||||
blocked_keywords = self.config.security.blocked_keywords
|
||||
else:
|
||||
custom_blocked = set()
|
||||
# Fallback to default if no configuration available
|
||||
blocked_keywords = [
|
||||
"DROP", "CREATE", "ALTER", "TRUNCATE",
|
||||
"DELETE", "INSERT", "UPDATE",
|
||||
"GRANT", "REVOKE",
|
||||
"EXEC", "EXECUTE", "SHUTDOWN", "KILL"
|
||||
]
|
||||
|
||||
return default_blocked.union(custom_blocked)
|
||||
return set(blocked_keywords)
|
||||
|
||||
def _load_sensitive_tables(self) -> dict[str, SecurityLevel]:
|
||||
"""Load sensitive table configuration"""
|
||||
@@ -189,8 +224,59 @@ class DorisSecurityManager:
|
||||
return default_rules
|
||||
|
||||
async def authenticate_request(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||
"""Validate request authentication information"""
|
||||
return await self.auth_provider.authenticate(auth_info)
|
||||
"""Validate request authentication information
|
||||
|
||||
Tries authentication methods in order: Token -> JWT -> OAuth
|
||||
Any one method succeeding allows access
|
||||
If all methods are disabled, returns anonymous context
|
||||
"""
|
||||
# Check if any authentication method is enabled
|
||||
if not (self.config.security.enable_token_auth or
|
||||
self.config.security.enable_jwt_auth or
|
||||
self.config.security.enable_oauth_auth):
|
||||
self.logger.debug("All authentication methods are disabled")
|
||||
# Return anonymous context when no authentication is enabled
|
||||
return AuthContext(
|
||||
token_id="anonymous",
|
||||
user_id="anonymous",
|
||||
roles=["anonymous"],
|
||||
permissions=["read"],
|
||||
security_level=SecurityLevel.PUBLIC,
|
||||
client_ip=auth_info.get("client_ip", "unknown"),
|
||||
session_id="anonymous_session"
|
||||
)
|
||||
|
||||
# Try authentication methods in order of preference
|
||||
last_error = None
|
||||
|
||||
# 1. Try Token authentication first (most common)
|
||||
if self.config.security.enable_token_auth:
|
||||
try:
|
||||
return await self.auth_provider.authenticate_token(auth_info)
|
||||
except Exception as e:
|
||||
self.logger.debug(f"Token authentication failed: {e}")
|
||||
last_error = e
|
||||
|
||||
# 2. Try JWT authentication
|
||||
if self.config.security.enable_jwt_auth:
|
||||
try:
|
||||
return await self.auth_provider.authenticate_jwt(auth_info)
|
||||
except Exception as e:
|
||||
self.logger.debug(f"JWT authentication failed: {e}")
|
||||
last_error = e
|
||||
|
||||
# 3. Try OAuth authentication
|
||||
if self.config.security.enable_oauth_auth:
|
||||
try:
|
||||
return await self.auth_provider.authenticate_oauth(auth_info)
|
||||
except Exception as e:
|
||||
self.logger.debug(f"OAuth authentication failed: {e}")
|
||||
last_error = e
|
||||
|
||||
# All enabled authentication methods failed
|
||||
error_message = f"Authentication failed: {str(last_error)}" if last_error else "No authentication method succeeded"
|
||||
self.logger.warning(f"Authentication failed for client {auth_info.get('client_ip', 'unknown')}: {error_message}")
|
||||
raise ValueError(error_message)
|
||||
|
||||
async def authorize_resource_access(
|
||||
self, auth_context: AuthContext, resource_uri: str
|
||||
@@ -212,43 +298,362 @@ class DorisSecurityManager:
|
||||
"""Apply data masking processing"""
|
||||
return await self.masking_processor.process(data, auth_context)
|
||||
|
||||
# OAuth-specific methods
|
||||
def get_oauth_authorization_url(self) -> tuple[str, str]:
|
||||
"""Get OAuth authorization URL
|
||||
|
||||
Returns:
|
||||
Tuple of (authorization_url, state)
|
||||
"""
|
||||
if not self.auth_provider.oauth_provider:
|
||||
raise ValueError("OAuth is not enabled")
|
||||
return self.auth_provider.oauth_provider.get_authorization_url()
|
||||
|
||||
async def handle_oauth_callback(self, code: str, state: str) -> AuthContext:
|
||||
"""Handle OAuth callback
|
||||
|
||||
Args:
|
||||
code: Authorization code from OAuth provider
|
||||
state: State parameter for CSRF protection
|
||||
|
||||
Returns:
|
||||
AuthContext for authenticated user
|
||||
"""
|
||||
if not self.auth_provider.oauth_provider:
|
||||
raise ValueError("OAuth is not enabled")
|
||||
return await self.auth_provider.oauth_provider.handle_callback(code, state)
|
||||
|
||||
def get_oauth_provider_info(self) -> dict[str, Any]:
|
||||
"""Get OAuth provider information
|
||||
|
||||
Returns:
|
||||
OAuth provider information
|
||||
"""
|
||||
if not self.auth_provider.oauth_provider:
|
||||
return {"enabled": False}
|
||||
return self.auth_provider.oauth_provider.get_provider_info()
|
||||
|
||||
# Token management methods
|
||||
async def create_token(
|
||||
self,
|
||||
token_id: str,
|
||||
expires_hours: Optional[int] = None,
|
||||
description: str = "",
|
||||
custom_token: Optional[str] = None,
|
||||
database_config: Optional[DatabaseConfig] = None
|
||||
) -> str:
|
||||
"""Create a new API access token
|
||||
|
||||
Args:
|
||||
token_id: Unique token identifier for audit and management
|
||||
expires_hours: Token expiration in hours (None for no expiration)
|
||||
description: Token description for management purposes
|
||||
custom_token: Custom token string (if None, generates random token)
|
||||
database_config: Optional database configuration for this token
|
||||
|
||||
Returns:
|
||||
Generated token string
|
||||
"""
|
||||
if not self.auth_provider.token_manager:
|
||||
raise ValueError("Token manager not initialized")
|
||||
|
||||
return await self.auth_provider.token_manager.create_token(
|
||||
token_id=token_id,
|
||||
expires_hours=expires_hours,
|
||||
description=description,
|
||||
custom_token=custom_token,
|
||||
database_config=database_config
|
||||
)
|
||||
|
||||
async def revoke_token(self, token_id: str) -> bool:
|
||||
"""Revoke a token by token ID
|
||||
|
||||
Args:
|
||||
token_id: Token ID to revoke
|
||||
|
||||
Returns:
|
||||
True if token was revoked successfully
|
||||
"""
|
||||
if not self.auth_provider.token_manager:
|
||||
raise ValueError("Token manager not initialized")
|
||||
|
||||
return await self.auth_provider.token_manager.revoke_token(token_id)
|
||||
|
||||
async def list_tokens(self) -> list[dict[str, Any]]:
|
||||
"""List all tokens (without sensitive data)
|
||||
|
||||
Returns:
|
||||
List of token information
|
||||
"""
|
||||
if not self.auth_provider.token_manager:
|
||||
raise ValueError("Token manager not initialized")
|
||||
|
||||
return await self.auth_provider.token_manager.list_tokens()
|
||||
|
||||
async def cleanup_expired_tokens(self) -> int:
|
||||
"""Remove expired tokens and return count
|
||||
|
||||
Returns:
|
||||
Number of expired tokens removed
|
||||
"""
|
||||
if not self.auth_provider.token_manager:
|
||||
return 0
|
||||
|
||||
return await self.auth_provider.token_manager.cleanup_expired_tokens()
|
||||
|
||||
def get_token_stats(self) -> dict[str, Any]:
|
||||
"""Get token statistics
|
||||
|
||||
Returns:
|
||||
Token statistics dictionary
|
||||
"""
|
||||
if not self.auth_provider.token_manager:
|
||||
return {"error": "Token manager not initialized"}
|
||||
|
||||
return self.auth_provider.token_manager.get_token_stats()
|
||||
|
||||
async def _validate_token_database_config(self, token: str, token_info) -> None:
|
||||
"""Validate database configuration for token immediately during authentication
|
||||
|
||||
This ensures database connectivity issues are caught at authentication time,
|
||||
not during query execution, providing better user experience.
|
||||
|
||||
Args:
|
||||
token: Raw authentication token
|
||||
token_info: TokenInfo object from token validation
|
||||
|
||||
Raises:
|
||||
ValueError: If database configuration is invalid or connection fails
|
||||
"""
|
||||
try:
|
||||
if not self.connection_manager:
|
||||
self.logger.warning("Connection manager not available for immediate database validation")
|
||||
return
|
||||
|
||||
# Configure and test database connection for this token
|
||||
success, config_source = await self.connection_manager.configure_for_token(token)
|
||||
|
||||
if success:
|
||||
self.logger.info(f"Database configuration validated successfully for token {token_info.token_id} (source: {config_source})")
|
||||
else:
|
||||
raise ValueError("Database configuration validation failed")
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Database configuration validation failed for token {token_info.token_id}: {str(e)}"
|
||||
self.logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
|
||||
class AuthenticationProvider:
|
||||
"""Authentication provider"""
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, security_manager=None):
|
||||
self.config = config
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.logger = get_logger(__name__)
|
||||
self.session_cache = {}
|
||||
|
||||
async def authenticate(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||
"""Perform identity authentication"""
|
||||
auth_type = auth_info.get("type", "token")
|
||||
|
||||
if auth_type == "token":
|
||||
return await self._authenticate_token(auth_info)
|
||||
elif auth_type == "basic":
|
||||
return await self._authenticate_basic(auth_info)
|
||||
self.jwt_manager = None
|
||||
self.oauth_provider = None
|
||||
self.token_manager = None
|
||||
self.security_manager = security_manager
|
||||
|
||||
# Initialize authentication providers based on individual switches
|
||||
auth_methods_enabled = []
|
||||
|
||||
# Initialize Token manager if enabled
|
||||
if config.security.enable_token_auth:
|
||||
self._initialize_token_manager()
|
||||
auth_methods_enabled.append("Token")
|
||||
|
||||
# Initialize JWT manager if enabled
|
||||
if config.security.enable_jwt_auth:
|
||||
self._initialize_jwt_manager()
|
||||
auth_methods_enabled.append("JWT")
|
||||
|
||||
# Initialize OAuth provider if enabled
|
||||
if config.security.enable_oauth_auth or (hasattr(config.security, 'oauth_enabled') and config.security.oauth_enabled):
|
||||
self._initialize_oauth_provider()
|
||||
auth_methods_enabled.append("OAuth")
|
||||
|
||||
if auth_methods_enabled:
|
||||
self.logger.info(f"Authentication enabled with methods: {', '.join(auth_methods_enabled)}")
|
||||
else:
|
||||
raise ValueError(f"Unsupported authentication type: {auth_type}")
|
||||
self.logger.info("All authentication methods are disabled - anonymous access allowed")
|
||||
|
||||
def _initialize_jwt_manager(self):
|
||||
"""Initialize JWT manager"""
|
||||
try:
|
||||
from ..auth.jwt_manager import JWTManager
|
||||
self.jwt_manager = JWTManager(self.config)
|
||||
self.logger.info("JWT manager initialized")
|
||||
except ImportError as e:
|
||||
self.logger.error(f"Failed to import JWT manager: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to initialize JWT manager: {e}")
|
||||
raise
|
||||
|
||||
def _initialize_token_manager(self):
|
||||
"""Initialize Token manager"""
|
||||
try:
|
||||
from ..auth.token_manager import TokenManager
|
||||
self.token_manager = TokenManager(self.config)
|
||||
self.logger.info("Token manager initialized")
|
||||
except ImportError as e:
|
||||
self.logger.error(f"Failed to import Token manager: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to initialize Token manager: {e}")
|
||||
raise
|
||||
|
||||
def _initialize_oauth_provider(self):
|
||||
"""Initialize OAuth provider"""
|
||||
try:
|
||||
from ..auth.oauth_provider import OAuthAuthenticationProvider
|
||||
self.oauth_provider = OAuthAuthenticationProvider(self.config)
|
||||
self.logger.info("OAuth provider initialized")
|
||||
except ImportError as e:
|
||||
self.logger.error(f"Failed to import OAuth provider: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to initialize OAuth provider: {e}")
|
||||
raise
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize authentication provider asynchronously"""
|
||||
if self.jwt_manager:
|
||||
success = await self.jwt_manager.initialize()
|
||||
if not success:
|
||||
raise RuntimeError("Failed to initialize JWT manager")
|
||||
self.logger.info("JWT authentication provider initialized successfully")
|
||||
|
||||
if self.token_manager:
|
||||
# Token manager doesn't need async initialization, just log success
|
||||
self.logger.info("Token authentication provider initialized successfully")
|
||||
|
||||
if self.oauth_provider:
|
||||
success = await self.oauth_provider.initialize()
|
||||
if not success:
|
||||
raise RuntimeError("Failed to initialize OAuth provider")
|
||||
self.logger.info("OAuth authentication provider initialized successfully")
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shutdown authentication provider"""
|
||||
if self.jwt_manager:
|
||||
await self.jwt_manager.shutdown()
|
||||
self.logger.info("JWT authentication provider shutdown completed")
|
||||
|
||||
if self.token_manager:
|
||||
# Token manager doesn't need async shutdown, just log
|
||||
self.logger.info("Token authentication provider shutdown completed")
|
||||
|
||||
if self.oauth_provider:
|
||||
await self.oauth_provider.shutdown()
|
||||
self.logger.info("OAuth authentication provider shutdown completed")
|
||||
|
||||
async def authenticate_token(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||
"""Perform token authentication"""
|
||||
if not self.config.security.enable_token_auth:
|
||||
raise ValueError("Token authentication is not enabled")
|
||||
return await self._authenticate_token(auth_info)
|
||||
|
||||
async def authenticate_jwt(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||
"""Perform JWT authentication"""
|
||||
if not self.config.security.enable_jwt_auth:
|
||||
raise ValueError("JWT authentication is not enabled")
|
||||
return await self._authenticate_jwt(auth_info)
|
||||
|
||||
async def authenticate_oauth(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||
"""Perform OAuth authentication"""
|
||||
if not self.config.security.enable_oauth_auth:
|
||||
raise ValueError("OAuth authentication is not enabled")
|
||||
return await self._authenticate_oauth(auth_info)
|
||||
|
||||
async def _authenticate_jwt(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||
"""JWT authentication"""
|
||||
if not self.jwt_manager:
|
||||
raise ValueError("JWT manager not initialized")
|
||||
|
||||
token = auth_info.get("token")
|
||||
if not token:
|
||||
# Try to extract from Authorization header
|
||||
authorization = auth_info.get("authorization")
|
||||
if authorization and authorization.startswith('Bearer '):
|
||||
token = authorization[7:]
|
||||
|
||||
if not token:
|
||||
raise ValueError("Missing JWT token")
|
||||
|
||||
try:
|
||||
# Use JWT middleware for authentication
|
||||
from ..auth.auth_middleware import AuthMiddleware
|
||||
middleware = AuthMiddleware(self.jwt_manager)
|
||||
return await middleware.authenticate_request(auth_info)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"JWT authentication failed: {e}")
|
||||
raise ValueError(f"JWT authentication failed: {str(e)}")
|
||||
|
||||
async def _authenticate_oauth(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||
"""OAuth authentication"""
|
||||
if not self.oauth_provider:
|
||||
raise ValueError("OAuth provider not initialized")
|
||||
|
||||
# Handle different OAuth authentication scenarios
|
||||
if "access_token" in auth_info:
|
||||
# Direct OAuth access token authentication
|
||||
return await self.oauth_provider.authenticate_with_token(auth_info["access_token"])
|
||||
elif "code" in auth_info and "state" in auth_info:
|
||||
# OAuth callback authentication
|
||||
return await self.oauth_provider.handle_callback(auth_info["code"], auth_info["state"])
|
||||
else:
|
||||
raise ValueError("OAuth authentication requires either access_token or code+state")
|
||||
|
||||
async def _authenticate_token(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||
"""Token authentication"""
|
||||
if not self.token_manager:
|
||||
raise ValueError("Token manager not initialized")
|
||||
|
||||
token = auth_info.get("token")
|
||||
if not token:
|
||||
# Try to extract from Authorization header
|
||||
authorization = auth_info.get("authorization")
|
||||
if authorization and authorization.startswith('Bearer '):
|
||||
token = authorization[7:]
|
||||
elif authorization and authorization.startswith('Token '):
|
||||
token = authorization[6:]
|
||||
|
||||
if not token:
|
||||
raise ValueError("Missing authentication token")
|
||||
|
||||
# Validate token (simplified implementation, should validate JWT or query authentication service in practice)
|
||||
user_info = await self._validate_token(token)
|
||||
|
||||
return AuthContext(
|
||||
user_id=user_info["user_id"],
|
||||
roles=user_info["roles"],
|
||||
permissions=user_info["permissions"],
|
||||
session_id=auth_info.get("session_id", "default"),
|
||||
login_time=datetime.utcnow(),
|
||||
security_level=SecurityLevel(user_info.get("security_level", "internal")),
|
||||
)
|
||||
try:
|
||||
# Validate token using TokenManager
|
||||
validation_result = await self.token_manager.validate_token(token)
|
||||
|
||||
if not validation_result.is_valid:
|
||||
raise ValueError(f"Token validation failed: {validation_result.error_message}")
|
||||
|
||||
token_info = validation_result.token_info
|
||||
|
||||
# Immediately validate database configuration for this token
|
||||
if self.security_manager:
|
||||
await self.security_manager._validate_token_database_config(token, token_info)
|
||||
|
||||
return AuthContext(
|
||||
token_id=token_info.token_id,
|
||||
user_id=token_info.token_id, # Use token_id as user_id for token auth
|
||||
roles=["token_user"], # Default role for token users
|
||||
permissions=["read", "write"], # Default permissions for token users
|
||||
security_level=SecurityLevel.INTERNAL,
|
||||
client_ip=auth_info.get("client_ip", "unknown"),
|
||||
session_id=auth_info.get("session_id", f"session_{token_info.token_id}"),
|
||||
login_time=datetime.utcnow(),
|
||||
last_activity=token_info.last_used,
|
||||
token=token # Store raw token for token-bound database configuration
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Token authentication failed: {e}")
|
||||
raise ValueError(f"Token authentication failed: {str(e)}")
|
||||
|
||||
async def _authenticate_basic(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||
"""Basic authentication (username password)"""
|
||||
@@ -328,7 +733,7 @@ class AuthorizationProvider:
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.logger = get_logger(__name__)
|
||||
self.permission_cache = {}
|
||||
|
||||
# Load sensitive tables configuration
|
||||
@@ -471,43 +876,80 @@ class SQLSecurityValidator:
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
# Handle DorisConfig object or dictionary configuration
|
||||
if hasattr(config, 'get'):
|
||||
# Dictionary configuration
|
||||
self.blocked_keywords = set(config.get("blocked_keywords", []))
|
||||
self.max_query_complexity = config.get("max_query_complexity", 100)
|
||||
self.enable_security_check = config.get("enable_security_check", True)
|
||||
elif hasattr(config, 'security'):
|
||||
# DorisConfig object with security attribute - unified source from config
|
||||
self.blocked_keywords = set(config.security.blocked_keywords)
|
||||
self.max_query_complexity = config.security.max_query_complexity
|
||||
self.enable_security_check = getattr(config.security, 'enable_security_check', True)
|
||||
else:
|
||||
# DorisConfig object, use default values
|
||||
self.blocked_keywords = set(["DROP", "DELETE", "TRUNCATE", "ALTER", "CREATE", "INSERT", "UPDATE"])
|
||||
# Fallback to default if no configuration available
|
||||
self.blocked_keywords = set([
|
||||
"DROP", "CREATE", "ALTER", "TRUNCATE",
|
||||
"DELETE", "INSERT", "UPDATE",
|
||||
"GRANT", "REVOKE",
|
||||
"EXEC", "EXECUTE", "SHUTDOWN", "KILL"
|
||||
])
|
||||
self.max_query_complexity = 100
|
||||
self.enable_security_check = True
|
||||
|
||||
async def validate(self, sql: str, auth_context: AuthContext) -> ValidationResult:
|
||||
"""Validate SQL query security"""
|
||||
# If security check is disabled, always return valid
|
||||
if not self.enable_security_check:
|
||||
self.logger.debug("SQL security check is disabled, allowing all queries")
|
||||
return ValidationResult(is_valid=True)
|
||||
|
||||
try:
|
||||
# Parse SQL statement
|
||||
parsed = sqlparse.parse(sql)[0]
|
||||
# SECURITY FIX: Parse ALL SQL statements, not just the first one
|
||||
# This prevents bypassing security checks by injecting additional statements
|
||||
all_statements = sqlparse.parse(sql)
|
||||
|
||||
# Check blocked operations first (more specific)
|
||||
keyword_result = await self._check_blocked_keywords(parsed)
|
||||
if not keyword_result.is_valid:
|
||||
return keyword_result
|
||||
if not all_statements:
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
error_message="Empty or invalid SQL statement",
|
||||
risk_level="medium"
|
||||
)
|
||||
|
||||
# Check SQL injection risks
|
||||
injection_result = await self._check_sql_injection(sql, parsed)
|
||||
if not injection_result.is_valid:
|
||||
return injection_result
|
||||
# SECURITY FIX: Validate each statement individually
|
||||
for idx, parsed in enumerate(all_statements):
|
||||
# Skip empty statements (e.g., from trailing semicolons)
|
||||
if not parsed.tokens or str(parsed).strip() == '':
|
||||
continue
|
||||
|
||||
# Check query complexity
|
||||
complexity_result = await self._check_query_complexity(parsed)
|
||||
if not complexity_result.is_valid:
|
||||
return complexity_result
|
||||
self.logger.debug(f"Validating SQL statement {idx + 1}/{len(all_statements)}: {str(parsed)[:100]}...")
|
||||
|
||||
# Check table access permissions
|
||||
table_result = await self._check_table_access(parsed, auth_context)
|
||||
if not table_result.is_valid:
|
||||
return table_result
|
||||
# Check blocked operations first (more specific)
|
||||
keyword_result = await self._check_blocked_keywords(parsed)
|
||||
if not keyword_result.is_valid:
|
||||
keyword_result.error_message = f"Statement {idx + 1}: {keyword_result.error_message}"
|
||||
return keyword_result
|
||||
|
||||
# Check SQL injection risks
|
||||
injection_result = await self._check_sql_injection(sql, parsed)
|
||||
if not injection_result.is_valid:
|
||||
injection_result.error_message = f"Statement {idx + 1}: {injection_result.error_message}"
|
||||
return injection_result
|
||||
|
||||
# Check query complexity
|
||||
complexity_result = await self._check_query_complexity(parsed)
|
||||
if not complexity_result.is_valid:
|
||||
complexity_result.error_message = f"Statement {idx + 1}: {complexity_result.error_message}"
|
||||
return complexity_result
|
||||
|
||||
# Check table access permissions
|
||||
table_result = await self._check_table_access(parsed, auth_context)
|
||||
if not table_result.is_valid:
|
||||
table_result.error_message = f"Statement {idx + 1}: {table_result.error_message}"
|
||||
return table_result
|
||||
|
||||
return ValidationResult(is_valid=True)
|
||||
|
||||
@@ -522,28 +964,69 @@ class SQLSecurityValidator:
|
||||
async def _check_sql_injection(
|
||||
self, sql: str, parsed: Statement
|
||||
) -> ValidationResult:
|
||||
"""Check SQL injection risks"""
|
||||
# Check common SQL injection patterns
|
||||
"""Check SQL injection risks with improved pattern detection
|
||||
|
||||
FIX for Issue #62 Bug 2: Improved patterns to reduce false positives
|
||||
Now better distinguishes between legitimate SQL (like BETWEEN...AND) and injection attempts
|
||||
"""
|
||||
# Improved injection patterns that are more specific and less prone to false positives
|
||||
injection_patterns = [
|
||||
r"(\s|^)(union|select|insert|update|delete|drop|create|alter)\s+.*\s+(union|select|insert|update|delete|drop|create|alter)",
|
||||
r"(\s|^)(or|and)\s+\d+\s*=\s*\d+",
|
||||
r"(\s|^)(or|and)\s+['\"].*['\"]",
|
||||
r";\s*(drop|delete|truncate|alter|create)",
|
||||
r"(exec|execute|sp_|xp_)",
|
||||
r"(script|javascript|vbscript)",
|
||||
r"(char|ascii|substring|concat)\s*\(",
|
||||
# Stacked queries with dangerous operations (true injection risk)
|
||||
r";\s*(DROP|DELETE|TRUNCATE|ALTER|CREATE|INSERT|UPDATE)\s+",
|
||||
|
||||
# UNION-based injection (but allow legitimate UNION queries)
|
||||
# Only flag if UNION is followed by suspicious patterns like SELECT with WHERE 1=1
|
||||
r"UNION\s+(ALL\s+)?SELECT\s+.*\s+(WHERE|AND|OR)\s+\d+\s*=\s*\d+",
|
||||
|
||||
# Boolean-based blind injection with comments (true injection pattern)
|
||||
r"(WHERE|AND|OR)\s+\d+\s*=\s*\d+\s*(--|#|/\*)",
|
||||
|
||||
# Quote-based injection attempts (but not in legitimate strings)
|
||||
r"(WHERE|AND|OR)\s+(['\"])[^\2]*\2\s*=\s*\2[^\2]*\2",
|
||||
|
||||
# Time-based blind injection
|
||||
r"(SLEEP|WAITFOR|BENCHMARK)\s*\(",
|
||||
|
||||
# System stored procedure injection
|
||||
r"(EXEC|EXECUTE|SP_|XP_)\s*\(",
|
||||
|
||||
# Script injection attempts
|
||||
r"<\s*(SCRIPT|JAVASCRIPT|VBSCRIPT)",
|
||||
]
|
||||
|
||||
sql_lower = sql.lower()
|
||||
# FIX: Don't flag legitimate SQL functions and keywords
|
||||
# These patterns are too broad and cause false positives:
|
||||
# - REMOVED: r"(char|ascii|substring|concat)\s*\(" - These are legitimate SQL functions
|
||||
# - REMOVED: r"(\s|^)(or|and)\s+\d+\s*=\s*\d+" - This flags BETWEEN...AND constructs
|
||||
# - REMOVED: r"(\s|^)(or|and)\s+['\"].*['\"]" - This is too broad
|
||||
|
||||
sql_upper = sql.upper()
|
||||
|
||||
# Special case: Allow BETWEEN...AND which is legitimate SQL
|
||||
# This prevents false positives like "WHERE dt BETWEEN '2025-01-01' AND '2025-01-31'"
|
||||
if "BETWEEN" in sql_upper and "AND" in sql_upper:
|
||||
# This is likely a BETWEEN clause, not injection
|
||||
# Check if AND appears in a BETWEEN context
|
||||
between_pattern = r"BETWEEN\s+[^\s]+\s+AND\s+[^\s]+"
|
||||
if re.search(between_pattern, sql_upper, re.IGNORECASE):
|
||||
# Remove BETWEEN clauses before checking other patterns
|
||||
sql_cleaned = re.sub(between_pattern, "BETWEEN_CLAUSE", sql_upper, flags=re.IGNORECASE)
|
||||
sql_to_check = sql_cleaned
|
||||
else:
|
||||
sql_to_check = sql_upper
|
||||
else:
|
||||
sql_to_check = sql_upper
|
||||
|
||||
for pattern in injection_patterns:
|
||||
if re.search(pattern, sql_lower, re.IGNORECASE):
|
||||
if re.search(pattern, sql_to_check, re.IGNORECASE):
|
||||
self.logger.warning(f"Potential SQL injection pattern detected: {pattern}")
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
error_message="Potential SQL injection risk detected",
|
||||
risk_level="high",
|
||||
)
|
||||
|
||||
# Check suspicious quotes and comments
|
||||
# Check suspicious quotes and comments (with improved detection)
|
||||
if self._has_suspicious_quotes_or_comments(sql):
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
@@ -554,19 +1037,67 @@ class SQLSecurityValidator:
|
||||
return ValidationResult(is_valid=True)
|
||||
|
||||
def _has_suspicious_quotes_or_comments(self, sql: str) -> bool:
|
||||
"""Check suspicious quote and comment patterns"""
|
||||
# Check unmatched quotes
|
||||
single_quotes = sql.count("'")
|
||||
double_quotes = sql.count('"')
|
||||
"""Check suspicious quote and comment patterns with improved detection
|
||||
|
||||
if single_quotes % 2 != 0 or double_quotes % 2 != 0:
|
||||
return True
|
||||
FIX for Issue #62 Bug 2: Improved detection to reduce false positives
|
||||
Now distinguishes between legitimate comments/strings and injection attempts
|
||||
"""
|
||||
try:
|
||||
# Use sqlparse to parse the SQL and distinguish between code and comments/strings
|
||||
import sqlparse
|
||||
from sqlparse.tokens import Comment, String
|
||||
|
||||
# Check SQL comments
|
||||
if "--" in sql or "/*" in sql:
|
||||
return True
|
||||
# Parse the SQL
|
||||
parsed = sqlparse.parse(sql)
|
||||
if not parsed:
|
||||
# If parsing fails, be conservative
|
||||
return True
|
||||
|
||||
return False
|
||||
statement = parsed[0]
|
||||
|
||||
# Check for unmatched quotes ONLY in non-string tokens
|
||||
# This prevents false positives from legitimate string content
|
||||
non_string_content = []
|
||||
has_string_tokens = False
|
||||
|
||||
for token in statement.flatten():
|
||||
if token.ttype in (String.Single, String.Double):
|
||||
has_string_tokens = True
|
||||
# Skip string content - quotes inside strings are legitimate
|
||||
continue
|
||||
elif token.ttype in (Comment.Single, Comment.Multi):
|
||||
# Comments are generally OK, but check for suspicious injection patterns
|
||||
comment_value = str(token).lower()
|
||||
# Check if comment contains dangerous SQL keywords
|
||||
dangerous_in_comments = ['drop', 'delete', 'insert', 'update', 'union', 'exec', 'execute']
|
||||
if any(keyword in comment_value for keyword in dangerous_in_comments):
|
||||
self.logger.warning(f"Suspicious SQL keyword in comment: {token}")
|
||||
return True
|
||||
# Normal comments are OK
|
||||
continue
|
||||
else:
|
||||
# Accumulate non-string, non-comment content
|
||||
non_string_content.append(str(token))
|
||||
|
||||
# Check for unmatched quotes in non-string content
|
||||
non_string_text = ''.join(non_string_content)
|
||||
single_quotes = non_string_text.count("'")
|
||||
double_quotes = non_string_text.count('"')
|
||||
|
||||
# Only flag if there are unmatched quotes in actual SQL code (not in strings)
|
||||
if single_quotes % 2 != 0 or double_quotes % 2 != 0:
|
||||
return True
|
||||
|
||||
# FIX: Don't flag legitimate SQL comments
|
||||
# Comments are OK as long as they don't contain dangerous patterns (already checked above)
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.logger.debug(f"SQL parsing error in quote/comment check: {e}")
|
||||
# On parsing error, fall back to conservative check
|
||||
# But be more lenient than before
|
||||
return False # Don't flag on parse errors to reduce false positives
|
||||
|
||||
async def _check_blocked_keywords(self, parsed: Statement) -> ValidationResult:
|
||||
"""Check blocked keywords"""
|
||||
@@ -628,6 +1159,10 @@ class SQLSecurityValidator:
|
||||
self, parsed: Statement, auth_context: AuthContext
|
||||
) -> ValidationResult:
|
||||
"""Check table access permissions"""
|
||||
# If no auth_context, skip table access checks (rely on other security checks)
|
||||
if auth_context is None:
|
||||
return ValidationResult(is_valid=True)
|
||||
|
||||
# Extract table names from query
|
||||
tables = self._extract_table_names(parsed)
|
||||
|
||||
@@ -676,7 +1211,7 @@ class DataMaskingProcessor:
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.logger = get_logger(__name__)
|
||||
self.masking_algorithms = self._init_masking_algorithms()
|
||||
self.masking_rules = self._load_masking_rules()
|
||||
|
||||
|
||||
788
doris_mcp_server/utils/security_analytics_tools.py
Normal file
@@ -0,0 +1,788 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Security Analytics Tools Module
|
||||
Provides data access analysis, user behavior monitoring, and security insights
|
||||
"""
|
||||
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Optional
|
||||
from collections import Counter, defaultdict
|
||||
|
||||
from .db import DorisConnectionManager
|
||||
from .logger import get_logger
|
||||
from .sql_security_utils import get_auth_context
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class SecurityAnalyticsTools:
|
||||
"""Security analytics tools for access pattern analysis and user monitoring"""
|
||||
|
||||
def __init__(self, connection_manager: DorisConnectionManager):
|
||||
self.connection_manager = connection_manager
|
||||
logger.info("SecurityAnalyticsTools initialized")
|
||||
|
||||
async def analyze_data_access_patterns(
|
||||
self,
|
||||
days: int = 7,
|
||||
include_system_users: bool = False,
|
||||
min_query_threshold: int = 5
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Analyze data access patterns for users and roles
|
||||
|
||||
Args:
|
||||
days: Number of days to analyze
|
||||
include_system_users: Whether to include system/service users
|
||||
min_query_threshold: Minimum queries for a user to be included in analysis
|
||||
|
||||
Returns:
|
||||
Comprehensive access pattern analysis
|
||||
"""
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
# 🚀 PROGRESS: Initialize security analysis
|
||||
logger.info("=" * 70)
|
||||
logger.info(f"🔒 Starting Data Access Pattern Analysis")
|
||||
logger.info(f"📅 Analysis period: {days} days")
|
||||
logger.info(f"👥 Include system users: {include_system_users}")
|
||||
logger.info(f"🎯 Min query threshold: {min_query_threshold}")
|
||||
logger.info("=" * 70)
|
||||
|
||||
connection = await self.connection_manager.get_connection("query")
|
||||
|
||||
# Define analysis period
|
||||
end_date = datetime.now()
|
||||
start_date = end_date - timedelta(days=days)
|
||||
|
||||
logger.info(f"📊 Period: {start_date.strftime('%Y-%m-%d')} to {end_date.strftime('%Y-%m-%d')}")
|
||||
|
||||
# 🚀 PROGRESS: Step 1 - Get audit log data
|
||||
logger.info("📋 Step 1/5: Retrieving audit log data...")
|
||||
audit_start = time.time()
|
||||
audit_data = await self._get_audit_log_data(connection, start_date, end_date, include_system_users)
|
||||
audit_time = time.time() - audit_start
|
||||
|
||||
if not audit_data:
|
||||
logger.warning("⚠️ No audit data available for the specified period")
|
||||
return {
|
||||
"error": "No audit data available for the specified period",
|
||||
"analysis_period": {
|
||||
"start_date": start_date.isoformat(),
|
||||
"end_date": end_date.isoformat(),
|
||||
"days": days
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(f"✅ Retrieved {len(audit_data)} audit records in {audit_time:.2f}s")
|
||||
|
||||
# 🚀 PROGRESS: Step 2 - Analyze user access patterns
|
||||
logger.info("👤 Step 2/5: Analyzing user access patterns...")
|
||||
user_start = time.time()
|
||||
user_access_analysis = await self._analyze_user_access_patterns(
|
||||
audit_data, min_query_threshold
|
||||
)
|
||||
user_time = time.time() - user_start
|
||||
logger.info(f"✅ Analyzed {len(user_access_analysis)} users in {user_time:.2f}s")
|
||||
|
||||
# 🚀 PROGRESS: Step 3 - Analyze role-based access
|
||||
logger.info("🎭 Step 3/5: Analyzing role-based access patterns...")
|
||||
role_start = time.time()
|
||||
role_access_analysis = await self._analyze_role_access_patterns(
|
||||
connection, user_access_analysis
|
||||
)
|
||||
role_time = time.time() - role_start
|
||||
logger.info(f"✅ Role analysis completed in {role_time:.2f}s")
|
||||
|
||||
# 🚀 PROGRESS: Step 4 - Detect security anomalies
|
||||
logger.info("🚨 Step 4/5: Detecting security anomalies...")
|
||||
anomaly_start = time.time()
|
||||
security_alerts = await self._detect_security_anomalies(
|
||||
audit_data, user_access_analysis
|
||||
)
|
||||
anomaly_time = time.time() - anomaly_start
|
||||
logger.info(f"✅ Found {len(security_alerts)} security alerts in {anomaly_time:.2f}s")
|
||||
|
||||
# Log alert summary
|
||||
if security_alerts:
|
||||
high_alerts = sum(1 for alert in security_alerts if alert.get("severity") == "high")
|
||||
medium_alerts = sum(1 for alert in security_alerts if alert.get("severity") == "medium")
|
||||
logger.info(f"🚨 Alert breakdown: {high_alerts} high, {medium_alerts} medium")
|
||||
|
||||
# 🚀 PROGRESS: Step 5 - Generate access insights
|
||||
logger.info("💡 Step 5/5: Generating access insights...")
|
||||
insights_start = time.time()
|
||||
access_insights = await self._generate_access_insights(
|
||||
user_access_analysis, role_access_analysis
|
||||
)
|
||||
insights_time = time.time() - insights_start
|
||||
logger.info(f"✅ Access insights generated in {insights_time:.2f}s")
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
return {
|
||||
"analysis_period": {
|
||||
"start_date": start_date.isoformat(),
|
||||
"end_date": end_date.isoformat(),
|
||||
"days": days
|
||||
},
|
||||
"analysis_timestamp": datetime.now().isoformat(),
|
||||
"execution_time_seconds": round(execution_time, 3),
|
||||
"user_access_summary": self._generate_user_access_summary(user_access_analysis),
|
||||
"user_access_details": user_access_analysis,
|
||||
"role_analysis": role_access_analysis,
|
||||
"security_alerts": security_alerts,
|
||||
"access_insights": access_insights,
|
||||
"recommendations": self._generate_security_recommendations(security_alerts, access_insights)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Data access pattern analysis failed: {str(e)}")
|
||||
return {
|
||||
"error": str(e),
|
||||
"analysis_timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# ==================== Private Helper Methods ====================
|
||||
|
||||
async def _get_audit_log_data(self, connection, start_date: datetime, end_date: datetime, include_system_users: bool) -> List[Dict]:
|
||||
"""Retrieve audit log data for the specified period"""
|
||||
try:
|
||||
# System users filter
|
||||
system_user_filter = ""
|
||||
if not include_system_users:
|
||||
system_users = ['root', 'admin', 'system', 'doris', 'information_schema']
|
||||
user_list = ','.join([f'"{user}"' for user in system_users])
|
||||
system_user_filter = f"AND `user` NOT IN ({user_list})"
|
||||
|
||||
audit_sql = f"""
|
||||
SELECT
|
||||
`user` as user_name,
|
||||
`client_ip` as host,
|
||||
`time` as query_time,
|
||||
`stmt` as sql_statement,
|
||||
`state` as query_status,
|
||||
`scan_bytes` as scan_bytes,
|
||||
`scan_rows` as scan_rows,
|
||||
`return_rows` as return_rows,
|
||||
`query_time` as execution_time_ms
|
||||
FROM internal.__internal_schema.audit_log
|
||||
WHERE `time` >= '{start_date.strftime('%Y-%m-%d %H:%M:%S')}'
|
||||
AND `time` <= '{end_date.strftime('%Y-%m-%d %H:%M:%S')}'
|
||||
AND `stmt` IS NOT NULL
|
||||
AND `stmt` != ''
|
||||
{system_user_filter}
|
||||
ORDER BY `time` DESC
|
||||
LIMIT 10000
|
||||
"""
|
||||
|
||||
# SECURITY FIX: Pass auth_context to execute
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(audit_sql, auth_context=auth_context)
|
||||
return result.data if result.data else []
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get audit log data: {str(e)}")
|
||||
# Try alternative method without detailed metrics
|
||||
try:
|
||||
simple_audit_sql = f"""
|
||||
SELECT
|
||||
`user` as user_name,
|
||||
`client_ip` as host,
|
||||
`time` as query_time,
|
||||
`stmt` as sql_statement,
|
||||
`state` as query_status
|
||||
FROM internal.__internal_schema.audit_log
|
||||
WHERE `time` >= '{start_date.strftime('%Y-%m-%d %H:%M:%S')}'
|
||||
AND `time` <= '{end_date.strftime('%Y-%m-%d %H:%M:%S')}'
|
||||
AND `stmt` IS NOT NULL
|
||||
{system_user_filter}
|
||||
ORDER BY `time` DESC
|
||||
LIMIT 10000
|
||||
"""
|
||||
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(simple_audit_sql, auth_context=auth_context)
|
||||
return result.data if result.data else []
|
||||
|
||||
except Exception as e2:
|
||||
logger.error(f"Failed to get simplified audit log data: {str(e2)}")
|
||||
return []
|
||||
|
||||
async def _analyze_user_access_patterns(self, audit_data: List[Dict], min_query_threshold: int) -> List[Dict]:
|
||||
"""Analyze access patterns for individual users"""
|
||||
user_stats = defaultdict(lambda: {
|
||||
"total_queries": 0,
|
||||
"unique_tables_accessed": set(),
|
||||
"hosts": set(),
|
||||
"query_types": Counter(),
|
||||
"query_times": [],
|
||||
"failed_queries": 0,
|
||||
"data_volume_read_bytes": 0,
|
||||
"data_volume_read_rows": 0,
|
||||
"hourly_pattern": [0] * 24,
|
||||
"daily_pattern": [0] * 7,
|
||||
"query_statements": []
|
||||
})
|
||||
|
||||
# Process audit data
|
||||
for entry in audit_data:
|
||||
user_name = entry.get("user_name", "unknown")
|
||||
query_time = entry.get("query_time")
|
||||
sql_statement = entry.get("sql_statement", "")
|
||||
query_status = entry.get("query_status", "")
|
||||
|
||||
stats = user_stats[user_name]
|
||||
stats["total_queries"] += 1
|
||||
|
||||
# Extract table names from SQL
|
||||
tables = self._extract_table_names_from_sql(sql_statement)
|
||||
stats["unique_tables_accessed"].update(tables)
|
||||
|
||||
# Host tracking
|
||||
if entry.get("host"):
|
||||
stats["hosts"].add(entry["host"])
|
||||
|
||||
# Query type analysis
|
||||
query_type = self._classify_query_type(sql_statement)
|
||||
stats["query_types"][query_type] += 1
|
||||
|
||||
# Query time patterns
|
||||
if query_time:
|
||||
try:
|
||||
if isinstance(query_time, str):
|
||||
query_dt = datetime.fromisoformat(query_time.replace('Z', '+00:00'))
|
||||
else:
|
||||
query_dt = query_time
|
||||
|
||||
stats["query_times"].append(query_dt)
|
||||
stats["hourly_pattern"][query_dt.hour] += 1
|
||||
stats["daily_pattern"][query_dt.weekday()] += 1
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Error tracking
|
||||
if query_status and "error" in query_status.lower():
|
||||
stats["failed_queries"] += 1
|
||||
|
||||
# Data volume tracking
|
||||
if entry.get("scan_bytes"):
|
||||
try:
|
||||
stats["data_volume_read_bytes"] += int(entry["scan_bytes"])
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
if entry.get("scan_rows"):
|
||||
try:
|
||||
stats["data_volume_read_rows"] += int(entry["scan_rows"])
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
# Store sample queries
|
||||
if len(stats["query_statements"]) < 10:
|
||||
stats["query_statements"].append({
|
||||
"sql": sql_statement[:200] + "..." if len(sql_statement) > 200 else sql_statement,
|
||||
"timestamp": str(query_time),
|
||||
"type": query_type
|
||||
})
|
||||
|
||||
# Convert to analysis results
|
||||
user_analysis = []
|
||||
for user_name, stats in user_stats.items():
|
||||
if stats["total_queries"] >= min_query_threshold:
|
||||
# Calculate patterns and insights
|
||||
access_pattern = self._classify_access_pattern(stats["hourly_pattern"])
|
||||
table_access_frequency = dict(Counter(
|
||||
table for entry in audit_data
|
||||
if entry.get("user_name") == user_name
|
||||
for table in self._extract_table_names_from_sql(entry.get("sql_statement", ""))
|
||||
).most_common(10))
|
||||
|
||||
user_analysis.append({
|
||||
"user_name": user_name,
|
||||
"access_stats": {
|
||||
"total_queries": stats["total_queries"],
|
||||
"unique_tables_accessed": len(stats["unique_tables_accessed"]),
|
||||
"unique_hosts": len(stats["hosts"]),
|
||||
"data_volume_read_gb": round(stats["data_volume_read_bytes"] / (1024**3), 3),
|
||||
"data_volume_read_rows": stats["data_volume_read_rows"],
|
||||
"failed_queries": stats["failed_queries"],
|
||||
"success_rate": round((stats["total_queries"] - stats["failed_queries"]) / stats["total_queries"], 3) if stats["total_queries"] > 0 else 0,
|
||||
"peak_access_hour": stats["hourly_pattern"].index(max(stats["hourly_pattern"])) if max(stats["hourly_pattern"]) > 0 else None,
|
||||
"access_pattern": access_pattern
|
||||
},
|
||||
"query_type_distribution": dict(stats["query_types"]),
|
||||
"table_access_frequency": table_access_frequency,
|
||||
"hosts_used": list(stats["hosts"]),
|
||||
"sample_queries": stats["query_statements"],
|
||||
"temporal_patterns": {
|
||||
"hourly_distribution": stats["hourly_pattern"],
|
||||
"daily_distribution": stats["daily_pattern"]
|
||||
}
|
||||
})
|
||||
|
||||
return sorted(user_analysis, key=lambda x: x["access_stats"]["total_queries"], reverse=True)
|
||||
|
||||
def _extract_table_names_from_sql(self, sql: str) -> List[str]:
|
||||
"""Extract table names from SQL statement (simplified implementation)"""
|
||||
if not sql:
|
||||
return []
|
||||
|
||||
import re
|
||||
|
||||
# Simple regex patterns to match table names
|
||||
patterns = [
|
||||
r'\bFROM\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)',
|
||||
r'\bJOIN\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)',
|
||||
r'\bINTO\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)',
|
||||
r'\bUPDATE\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)',
|
||||
r'\bDELETE\s+FROM\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)'
|
||||
]
|
||||
|
||||
tables = []
|
||||
for pattern in patterns:
|
||||
matches = re.findall(pattern, sql, re.IGNORECASE)
|
||||
tables.extend(matches)
|
||||
|
||||
# Clean up table names (remove quotes, aliases, etc.)
|
||||
cleaned_tables = []
|
||||
for table in tables:
|
||||
# Remove backticks, quotes, and get just the table name
|
||||
clean_table = table.strip('`"\'').split(' ')[0]
|
||||
if clean_table and not clean_table.upper() in ['SELECT', 'WHERE', 'AND', 'OR']:
|
||||
cleaned_tables.append(clean_table)
|
||||
|
||||
return list(set(cleaned_tables))
|
||||
|
||||
def _classify_query_type(self, sql: str) -> str:
|
||||
"""Classify SQL query type"""
|
||||
if not sql:
|
||||
return "unknown"
|
||||
|
||||
sql_upper = sql.upper().strip()
|
||||
|
||||
if sql_upper.startswith('SELECT'):
|
||||
return "SELECT"
|
||||
elif sql_upper.startswith('INSERT'):
|
||||
return "INSERT"
|
||||
elif sql_upper.startswith('UPDATE'):
|
||||
return "UPDATE"
|
||||
elif sql_upper.startswith('DELETE'):
|
||||
return "DELETE"
|
||||
elif sql_upper.startswith('CREATE'):
|
||||
return "CREATE"
|
||||
elif sql_upper.startswith('ALTER'):
|
||||
return "ALTER"
|
||||
elif sql_upper.startswith('DROP'):
|
||||
return "DROP"
|
||||
elif sql_upper.startswith('SHOW'):
|
||||
return "SHOW"
|
||||
elif sql_upper.startswith('DESCRIBE') or sql_upper.startswith('DESC'):
|
||||
return "DESCRIBE"
|
||||
else:
|
||||
return "OTHER"
|
||||
|
||||
def _classify_access_pattern(self, hourly_pattern: List[int]) -> str:
|
||||
"""Classify user access pattern based on hourly distribution"""
|
||||
if not hourly_pattern or max(hourly_pattern) == 0:
|
||||
return "no_pattern"
|
||||
|
||||
# Find peak hours
|
||||
max_queries = max(hourly_pattern)
|
||||
peak_hours = [i for i, count in enumerate(hourly_pattern) if count == max_queries]
|
||||
|
||||
# Business hours: 9-17
|
||||
business_hours = set(range(9, 18))
|
||||
peak_in_business_hours = any(hour in business_hours for hour in peak_hours)
|
||||
|
||||
# Night hours: 22-6
|
||||
night_hours = set(list(range(22, 24)) + list(range(0, 7)))
|
||||
peak_in_night_hours = any(hour in night_hours for hour in peak_hours)
|
||||
|
||||
if peak_in_business_hours and not peak_in_night_hours:
|
||||
return "regular_business_hours"
|
||||
elif peak_in_night_hours:
|
||||
return "night_shift_or_batch"
|
||||
elif len(peak_hours) > 6: # Distributed throughout day
|
||||
return "distributed_access"
|
||||
else:
|
||||
return "irregular_pattern"
|
||||
|
||||
async def _analyze_role_access_patterns(self, connection, user_access_analysis: List[Dict]) -> Dict[str, Any]:
|
||||
"""Analyze access patterns by role"""
|
||||
try:
|
||||
# Get user roles information
|
||||
user_roles = await self._get_user_roles(connection)
|
||||
|
||||
# Group users by roles
|
||||
role_stats = defaultdict(lambda: {
|
||||
"user_count": 0,
|
||||
"total_queries": 0,
|
||||
"unique_tables": set(),
|
||||
"query_types": Counter(),
|
||||
"avg_queries_per_user": 0,
|
||||
"users": []
|
||||
})
|
||||
|
||||
# Process user access data
|
||||
for user_data in user_access_analysis:
|
||||
user_name = user_data["user_name"]
|
||||
user_stats = user_data["access_stats"]
|
||||
query_types = user_data["query_type_distribution"]
|
||||
|
||||
# Get user roles (default to 'unknown' if not found)
|
||||
roles = user_roles.get(user_name, ["unknown"])
|
||||
|
||||
for role in roles:
|
||||
stats = role_stats[role]
|
||||
stats["user_count"] += 1
|
||||
stats["total_queries"] += user_stats["total_queries"]
|
||||
stats["users"].append(user_name)
|
||||
|
||||
# Aggregate query types
|
||||
for query_type, count in query_types.items():
|
||||
stats["query_types"][query_type] += count
|
||||
|
||||
# Calculate role analysis
|
||||
role_analysis = {}
|
||||
for role, stats in role_stats.items():
|
||||
if stats["user_count"] > 0:
|
||||
avg_queries = stats["total_queries"] / stats["user_count"]
|
||||
|
||||
# Calculate privilege usage (simplified)
|
||||
total_role_queries = sum(stats["query_types"].values())
|
||||
privilege_usage = {}
|
||||
if total_role_queries > 0:
|
||||
privilege_usage = {
|
||||
query_type: round(count / total_role_queries, 3)
|
||||
for query_type, count in stats["query_types"].items()
|
||||
}
|
||||
|
||||
role_analysis[role] = {
|
||||
"user_count": stats["user_count"],
|
||||
"users": stats["users"],
|
||||
"total_queries": stats["total_queries"],
|
||||
"avg_queries_per_user": round(avg_queries, 1),
|
||||
"query_type_distribution": dict(stats["query_types"]),
|
||||
"privilege_usage": privilege_usage,
|
||||
"activity_level": self._classify_role_activity_level(avg_queries)
|
||||
}
|
||||
|
||||
return role_analysis
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to analyze role access patterns: {str(e)}")
|
||||
return {}
|
||||
|
||||
async def _get_user_roles(self, connection) -> Dict[str, List[str]]:
|
||||
"""Get user roles mapping"""
|
||||
try:
|
||||
# Try to get user role information
|
||||
roles_sql = """
|
||||
SELECT
|
||||
User as user_name,
|
||||
COALESCE(Default_role, 'default') as role_name
|
||||
FROM mysql.user
|
||||
"""
|
||||
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(roles_sql, auth_context=auth_context)
|
||||
|
||||
user_roles = defaultdict(list)
|
||||
if result.data:
|
||||
for row in result.data:
|
||||
user_name = row.get("user_name", "")
|
||||
role_name = row.get("role_name", "default")
|
||||
if user_name:
|
||||
user_roles[user_name].append(role_name)
|
||||
|
||||
return dict(user_roles)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get user roles: {str(e)}")
|
||||
return {}
|
||||
|
||||
def _classify_role_activity_level(self, avg_queries: float) -> str:
|
||||
"""Classify role activity level based on average queries"""
|
||||
if avg_queries > 100:
|
||||
return "high"
|
||||
elif avg_queries > 20:
|
||||
return "medium"
|
||||
elif avg_queries > 5:
|
||||
return "low"
|
||||
else:
|
||||
return "minimal"
|
||||
|
||||
async def _detect_security_anomalies(self, audit_data: List[Dict], user_access_analysis: List[Dict]) -> List[Dict]:
|
||||
"""Detect potential security anomalies"""
|
||||
alerts = []
|
||||
|
||||
# 1. Detect unusual access times
|
||||
for user_data in user_access_analysis:
|
||||
user_name = user_data["user_name"]
|
||||
hourly_pattern = user_data["temporal_patterns"]["hourly_distribution"]
|
||||
|
||||
# Check for significant night-time activity
|
||||
night_queries = sum(hourly_pattern[22:24]) + sum(hourly_pattern[0:6])
|
||||
total_queries = sum(hourly_pattern)
|
||||
|
||||
if total_queries > 0 and night_queries / total_queries > 0.3: # >30% night activity
|
||||
alerts.append({
|
||||
"alert_type": "unusual_access_time",
|
||||
"severity": "medium",
|
||||
"user": user_name,
|
||||
"description": f"User {user_name} has {night_queries/total_queries:.1%} of queries during night hours",
|
||||
"night_query_percentage": round(night_queries/total_queries, 3),
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
# 2. Detect users with high failure rates
|
||||
for user_data in user_access_analysis:
|
||||
user_name = user_data["user_name"]
|
||||
success_rate = user_data["access_stats"]["success_rate"]
|
||||
total_queries = user_data["access_stats"]["total_queries"]
|
||||
|
||||
if total_queries > 10 and success_rate < 0.8: # <80% success rate
|
||||
alerts.append({
|
||||
"alert_type": "high_failure_rate",
|
||||
"severity": "medium",
|
||||
"user": user_name,
|
||||
"description": f"User {user_name} has low query success rate ({success_rate:.1%})",
|
||||
"success_rate": success_rate,
|
||||
"total_queries": total_queries,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
# 3. Detect unusual data volume access
|
||||
data_volumes = [user["access_stats"]["data_volume_read_gb"] for user in user_access_analysis]
|
||||
if data_volumes:
|
||||
avg_volume = sum(data_volumes) / len(data_volumes)
|
||||
std_dev = (sum((x - avg_volume) ** 2 for x in data_volumes) / len(data_volumes)) ** 0.5
|
||||
threshold = avg_volume + 2 * std_dev # 2 standard deviations above mean
|
||||
|
||||
for user_data in user_access_analysis:
|
||||
user_name = user_data["user_name"]
|
||||
volume = user_data["access_stats"]["data_volume_read_gb"]
|
||||
|
||||
if volume > threshold and volume > 1.0: # >1GB and above threshold
|
||||
alerts.append({
|
||||
"alert_type": "unusual_data_volume",
|
||||
"severity": "high" if volume > threshold * 2 else "medium",
|
||||
"user": user_name,
|
||||
"description": f"User {user_name} read {volume:.2f}GB (threshold: {threshold:.2f}GB)",
|
||||
"data_volume_gb": volume,
|
||||
"threshold_gb": round(threshold, 2),
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
# 4. Detect users accessing many different tables
|
||||
for user_data in user_access_analysis:
|
||||
user_name = user_data["user_name"]
|
||||
unique_tables = user_data["access_stats"]["unique_tables_accessed"]
|
||||
total_queries = user_data["access_stats"]["total_queries"]
|
||||
|
||||
# High table diversity might indicate privilege escalation or data mining
|
||||
if unique_tables > 20 and total_queries > 50:
|
||||
alerts.append({
|
||||
"alert_type": "broad_table_access",
|
||||
"severity": "medium",
|
||||
"user": user_name,
|
||||
"description": f"User {user_name} accessed {unique_tables} different tables",
|
||||
"unique_tables_count": unique_tables,
|
||||
"total_queries": total_queries,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
return sorted(alerts, key=lambda x: {"high": 3, "medium": 2, "low": 1}.get(x["severity"], 0), reverse=True)
|
||||
|
||||
async def _generate_access_insights(self, user_access_analysis: List[Dict], role_analysis: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Generate access insights and patterns"""
|
||||
insights = {
|
||||
"user_behavior_patterns": {},
|
||||
"role_effectiveness": {},
|
||||
"security_posture": {}
|
||||
}
|
||||
|
||||
# User behavior patterns
|
||||
if user_access_analysis:
|
||||
total_users = len(user_access_analysis)
|
||||
active_users = len([u for u in user_access_analysis if u["access_stats"]["total_queries"] > 10])
|
||||
power_users = len([u for u in user_access_analysis if u["access_stats"]["total_queries"] > 100])
|
||||
|
||||
# Access pattern distribution
|
||||
pattern_distribution = Counter(
|
||||
user["access_stats"]["access_pattern"] for user in user_access_analysis
|
||||
)
|
||||
|
||||
insights["user_behavior_patterns"] = {
|
||||
"total_users_analyzed": total_users,
|
||||
"active_users": active_users,
|
||||
"power_users": power_users,
|
||||
"access_pattern_distribution": dict(pattern_distribution),
|
||||
"avg_queries_per_user": round(
|
||||
sum(u["access_stats"]["total_queries"] for u in user_access_analysis) / total_users, 1
|
||||
) if total_users > 0 else 0
|
||||
}
|
||||
|
||||
# Role effectiveness
|
||||
if role_analysis:
|
||||
most_active_role = max(role_analysis.items(), key=lambda x: x[1]["total_queries"])
|
||||
least_active_role = min(role_analysis.items(), key=lambda x: x[1]["total_queries"])
|
||||
|
||||
insights["role_effectiveness"] = {
|
||||
"total_roles": len(role_analysis),
|
||||
"most_active_role": {
|
||||
"role": most_active_role[0],
|
||||
"total_queries": most_active_role[1]["total_queries"],
|
||||
"user_count": most_active_role[1]["user_count"]
|
||||
},
|
||||
"least_active_role": {
|
||||
"role": least_active_role[0],
|
||||
"total_queries": least_active_role[1]["total_queries"],
|
||||
"user_count": least_active_role[1]["user_count"]
|
||||
},
|
||||
"avg_users_per_role": round(
|
||||
sum(role_info["user_count"] for role_info in role_analysis.values()) / len(role_analysis), 1
|
||||
)
|
||||
}
|
||||
|
||||
# Security posture assessment
|
||||
if user_access_analysis:
|
||||
users_with_failures = len([u for u in user_access_analysis if u["access_stats"]["failed_queries"] > 0])
|
||||
users_night_access = len([
|
||||
u for u in user_access_analysis
|
||||
if any(u["temporal_patterns"]["hourly_distribution"][hour] > 0 for hour in list(range(22, 24)) + list(range(0, 6)))
|
||||
])
|
||||
|
||||
insights["security_posture"] = {
|
||||
"users_with_query_failures": users_with_failures,
|
||||
"users_with_night_access": users_night_access,
|
||||
"security_score": self._calculate_security_score(user_access_analysis),
|
||||
"risk_level": self._assess_overall_risk_level(user_access_analysis)
|
||||
}
|
||||
|
||||
return insights
|
||||
|
||||
def _calculate_security_score(self, user_access_analysis: List[Dict]) -> float:
|
||||
"""Calculate overall security score (0-1, higher is better)"""
|
||||
if not user_access_analysis:
|
||||
return 0.0
|
||||
|
||||
total_users = len(user_access_analysis)
|
||||
|
||||
# Factors that contribute to security score
|
||||
users_with_high_success_rate = len([u for u in user_access_analysis if u["access_stats"]["success_rate"] > 0.9])
|
||||
users_with_normal_patterns = len([u for u in user_access_analysis if u["access_stats"]["access_pattern"] == "regular_business_hours"])
|
||||
|
||||
success_rate_score = users_with_high_success_rate / total_users
|
||||
pattern_score = users_with_normal_patterns / total_users
|
||||
|
||||
# Combined score
|
||||
overall_score = (success_rate_score * 0.6 + pattern_score * 0.4)
|
||||
return round(overall_score, 3)
|
||||
|
||||
def _assess_overall_risk_level(self, user_access_analysis: List[Dict]) -> str:
|
||||
"""Assess overall security risk level"""
|
||||
security_score = self._calculate_security_score(user_access_analysis)
|
||||
|
||||
if security_score > 0.8:
|
||||
return "low"
|
||||
elif security_score > 0.6:
|
||||
return "medium"
|
||||
else:
|
||||
return "high"
|
||||
|
||||
def _generate_user_access_summary(self, user_access_analysis: List[Dict]) -> Dict[str, Any]:
|
||||
"""Generate summary statistics for user access"""
|
||||
if not user_access_analysis:
|
||||
return {
|
||||
"total_users": 0,
|
||||
"active_users": 0,
|
||||
"high_activity_users": 0,
|
||||
"dormant_users": 0
|
||||
}
|
||||
|
||||
total_users = len(user_access_analysis)
|
||||
active_users = len([u for u in user_access_analysis if u["access_stats"]["total_queries"] > 10])
|
||||
high_activity_users = len([u for u in user_access_analysis if u["access_stats"]["total_queries"] > 100])
|
||||
dormant_users = total_users - active_users
|
||||
|
||||
return {
|
||||
"total_users": total_users,
|
||||
"active_users": active_users,
|
||||
"high_activity_users": high_activity_users,
|
||||
"dormant_users": dormant_users,
|
||||
"activity_distribution": {
|
||||
"high": high_activity_users,
|
||||
"medium": active_users - high_activity_users,
|
||||
"low": dormant_users
|
||||
}
|
||||
}
|
||||
|
||||
def _generate_security_recommendations(self, security_alerts: List[Dict], access_insights: Dict[str, Any]) -> List[Dict]:
|
||||
"""Generate security recommendations based on analysis"""
|
||||
recommendations = []
|
||||
|
||||
# Recommendations based on alerts
|
||||
if security_alerts:
|
||||
high_severity_alerts = [alert for alert in security_alerts if alert["severity"] == "high"]
|
||||
if high_severity_alerts:
|
||||
recommendations.append({
|
||||
"type": "urgent_security_review",
|
||||
"priority": "high",
|
||||
"description": f"Found {len(high_severity_alerts)} high-severity security alerts",
|
||||
"action": "Immediate review of flagged users and access patterns required",
|
||||
"affected_users": list(set(alert["user"] for alert in high_severity_alerts if "user" in alert))
|
||||
})
|
||||
|
||||
# Night access recommendations
|
||||
night_access_alerts = [alert for alert in security_alerts if alert["alert_type"] == "unusual_access_time"]
|
||||
if night_access_alerts:
|
||||
recommendations.append({
|
||||
"type": "access_time_policy",
|
||||
"priority": "medium",
|
||||
"description": f"{len(night_access_alerts)} users have significant night-time access",
|
||||
"action": "Review access time policies and consider time-based restrictions",
|
||||
"affected_users": [alert["user"] for alert in night_access_alerts]
|
||||
})
|
||||
|
||||
# Recommendations based on insights
|
||||
security_posture = access_insights.get("security_posture", {})
|
||||
risk_level = security_posture.get("risk_level", "unknown")
|
||||
|
||||
if risk_level == "high":
|
||||
recommendations.append({
|
||||
"type": "overall_security_improvement",
|
||||
"priority": "high",
|
||||
"description": "Overall security posture indicates high risk",
|
||||
"action": "Comprehensive security audit and policy review recommended"
|
||||
})
|
||||
|
||||
# Role-based recommendations
|
||||
role_effectiveness = access_insights.get("role_effectiveness", {})
|
||||
if role_effectiveness and role_effectiveness.get("total_roles", 0) < 3:
|
||||
recommendations.append({
|
||||
"type": "role_management",
|
||||
"priority": "medium",
|
||||
"description": "Limited role diversity detected",
|
||||
"action": "Consider implementing more granular role-based access control"
|
||||
})
|
||||
|
||||
return recommendations
|
||||
301
doris_mcp_server/utils/sql_security_utils.py
Normal file
@@ -0,0 +1,301 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
SQL Security Utilities Module
|
||||
|
||||
Provides SQL identifier validation, escaping, and safe query building utilities
|
||||
to prevent SQL injection attacks.
|
||||
"""
|
||||
|
||||
import re
|
||||
from contextvars import ContextVar
|
||||
from typing import Optional, Tuple, List, Any
|
||||
|
||||
from .logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Context variable for auth_context (set by HTTP middleware)
|
||||
auth_context_var: ContextVar = ContextVar('mcp_auth_context', default=None)
|
||||
|
||||
|
||||
class SQLSecurityError(Exception):
|
||||
"""Exception raised for SQL security validation failures"""
|
||||
pass
|
||||
|
||||
|
||||
class SQLSecurityUtils:
|
||||
"""
|
||||
SQL Security Utilities for preventing SQL injection attacks.
|
||||
|
||||
Provides:
|
||||
- Identifier validation (database names, table names, column names)
|
||||
- Safe identifier quoting with backticks
|
||||
- Safe table reference building
|
||||
- Auth context retrieval from context variables
|
||||
"""
|
||||
|
||||
# Valid SQL identifier pattern: letters, numbers, underscores
|
||||
# Must start with letter or underscore, not a number
|
||||
# Supports Unicode letters for international database/table names
|
||||
IDENTIFIER_PATTERN = re.compile(r'^[a-zA-Z_\u4e00-\u9fff][a-zA-Z0-9_\u4e00-\u9fff]*$')
|
||||
|
||||
# Maximum identifier length (MySQL/Doris standard)
|
||||
MAX_IDENTIFIER_LENGTH = 64
|
||||
|
||||
# SQL reserved keywords that should be quoted
|
||||
SQL_KEYWORDS = {
|
||||
'SELECT', 'FROM', 'WHERE', 'INSERT', 'UPDATE', 'DELETE', 'DROP',
|
||||
'CREATE', 'ALTER', 'TABLE', 'DATABASE', 'INDEX', 'VIEW', 'AND',
|
||||
'OR', 'NOT', 'NULL', 'TRUE', 'FALSE', 'IN', 'LIKE', 'BETWEEN',
|
||||
'JOIN', 'LEFT', 'RIGHT', 'INNER', 'OUTER', 'ON', 'AS', 'ORDER',
|
||||
'BY', 'GROUP', 'HAVING', 'LIMIT', 'OFFSET', 'UNION', 'ALL',
|
||||
'DISTINCT', 'INTO', 'VALUES', 'SET', 'DEFAULT', 'PRIMARY', 'KEY',
|
||||
'FOREIGN', 'REFERENCES', 'CHECK', 'UNIQUE', 'CONSTRAINT'
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def validate_identifier(cls, name: str, identifier_type: str = "identifier") -> str:
|
||||
"""
|
||||
Validate a SQL identifier (database name, table name, column name, etc.)
|
||||
|
||||
Args:
|
||||
name: The identifier to validate
|
||||
identifier_type: Type description for error messages (e.g., "database name", "table name")
|
||||
|
||||
Returns:
|
||||
The validated identifier (unchanged if valid)
|
||||
|
||||
Raises:
|
||||
SQLSecurityError: If the identifier is invalid
|
||||
"""
|
||||
if not name:
|
||||
raise SQLSecurityError(f"Empty {identifier_type} is not allowed")
|
||||
|
||||
if not isinstance(name, str):
|
||||
raise SQLSecurityError(f"Invalid {identifier_type}: must be a string, got {type(name).__name__}")
|
||||
|
||||
# Strip whitespace
|
||||
name = name.strip()
|
||||
|
||||
if not name:
|
||||
raise SQLSecurityError(f"Empty {identifier_type} is not allowed")
|
||||
|
||||
# Check length
|
||||
if len(name) > cls.MAX_IDENTIFIER_LENGTH:
|
||||
raise SQLSecurityError(
|
||||
f"Invalid {identifier_type}: '{name[:20]}...' exceeds maximum length of {cls.MAX_IDENTIFIER_LENGTH} characters"
|
||||
)
|
||||
|
||||
# Check for dangerous characters that could be SQL injection
|
||||
dangerous_chars = ["'", '"', ';', '--', '/*', '*/', '\\', '\x00']
|
||||
for char in dangerous_chars:
|
||||
if char in name:
|
||||
raise SQLSecurityError(
|
||||
f"Invalid {identifier_type}: '{name}' contains forbidden character '{char}'"
|
||||
)
|
||||
|
||||
# Validate pattern
|
||||
if not cls.IDENTIFIER_PATTERN.match(name):
|
||||
raise SQLSecurityError(
|
||||
f"Invalid {identifier_type}: '{name}' contains invalid characters. "
|
||||
f"Only letters, numbers, and underscores are allowed, and must start with a letter or underscore."
|
||||
)
|
||||
|
||||
logger.debug(f"Validated {identifier_type}: {name}")
|
||||
return name
|
||||
|
||||
@classmethod
|
||||
def quote_identifier(cls, name: str, identifier_type: str = "identifier") -> str:
|
||||
"""
|
||||
Safely quote a SQL identifier using backticks.
|
||||
|
||||
Args:
|
||||
name: The identifier to quote
|
||||
identifier_type: Type description for error messages
|
||||
|
||||
Returns:
|
||||
The quoted identifier (e.g., `table_name`)
|
||||
|
||||
Raises:
|
||||
SQLSecurityError: If the identifier is invalid
|
||||
"""
|
||||
# First validate the identifier
|
||||
validated_name = cls.validate_identifier(name, identifier_type)
|
||||
|
||||
# Escape any backticks within the name (double them)
|
||||
escaped_name = validated_name.replace('`', '``')
|
||||
|
||||
return f"`{escaped_name}`"
|
||||
|
||||
@classmethod
|
||||
def build_table_reference(
|
||||
cls,
|
||||
table_name: str,
|
||||
db_name: Optional[str] = None,
|
||||
catalog_name: Optional[str] = None,
|
||||
quote: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
Build a safe, fully-qualified table reference.
|
||||
|
||||
Args:
|
||||
table_name: The table name (required)
|
||||
db_name: The database name (optional)
|
||||
catalog_name: The catalog name (optional)
|
||||
quote: Whether to quote identifiers with backticks (default: True)
|
||||
|
||||
Returns:
|
||||
A safe table reference string (e.g., `catalog`.`db`.`table`)
|
||||
|
||||
Raises:
|
||||
SQLSecurityError: If any identifier is invalid
|
||||
"""
|
||||
parts = []
|
||||
|
||||
if catalog_name:
|
||||
if quote:
|
||||
parts.append(cls.quote_identifier(catalog_name, "catalog name"))
|
||||
else:
|
||||
parts.append(cls.validate_identifier(catalog_name, "catalog name"))
|
||||
|
||||
if db_name:
|
||||
if quote:
|
||||
parts.append(cls.quote_identifier(db_name, "database name"))
|
||||
else:
|
||||
parts.append(cls.validate_identifier(db_name, "database name"))
|
||||
|
||||
if quote:
|
||||
parts.append(cls.quote_identifier(table_name, "table name"))
|
||||
else:
|
||||
parts.append(cls.validate_identifier(table_name, "table name"))
|
||||
|
||||
return '.'.join(parts)
|
||||
|
||||
@classmethod
|
||||
def build_column_reference(
|
||||
cls,
|
||||
column_name: str,
|
||||
table_name: Optional[str] = None,
|
||||
quote: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
Build a safe column reference.
|
||||
|
||||
Args:
|
||||
column_name: The column name (required)
|
||||
table_name: The table name (optional, for qualified references)
|
||||
quote: Whether to quote identifiers with backticks (default: True)
|
||||
|
||||
Returns:
|
||||
A safe column reference string (e.g., `table`.`column`)
|
||||
|
||||
Raises:
|
||||
SQLSecurityError: If any identifier is invalid
|
||||
"""
|
||||
parts = []
|
||||
|
||||
if table_name:
|
||||
if quote:
|
||||
parts.append(cls.quote_identifier(table_name, "table name"))
|
||||
else:
|
||||
parts.append(cls.validate_identifier(table_name, "table name"))
|
||||
|
||||
if quote:
|
||||
parts.append(cls.quote_identifier(column_name, "column name"))
|
||||
else:
|
||||
parts.append(cls.validate_identifier(column_name, "column name"))
|
||||
|
||||
return '.'.join(parts)
|
||||
|
||||
@classmethod
|
||||
def validate_and_build_where_condition(
|
||||
cls,
|
||||
column_name: str,
|
||||
operator: str = "=",
|
||||
use_param: bool = True
|
||||
) -> Tuple[str, bool]:
|
||||
"""
|
||||
Build a safe WHERE condition for a column.
|
||||
|
||||
Args:
|
||||
column_name: The column name
|
||||
operator: The comparison operator (=, !=, <, >, <=, >=, LIKE, IN)
|
||||
use_param: Whether to use parameterized placeholder (%s)
|
||||
|
||||
Returns:
|
||||
Tuple of (condition_string, needs_param)
|
||||
e.g., ("`column` = %s", True) or ("`column` = DATABASE()", False)
|
||||
|
||||
Raises:
|
||||
SQLSecurityError: If column name is invalid or operator is not allowed
|
||||
"""
|
||||
# Validate column name
|
||||
quoted_column = cls.quote_identifier(column_name, "column name")
|
||||
|
||||
# Validate operator
|
||||
allowed_operators = {'=', '!=', '<>', '<', '>', '<=', '>=', 'LIKE', 'IN', 'IS'}
|
||||
if operator.upper() not in allowed_operators:
|
||||
raise SQLSecurityError(f"Invalid operator: '{operator}'. Allowed: {allowed_operators}")
|
||||
|
||||
if use_param:
|
||||
return f"{quoted_column} {operator} %s", True
|
||||
else:
|
||||
return f"{quoted_column} {operator}", False
|
||||
|
||||
@staticmethod
|
||||
def get_auth_context():
|
||||
"""
|
||||
Get auth_context from the context variable.
|
||||
|
||||
This retrieves the auth_context that was set by the HTTP middleware
|
||||
during request processing.
|
||||
|
||||
Returns:
|
||||
The auth_context object, or None if not available
|
||||
"""
|
||||
try:
|
||||
auth_context = auth_context_var.get()
|
||||
if auth_context:
|
||||
logger.debug(f"Retrieved auth_context from context variable")
|
||||
return auth_context
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not retrieve auth_context: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def set_auth_context(auth_context):
|
||||
"""
|
||||
Set auth_context in the context variable.
|
||||
|
||||
This is typically called by the HTTP middleware during request processing.
|
||||
|
||||
Args:
|
||||
auth_context: The auth_context object to set
|
||||
"""
|
||||
auth_context_var.set(auth_context)
|
||||
logger.debug("Set auth_context in context variable")
|
||||
|
||||
|
||||
# Convenience functions for direct use
|
||||
validate_identifier = SQLSecurityUtils.validate_identifier
|
||||
quote_identifier = SQLSecurityUtils.quote_identifier
|
||||
build_table_reference = SQLSecurityUtils.build_table_reference
|
||||
build_column_reference = SQLSecurityUtils.build_column_reference
|
||||
get_auth_context = SQLSecurityUtils.get_auth_context
|
||||
set_auth_context = SQLSecurityUtils.set_auth_context
|
||||
|
||||
147
examples/cursor/README.md
Normal file
@@ -0,0 +1,147 @@
|
||||
<!--
|
||||
Licensed to the Apache Software Foundation (ASF) under one
|
||||
or more contributor license agreements. See the NOTICE file
|
||||
distributed with this work for additional information
|
||||
regarding copyright ownership. The ASF licenses this file
|
||||
to you under the Apache License, Version 2.0 (the
|
||||
"License"); you may not use this file except in compliance
|
||||
with the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing,
|
||||
software distributed under the License is distributed on an
|
||||
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations
|
||||
under the License.
|
||||
-->
|
||||
|
||||
|
||||
# Cursor Example: Integrating Doris MCP Server
|
||||
|
||||
This guide provides step-by-step instructions on how to integrate the `doris-mcp-server` with the [Cursor](https://cursor.sh/) IDE. This integration allows you to interact with your Apache Doris database using natural language queries directly within Cursor's AI chat.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
* [Prerequisites](#prerequisites)
|
||||
* [Step 1: Set Up the Project](#step-1-set-up-the-project)
|
||||
* [Step 2: Configure the MCP Server in Cursor](#step-2-configure-the-mcp-server-in-cursor)
|
||||
* [Step 3: Verify the Integration](#step-3-verify-the-integration)
|
||||
* [Step 4: Query Your Database](#step-4-query-your-database)
|
||||
|
||||
* [Example 1: List Tables](#example-1-list-tables)
|
||||
* [Example 2: Analyze Sales Trends](#example-2-analyze-sales-trends)
|
||||
|
||||
---
|
||||
|
||||
### Prerequisites
|
||||
|
||||
Before you begin, ensure you have the following installed and configured:
|
||||
|
||||
* The **Cursor** IDE
|
||||
* **Git** for cloning the repository
|
||||
* Access to an **Apache Doris** cluster (FE host, port, username, and password)
|
||||
* **uv**, a fast Python package installer and runner
|
||||
|
||||
You can install `uv` with one of the following commands:
|
||||
|
||||
```bash
|
||||
# For macOS (recommended)
|
||||
brew install uv
|
||||
|
||||
# For other systems using pipx
|
||||
pipx install uv
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Step 1: Set Up the Project
|
||||
|
||||
First, clone the `doris-mcp-server` repository to your local machine:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/apache/doris-mcp-server.git
|
||||
|
||||
cd doris-mcp-server
|
||||
```
|
||||
|
||||
The necessary dependencies are listed in `requirements.txt` and will be managed automatically by `uv` in the next step.
|
||||
|
||||
---
|
||||
|
||||
### Step 2: Configure the MCP Server in Cursor
|
||||
|
||||
1. Open the cloned `doris-mcp-server` directory in Cursor.
|
||||
2. Click the ⚙️ icon (top-right), then go to **Tools & Integrations**.
|
||||

|
||||
3. Click **Add a custom MCP Server**.
|
||||
4. Paste the following JSON configuration:
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"doris-mcp": {
|
||||
"command": "uv",
|
||||
"args": [
|
||||
"run",
|
||||
"--project",
|
||||
"/path/to/your/doris-mcp-server",
|
||||
"doris-mcp-server"
|
||||
],
|
||||
"env": {
|
||||
"DORIS_HOST": "your_doris_fe_host",
|
||||
"DORIS_PORT": "9030",
|
||||
"DORIS_USER": "your_username",
|
||||
"DORIS_PASSWORD": "your_password",
|
||||
"DORIS_DATABASE": "ssb"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
> ⚠️ **Important:**
|
||||
>
|
||||
> * Replace `"/path/to/your/doris-mcp-server"` with the **absolute path** to your local project directory.
|
||||
> * Fill in your actual Doris FE host, username, password, and database name.
|
||||
|
||||
---
|
||||
|
||||
### Step 3: Verify the Integration
|
||||
|
||||
Once saved, go back to the **Settings** panel. If everything is configured correctly, you’ll see a green status dot next to `doris-mcp-server`, along with available tools like `exec_query`.
|
||||
|
||||

|
||||
|
||||
---
|
||||
|
||||
### Step 4: Query Your Database
|
||||
|
||||
You can now chat with Cursor Agent to run SQL queries against your Doris database.
|
||||
|
||||
1. Open the chat panel using `Cmd + K` (macOS) or `Ctrl + K` (Windows/Linux), or click the chat icon in the top-right.
|
||||
2. Switch to **Agent Mode**.
|
||||
3. Start asking questions using natural language.
|
||||
|
||||

|
||||
|
||||
---
|
||||
|
||||
#### Example 1: List Tables
|
||||
|
||||
> **Prompt:** What tables are in the `ssb` database?
|
||||
|
||||
The agent will call the `get_db_table_list` tool and return the results.
|
||||
|
||||

|
||||
|
||||
---
|
||||
|
||||
#### Example 2: Analyze Sales Trends
|
||||
|
||||
> **Prompt:** What has been the sales trend over the past ten years in the `ssb` database, and which year had the fastest growth?
|
||||
|
||||
The agent will generate an appropriate SQL query, send it to the MCP server, and interpret the results to give you growth trends and highlights.
|
||||
|
||||

|
||||
156
examples/dify/README.md
Normal file
@@ -0,0 +1,156 @@
|
||||
<!--
|
||||
Licensed to the Apache Software Foundation (ASF) under one
|
||||
or more contributor license agreements. See the NOTICE file
|
||||
distributed with this work for additional information
|
||||
regarding copyright ownership. The ASF licenses this file
|
||||
to you under the Apache License, Version 2.0 (the
|
||||
"License"); you may not use this file except in compliance
|
||||
with the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing,
|
||||
software distributed under the License is distributed on an
|
||||
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations
|
||||
under the License.
|
||||
-->
|
||||
|
||||
|
||||
# Dify Example: Integrating Doris MCP Server
|
||||
|
||||
This document demonstrates how to integrate and use `doris-mcp-server` in Dify to perform Doris SQL calls via MCP.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Prerequisites](#prerequisites)
|
||||
- [Starting the MCP Server](#starting-the-mcp-server)
|
||||
- [Ngrok Tunnel (Optional)](#ngrok-tunnel-optional)
|
||||
- [Installing & Configuring the Plugin in Dify](#installing--configuring-the-plugin-in-dify)
|
||||
- [Creating a Dify App](#creating-a-dify-app)
|
||||
- [Adding MCP Tools](#adding-mcp-tools)
|
||||
- [Example Calls](#example-calls)
|
||||
|
||||
|
||||
-----
|
||||
|
||||
### Prerequisites
|
||||
|
||||
First, install `mcp-doris-server`:
|
||||
|
||||
```bash
|
||||
pip install mcp-doris-server
|
||||
```
|
||||
|
||||
## Starting the MCP Server
|
||||
|
||||
Run the startup script:
|
||||
|
||||
```bash
|
||||
# Full configuration with database connection
|
||||
doris-mcp-server \
|
||||
--transport http \
|
||||
--host 0.0.0.0 \
|
||||
--port 3000 \
|
||||
--db-host 127.0.0.1 \
|
||||
--db-port 9030 \
|
||||
--db-user root \
|
||||
--db-password your_password
|
||||
```
|
||||
|
||||
If successful, you'll see logs similar to this:
|
||||
|
||||

|
||||
|
||||
-----
|
||||
|
||||
## Ngrok Tunnel (Optional)
|
||||
|
||||
If your Dify deployment requires a publicly accessible endpoint, you can use the **ngrok** tool. Ngrok is a third-party service that securely exposes local servers to the internet.
|
||||
|
||||
|
||||
-----
|
||||
|
||||
## Installing & Configuring the Plugin in Dify
|
||||
|
||||
1. In the Dify console, go to **Plugin Marketplace**, search for, and install **MCP‑SSE / StreamableHTTP**:
|
||||

|
||||
|
||||
2. After installation, click **Configure** and set the URL to your public or local address. For example, if you're using `ngrok`, this should be the public URL `ngrok` provides, in the format `https://<your-domain>/mcp`. If Dify can directly access your local server, use `http://localhost:3000/mcp`.
|
||||
|
||||
```json
|
||||
{
|
||||
"doris_mcp_server": {
|
||||
"transport": "streamable_http",
|
||||
"url": "https://<your-domain>/mcp"
|
||||
}
|
||||
}
|
||||
```
|
||||

|
||||
|
||||
3. Click **Save**. If configured correctly, you'll see a green **Authorized** indicator:
|
||||
|
||||

|
||||
|
||||
-----
|
||||
|
||||
## Creating a Dify App
|
||||
|
||||
1. In the Dify console, click **New App** → **Blank App**.
|
||||

|
||||
|
||||
2. Select **Agent** as the template and set the **App Name** (e.g., `Doris ChatBI`).
|
||||

|
||||
|
||||
|
||||
3. Import from DSL,[dify_doris_dsl.yml](dify_doris_dsl.yml)
|
||||
|
||||
-----
|
||||
|
||||
## Instructions & Tool Configuration
|
||||
|
||||
### Instruction Block
|
||||
|
||||
Paste the following into the **Instruction** field:
|
||||
|
||||
```
|
||||
<instruction>
|
||||
Use MCP tools to complete tasks as much as possible. Carefully read the annotations, method names, and parameter descriptions of each tool. Please follow these steps:
|
||||
1. Analyze the user's question and match the most appropriate tool.
|
||||
2. Use tool names and parameters exactly as defined; do not invent new ones.
|
||||
3. Pass parameters in the required JSON format.
|
||||
4. When calling tools, use:
|
||||
{"mcp_sse_call_tool": {"tool_name": "<tool_name>", "arguments": "{}"}}
|
||||
5. Output plain text only—no XML tags.
|
||||
<input>
|
||||
User question: user_query
|
||||
</input>
|
||||
<output>
|
||||
Return tool results or a final answer, including analysis.
|
||||
</output>
|
||||
</instruction>
|
||||
```
|
||||
|
||||
### Adding MCP Tools
|
||||
|
||||
In the **Tools** pane, click **Add** twice to add two entries, both named `mcp_sse` (they will inherit the transport and URL from the plugin):
|
||||

|
||||
|
||||
-----
|
||||
|
||||
## Example Calls
|
||||
|
||||
### List Tables in Database
|
||||
|
||||
* **User**: What tables are in the database?
|
||||
|
||||
* **Result**: Dify will call the MCP tool to run `SHOW TABLES` and return the list.
|
||||

|
||||
|
||||
### Sales Trend Over Ten Years
|
||||
|
||||
* **User**: What has been the sales trend over the past ten years in the ssb database, and which year had the fastest growth?
|
||||
|
||||
* **Result**: The tool will execute the SQL, calculate growth rates, and return data.
|
||||

|
||||
127
examples/dify/dify_doris_dsl.yml
Normal file
@@ -0,0 +1,127 @@
|
||||
app:
|
||||
description: ''
|
||||
icon: 🤖
|
||||
icon_background: '#FFEAD5'
|
||||
mode: agent-chat
|
||||
name: doris
|
||||
use_icon_as_answer_icon: false
|
||||
dependencies:
|
||||
- current_identifier: null
|
||||
type: marketplace
|
||||
value:
|
||||
marketplace_plugin_unique_identifier: langgenius/deepseek:0.0.5@21408d5c48cd9f18d66b08883d0999fe89e6d049c891324c2229dea23b9665d5
|
||||
- current_identifier: null
|
||||
type: marketplace
|
||||
value:
|
||||
marketplace_plugin_unique_identifier: junjiem/mcp_sse:0.2.1@53cc613667fcf91dd7208dd5f6d2c8df3c7ff0af8b79e8f3c0a430f1b39bda4c
|
||||
kind: app
|
||||
model_config:
|
||||
agent_mode:
|
||||
enabled: true
|
||||
max_iteration: 10
|
||||
prompt: null
|
||||
strategy: function_call
|
||||
tools:
|
||||
- enabled: true
|
||||
isDeleted: false
|
||||
notAuthor: false
|
||||
provider_id: junjiem/mcp_sse/mcp_sse
|
||||
provider_name: junjiem/mcp_sse/mcp_sse
|
||||
provider_type: builtin
|
||||
tool_label: 获取 MCP 工具列表
|
||||
tool_name: mcp_sse_list_tools
|
||||
tool_parameters:
|
||||
prompts_as_tools: 1
|
||||
resources_as_tools: 1
|
||||
servers_config: null
|
||||
- enabled: true
|
||||
isDeleted: false
|
||||
notAuthor: false
|
||||
provider_id: junjiem/mcp_sse/mcp_sse
|
||||
provider_name: junjiem/mcp_sse/mcp_sse
|
||||
provider_type: builtin
|
||||
tool_label: 调用 MCP 工具
|
||||
tool_name: mcp_sse_call_tool
|
||||
tool_parameters:
|
||||
arguments: ''
|
||||
prompts_as_tools: ''
|
||||
resources_as_tools: ''
|
||||
servers_config: ''
|
||||
tool_name: ''
|
||||
annotation_reply:
|
||||
enabled: false
|
||||
chat_prompt_config: {}
|
||||
completion_prompt_config: {}
|
||||
dataset_configs:
|
||||
datasets:
|
||||
datasets: []
|
||||
reranking_enable: true
|
||||
reranking_mode: reranking_model
|
||||
reranking_model:
|
||||
reranking_model_name: ''
|
||||
reranking_provider_name: ''
|
||||
retrieval_model: multiple
|
||||
top_k: 4
|
||||
dataset_query_variable: ''
|
||||
external_data_tools: []
|
||||
file_upload:
|
||||
allowed_file_extensions:
|
||||
- .JPG
|
||||
- .JPEG
|
||||
- .PNG
|
||||
- .GIF
|
||||
- .WEBP
|
||||
- .SVG
|
||||
- .MP4
|
||||
- .MOV
|
||||
- .MPEG
|
||||
- .WEBM
|
||||
allowed_file_types: []
|
||||
allowed_file_upload_methods:
|
||||
- remote_url
|
||||
- local_file
|
||||
enabled: false
|
||||
image:
|
||||
detail: high
|
||||
enabled: false
|
||||
number_limits: 3
|
||||
transfer_methods:
|
||||
- remote_url
|
||||
- local_file
|
||||
number_limits: 3
|
||||
model:
|
||||
completion_params:
|
||||
stop: []
|
||||
mode: chat
|
||||
name: deepseek-chat
|
||||
provider: langgenius/deepseek/deepseek
|
||||
more_like_this:
|
||||
enabled: false
|
||||
opening_statement: ''
|
||||
pre_prompt: "<instruction>\nUse MCP tools to complete tasks as much as possible.\
|
||||
\ Carefully read the annotations, method names, and parameter descriptions of\
|
||||
\ each tool. Please follow these steps:\n1. Analyze the user's question and match\
|
||||
\ the most appropriate tool.\n2. Use tool names and parameters exactly as defined;\
|
||||
\ do not invent new ones.\n3. Pass parameters in the required JSON format.\n4.\
|
||||
\ When calling tools, use:\n {\"mcp_sse_call_tool\": {\"tool_name\": \"<tool_name>\"\
|
||||
, \"arguments\": \"{}\"}}\n5. Output plain text only—no XML tags.\n<input>\nUser\
|
||||
\ question: user_query\n</input>\n<output>\nReturn tool results or a final answer,\
|
||||
\ including analysis.\n</output>\n</instruction>"
|
||||
prompt_type: simple
|
||||
retriever_resource:
|
||||
enabled: true
|
||||
sensitive_word_avoidance:
|
||||
configs: []
|
||||
enabled: false
|
||||
type: ''
|
||||
speech_to_text:
|
||||
enabled: false
|
||||
suggested_questions: []
|
||||
suggested_questions_after_answer:
|
||||
enabled: false
|
||||
text_to_speech:
|
||||
enabled: false
|
||||
language: ''
|
||||
voice: ''
|
||||
user_input_form: []
|
||||
version: 0.3.0
|
||||
BIN
examples/images/cursor_add_mcp.png
Normal file
|
After Width: | Height: | Size: 323 KiB |
BIN
examples/images/cursor_agent.png
Normal file
|
After Width: | Height: | Size: 673 KiB |
BIN
examples/images/cursor_ask1.png
Normal file
|
After Width: | Height: | Size: 118 KiB |
BIN
examples/images/cursor_ask2.png
Normal file
|
After Width: | Height: | Size: 232 KiB |
BIN
examples/images/cursor_doris-mcp.png
Normal file
|
After Width: | Height: | Size: 80 KiB |
BIN
examples/images/dify_add_tools.png
Normal file
|
After Width: | Height: | Size: 17 KiB |
BIN
examples/images/dify_agent_setup.png
Normal file
|
After Width: | Height: | Size: 258 KiB |
BIN
examples/images/dify_authorized.png
Normal file
|
After Width: | Height: | Size: 44 KiB |
BIN
examples/images/dify_config_mcp.png
Normal file
|
After Width: | Height: | Size: 66 KiB |
BIN
examples/images/dify_create_app.png
Normal file
|
After Width: | Height: | Size: 127 KiB |
BIN
examples/images/dify_install_plugin.png
Normal file
|
After Width: | Height: | Size: 317 KiB |
BIN
examples/images/dify_query_tabels.png
Normal file
|
After Width: | Height: | Size: 369 KiB |
BIN
examples/images/dify_sale_trend.png
Normal file
|
After Width: | Height: | Size: 272 KiB |
BIN
examples/images/dify_start_server.png
Normal file
|
After Width: | Height: | Size: 73 KiB |
@@ -20,7 +20,7 @@ build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "doris-mcp-server"
|
||||
version = "0.3.0"
|
||||
version = "0.6.1"
|
||||
description = "Enterprise-grade Model Context Protocol (MCP) server implementation for Apache Doris"
|
||||
authors = [
|
||||
{name = "Yijia Su", email = "freeoneplus@apache.org"}
|
||||
@@ -42,10 +42,14 @@ classifiers = [
|
||||
|
||||
dependencies = [
|
||||
# Core MCP dependencies
|
||||
"mcp>=1.0.0",
|
||||
"mcp>=1.8.0,<2.0.0",
|
||||
# Database drivers
|
||||
"aiomysql>=0.2.0",
|
||||
"PyMySQL>=1.1.0",
|
||||
# ADBC (Arrow Flight SQL) dependencies
|
||||
"adbc-driver-manager>=0.8.0",
|
||||
"adbc-driver-flightsql>=0.8.0",
|
||||
"pyarrow>=14.0.0",
|
||||
# Async and utility libraries
|
||||
"asyncio-mqtt>=0.16.0",
|
||||
"aiofiles>=23.0.0",
|
||||
|
||||
@@ -1,21 +1,5 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
# 开发依赖 - 从 pyproject.toml 自动生成
|
||||
# 安装命令: pip install -r requirements-dev.txt
|
||||
# Development dependencies - auto-generated from pyproject.toml
|
||||
# Installation command: pip install -r requirements-dev.txt
|
||||
|
||||
pytest>=7.4.0
|
||||
pytest-asyncio>=0.23.0
|
||||
|
||||
@@ -1,26 +1,13 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
# 主要依赖 - 从 pyproject.toml 自动生成
|
||||
# 请不要手动编辑此文件,使用 python generate_requirements.py 重新生成
|
||||
# Main dependencies - auto-generated from pyproject.toml
|
||||
# Do not edit this file manually, use 'python generate_requirements.py' to regenerate
|
||||
|
||||
# === 核心依赖 ===
|
||||
mcp>=1.0.0
|
||||
# === Core Dependencies ===
|
||||
mcp>=1.8.0,<2.0.0
|
||||
aiomysql>=0.2.0
|
||||
PyMySQL>=1.1.0
|
||||
adbc-driver-manager>=0.8.0
|
||||
adbc-driver-flightsql>=0.8.0
|
||||
pyarrow>=14.0.0
|
||||
asyncio-mqtt>=0.16.0
|
||||
aiofiles>=23.0.0
|
||||
aiohttp>=3.9.0
|
||||
@@ -53,8 +40,11 @@ click>=8.1.0
|
||||
typer>=0.9.0
|
||||
requests>=2.31.0
|
||||
tqdm>=4.66.0
|
||||
pytest>=8.4.0
|
||||
pytest-asyncio>=1.0.0
|
||||
pytest-cov>=6.1.1
|
||||
|
||||
# === 开发依赖 ===
|
||||
# === Development Dependencies ===
|
||||
pytest>=7.4.0
|
||||
pytest-asyncio>=0.23.0
|
||||
pytest-cov>=4.1.0
|
||||
|
||||
@@ -64,9 +64,11 @@ else
|
||||
fi
|
||||
|
||||
# Set HTTP-specific environment variables
|
||||
# FIX for Issue #62 Bug 4: Use SERVER_PORT instead of MCP_PORT for consistency with code
|
||||
export MCP_TRANSPORT_TYPE="http"
|
||||
export MCP_HOST="${MCP_HOST:-0.0.0.0}"
|
||||
export MCP_PORT="${MCP_PORT:-3000}"
|
||||
export SERVER_PORT="${SERVER_PORT:-3000}" # Changed from MCP_PORT to SERVER_PORT
|
||||
export WORKERS="${WORKERS:-1}"
|
||||
export ALLOWED_ORIGINS="${ALLOWED_ORIGINS:-*}"
|
||||
export LOG_LEVEL="${LOG_LEVEL:-info}"
|
||||
export MCP_ALLOW_CREDENTIALS="${MCP_ALLOW_CREDENTIALS:-false}"
|
||||
@@ -76,14 +78,15 @@ export MCP_DEBUG_ADAPTER="true"
|
||||
export PYTHONPATH="$(pwd):$PYTHONPATH"
|
||||
|
||||
echo -e "${GREEN}Starting MCP server (Streamable HTTP mode)...${NC}"
|
||||
echo -e "${YELLOW}Service will run on http://${MCP_HOST}:${MCP_PORT}/mcp${NC}"
|
||||
echo -e "${YELLOW}Health Check: http://${MCP_HOST}:${MCP_PORT}/health${NC}"
|
||||
echo -e "${YELLOW}MCP Endpoint: http://${MCP_HOST}:${MCP_PORT}/mcp${NC}"
|
||||
echo -e "${YELLOW}Local access: http://localhost:${MCP_PORT}/mcp${NC}"
|
||||
echo -e "${YELLOW}Service will run on http://${MCP_HOST}:${SERVER_PORT}/mcp${NC}"
|
||||
echo -e "${YELLOW}Health Check: http://${MCP_HOST}:${SERVER_PORT}/health${NC}"
|
||||
echo -e "${YELLOW}MCP Endpoint: http://${MCP_HOST}:${SERVER_PORT}/mcp${NC}"
|
||||
echo -e "${YELLOW}Local access: http://localhost:${SERVER_PORT}/mcp${NC}"
|
||||
echo -e "${YELLOW}Workers: ${WORKERS}${NC}"
|
||||
echo -e "${YELLOW}Use Ctrl+C to stop the service${NC}"
|
||||
|
||||
# Start the server in HTTP mode (Streamable HTTP)
|
||||
python -m doris_mcp_server.main --transport http --host ${MCP_HOST} --port ${MCP_PORT}
|
||||
python -m doris_mcp_server.main --transport http --host ${MCP_HOST} --port ${SERVER_PORT} --workers ${WORKERS}
|
||||
|
||||
# Check exit status
|
||||
if [ $? -ne 0 ]; then
|
||||
@@ -95,4 +98,4 @@ fi
|
||||
echo -e "${YELLOW}Tip: If the page displays abnormally, please clear your browser cache or use incognito mode${NC}"
|
||||
echo -e "${YELLOW}Chrome browser clear cache shortcut: Ctrl+Shift+Del (Windows) or Cmd+Shift+Del (Mac)${NC}"
|
||||
echo -e "${CYAN}For testing HTTP endpoints, you can use:${NC}"
|
||||
echo -e "${CYAN} curl -X POST http://localhost:${MCP_PORT}/mcp -H 'Content-Type: application/json' -d '{\"method\":\"tools/list\"}'${NC}"
|
||||
echo -e "${CYAN} curl -X POST http://localhost:${SERVER_PORT}/mcp -H 'Content-Type: application/json' -d '{\"method\":\"tools/list\"}'${NC}"
|
||||
@@ -47,22 +47,29 @@ def event_loop():
|
||||
|
||||
@pytest.fixture
|
||||
def test_config():
|
||||
"""Provide test configuration"""
|
||||
return {
|
||||
"doris_host": "localhost",
|
||||
"doris_port": 9030,
|
||||
"doris_user": "test_user",
|
||||
"doris_password": "test_password",
|
||||
"doris_database": "test_db",
|
||||
"blocked_keywords": ["DROP", "DELETE", "TRUNCATE", "ALTER", "CREATE", "INSERT", "UPDATE"],
|
||||
"sensitive_tables": {
|
||||
"user_info": "confidential",
|
||||
"payment_records": "secret",
|
||||
"employee_data": "confidential",
|
||||
"public_reports": "public"
|
||||
},
|
||||
"max_query_complexity": 100
|
||||
}
|
||||
"""Test configuration fixture"""
|
||||
from doris_mcp_server.utils.config import DorisConfig, DatabaseConfig, SecurityConfig
|
||||
|
||||
config = DorisConfig()
|
||||
|
||||
# Database configuration
|
||||
config.database.host = "localhost"
|
||||
config.database.port = 9030
|
||||
config.database.user = "test_user"
|
||||
config.database.password = "test_password"
|
||||
config.database.database = "test_db"
|
||||
config.database.health_check_interval = 60
|
||||
config.database.max_connections = 20
|
||||
config.database.connection_timeout = 30
|
||||
config.database.max_connection_age = 3600
|
||||
|
||||
# Security configuration
|
||||
config.security.enable_masking = True
|
||||
config.security.auth_type = "token"
|
||||
config.security.token_secret = "test_secret"
|
||||
config.security.token_expiry = 3600
|
||||
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -34,17 +34,9 @@ class TestEndToEndIntegration:
|
||||
@pytest.fixture
|
||||
def mock_config(self):
|
||||
"""Create mock configuration"""
|
||||
from doris_mcp_server.utils.config import DatabaseConfig, SecurityConfig
|
||||
from doris_mcp_server.utils.config import ADBCConfig, DatabaseConfig, SecurityConfig
|
||||
|
||||
config = Mock(spec=DorisConfig)
|
||||
config.doris_host = "localhost"
|
||||
config.doris_port = 9030
|
||||
config.doris_user = "test_user"
|
||||
config.doris_password = "test_password"
|
||||
config.doris_database = "test_db"
|
||||
config.server_host = "localhost"
|
||||
config.server_port = 8000
|
||||
config.enable_security = True
|
||||
|
||||
# Add database config
|
||||
config.database = Mock(spec=DatabaseConfig)
|
||||
@@ -54,7 +46,6 @@ class TestEndToEndIntegration:
|
||||
config.database.password = "test_password"
|
||||
config.database.database = "test_db"
|
||||
config.database.health_check_interval = 60
|
||||
config.database.min_connections = 5
|
||||
config.database.max_connections = 20
|
||||
config.database.connection_timeout = 30
|
||||
config.database.max_connection_age = 3600
|
||||
@@ -65,7 +56,12 @@ class TestEndToEndIntegration:
|
||||
config.security.auth_type = "token"
|
||||
config.security.token_secret = "test_secret"
|
||||
config.security.token_expiry = 3600
|
||||
config.security.blocked_keywords = ["DROP"]
|
||||
|
||||
# Add adbc config
|
||||
config.adbc = Mock(spec=ADBCConfig)
|
||||
config.adbc.enabled = True
|
||||
|
||||
return config
|
||||
|
||||
@pytest.fixture
|
||||
@@ -239,7 +235,7 @@ class TestEndToEndIntegration:
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_execution_with_security(self, doris_server):
|
||||
"""Test tool execution with security checks"""
|
||||
with patch.object(doris_server.tools_manager.query_executor, 'execute_query') as mock_execute:
|
||||
with patch.object(doris_server.tools_manager.connection_manager, 'execute_query') as mock_execute:
|
||||
mock_execute.return_value = [{"Database": "test_db"}]
|
||||
|
||||
# Test tool execution through tools manager
|
||||
@@ -266,7 +262,7 @@ class TestEndToEndIntegration:
|
||||
@pytest.mark.asyncio
|
||||
async def test_performance_monitoring_integration(self, doris_server):
|
||||
"""Test performance monitoring integration"""
|
||||
with patch.object(doris_server.tools_manager.query_executor, 'execute_query') as mock_execute:
|
||||
with patch.object(doris_server.tools_manager.connection_manager, 'execute_query') as mock_execute:
|
||||
mock_execute.return_value = [
|
||||
{
|
||||
"query_count": 1500,
|
||||
@@ -277,10 +273,7 @@ class TestEndToEndIntegration:
|
||||
]
|
||||
|
||||
# Test performance stats tool
|
||||
result = await doris_server.tools_manager.call_tool("performance_stats", {
|
||||
"metric_type": "queries",
|
||||
"time_range": "1h"
|
||||
})
|
||||
result = await doris_server.tools_manager.call_tool("get_db_list", {})
|
||||
result_data = json.loads(result)
|
||||
|
||||
# Accept either success result or error (due to mock environment)
|
||||
@@ -296,4 +289,4 @@ class TestEndToEndIntegration:
|
||||
# Verify tools are available - use list_tools instead
|
||||
import asyncio
|
||||
tools = asyncio.run(doris_server.tools_manager.list_tools())
|
||||
assert len(tools) > 0
|
||||
assert len(tools) > 0
|
||||
|
||||
367
test/security/test_sql_injection.py
Normal file
@@ -0,0 +1,367 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
SQL Security Test Suite for Apache Doris MCP Server
|
||||
|
||||
Tests for:
|
||||
1. SQL injection prevention via identifier validation
|
||||
2. Multi-statement SQL parsing in security validator
|
||||
3. auth_context enforcement
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
|
||||
class TestSQLSecurityUtils:
|
||||
"""Test cases for sql_security_utils module"""
|
||||
|
||||
def test_validate_identifier_accepts_valid_names(self):
|
||||
"""Test that valid identifiers are accepted"""
|
||||
from doris_mcp_server.utils.sql_security_utils import validate_identifier
|
||||
|
||||
valid_names = [
|
||||
"users",
|
||||
"my_table",
|
||||
"Table123",
|
||||
"_private_table",
|
||||
"CamelCaseTable",
|
||||
"table_with_numbers_123",
|
||||
]
|
||||
|
||||
for name in valid_names:
|
||||
result = validate_identifier(name, "table")
|
||||
assert result == name, f"Valid identifier '{name}' should be accepted"
|
||||
|
||||
def test_validate_identifier_rejects_sql_injection(self):
|
||||
"""Test that SQL injection attempts are rejected"""
|
||||
from doris_mcp_server.utils.sql_security_utils import (
|
||||
validate_identifier,
|
||||
SQLSecurityError
|
||||
)
|
||||
|
||||
injection_attempts = [
|
||||
# Basic SQL injection
|
||||
"'; DROP TABLE users; --",
|
||||
"table' OR '1'='1",
|
||||
"table'; DELETE FROM users; --",
|
||||
|
||||
# Union-based injection
|
||||
"table' UNION SELECT * FROM passwords --",
|
||||
|
||||
# Comment injection
|
||||
"table/**/OR/**/1=1",
|
||||
"table--comment",
|
||||
|
||||
# Special characters
|
||||
"table`; DROP TABLE users;",
|
||||
'table"; DROP TABLE users;',
|
||||
"table\"; DELETE FROM",
|
||||
|
||||
# Backtick escape attempt
|
||||
"analytics`; SELECT * FROM sensitive_table;--",
|
||||
|
||||
# Whitespace injection
|
||||
"table name with spaces",
|
||||
"table\ttab",
|
||||
"table\nnewline",
|
||||
]
|
||||
|
||||
for injection in injection_attempts:
|
||||
with pytest.raises(SQLSecurityError):
|
||||
validate_identifier(injection, "table")
|
||||
|
||||
def test_validate_identifier_rejects_empty(self):
|
||||
"""Test that empty identifiers are rejected"""
|
||||
from doris_mcp_server.utils.sql_security_utils import (
|
||||
validate_identifier,
|
||||
SQLSecurityError
|
||||
)
|
||||
|
||||
with pytest.raises(SQLSecurityError):
|
||||
validate_identifier("", "table")
|
||||
|
||||
with pytest.raises(SQLSecurityError):
|
||||
validate_identifier(None, "table")
|
||||
|
||||
def test_validate_identifier_rejects_too_long(self):
|
||||
"""Test that identifiers exceeding max length are rejected"""
|
||||
from doris_mcp_server.utils.sql_security_utils import (
|
||||
validate_identifier,
|
||||
SQLSecurityError
|
||||
)
|
||||
|
||||
# Doris identifier max length is typically 64 characters
|
||||
long_name = "a" * 100
|
||||
with pytest.raises(SQLSecurityError):
|
||||
validate_identifier(long_name, "table")
|
||||
|
||||
def test_quote_identifier_adds_backticks(self):
|
||||
"""Test that quote_identifier properly escapes identifiers"""
|
||||
from doris_mcp_server.utils.sql_security_utils import quote_identifier
|
||||
|
||||
assert quote_identifier("my_table", "table") == "`my_table`"
|
||||
assert quote_identifier("users", "table") == "`users`"
|
||||
assert quote_identifier("Table123", "table") == "`Table123`"
|
||||
|
||||
def test_quote_identifier_validates_first(self):
|
||||
"""Test that quote_identifier validates before quoting"""
|
||||
from doris_mcp_server.utils.sql_security_utils import (
|
||||
quote_identifier,
|
||||
SQLSecurityError
|
||||
)
|
||||
|
||||
with pytest.raises(SQLSecurityError):
|
||||
quote_identifier("'; DROP TABLE users; --", "table")
|
||||
|
||||
|
||||
class TestSQLSecurityValidator:
|
||||
"""Test cases for SQLSecurityValidator multi-statement parsing"""
|
||||
|
||||
@pytest.fixture
|
||||
def dict_config(self):
|
||||
"""Create dictionary configuration"""
|
||||
return {
|
||||
"blocked_keywords": [
|
||||
"DROP", "CREATE", "ALTER", "TRUNCATE",
|
||||
"DELETE", "INSERT", "UPDATE",
|
||||
"GRANT", "REVOKE", "EXEC", "EXECUTE"
|
||||
],
|
||||
"max_query_complexity": 100,
|
||||
"enable_security_check": True
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def mock_auth_context(self):
|
||||
"""Create mock auth context"""
|
||||
from doris_mcp_server.utils.security import AuthContext, SecurityLevel
|
||||
return AuthContext(
|
||||
user_id="test_user",
|
||||
roles=["user"],
|
||||
security_level=SecurityLevel.INTERNAL
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validates_all_statements(self, dict_config, mock_auth_context):
|
||||
"""Test that validator checks ALL SQL statements, not just the first"""
|
||||
from doris_mcp_server.utils.security import SQLSecurityValidator
|
||||
|
||||
validator = SQLSecurityValidator(dict_config)
|
||||
|
||||
# Multi-statement with injection in second statement
|
||||
# This should be BLOCKED
|
||||
malicious_sql = "SELECT 1; DROP TABLE users; SELECT 2"
|
||||
|
||||
result = await validator.validate(malicious_sql, mock_auth_context)
|
||||
|
||||
assert not result.is_valid, "Multi-statement injection should be blocked"
|
||||
# Check for either DROP keyword detection or SQL injection detection
|
||||
error_upper = result.error_message.upper()
|
||||
assert ("DROP" in error_upper or
|
||||
"INJECTION" in error_upper or
|
||||
"BLOCKED" in error_upper), f"Expected DROP/injection/blocked in: {result.error_message}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_blocks_hidden_dangerous_statement(self, dict_config, mock_auth_context):
|
||||
"""Test that dangerous statements hidden after safe ones are blocked"""
|
||||
from doris_mcp_server.utils.security import SQLSecurityValidator
|
||||
|
||||
validator = SQLSecurityValidator(dict_config)
|
||||
|
||||
# Safe statement followed by dangerous one
|
||||
malicious_sql = """
|
||||
SELECT * FROM users WHERE id = 1;
|
||||
DELETE FROM audit_log;
|
||||
SELECT 1;
|
||||
"""
|
||||
|
||||
result = await validator.validate(malicious_sql, mock_auth_context)
|
||||
|
||||
assert not result.is_valid, "Hidden DELETE statement should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allows_safe_multi_statement(self, dict_config, mock_auth_context):
|
||||
"""Test that multiple safe SELECT statements are allowed"""
|
||||
from doris_mcp_server.utils.security import SQLSecurityValidator
|
||||
|
||||
validator = SQLSecurityValidator(dict_config)
|
||||
|
||||
safe_sql = """
|
||||
SELECT * FROM users;
|
||||
SELECT COUNT(*) FROM orders;
|
||||
SELECT id, name FROM products;
|
||||
"""
|
||||
|
||||
result = await validator.validate(safe_sql, mock_auth_context)
|
||||
|
||||
assert result.is_valid, f"Multiple safe SELECT statements should be allowed, got: {result.error_message}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_switch_injection_blocked(self, dict_config, mock_auth_context):
|
||||
"""Test that context switch SQL injection is blocked"""
|
||||
from doris_mcp_server.utils.security import SQLSecurityValidator
|
||||
|
||||
validator = SQLSecurityValidator(dict_config)
|
||||
|
||||
# Simulating the exec_query_for_mcp attack vector
|
||||
injected_sql = """
|
||||
USE `analytics`; SELECT * FROM sensitive_table;-- `;
|
||||
SELECT * FROM public_table;
|
||||
"""
|
||||
|
||||
result = await validator.validate(injected_sql, mock_auth_context)
|
||||
|
||||
# The validator should process all statements
|
||||
# Even if USE is allowed, subsequent unauthorized access should be caught
|
||||
# by table access checks (if configured)
|
||||
|
||||
|
||||
class TestExecQueryForMCP:
|
||||
"""Test cases for exec_query_for_mcp function"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_malicious_db_name(self):
|
||||
"""Test that malicious db_name is rejected"""
|
||||
from doris_mcp_server.utils.sql_security_utils import (
|
||||
validate_identifier,
|
||||
SQLSecurityError
|
||||
)
|
||||
|
||||
# The attack vector from security report
|
||||
malicious_db_name = "analytics`; SELECT * FROM sensitive_table;--"
|
||||
|
||||
with pytest.raises(SQLSecurityError):
|
||||
validate_identifier(malicious_db_name, "database name")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_malicious_catalog_name(self):
|
||||
"""Test that malicious catalog_name is rejected"""
|
||||
from doris_mcp_server.utils.sql_security_utils import (
|
||||
validate_identifier,
|
||||
SQLSecurityError
|
||||
)
|
||||
|
||||
malicious_catalog_name = "internal'; DROP DATABASE production;--"
|
||||
|
||||
with pytest.raises(SQLSecurityError):
|
||||
validate_identifier(malicious_catalog_name, "catalog name")
|
||||
|
||||
|
||||
class TestDependencyAnalysisTools:
|
||||
"""Test cases for dependency_analysis_tools security fixes"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_tables_metadata_rejects_injection(self):
|
||||
"""Test that _get_tables_metadata rejects SQL injection"""
|
||||
from doris_mcp_server.utils.sql_security_utils import (
|
||||
validate_identifier,
|
||||
SQLSecurityError
|
||||
)
|
||||
|
||||
# The attack vector from security report
|
||||
injection_db_name = "test_db' OR '1'='1' --"
|
||||
|
||||
with pytest.raises(SQLSecurityError):
|
||||
validate_identifier(injection_db_name, "database name")
|
||||
|
||||
|
||||
class TestAuthContextEnforcement:
|
||||
"""Test cases for auth_context enforcement"""
|
||||
|
||||
def test_execute_requires_auth_context_for_security(self):
|
||||
"""Test that security checks require auth_context"""
|
||||
# This test documents the expected behavior:
|
||||
# When auth_context is None, security checks are skipped
|
||||
# When auth_context is provided, security checks are performed
|
||||
|
||||
# The fix ensures all execute() calls pass auth_context
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_auth_context_returns_context(self):
|
||||
"""Test that get_auth_context retrieves context from ContextVar"""
|
||||
from doris_mcp_server.utils.sql_security_utils import get_auth_context
|
||||
|
||||
# When no context is set, should return None
|
||||
result = get_auth_context()
|
||||
# This is expected - context is set by HTTP middleware
|
||||
assert result is None or hasattr(result, 'user_id')
|
||||
|
||||
|
||||
class TestIntegrationScenarios:
|
||||
"""Integration test scenarios for security fixes"""
|
||||
|
||||
def test_attack_scenario_1_permission_bypass(self):
|
||||
"""
|
||||
Attack Scenario 1: Permission Bypass in Multi-Tenant Environment
|
||||
|
||||
Expected: User can only query their own database (db_name="tenant_a_db")
|
||||
Attack: Inject "tenant_a_db' OR '1'='1' --" to query ALL databases
|
||||
Result: Should be BLOCKED by validate_identifier()
|
||||
"""
|
||||
from doris_mcp_server.utils.sql_security_utils import (
|
||||
validate_identifier,
|
||||
SQLSecurityError
|
||||
)
|
||||
|
||||
with pytest.raises(SQLSecurityError):
|
||||
validate_identifier("tenant_a_db' OR '1'='1' --", "database name")
|
||||
|
||||
def test_attack_scenario_2_union_injection(self):
|
||||
"""
|
||||
Attack Scenario 2: UNION-based Information Disclosure
|
||||
|
||||
Attack: Inject UNION SELECT to extract sensitive data
|
||||
Result: Should be BLOCKED
|
||||
"""
|
||||
from doris_mcp_server.utils.sql_security_utils import (
|
||||
validate_identifier,
|
||||
SQLSecurityError
|
||||
)
|
||||
|
||||
with pytest.raises(SQLSecurityError):
|
||||
validate_identifier(
|
||||
"test' UNION SELECT password FROM users --",
|
||||
"database name"
|
||||
)
|
||||
|
||||
def test_attack_scenario_3_backtick_escape(self):
|
||||
"""
|
||||
Attack Scenario 3: Backtick Escape Attempt
|
||||
|
||||
Attack: Use backticks to break out of quoted identifier
|
||||
Result: Should be BLOCKED
|
||||
"""
|
||||
from doris_mcp_server.utils.sql_security_utils import (
|
||||
validate_identifier,
|
||||
SQLSecurityError
|
||||
)
|
||||
|
||||
with pytest.raises(SQLSecurityError):
|
||||
validate_identifier(
|
||||
"analytics`; SELECT * FROM sensitive_table;--",
|
||||
"database name"
|
||||
)
|
||||
|
||||
|
||||
# Run tests with: pytest tests/test_sql_security.py -v
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "--tb=short"])
|
||||
|
||||
871
test/security/test_sql_injection_api.py
Normal file
@@ -0,0 +1,871 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
SQL Injection API Integration Tests
|
||||
|
||||
This module tests SQL injection prevention through the MCP HTTP API.
|
||||
It sends malicious payloads and verifies they are properly blocked.
|
||||
|
||||
Prerequisites:
|
||||
- MCP server running on localhost:3000
|
||||
- Run with: pytest test/security/test_sql_injection_api.py -v
|
||||
|
||||
Usage:
|
||||
# Start server first
|
||||
bash start_server.sh
|
||||
|
||||
# Run tests
|
||||
pytest test/security/test_sql_injection_api.py -v --no-cov
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import httpx
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
|
||||
# Server configuration
|
||||
MCP_BASE_URL = "http://localhost:3000"
|
||||
MCP_ENDPOINT = f"{MCP_BASE_URL}/mcp"
|
||||
HEALTH_ENDPOINT = f"{MCP_BASE_URL}/health"
|
||||
TIMEOUT = 30.0
|
||||
|
||||
|
||||
class MCPClient:
|
||||
"""Simple MCP HTTP client for testing"""
|
||||
|
||||
def __init__(self, base_url: str = MCP_BASE_URL):
|
||||
self.base_url = base_url
|
||||
self.mcp_endpoint = f"{base_url}/mcp"
|
||||
self.session_id: Optional[str] = None
|
||||
self.request_id = 0
|
||||
self.client = httpx.AsyncClient(timeout=TIMEOUT)
|
||||
|
||||
async def close(self):
|
||||
await self.client.aclose()
|
||||
|
||||
def _next_id(self) -> int:
|
||||
self.request_id += 1
|
||||
return self.request_id
|
||||
|
||||
async def initialize(self) -> dict:
|
||||
"""Initialize MCP session"""
|
||||
response = await self.client.post(
|
||||
self.mcp_endpoint,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json, text/event-stream"
|
||||
},
|
||||
json={
|
||||
"jsonrpc": "2.0",
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {
|
||||
"name": "sql-injection-test",
|
||||
"version": "1.0.0"
|
||||
}
|
||||
},
|
||||
"id": self._next_id()
|
||||
}
|
||||
)
|
||||
|
||||
# Extract session ID from response header
|
||||
self.session_id = response.headers.get("mcp-session-id")
|
||||
return self._parse_response(response.text)
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: dict) -> dict:
|
||||
"""Call an MCP tool"""
|
||||
if not self.session_id:
|
||||
await self.initialize()
|
||||
|
||||
response = await self.client.post(
|
||||
self.mcp_endpoint,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json, text/event-stream",
|
||||
"mcp-session-id": self.session_id
|
||||
},
|
||||
json={
|
||||
"jsonrpc": "2.0",
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": tool_name,
|
||||
"arguments": arguments
|
||||
},
|
||||
"id": self._next_id()
|
||||
}
|
||||
)
|
||||
|
||||
return self._parse_response(response.text)
|
||||
|
||||
def _parse_response(self, text: str) -> dict:
|
||||
"""Parse JSON response"""
|
||||
try:
|
||||
return json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
# Try SSE format
|
||||
lines = text.strip().split("\n")
|
||||
for line in lines:
|
||||
if line.startswith("data: "):
|
||||
try:
|
||||
return json.loads(line[6:])
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
return {"raw": text}
|
||||
|
||||
|
||||
def print_result(test_name: str, payload: dict, result: dict):
|
||||
"""Print test result in a readable format"""
|
||||
print(f"\n{'='*60}")
|
||||
print(f"TEST: {test_name}")
|
||||
print(f"{'='*60}")
|
||||
print(f"PAYLOAD: {json.dumps(payload, ensure_ascii=False)}")
|
||||
print(f"{'-'*60}")
|
||||
|
||||
# Extract inner result content
|
||||
if "result" in result and "content" in result.get("result", {}):
|
||||
for item in result["result"]["content"]:
|
||||
if item.get("type") == "text":
|
||||
try:
|
||||
inner = json.loads(item["text"])
|
||||
print("RESPONSE:")
|
||||
print(f" success: {inner.get('success')}")
|
||||
if inner.get('error'):
|
||||
print(f" error: {inner.get('error')}")
|
||||
if inner.get('error_type'):
|
||||
print(f" error_type: {inner.get('error_type')}")
|
||||
if inner.get('risk_level'):
|
||||
print(f" risk_level: {inner.get('risk_level')}")
|
||||
if inner.get('message'):
|
||||
print(f" message: {inner.get('message')}")
|
||||
if inner.get('data') is not None and inner.get('success'):
|
||||
data_str = json.dumps(inner.get('data'), ensure_ascii=False)
|
||||
if len(data_str) > 200:
|
||||
data_str = data_str[:200] + "..."
|
||||
print(f" data: {data_str}")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
print(f"RESPONSE (raw): {item.get('text', '')[:500]}")
|
||||
elif "error" in result:
|
||||
print(f"RESPONSE ERROR: {result['error']}")
|
||||
else:
|
||||
print(f"RESPONSE (raw): {json.dumps(result, ensure_ascii=False)[:500]}")
|
||||
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
|
||||
class TestSQLInjectionAPI:
|
||||
"""Test SQL injection prevention through MCP API"""
|
||||
|
||||
@pytest.fixture
|
||||
async def mcp_client(self):
|
||||
"""Create MCP client instance"""
|
||||
client = MCPClient()
|
||||
yield client
|
||||
await client.close()
|
||||
|
||||
@pytest.fixture
|
||||
def is_server_running(self):
|
||||
"""Check if MCP server is running"""
|
||||
import httpx
|
||||
try:
|
||||
response = httpx.get(HEALTH_ENDPOINT, timeout=5.0)
|
||||
return response.status_code == 200
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_server_health(self):
|
||||
"""Test that MCP server is running and healthy"""
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(HEALTH_ENDPOINT)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "healthy"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_query_with_drop_injection(self, mcp_client):
|
||||
"""Test exec_query rejects DROP TABLE injection"""
|
||||
# Classic SQL injection: append DROP TABLE
|
||||
payload = {"sql": "SELECT * FROM users; DROP TABLE users; --"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("DROP TABLE Injection", payload, result)
|
||||
|
||||
# Should return error, not execute the DROP
|
||||
assert self._is_blocked_or_error(result), \
|
||||
f"DROP TABLE injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_query_with_union_injection(self, mcp_client):
|
||||
"""Test exec_query blocks UNION-based injection attempts"""
|
||||
# UNION injection to extract data from other tables
|
||||
payload = {"sql": "SELECT id FROM users UNION SELECT password FROM admin_users"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("UNION Injection", payload, result)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_query_with_delete_injection(self, mcp_client):
|
||||
"""Test exec_query rejects DELETE injection"""
|
||||
payload = {"sql": "SELECT 1; DELETE FROM users WHERE 1=1; SELECT 2"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("DELETE Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
f"DELETE injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_query_with_update_injection(self, mcp_client):
|
||||
"""Test exec_query rejects UPDATE injection"""
|
||||
payload = {"sql": "SELECT 1; UPDATE users SET role='admin' WHERE id=1; --"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("UPDATE Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
f"UPDATE injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_query_db_name_injection(self, mcp_client):
|
||||
"""Test exec_query rejects SQL injection via db_name parameter"""
|
||||
# Attack vector: inject SQL via db_name parameter
|
||||
payload = {"sql": "SELECT 1", "db_name": "test'; DROP TABLE users; --"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("db_name Parameter Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
f"db_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_query_catalog_name_injection(self, mcp_client):
|
||||
"""Test exec_query rejects SQL injection via catalog_name parameter"""
|
||||
# Attack vector: inject SQL via catalog_name parameter
|
||||
payload = {"sql": "SELECT 1", "catalog_name": "internal`; SELECT * FROM mysql.user; --"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("catalog_name Parameter Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
f"catalog_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_table_schema_injection(self, mcp_client):
|
||||
"""Test get_table_schema rejects SQL injection via table_name"""
|
||||
# Attack vector: inject SQL via table_name parameter
|
||||
payload = {"table_name": "users'; DROP TABLE users; --"}
|
||||
result = await mcp_client.call_tool("get_table_schema", payload)
|
||||
print_result("table_name Injection (get_table_schema)", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
f"table_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_table_schema_db_injection(self, mcp_client):
|
||||
"""Test get_table_schema rejects SQL injection via db_name"""
|
||||
payload = {"table_name": "users", "db_name": "test' OR '1'='1"}
|
||||
result = await mcp_client.call_tool("get_table_schema", payload)
|
||||
print_result("db_name Injection (get_table_schema)", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
f"db_name injection in get_table_schema should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analyze_dependencies_injection(self, mcp_client):
|
||||
"""Test analyze_dependencies rejects SQL injection"""
|
||||
# This was the original vulnerability reported
|
||||
payload = {"table_name": "users", "db_name": "test_db' OR '1'='1' --"}
|
||||
result = await mcp_client.call_tool("analyze_dependencies", payload)
|
||||
print_result("analyze_dependencies Injection (Original Report)", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
f"analyze_dependencies db_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stacked_queries_injection(self, mcp_client):
|
||||
"""Test that stacked queries (multiple statements) are blocked"""
|
||||
# Multiple statements injection
|
||||
payload = {"sql": "SELECT * FROM users WHERE id = 1; INSERT INTO audit_log VALUES (NULL, 'hacked', NOW()); SELECT 1;"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("Stacked Queries (INSERT) Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
f"Stacked queries with INSERT should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_comment_based_injection(self, mcp_client):
|
||||
"""Test that comment-based injection is blocked"""
|
||||
# Using comments to bypass filters
|
||||
payload = {"sql": "SELECT * FROM users WHERE id = 1/**/OR/**/1=1"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("Comment-based Injection", payload, result)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hex_encoded_injection(self, mcp_client):
|
||||
"""Test that hex-encoded injection attempts are handled"""
|
||||
# Hex-encoded 'DROP' attempt
|
||||
payload = {"sql": "SELECT 0x44524F50205441424C4520757365727320"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("Hex Encoded Injection", payload, result)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backtick_escape_injection(self, mcp_client):
|
||||
"""Test backtick escape injection is blocked"""
|
||||
# Attempt to escape backtick quoting
|
||||
payload = {"sql": "SELECT 1", "db_name": "analytics`; SELECT * FROM sensitive_table;--"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("Backtick Escape Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
f"Backtick escape injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_query_succeeds(self, mcp_client):
|
||||
"""Test that valid queries still work"""
|
||||
# Simple valid query should work
|
||||
payload = {"sql": "SELECT 1 AS test_value"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("Valid Query (should succeed)", payload, result)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_show_databases(self, mcp_client):
|
||||
"""Test that SHOW DATABASES works"""
|
||||
payload = {"sql": "SHOW DATABASES"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("SHOW DATABASES (should succeed)", payload, result)
|
||||
|
||||
def _is_blocked_or_error(self, result: dict) -> bool:
|
||||
"""Check if result indicates blocked or error"""
|
||||
if not result:
|
||||
return True
|
||||
|
||||
# Check for JSON-RPC error
|
||||
if "error" in result:
|
||||
return True
|
||||
|
||||
# Check for error in result content
|
||||
if "result" in result:
|
||||
result_content = result.get("result", {})
|
||||
if isinstance(result_content, dict):
|
||||
# Check for isError flag
|
||||
if result_content.get("isError"):
|
||||
return True
|
||||
# Check content array for error messages
|
||||
content = result_content.get("content", [])
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
text = item.get("text", "")
|
||||
# Parse the JSON text content
|
||||
try:
|
||||
text_data = json.loads(text)
|
||||
# Check for success: false
|
||||
if text_data.get("success") is False:
|
||||
return True
|
||||
# Check for error field
|
||||
if text_data.get("error"):
|
||||
return True
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
# Check text for security keywords
|
||||
if any(keyword in text.lower() for keyword in [
|
||||
"error", "blocked", "invalid", "security",
|
||||
"injection", "denied", "forbidden", "not allowed",
|
||||
"security_violation", "risk_level"
|
||||
]):
|
||||
return True
|
||||
|
||||
# Check raw text response
|
||||
raw = result.get("raw", "")
|
||||
if isinstance(raw, str) and any(keyword in raw.lower() for keyword in [
|
||||
"error", "blocked", "invalid", "security"
|
||||
]):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class TestIdentifierInjectionAPI:
|
||||
"""Test identifier-based SQL injection prevention"""
|
||||
|
||||
@pytest.fixture
|
||||
async def mcp_client(self):
|
||||
"""Create MCP client instance"""
|
||||
client = MCPClient()
|
||||
yield client
|
||||
await client.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_table_name_with_semicolon(self, mcp_client):
|
||||
"""Test table name containing semicolon is rejected"""
|
||||
payload = {"table_name": "users; DROP TABLE users"}
|
||||
result = await mcp_client.call_tool("get_table_schema", payload)
|
||||
print_result("Table Name with Semicolon", payload, result)
|
||||
|
||||
# Should be blocked by identifier validation
|
||||
assert self._contains_error_indicator(result), \
|
||||
f"Table name with semicolon should be rejected"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_table_name_with_quotes(self, mcp_client):
|
||||
"""Test table name containing quotes is rejected"""
|
||||
payload = {"table_name": "users' OR '1'='1"}
|
||||
result = await mcp_client.call_tool("get_table_schema", payload)
|
||||
print_result("Table Name with Quotes", payload, result)
|
||||
|
||||
assert self._contains_error_indicator(result), \
|
||||
f"Table name with quotes should be rejected"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_db_name_with_special_chars(self, mcp_client):
|
||||
"""Test database name with special characters is rejected"""
|
||||
special_chars = [
|
||||
"test;db",
|
||||
"test'db",
|
||||
"test\"db",
|
||||
"test`db",
|
||||
"test--db",
|
||||
"test/*db*/",
|
||||
]
|
||||
|
||||
for db_name in special_chars:
|
||||
payload = {"table_name": "users", "db_name": db_name}
|
||||
result = await mcp_client.call_tool("get_table_schema", payload)
|
||||
print_result(f"Special Char in db_name: {db_name}", payload, result)
|
||||
|
||||
assert self._contains_error_indicator(result), \
|
||||
f"db_name '{db_name}' should be rejected"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_identifiers_accepted(self, mcp_client):
|
||||
"""Test that valid identifiers are accepted"""
|
||||
valid_names = [
|
||||
"users",
|
||||
"my_table",
|
||||
"Table123",
|
||||
"_internal_table",
|
||||
]
|
||||
|
||||
for table_name in valid_names:
|
||||
payload = {"table_name": table_name}
|
||||
result = await mcp_client.call_tool("get_table_schema", payload)
|
||||
print_result(f"Valid Identifier: {table_name}", payload, result)
|
||||
|
||||
def _contains_error_indicator(self, result: dict) -> bool:
|
||||
"""Check if result contains error indicators"""
|
||||
if not result:
|
||||
return True
|
||||
|
||||
# Check for JSON-RPC error
|
||||
if "error" in result:
|
||||
return True
|
||||
|
||||
# Check result content
|
||||
result_str = json.dumps(result).lower()
|
||||
error_keywords = [
|
||||
"error", "invalid", "illegal", "blocked",
|
||||
"security", "injection", "denied", "forbidden"
|
||||
]
|
||||
|
||||
return any(keyword in result_str for keyword in error_keywords)
|
||||
|
||||
|
||||
class TestMultiStatementInjectionAPI:
|
||||
"""Test multi-statement SQL injection prevention"""
|
||||
|
||||
@pytest.fixture
|
||||
async def mcp_client(self):
|
||||
"""Create MCP client instance"""
|
||||
client = MCPClient()
|
||||
yield client
|
||||
await client.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hidden_drop_after_select(self, mcp_client):
|
||||
"""Test DROP hidden after legitimate SELECT is blocked"""
|
||||
payload = {"sql": "SELECT id, name FROM users WHERE status = 'active'; DROP TABLE audit_log; SELECT 1;"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("Hidden DROP after SELECT", payload, result)
|
||||
|
||||
assert self._is_dangerous_blocked(result), \
|
||||
f"Hidden DROP statement should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hidden_truncate_after_select(self, mcp_client):
|
||||
"""Test TRUNCATE hidden after SELECT is blocked"""
|
||||
payload = {"sql": "SELECT 1; TRUNCATE TABLE users"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("Hidden TRUNCATE after SELECT", payload, result)
|
||||
|
||||
assert self._is_dangerous_blocked(result), \
|
||||
f"Hidden TRUNCATE should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hidden_grant_after_select(self, mcp_client):
|
||||
"""Test GRANT hidden after SELECT is blocked"""
|
||||
payload = {"sql": "SELECT 1; GRANT ALL ON *.* TO 'hacker'@'%'"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("Hidden GRANT after SELECT", payload, result)
|
||||
|
||||
assert self._is_dangerous_blocked(result), \
|
||||
f"Hidden GRANT should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_safe_selects_allowed(self, mcp_client):
|
||||
"""Test that multiple SELECT statements may be allowed"""
|
||||
payload = {"sql": "SELECT 1; SELECT 2; SELECT 3;"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("Multiple Safe SELECTs", payload, result)
|
||||
|
||||
def _is_dangerous_blocked(self, result: dict) -> bool:
|
||||
"""Check if dangerous operation was blocked"""
|
||||
if not result:
|
||||
return True
|
||||
|
||||
# Check for error
|
||||
if "error" in result:
|
||||
return True
|
||||
|
||||
# Check result content for blocking indicators
|
||||
result_str = json.dumps(result).lower()
|
||||
block_indicators = [
|
||||
"drop", "truncate", "grant", "revoke",
|
||||
"blocked", "denied", "forbidden", "not allowed",
|
||||
"security", "error"
|
||||
]
|
||||
|
||||
return any(indicator in result_str for indicator in block_indicators)
|
||||
|
||||
|
||||
class TestADBCQueryInjectionAPI:
|
||||
"""Test ADBC query SQL injection prevention"""
|
||||
|
||||
@pytest.fixture
|
||||
async def mcp_client(self):
|
||||
"""Create MCP client instance"""
|
||||
client = MCPClient()
|
||||
yield client
|
||||
await client.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_adbc_query_drop_injection(self, mcp_client):
|
||||
"""Test exec_adbc_query rejects DROP TABLE injection"""
|
||||
payload = {"sql": "SELECT * FROM users; DROP TABLE users; --"}
|
||||
result = await mcp_client.call_tool("exec_adbc_query", payload)
|
||||
print_result("ADBC DROP TABLE Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"ADBC DROP TABLE injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_adbc_query_delete_injection(self, mcp_client):
|
||||
"""Test exec_adbc_query rejects DELETE injection"""
|
||||
payload = {"sql": "SELECT 1; DELETE FROM users; --"}
|
||||
result = await mcp_client.call_tool("exec_adbc_query", payload)
|
||||
print_result("ADBC DELETE Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"ADBC DELETE injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_adbc_query_valid(self, mcp_client):
|
||||
"""Test exec_adbc_query allows valid queries"""
|
||||
payload = {"sql": "SELECT 1 AS test"}
|
||||
result = await mcp_client.call_tool("exec_adbc_query", payload)
|
||||
print_result("ADBC Valid Query", payload, result)
|
||||
|
||||
def _is_blocked_or_error(self, result: dict) -> bool:
|
||||
"""Check if result indicates blocked or error"""
|
||||
if not result:
|
||||
return True
|
||||
if "error" in result:
|
||||
return True
|
||||
result_str = json.dumps(result).lower()
|
||||
return any(kw in result_str for kw in ["error", "blocked", "invalid", "security", "injection"])
|
||||
|
||||
|
||||
class TestMetadataToolsInjectionAPI:
|
||||
"""Test metadata tools SQL injection prevention"""
|
||||
|
||||
@pytest.fixture
|
||||
async def mcp_client(self):
|
||||
"""Create MCP client instance"""
|
||||
client = MCPClient()
|
||||
yield client
|
||||
await client.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_db_table_list_db_injection(self, mcp_client):
|
||||
"""Test get_db_table_list rejects db_name injection"""
|
||||
payload = {"db_name": "test'; DROP TABLE users; --"}
|
||||
result = await mcp_client.call_tool("get_db_table_list", payload)
|
||||
print_result("get_db_table_list db_name Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"db_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_db_table_list_catalog_injection(self, mcp_client):
|
||||
"""Test get_db_table_list rejects catalog_name injection"""
|
||||
payload = {"catalog_name": "internal`; SELECT * FROM mysql.user; --"}
|
||||
result = await mcp_client.call_tool("get_db_table_list", payload)
|
||||
print_result("get_db_table_list catalog_name Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"catalog_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_table_comment_injection(self, mcp_client):
|
||||
"""Test get_table_comment rejects table_name injection"""
|
||||
payload = {"table_name": "users'; DROP TABLE users; --"}
|
||||
result = await mcp_client.call_tool("get_table_comment", payload)
|
||||
print_result("get_table_comment table_name Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"table_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_table_column_comments_injection(self, mcp_client):
|
||||
"""Test get_table_column_comments rejects injection"""
|
||||
payload = {"table_name": "users'; DROP TABLE users; --", "db_name": "test"}
|
||||
result = await mcp_client.call_tool("get_table_column_comments", payload)
|
||||
print_result("get_table_column_comments Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"table_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_table_indexes_injection(self, mcp_client):
|
||||
"""Test get_table_indexes rejects table_name injection"""
|
||||
payload = {"table_name": "users; DROP TABLE users", "db_name": "test"}
|
||||
result = await mcp_client.call_tool("get_table_indexes", payload)
|
||||
print_result("get_table_indexes Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"table_name injection should be blocked"
|
||||
|
||||
def _is_blocked_or_error(self, result: dict) -> bool:
|
||||
"""Check if result indicates blocked or error"""
|
||||
if not result:
|
||||
return True
|
||||
if "error" in result:
|
||||
return True
|
||||
result_str = json.dumps(result).lower()
|
||||
return any(kw in result_str for kw in ["error", "blocked", "invalid", "security", "injection"])
|
||||
|
||||
|
||||
class TestAnalyticsToolsInjectionAPI:
|
||||
"""Test analytics tools SQL injection prevention"""
|
||||
|
||||
@pytest.fixture
|
||||
async def mcp_client(self):
|
||||
"""Create MCP client instance"""
|
||||
client = MCPClient()
|
||||
yield client
|
||||
await client.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analyze_columns_table_injection(self, mcp_client):
|
||||
"""Test analyze_columns rejects table_name injection"""
|
||||
payload = {"table_name": "users'; DROP TABLE users; --"}
|
||||
result = await mcp_client.call_tool("analyze_columns", payload)
|
||||
print_result("analyze_columns table_name Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"table_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analyze_columns_db_injection(self, mcp_client):
|
||||
"""Test analyze_columns rejects db_name injection"""
|
||||
payload = {"table_name": "users", "db_name": "test' OR '1'='1"}
|
||||
result = await mcp_client.call_tool("analyze_columns", payload)
|
||||
print_result("analyze_columns db_name Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"db_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_table_basic_info_injection(self, mcp_client):
|
||||
"""Test get_table_basic_info rejects injection"""
|
||||
payload = {"table_name": "users; DROP TABLE audit_log"}
|
||||
result = await mcp_client.call_tool("get_table_basic_info", payload)
|
||||
print_result("get_table_basic_info Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"table_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analyze_table_storage_injection(self, mcp_client):
|
||||
"""Test analyze_table_storage rejects injection"""
|
||||
payload = {"table_name": "users`; SELECT * FROM sensitive; --"}
|
||||
result = await mcp_client.call_tool("analyze_table_storage", payload)
|
||||
print_result("analyze_table_storage Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"table_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_sql_explain_injection(self, mcp_client):
|
||||
"""Test get_sql_explain rejects SQL injection"""
|
||||
payload = {"sql": "SELECT 1; DROP TABLE users; --"}
|
||||
result = await mcp_client.call_tool("get_sql_explain", payload)
|
||||
print_result("get_sql_explain SQL Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"SQL injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_sql_profile_injection(self, mcp_client):
|
||||
"""Test get_sql_profile rejects SQL injection"""
|
||||
payload = {"sql": "SELECT 1; DELETE FROM audit_log; --"}
|
||||
result = await mcp_client.call_tool("get_sql_profile", payload)
|
||||
print_result("get_sql_profile SQL Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"SQL injection should be blocked"
|
||||
|
||||
def _is_blocked_or_error(self, result: dict) -> bool:
|
||||
"""Check if result indicates blocked or error"""
|
||||
if not result:
|
||||
return True
|
||||
if "error" in result:
|
||||
return True
|
||||
result_str = json.dumps(result).lower()
|
||||
return any(kw in result_str for kw in ["error", "blocked", "invalid", "security", "injection"])
|
||||
|
||||
|
||||
class TestGovernanceToolsInjectionAPI:
|
||||
"""Test data governance tools SQL injection prevention"""
|
||||
|
||||
@pytest.fixture
|
||||
async def mcp_client(self):
|
||||
"""Create MCP client instance"""
|
||||
client = MCPClient()
|
||||
yield client
|
||||
await client.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trace_column_lineage_table_injection(self, mcp_client):
|
||||
"""Test trace_column_lineage rejects table_name injection"""
|
||||
payload = {"table_name": "users'; DROP TABLE users; --", "column_name": "id"}
|
||||
result = await mcp_client.call_tool("trace_column_lineage", payload)
|
||||
print_result("trace_column_lineage table_name Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"table_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trace_column_lineage_column_injection(self, mcp_client):
|
||||
"""Test trace_column_lineage rejects column_name injection"""
|
||||
payload = {"table_name": "users", "column_name": "id; DROP TABLE users"}
|
||||
result = await mcp_client.call_tool("trace_column_lineage", payload)
|
||||
print_result("trace_column_lineage column_name Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"column_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_monitor_data_freshness_injection(self, mcp_client):
|
||||
"""Test monitor_data_freshness rejects table_name injection"""
|
||||
payload = {"table_name": "users`; SELECT * FROM passwords; --"}
|
||||
result = await mcp_client.call_tool("monitor_data_freshness", payload)
|
||||
print_result("monitor_data_freshness Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"table_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analyze_data_access_patterns_injection(self, mcp_client):
|
||||
"""Test analyze_data_access_patterns rejects injection"""
|
||||
payload = {"table_name": "users' UNION SELECT password FROM admin --"}
|
||||
result = await mcp_client.call_tool("analyze_data_access_patterns", payload)
|
||||
print_result("analyze_data_access_patterns Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"table_name injection should be blocked"
|
||||
|
||||
def _is_blocked_or_error(self, result: dict) -> bool:
|
||||
"""Check if result indicates blocked or error"""
|
||||
if not result:
|
||||
return True
|
||||
if "error" in result:
|
||||
return True
|
||||
result_str = json.dumps(result).lower()
|
||||
return any(kw in result_str for kw in ["error", "blocked", "invalid", "security", "injection"])
|
||||
|
||||
|
||||
class TestPerformanceToolsInjectionAPI:
|
||||
"""Test performance analytics tools SQL injection prevention"""
|
||||
|
||||
@pytest.fixture
|
||||
async def mcp_client(self):
|
||||
"""Create MCP client instance"""
|
||||
client = MCPClient()
|
||||
yield client
|
||||
await client.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analyze_slow_queries_db_injection(self, mcp_client):
|
||||
"""Test analyze_slow_queries_topn rejects db_name injection"""
|
||||
payload = {"db_name": "test'; DROP TABLE audit_log; --"}
|
||||
result = await mcp_client.call_tool("analyze_slow_queries_topn", payload)
|
||||
print_result("analyze_slow_queries_topn db_name Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"db_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analyze_resource_growth_db_injection(self, mcp_client):
|
||||
"""Test analyze_resource_growth_curves rejects db_name injection"""
|
||||
payload = {"db_name": "test`; GRANT ALL ON *.* TO 'hacker'; --"}
|
||||
result = await mcp_client.call_tool("analyze_resource_growth_curves", payload)
|
||||
print_result("analyze_resource_growth_curves db_name Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"db_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_table_data_size_injection(self, mcp_client):
|
||||
"""Test get_table_data_size rejects table_name injection"""
|
||||
payload = {"table_name": "users; TRUNCATE TABLE logs"}
|
||||
result = await mcp_client.call_tool("get_table_data_size", payload)
|
||||
print_result("get_table_data_size Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"table_name injection should be blocked"
|
||||
|
||||
def _is_blocked_or_error(self, result: dict) -> bool:
|
||||
"""Check if result indicates blocked or error"""
|
||||
if not result:
|
||||
return True
|
||||
if "error" in result:
|
||||
return True
|
||||
result_str = json.dumps(result).lower()
|
||||
return any(kw in result_str for kw in ["error", "blocked", "invalid", "security", "injection"])
|
||||
|
||||
|
||||
# Pytest configuration for async tests
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop():
|
||||
"""Create event loop for async tests"""
|
||||
loop = asyncio.new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "--tb=short", "-x"])
|
||||
|
||||
@@ -44,17 +44,31 @@
|
||||
}
|
||||
},
|
||||
"expected_tools": [
|
||||
"analyze_columns",
|
||||
"analyze_data_access_patterns",
|
||||
"analyze_data_flow_dependencies",
|
||||
"analyze_resource_growth_curves",
|
||||
"analyze_slow_queries_topn",
|
||||
"analyze_table_storage",
|
||||
"exec_adbc_query",
|
||||
"exec_query",
|
||||
"get_db_list",
|
||||
"get_adbc_connection_info",
|
||||
"get_catalog_list",
|
||||
"get_db_list",
|
||||
"get_db_table_list",
|
||||
"get_table_schema",
|
||||
"get_table_comment",
|
||||
"get_table_column_comments",
|
||||
"get_table_indexes",
|
||||
"column_analysis",
|
||||
"performance_stats",
|
||||
"get_memory_stats",
|
||||
"get_monitoring_metrics",
|
||||
"get_recent_audit_logs",
|
||||
"get_catalog_list"
|
||||
"get_sql_explain",
|
||||
"get_sql_profile",
|
||||
"get_table_basic_info",
|
||||
"get_table_column_comments",
|
||||
"get_table_comment",
|
||||
"get_table_data_size",
|
||||
"get_table_indexes",
|
||||
"get_table_schema",
|
||||
"monitor_data_freshness",
|
||||
"trace_column_lineage"
|
||||
],
|
||||
"expected_resources": [
|
||||
"database",
|
||||
@@ -66,4 +80,4 @@
|
||||
"data_analysis_helper",
|
||||
"schema_explorer"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -185,8 +185,9 @@ async def test_server_connectivity(transport: Optional[str] = None) -> bool:
|
||||
logger.error(f"Connectivity test failed: {e}")
|
||||
return False
|
||||
|
||||
result = await client.connect_and_run(test_connection)
|
||||
return result
|
||||
await client.connect_and_run(test_connection)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to test server connectivity: {e}")
|
||||
return False
|
||||
@@ -211,4 +212,4 @@ if __name__ == "__main__":
|
||||
stdio_ok = await test_server_connectivity("stdio")
|
||||
print(f" Stdio connectivity: {'✓' if stdio_ok else '✗'}")
|
||||
|
||||
asyncio.run(main())
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -72,8 +72,7 @@ class TestToolsClientServer:
|
||||
|
||||
return tools
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert len(result) > 0
|
||||
await client.connect_and_run(test_callback)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_exec_query_via_client(self, client, test_config):
|
||||
@@ -91,14 +90,13 @@ class TestToolsClientServer:
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
|
||||
if result["success"]:
|
||||
assert "result" in result, "Successful result should contain 'result' field"
|
||||
assert "data" in result, "Successful result should contain 'data' field"
|
||||
else:
|
||||
assert "error" in result, "Failed result should contain 'error' field"
|
||||
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
# Don't assert success=True as it depends on actual server state
|
||||
await client.connect_and_run(test_callback)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_get_db_list_via_client(self, client, test_config):
|
||||
@@ -115,8 +113,7 @@ class TestToolsClientServer:
|
||||
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
await client.connect_and_run(test_callback)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_get_table_schema_via_client(self, client, test_config):
|
||||
@@ -133,27 +130,7 @@ class TestToolsClientServer:
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_performance_stats_via_client(self, client, test_config):
|
||||
"""Test calling performance_stats tool through client"""
|
||||
if not test_config.is_performance_tests_enabled():
|
||||
pytest.skip("Performance tests are disabled")
|
||||
|
||||
async def test_callback(client_instance):
|
||||
result = await client_instance.call_tool("performance_stats", {
|
||||
"metric_type": "queries",
|
||||
"time_range": "1h"
|
||||
})
|
||||
|
||||
# Verify result structure
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
await client.connect_and_run(test_callback)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_error_handling_via_client(self, client, test_config):
|
||||
@@ -168,8 +145,7 @@ class TestToolsClientServer:
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
await client.connect_and_run(test_callback)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_with_auth_token_via_client(self, client, test_config):
|
||||
@@ -188,5 +164,4 @@ class TestToolsClientServer:
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
await client.connect_and_run(test_callback)
|
||||
|
||||
@@ -36,11 +36,6 @@ class TestDorisToolsManager:
|
||||
from doris_mcp_server.utils.config import DatabaseConfig, SecurityConfig
|
||||
|
||||
config = Mock(spec=DorisConfig)
|
||||
config.doris_host = "localhost"
|
||||
config.doris_port = 9030
|
||||
config.doris_user = "test_user"
|
||||
config.doris_password = "test_password"
|
||||
config.doris_database = "test_db"
|
||||
|
||||
# Add database config
|
||||
config.database = Mock(spec=DatabaseConfig)
|
||||
@@ -50,7 +45,6 @@ class TestDorisToolsManager:
|
||||
config.database.password = "test_password"
|
||||
config.database.database = "test_db"
|
||||
config.database.health_check_interval = 60
|
||||
config.database.min_connections = 5
|
||||
config.database.max_connections = 20
|
||||
config.database.connection_timeout = 30
|
||||
config.database.max_connection_age = 3600
|
||||
@@ -235,62 +229,7 @@ class TestDorisToolsManager:
|
||||
elif "result" in result_data:
|
||||
assert len(result_data["result"]) >= 0 # May be empty if no catalogs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_column_analysis_tool(self, tools_manager):
|
||||
"""Test column_analysis tool"""
|
||||
with patch.object(tools_manager.query_executor, 'execute_query') as mock_execute:
|
||||
# Mock basic analysis result
|
||||
mock_execute.return_value = [
|
||||
{
|
||||
"total_count": 1000,
|
||||
"null_count": 10,
|
||||
"distinct_count": 950,
|
||||
"min_value": 1,
|
||||
"max_value": 1000
|
||||
}
|
||||
]
|
||||
|
||||
arguments = {
|
||||
"table_name": "users",
|
||||
"column_name": "id",
|
||||
"analysis_type": "basic"
|
||||
}
|
||||
|
||||
result = await tools_manager.call_tool("column_analysis", arguments)
|
||||
result_data = json.loads(result) if isinstance(result, str) else result
|
||||
|
||||
# Check if result has analysis field or result field
|
||||
if "analysis" in result_data:
|
||||
assert result_data["analysis"]["total_count"] == 1000
|
||||
elif "result" in result_data:
|
||||
assert "result" in result_data # Just check result exists
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_performance_stats_tool(self, tools_manager):
|
||||
"""Test performance_stats tool"""
|
||||
with patch.object(tools_manager.query_executor, 'execute_query') as mock_execute:
|
||||
mock_execute.return_value = [
|
||||
{
|
||||
"query_count": 1500,
|
||||
"avg_execution_time": 0.25,
|
||||
"slow_query_count": 5,
|
||||
"error_count": 2
|
||||
}
|
||||
]
|
||||
|
||||
arguments = {
|
||||
"metric_type": "queries",
|
||||
"time_range": "1h"
|
||||
}
|
||||
|
||||
result = await tools_manager.call_tool("performance_stats", arguments)
|
||||
result_data = json.loads(result) if isinstance(result, str) else result
|
||||
|
||||
# Check if result has stats field or result field
|
||||
if "stats" in result_data:
|
||||
assert result_data["stats"]["query_count"] == 1500
|
||||
elif "result" in result_data:
|
||||
assert "result" in result_data # Just check result exists
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_tool_name(self, tools_manager):
|
||||
@@ -328,4 +267,4 @@ class TestDorisToolsManager:
|
||||
|
||||
# Required fields should be defined
|
||||
if 'required' in tool.inputSchema:
|
||||
assert isinstance(tool.inputSchema['required'], list)
|
||||
assert isinstance(tool.inputSchema['required'], list)
|
||||
|
||||
78
test/utils/test_db.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from unittest.mock import MagicMock
|
||||
import pytest
|
||||
|
||||
from doris_mcp_server.utils.db import DorisConnection, DorisSessionCache
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session_cache():
|
||||
"""Provides a DorisSessionCache instance with a mock connection manager."""
|
||||
connection_manager = MagicMock()
|
||||
cache = DorisSessionCache(connection_manager=connection_manager)
|
||||
yield cache, connection_manager
|
||||
|
||||
|
||||
class TestDorisSessionCache:
|
||||
|
||||
def test_initialization(self, session_cache):
|
||||
cache, _ = session_cache
|
||||
assert cache.cache_system_session is True
|
||||
assert cache.cache_user_session is False
|
||||
assert not cache.cached
|
||||
|
||||
def test_should_cache(self, session_cache):
|
||||
cache, _ = session_cache
|
||||
assert cache._should_cache("query") is True
|
||||
assert cache._should_cache("system") is True
|
||||
assert cache._should_cache("user-test-session-id") is False
|
||||
|
||||
cache.cache_user_session = True
|
||||
assert cache._should_cache("user-test-session-id") is True
|
||||
|
||||
def test_save_and_get_session(self, session_cache):
|
||||
cache, _ = session_cache
|
||||
mock_connection = MagicMock(spec=DorisConnection)
|
||||
mock_connection.session_id = "query"
|
||||
|
||||
cache.save(mock_connection)
|
||||
retrieved_conn = cache.get("query")
|
||||
assert retrieved_conn is mock_connection
|
||||
|
||||
mock_user_connection = MagicMock(spec=DorisConnection)
|
||||
mock_user_connection.session_id = "user-test-session-id"
|
||||
cache.save(mock_user_connection)
|
||||
assert cache.get("user-test-session-id") is None
|
||||
|
||||
cache.cache_user_session = True
|
||||
cache.save(mock_user_connection)
|
||||
retrieved_user_conn = cache.get("user-test-session-id")
|
||||
assert retrieved_user_conn is mock_user_connection
|
||||
|
||||
def test_remove_session(self, session_cache):
|
||||
cache, _ = session_cache
|
||||
mock_connection = MagicMock(spec=DorisConnection)
|
||||
mock_connection.session_id = "system"
|
||||
|
||||
cache.save(mock_connection)
|
||||
assert cache.get("system") is not None
|
||||
|
||||
cache.remove("system")
|
||||
assert cache.get("system") is None
|
||||
|
||||
def test_clear_cache(self, session_cache):
|
||||
cache, connection_manager = session_cache
|
||||
mock_conn1 = MagicMock(spec=DorisConnection)
|
||||
mock_conn1.session_id = "query"
|
||||
mock_conn2 = MagicMock(spec=DorisConnection)
|
||||
mock_conn2.session_id = "system"
|
||||
|
||||
cache.save(mock_conn1)
|
||||
cache.save(mock_conn2)
|
||||
assert len(cache.cached) == 2
|
||||
|
||||
cache.clear()
|
||||
|
||||
assert not cache.cached
|
||||
connection_manager.release_connection.assert_any_call("query", mock_conn1)
|
||||
connection_manager.release_connection.assert_any_call("system", mock_conn2)
|
||||
assert connection_manager.release_connection.call_count == 2
|
||||
@@ -35,11 +35,6 @@ class TestDorisQueryExecutor:
|
||||
from doris_mcp_server.utils.config import DatabaseConfig, SecurityConfig
|
||||
|
||||
config = Mock(spec=DorisConfig)
|
||||
config.doris_host = "localhost"
|
||||
config.doris_port = 9030
|
||||
config.doris_user = "test_user"
|
||||
config.doris_password = "test_password"
|
||||
config.doris_database = "test_db"
|
||||
|
||||
# Add database config
|
||||
config.database = Mock(spec=DatabaseConfig)
|
||||
@@ -49,11 +44,17 @@ class TestDorisQueryExecutor:
|
||||
config.database.password = "test_password"
|
||||
config.database.database = "test_db"
|
||||
config.database.health_check_interval = 60
|
||||
config.database.min_connections = 5
|
||||
config.database.max_connections = 20
|
||||
config.database.connection_timeout = 30
|
||||
config.database.max_connection_age = 3600
|
||||
|
||||
# Add security config
|
||||
config.security = Mock(spec=SecurityConfig)
|
||||
config.security.enable_masking = True
|
||||
config.security.auth_type = "token"
|
||||
config.security.token_secret = "test_secret"
|
||||
config.security.token_expiry = 3600
|
||||
|
||||
return config
|
||||
|
||||
@pytest.fixture
|
||||
@@ -199,4 +200,74 @@ class TestDorisQueryExecutor:
|
||||
assert "success" in result
|
||||
if result["success"]:
|
||||
assert "data" in result
|
||||
assert "row_count" in result
|
||||
assert "row_count" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_multi_sql_statements(self, query_executor):
|
||||
"""Test execution of multiple SQL statements"""
|
||||
from doris_mcp_server.utils.query_executor import QueryResult
|
||||
|
||||
# Disable security check for this test
|
||||
query_executor.connection_manager.config.security.enable_security_check = False
|
||||
|
||||
with patch.object(query_executor, 'execute_query') as mock_execute:
|
||||
# Mock results for three SQL statements
|
||||
mock_execute.side_effect = [
|
||||
QueryResult(
|
||||
data=[{"id": 1, "name": "张三"}],
|
||||
row_count=1,
|
||||
execution_time=0.1,
|
||||
sql="SELECT id, name FROM users WHERE id = 1",
|
||||
metadata={"columns": ["id", "name"]}
|
||||
),
|
||||
QueryResult(
|
||||
data=[{"id": 2, "name": "李四"}],
|
||||
row_count=1,
|
||||
execution_time=0.12,
|
||||
sql="SELECT id, name FROM users WHERE id = 2",
|
||||
metadata={"columns": ["id", "name"]}
|
||||
),
|
||||
QueryResult(
|
||||
data=[{"count": 100}],
|
||||
row_count=1,
|
||||
execution_time=0.08,
|
||||
sql="SELECT COUNT(*) as count FROM users",
|
||||
metadata={"columns": ["count"]}
|
||||
)
|
||||
]
|
||||
|
||||
# Execute multiple SQL statements separated by semicolons
|
||||
multi_sql = """
|
||||
SELECT id, name FROM users WHERE id = 1;
|
||||
SELECT id, name FROM users WHERE id = 2;
|
||||
SELECT COUNT(*) as count FROM users;
|
||||
"""
|
||||
|
||||
result = await query_executor.execute_sql_for_mcp(multi_sql)
|
||||
|
||||
# Verify the result structure for multiple statements
|
||||
assert result["success"] is True
|
||||
assert result["multiple_results"] is True
|
||||
assert "results" in result
|
||||
assert len(result["results"]) == 3
|
||||
|
||||
# Verify first query result
|
||||
assert result["results"][0]["data"] == [{"id": 1, "name": "张三"}]
|
||||
assert result["results"][0]["row_count"] == 1
|
||||
assert result["results"][0]["metadata"]["columns"] == ["id", "name"]
|
||||
assert result["results"][0]["metadata"]["query"] == "SELECT id, name FROM users WHERE id = 1"
|
||||
|
||||
# Verify second query result
|
||||
assert result["results"][1]["data"] == [{"id": 2, "name": "李四"}]
|
||||
assert result["results"][1]["row_count"] == 1
|
||||
assert result["results"][1]["metadata"]["columns"] == ["id", "name"]
|
||||
assert result["results"][1]["metadata"]["query"] == "SELECT id, name FROM users WHERE id = 2"
|
||||
|
||||
# Verify third query result
|
||||
assert result["results"][2]["data"] == [{"count": 100}]
|
||||
assert result["results"][2]["row_count"] == 1
|
||||
assert result["results"][2]["metadata"]["columns"] == ["count"]
|
||||
assert result["results"][2]["metadata"]["query"] == "SELECT COUNT(*) as count FROM users"
|
||||
|
||||
# Verify execute_query was called three times
|
||||
assert mock_execute.call_count == 3
|
||||
|
||||
@@ -21,8 +21,6 @@ Tests the query execution functionality through actual MCP client-server communi
|
||||
Assumes the server is already running and configured properly
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import pytest
|
||||
import os
|
||||
import sys
|
||||
@@ -66,14 +64,13 @@ class TestQueryExecutorClientServer:
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
|
||||
if result["success"]:
|
||||
assert "result" in result, "Successful result should contain 'result' field"
|
||||
assert "data" in result, "Successful result should contain 'data' field"
|
||||
else:
|
||||
assert "error" in result, "Failed result should contain 'error' field"
|
||||
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
await client.connect_and_run(test_callback)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_show_databases_query_via_client(self, client, test_config):
|
||||
@@ -87,8 +84,7 @@ class TestQueryExecutorClientServer:
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
await client.connect_and_run(test_callback)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_information_schema_query_via_client(self, client, test_config):
|
||||
@@ -102,8 +98,7 @@ class TestQueryExecutorClientServer:
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
await client.connect_and_run(test_callback)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_with_max_rows_parameter_via_client(self, client, test_config):
|
||||
@@ -118,8 +113,7 @@ class TestQueryExecutorClientServer:
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
await client.connect_and_run(test_callback)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_error_handling_via_client(self, client, test_config):
|
||||
@@ -131,8 +125,7 @@ class TestQueryExecutorClientServer:
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
await client.connect_and_run(test_callback)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_with_auth_token_via_client(self, client, test_config):
|
||||
@@ -152,5 +145,4 @@ class TestQueryExecutorClientServer:
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
await client.connect_and_run(test_callback)
|
||||
|
||||
64
tokens.json
Normal file
@@ -0,0 +1,64 @@
|
||||
{
|
||||
"version": "1.0",
|
||||
"description": "Doris MCP Server Token configuration file",
|
||||
"created_at": "2025-09-01T00:00:00Z",
|
||||
"tokens": [
|
||||
{
|
||||
"token_id": "admin-token",
|
||||
"token": "doris_admin_token_123456",
|
||||
"description": "Doris admin API access token",
|
||||
"expires_hours": null,
|
||||
"is_active": true,
|
||||
"database_config": {
|
||||
"host": "127.0.0.1",
|
||||
"port": 9030,
|
||||
"user": "root",
|
||||
"password": "",
|
||||
"database": "information_schema",
|
||||
"charset": "UTF8",
|
||||
"fe_http_port": 8030
|
||||
}
|
||||
},
|
||||
{
|
||||
"token_id": "analyst-token",
|
||||
"token": "doris_analyst_token_123456",
|
||||
"description": "Doris analyst API access token",
|
||||
"expires_hours": 8760,
|
||||
"is_active": true,
|
||||
"database_config": {
|
||||
"host": "127.0.0.1",
|
||||
"port": 9030,
|
||||
"user": "root",
|
||||
"password": "",
|
||||
"database": "information_schema",
|
||||
"charset": "UTF8",
|
||||
"fe_http_port": 8030
|
||||
}
|
||||
},
|
||||
{
|
||||
"token_id": "readonly-token",
|
||||
"token": "doris_readonly_token_123456",
|
||||
"description": "Doris readonly API access token",
|
||||
"expires_hours": 4320,
|
||||
"is_active": true,
|
||||
"database_config": {
|
||||
"host": "127.0.0.1",
|
||||
"port": 9030,
|
||||
"user": "root",
|
||||
"password": "",
|
||||
"database": "information_schema",
|
||||
"charset": "UTF8",
|
||||
"fe_http_port": 8030
|
||||
}
|
||||
}
|
||||
],
|
||||
"notes": [
|
||||
"The admin_token, analyst_token, readonly_token is default token,Please change the token before using in production!",
|
||||
"The token_id is the key of the token,Please use the token_id to identify the token",
|
||||
"The token is the value of the token,Please use the token to identify the token",
|
||||
"The description is the description of the token,Please use the description to identify the token",
|
||||
"The expires_hours is the expires hours of the token,Please use the expires_hours to identify the token",
|
||||
"The is_active is the is active of the token,Please use the is_active to identify the token",
|
||||
"The token_id, token, description, expires_hours, is_active is the metadata of the token,Please use the metadata to identify the token"
|
||||
]
|
||||
}
|
||||
112
uv.lock
generated
@@ -1,19 +1,3 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
version = 1
|
||||
revision = 1
|
||||
requires-python = ">=3.12"
|
||||
@@ -22,6 +6,48 @@ resolution-markers = [
|
||||
"python_full_version < '3.13'",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "adbc-driver-flightsql"
|
||||
version = "1.7.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "adbc-driver-manager" },
|
||||
{ name = "importlib-resources" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b8/d4/ebd3eed981c771565677084474cdf465141455b5deb1ca409c616609bfd7/adbc_driver_flightsql-1.7.0.tar.gz", hash = "sha256:5dca460a2c66e45b29208eaf41a7206f252177435fa48b16f19833b12586f7a0", size = 21247 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/36/20/807fca9d904b7e0d3020439828d6410db7fd7fd635824a80cab113d9fad1/adbc_driver_flightsql-1.7.0-py3-none-macosx_10_15_x86_64.whl", hash = "sha256:a5658f9bc3676bd122b26138e9b9ce56b8bf37387efe157b4c66d56f942361c6", size = 7749664 },
|
||||
{ url = "https://files.pythonhosted.org/packages/cd/e6/9e50f6497819c911b9cc1962ffde610b60f7d8e951d6bb3fa145dcfb50a7/adbc_driver_flightsql-1.7.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:65e21df86b454d8db422c8ee22db31be217d88c42d9d6dd89119f06813037c91", size = 7302476 },
|
||||
{ url = "https://files.pythonhosted.org/packages/27/82/e51af85e7cc8c87bc8ce4fae8ca7ee1d3cf39c926be0aeab789cedc93f0a/adbc_driver_flightsql-1.7.0-py3-none-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:3282fdc7b73c712780cc777975288c88b1e3a555355bbe09df101aa954f8f105", size = 7686056 },
|
||||
{ url = "https://files.pythonhosted.org/packages/8b/c9/591c8ecbaf010ba3f4b360db602050ee5880cd077a573c9e90fcb270ab71/adbc_driver_flightsql-1.7.0-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e0c5737ae6ee3bbfba44dcbc28ba1ff8cf3ab6521888c4b0f10dd6a482482161", size = 7050275 },
|
||||
{ url = "https://files.pythonhosted.org/packages/10/14/f339e9a5d8dbb3e3040215514cea9cca0a58640964aaccc6532f18003a03/adbc_driver_flightsql-1.7.0-py3-none-win_amd64.whl", hash = "sha256:f8b5290b322304b7d944ca823754e6354c1868dbbe94ddf84236f3e0329545da", size = 14312858 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "adbc-driver-manager"
|
||||
version = "1.7.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/bb/bf/2986a2cd3e1af658d2597f7e2308564e5c11e036f9736d5c256f1e00d578/adbc_driver_manager-1.7.0.tar.gz", hash = "sha256:e3edc5d77634b5925adf6eb4fbcd01676b54acb2f5b1d6864b6a97c6a899591a", size = 198128 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/74/3a/72bd9c45d55f1f5f4c549e206de8cfe3313b31f7b95fbcb180da05c81044/adbc_driver_manager-1.7.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:8da1ac4c19bcbf30b3bd54247ec889dfacc9b44147c70b4da79efe2e9ba93600", size = 524210 },
|
||||
{ url = "https://files.pythonhosted.org/packages/33/29/e1a8d8dde713a287f8021f3207127f133ddce578711a4575218bdf78ef27/adbc_driver_manager-1.7.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:408bc23bad1a6823b364e2388f85f96545e82c3b2db97d7828a4b94839d3f29e", size = 505902 },
|
||||
{ url = "https://files.pythonhosted.org/packages/59/00/773ece64a58c0ade797ab4577e7cdc4c71ebf800b86d2d5637e3bfe605e9/adbc_driver_manager-1.7.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cf38294320c23e47ed3455348e910031ad8289c3f9167ae35519ac957b7add01", size = 2974883 },
|
||||
{ url = "https://files.pythonhosted.org/packages/7c/ad/1568da6ae9ab70983f1438503d3906c6b1355601230e891d16e272376a04/adbc_driver_manager-1.7.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:689f91b62c18a9f86f892f112786fb157cacc4729b4d81666db4ca778eade2a8", size = 2997781 },
|
||||
{ url = "https://files.pythonhosted.org/packages/19/66/2b6ea5afded25a3fa009873c2bbebcd9283910877cc10b9453d680c00b9a/adbc_driver_manager-1.7.0-cp312-cp312-win_amd64.whl", hash = "sha256:f936cfc8d098898a47ef60396bd7a73926ec3068f2d6d92a2be4e56e4aaf3770", size = 690041 },
|
||||
{ url = "https://files.pythonhosted.org/packages/b2/3b/91154c83a98f103a3d97c9e2cb838c3842aef84ca4f4b219164b182d9516/adbc_driver_manager-1.7.0-cp313-cp313-macosx_10_15_x86_64.whl", hash = "sha256:ab9ee36683fd54f61b0db0f4a96f70fe1932223e61df9329290370b145abb0a9", size = 522737 },
|
||||
{ url = "https://files.pythonhosted.org/packages/9c/52/4bc80c3388d5e2a3b6e504ba9656dd9eb3d8dbe822d07af38db1b8c96fb1/adbc_driver_manager-1.7.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:4ec03d94177f71a8d3a149709f4111e021f9950229b35c0a803aadb1a1855a4b", size = 503896 },
|
||||
{ url = "https://files.pythonhosted.org/packages/e1/f3/46052ca11224f661cef4721e19138bc73e750ba6aea54f22606950491606/adbc_driver_manager-1.7.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:700c79dac08a620018c912ede45a6dc7851819bc569a53073ab652dc0bd0c92f", size = 2972586 },
|
||||
{ url = "https://files.pythonhosted.org/packages/a2/22/44738b41bb5ca30f94b5f4c00c71c20be86d7eb4ddc389d4cf3c7b8b69ef/adbc_driver_manager-1.7.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:98db0f5d0aa1635475f63700a7b6f677390beb59c69c7ba9d388bc8ce3779388", size = 2992001 },
|
||||
{ url = "https://files.pythonhosted.org/packages/1b/2b/5184fe5a529feb019582cc90d0f65e0021d52c34ca20620551532340645a/adbc_driver_manager-1.7.0-cp313-cp313-win_amd64.whl", hash = "sha256:4b7e5e9a163acb21804647cc7894501df51cdcd780ead770557112a26ca01ca6", size = 688789 },
|
||||
{ url = "https://files.pythonhosted.org/packages/3f/e0/b283544e1bb7864bf5a5ac9cd330f111009eff9180ec5000420510cf9342/adbc_driver_manager-1.7.0-cp313-cp313t-macosx_10_15_x86_64.whl", hash = "sha256:ac83717965b83367a8ad6c0536603acdcfa66e0592d783f8940f55fda47d963e", size = 538625 },
|
||||
{ url = "https://files.pythonhosted.org/packages/77/5a/dc244264bd8d0c331a418d2bdda5cb6e26c30493ff075d706aa81d4e3b30/adbc_driver_manager-1.7.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:4c234cf81b00eaf7e7c65dbd0f0ddf7bdae93dfcf41e9d8543f9ecf4b10590f6", size = 523627 },
|
||||
{ url = "https://files.pythonhosted.org/packages/e9/ff/a499a00367fd092edb20dc6e36c81e3c7a437671c70481cae97f46c8156a/adbc_driver_manager-1.7.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ad8aa4b039cc50722a700b544773388c6b1dea955781a01f79cd35d0a1e6edbf", size = 3037517 },
|
||||
{ url = "https://files.pythonhosted.org/packages/25/6e/9dfdb113294dcb24b4f53924cd4a9c9af3fbe45a9790c1327048df731246/adbc_driver_manager-1.7.0-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4409ff53578e01842a8f57787ebfbfee790c1da01a6bd57fcb7701ed5d4dd4f7", size = 3016543 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aiofiles"
|
||||
version = "24.1.0"
|
||||
@@ -536,9 +562,11 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "doris-mcp-server"
|
||||
version = "0.3.0"
|
||||
version = "0.6.1"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "adbc-driver-flightsql" },
|
||||
{ name = "adbc-driver-manager" },
|
||||
{ name = "aiofiles" },
|
||||
{ name = "aiohttp" },
|
||||
{ name = "aiomysql" },
|
||||
@@ -555,6 +583,7 @@ dependencies = [
|
||||
{ name = "pandas" },
|
||||
{ name = "passlib", extra = ["bcrypt"] },
|
||||
{ name = "prometheus-client" },
|
||||
{ name = "pyarrow" },
|
||||
{ name = "pydantic" },
|
||||
{ name = "pydantic-settings" },
|
||||
{ name = "pyjwt" },
|
||||
@@ -625,6 +654,8 @@ dev = [
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "adbc-driver-flightsql", specifier = ">=0.8.0" },
|
||||
{ name = "adbc-driver-manager", specifier = ">=0.8.0" },
|
||||
{ name = "aiofiles", specifier = ">=23.0.0" },
|
||||
{ name = "aiohttp", specifier = ">=3.9.0" },
|
||||
{ name = "aiomysql", specifier = ">=0.2.0" },
|
||||
@@ -642,7 +673,7 @@ requires-dist = [
|
||||
{ name = "httpx", specifier = ">=0.26.0" },
|
||||
{ name = "isort", marker = "extra == 'dev'", specifier = ">=5.13.0" },
|
||||
{ name = "jaeger-client", marker = "extra == 'monitoring'", specifier = ">=4.8.0" },
|
||||
{ name = "mcp", specifier = ">=1.0.0" },
|
||||
{ name = "mcp", specifier = ">=1.8.0,<2.0.0" },
|
||||
{ name = "mypy", marker = "extra == 'dev'", specifier = ">=1.8.0" },
|
||||
{ name = "myst-parser", marker = "extra == 'dev'", specifier = ">=2.0.0" },
|
||||
{ name = "myst-parser", marker = "extra == 'docs'", specifier = ">=2.0.0" },
|
||||
@@ -656,6 +687,7 @@ requires-dist = [
|
||||
{ name = "pre-commit", marker = "extra == 'dev'", specifier = ">=3.6.0" },
|
||||
{ name = "prometheus-client", specifier = ">=0.19.0" },
|
||||
{ name = "prometheus-client", marker = "extra == 'monitoring'", specifier = ">=0.19.0" },
|
||||
{ name = "pyarrow", specifier = ">=14.0.0" },
|
||||
{ name = "pydantic", specifier = ">=2.5.0" },
|
||||
{ name = "pydantic-settings", specifier = ">=2.1.0" },
|
||||
{ name = "pyjwt", specifier = ">=2.8.0" },
|
||||
@@ -948,6 +980,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/20/b0/36bd937216ec521246249be3bf9855081de4c5e06a0c9b4219dbeda50373/importlib_metadata-8.7.0-py3-none-any.whl", hash = "sha256:e5dd1551894c77868a30651cef00984d50e1002d06942a7101d34870c5f02afd", size = 27656 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "importlib-resources"
|
||||
version = "6.5.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/cf/8c/f834fbf984f691b4f7ff60f50b514cc3de5cc08abfc3295564dd89c5e2e7/importlib_resources-6.5.2.tar.gz", hash = "sha256:185f87adef5bcc288449d98fb4fba07cea78bc036455dd44c5fc4a2fe78fed2c", size = 44693 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/a4/ed/1f1afb2e9e7f38a545d628f864d562a5ae64fe6f7a10e28ffb9b185b4e89/importlib_resources-6.5.2-py3-none-any.whl", hash = "sha256:789cfdc3ed28c78b67a06acb8126751ced69a3d5f79c095a98298cd8a760ccec", size = 37461 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "iniconfig"
|
||||
version = "2.1.0"
|
||||
@@ -1621,6 +1662,41 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/7b/d7/7831438e6c3ebbfa6e01a927127a6cb42ad3ab844247f3c5b96bea25d73d/psutil-6.1.1-cp37-abi3-win_amd64.whl", hash = "sha256:f35cfccb065fff93529d2afb4a2e89e363fe63ca1e4a5da22b603a85833c2649", size = 254444 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyarrow"
|
||||
version = "20.0.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/a2/ee/a7810cb9f3d6e9238e61d312076a9859bf3668fd21c69744de9532383912/pyarrow-20.0.0.tar.gz", hash = "sha256:febc4a913592573c8d5805091a6c2b5064c8bd6e002131f01061797d91c783c1", size = 1125187 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/a1/d6/0c10e0d54f6c13eb464ee9b67a68b8c71bcf2f67760ef5b6fbcddd2ab05f/pyarrow-20.0.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:75a51a5b0eef32727a247707d4755322cb970be7e935172b6a3a9f9ae98404ba", size = 30815067 },
|
||||
{ url = "https://files.pythonhosted.org/packages/7e/e2/04e9874abe4094a06fd8b0cbb0f1312d8dd7d707f144c2ec1e5e8f452ffa/pyarrow-20.0.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:211d5e84cecc640c7a3ab900f930aaff5cd2702177e0d562d426fb7c4f737781", size = 32297128 },
|
||||
{ url = "https://files.pythonhosted.org/packages/31/fd/c565e5dcc906a3b471a83273039cb75cb79aad4a2d4a12f76cc5ae90a4b8/pyarrow-20.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ba3cf4182828be7a896cbd232aa8dd6a31bd1f9e32776cc3796c012855e1199", size = 41334890 },
|
||||
{ url = "https://files.pythonhosted.org/packages/af/a9/3bdd799e2c9b20c1ea6dc6fa8e83f29480a97711cf806e823f808c2316ac/pyarrow-20.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c3a01f313ffe27ac4126f4c2e5ea0f36a5fc6ab51f8726cf41fee4b256680bd", size = 42421775 },
|
||||
{ url = "https://files.pythonhosted.org/packages/10/f7/da98ccd86354c332f593218101ae56568d5dcedb460e342000bd89c49cc1/pyarrow-20.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:a2791f69ad72addd33510fec7bb14ee06c2a448e06b649e264c094c5b5f7ce28", size = 40687231 },
|
||||
{ url = "https://files.pythonhosted.org/packages/bb/1b/2168d6050e52ff1e6cefc61d600723870bf569cbf41d13db939c8cf97a16/pyarrow-20.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:4250e28a22302ce8692d3a0e8ec9d9dde54ec00d237cff4dfa9c1fbf79e472a8", size = 42295639 },
|
||||
{ url = "https://files.pythonhosted.org/packages/b2/66/2d976c0c7158fd25591c8ca55aee026e6d5745a021915a1835578707feb3/pyarrow-20.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:89e030dc58fc760e4010148e6ff164d2f44441490280ef1e97a542375e41058e", size = 42908549 },
|
||||
{ url = "https://files.pythonhosted.org/packages/31/a9/dfb999c2fc6911201dcbf348247f9cc382a8990f9ab45c12eabfd7243a38/pyarrow-20.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:6102b4864d77102dbbb72965618e204e550135a940c2534711d5ffa787df2a5a", size = 44557216 },
|
||||
{ url = "https://files.pythonhosted.org/packages/a0/8e/9adee63dfa3911be2382fb4d92e4b2e7d82610f9d9f668493bebaa2af50f/pyarrow-20.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:96d6a0a37d9c98be08f5ed6a10831d88d52cac7b13f5287f1e0f625a0de8062b", size = 25660496 },
|
||||
{ url = "https://files.pythonhosted.org/packages/9b/aa/daa413b81446d20d4dad2944110dcf4cf4f4179ef7f685dd5a6d7570dc8e/pyarrow-20.0.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:a15532e77b94c61efadde86d10957950392999503b3616b2ffcef7621a002893", size = 30798501 },
|
||||
{ url = "https://files.pythonhosted.org/packages/ff/75/2303d1caa410925de902d32ac215dc80a7ce7dd8dfe95358c165f2adf107/pyarrow-20.0.0-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:dd43f58037443af715f34f1322c782ec463a3c8a94a85fdb2d987ceb5658e061", size = 32277895 },
|
||||
{ url = "https://files.pythonhosted.org/packages/92/41/fe18c7c0b38b20811b73d1bdd54b1fccba0dab0e51d2048878042d84afa8/pyarrow-20.0.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aa0d288143a8585806e3cc7c39566407aab646fb9ece164609dac1cfff45f6ae", size = 41327322 },
|
||||
{ url = "https://files.pythonhosted.org/packages/da/ab/7dbf3d11db67c72dbf36ae63dcbc9f30b866c153b3a22ef728523943eee6/pyarrow-20.0.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b6953f0114f8d6f3d905d98e987d0924dabce59c3cda380bdfaa25a6201563b4", size = 42411441 },
|
||||
{ url = "https://files.pythonhosted.org/packages/90/c3/0c7da7b6dac863af75b64e2f827e4742161128c350bfe7955b426484e226/pyarrow-20.0.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:991f85b48a8a5e839b2128590ce07611fae48a904cae6cab1f089c5955b57eb5", size = 40677027 },
|
||||
{ url = "https://files.pythonhosted.org/packages/be/27/43a47fa0ff9053ab5203bb3faeec435d43c0d8bfa40179bfd076cdbd4e1c/pyarrow-20.0.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:97c8dc984ed09cb07d618d57d8d4b67a5100a30c3818c2fb0b04599f0da2de7b", size = 42281473 },
|
||||
{ url = "https://files.pythonhosted.org/packages/bc/0b/d56c63b078876da81bbb9ba695a596eabee9b085555ed12bf6eb3b7cab0e/pyarrow-20.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9b71daf534f4745818f96c214dbc1e6124d7daf059167330b610fc69b6f3d3e3", size = 42893897 },
|
||||
{ url = "https://files.pythonhosted.org/packages/92/ac/7d4bd020ba9145f354012838692d48300c1b8fe5634bfda886abcada67ed/pyarrow-20.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e8b88758f9303fa5a83d6c90e176714b2fd3852e776fc2d7e42a22dd6c2fb368", size = 44543847 },
|
||||
{ url = "https://files.pythonhosted.org/packages/9d/07/290f4abf9ca702c5df7b47739c1b2c83588641ddfa2cc75e34a301d42e55/pyarrow-20.0.0-cp313-cp313-win_amd64.whl", hash = "sha256:30b3051b7975801c1e1d387e17c588d8ab05ced9b1e14eec57915f79869b5031", size = 25653219 },
|
||||
{ url = "https://files.pythonhosted.org/packages/95/df/720bb17704b10bd69dde086e1400b8eefb8f58df3f8ac9cff6c425bf57f1/pyarrow-20.0.0-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:ca151afa4f9b7bc45bcc791eb9a89e90a9eb2772767d0b1e5389609c7d03db63", size = 30853957 },
|
||||
{ url = "https://files.pythonhosted.org/packages/d9/72/0d5f875efc31baef742ba55a00a25213a19ea64d7176e0fe001c5d8b6e9a/pyarrow-20.0.0-cp313-cp313t-macosx_12_0_x86_64.whl", hash = "sha256:4680f01ecd86e0dd63e39eb5cd59ef9ff24a9d166db328679e36c108dc993d4c", size = 32247972 },
|
||||
{ url = "https://files.pythonhosted.org/packages/d5/bc/e48b4fa544d2eea72f7844180eb77f83f2030b84c8dad860f199f94307ed/pyarrow-20.0.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7f4c8534e2ff059765647aa69b75d6543f9fef59e2cd4c6d18015192565d2b70", size = 41256434 },
|
||||
{ url = "https://files.pythonhosted.org/packages/c3/01/974043a29874aa2cf4f87fb07fd108828fc7362300265a2a64a94965e35b/pyarrow-20.0.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e1f8a47f4b4ae4c69c4d702cfbdfe4d41e18e5c7ef6f1bb1c50918c1e81c57b", size = 42353648 },
|
||||
{ url = "https://files.pythonhosted.org/packages/68/95/cc0d3634cde9ca69b0e51cbe830d8915ea32dda2157560dda27ff3b3337b/pyarrow-20.0.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:a1f60dc14658efaa927f8214734f6a01a806d7690be4b3232ba526836d216122", size = 40619853 },
|
||||
{ url = "https://files.pythonhosted.org/packages/29/c2/3ad40e07e96a3e74e7ed7cc8285aadfa84eb848a798c98ec0ad009eb6bcc/pyarrow-20.0.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:204a846dca751428991346976b914d6d2a82ae5b8316a6ed99789ebf976551e6", size = 42241743 },
|
||||
{ url = "https://files.pythonhosted.org/packages/eb/cb/65fa110b483339add6a9bc7b6373614166b14e20375d4daa73483755f830/pyarrow-20.0.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:f3b117b922af5e4c6b9a9115825726cac7d8b1421c37c2b5e24fbacc8930612c", size = 42839441 },
|
||||
{ url = "https://files.pythonhosted.org/packages/98/7b/f30b1954589243207d7a0fbc9997401044bf9a033eec78f6cb50da3f304a/pyarrow-20.0.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:e724a3fd23ae5b9c010e7be857f4405ed5e679db5c93e66204db1a69f733936a", size = 44503279 },
|
||||
{ url = "https://files.pythonhosted.org/packages/37/40/ad395740cd641869a13bcf60851296c89624662575621968dcfafabaa7f6/pyarrow-20.0.0-cp313-cp313t-win_amd64.whl", hash = "sha256:82f1ee5133bd8f49d31be1299dc07f585136679666b502540db854968576faf9", size = 25944982 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyasn1"
|
||||
version = "0.6.1"
|
||||
|
||||