Compare commits
57 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
81305ffbf9 | ||
|
|
43143f0b30 | ||
|
|
e58361e04b | ||
|
|
a125a2f5f8 | ||
|
|
2613912df3 | ||
|
|
067f160b3e | ||
|
|
9ba4cc6f45 | ||
|
|
f99399c6c7 | ||
|
|
c3d487ccdd | ||
|
|
c1e3b13851 | ||
|
|
5923cc1c89 | ||
|
|
9b5ac8533d | ||
|
|
cc84d605e5 | ||
|
|
55dbdd5e14 | ||
|
|
affa4a0319 | ||
|
|
ecb5db8137 | ||
|
|
5d15f6f3a4 | ||
|
|
6247d49192 | ||
|
|
fb5e864a24 | ||
|
|
9bb5b17199 | ||
|
|
6d3c128f54 | ||
|
|
651d524814 | ||
|
|
54572d0861 | ||
|
|
d12dfbd014 | ||
|
|
4052b7e938 | ||
|
|
693c48d5ee | ||
|
|
c1ce9a5cc7 | ||
|
|
282a1c0bd9 | ||
|
|
e3b9bf96ab | ||
|
|
667cecbbe0 | ||
|
|
c777905bd3 | ||
|
|
d4ea125e35 | ||
|
|
f135d9b949 | ||
|
|
124dd0da88 | ||
|
|
775b4cb630 | ||
|
|
26e8bc1149 | ||
|
|
8526cb75fe | ||
|
|
97006a756d | ||
|
|
72865654e2 | ||
|
|
050c09f902 | ||
|
|
159399bd38 | ||
|
|
e859fbb778 | ||
|
|
1b9cb29f5f | ||
|
|
c95c0fe03c | ||
|
|
1e2e79d90d | ||
|
|
609816bc4a | ||
|
|
5d46d153e1 | ||
|
|
0a81d5693b | ||
|
|
a4306867f6 | ||
|
|
a22ff3ae9b | ||
|
|
2c5f26889c | ||
|
|
e47534c296 | ||
|
|
0f52591259 | ||
|
|
3b429f37b3 | ||
|
|
f5a4c8abbe | ||
|
|
87563ef6e1 | ||
|
|
b6157c500b |
@@ -24,9 +24,15 @@ github:
|
|||||||
- olap
|
- olap
|
||||||
- lakehouse
|
- lakehouse
|
||||||
- mcp
|
- mcp
|
||||||
|
- ai
|
||||||
enabled_merge_buttons:
|
enabled_merge_buttons:
|
||||||
squash: true
|
squash: true
|
||||||
merge: false
|
merge: false
|
||||||
rebase: false
|
rebase: false
|
||||||
|
features:
|
||||||
|
issues: true
|
||||||
|
projects: true
|
||||||
notifications:
|
notifications:
|
||||||
pullrequests_status: commits@doris.apache.org
|
issues: commits@doris.apache.org
|
||||||
|
commits: commits@doris.apache.org
|
||||||
|
pullrequests: commits@doris.apache.org
|
||||||
|
|||||||
2
.dockerignore
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
**/.venv
|
||||||
|
**/venv
|
||||||
500
.env.example
@@ -14,58 +14,512 @@
|
|||||||
# KIND, either express or implied. See the License for the
|
# KIND, either express or implied. See the License for the
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
# ===================================================================
|
||||||
|
# Doris MCP Server Environment Configuration Example
|
||||||
|
# ===================================================================
|
||||||
|
# Copy this file to .env and modify the configuration values as needed
|
||||||
|
|
||||||
# Doris MCP Server Environment Configuration
|
# ===================================================================
|
||||||
# Copy this file to .env and modify the values as needed
|
# Database Connection Configuration
|
||||||
|
# ===================================================================
|
||||||
|
|
||||||
# Database Configuration
|
# Doris FE (Frontend) connection settings
|
||||||
DORIS_HOST=localhost
|
DORIS_HOST=localhost
|
||||||
DORIS_PORT=9030
|
DORIS_PORT=9030
|
||||||
DORIS_USER=root
|
DORIS_USER=root
|
||||||
DORIS_PASSWORD=your_password_here
|
DORIS_PASSWORD=
|
||||||
DORIS_DATABASE=your_database_name
|
DORIS_DATABASE=information_schema
|
||||||
|
|
||||||
# Connection Pool Settings
|
# Doris FE HTTP API port (for Profile and other HTTP APIs)
|
||||||
DORIS_MIN_CONNECTIONS=5
|
DORIS_FE_HTTP_PORT=8030
|
||||||
|
|
||||||
|
# Doris BE (Backend) nodes configuration (optional, for external access)
|
||||||
|
# Format: host1,host2,host3 (if empty, will use "show backends" to get BE nodes)
|
||||||
|
DORIS_BE_HOSTS=
|
||||||
|
DORIS_BE_WEBSERVER_PORT=8040
|
||||||
|
|
||||||
|
# Connection pool configuration
|
||||||
DORIS_MAX_CONNECTIONS=20
|
DORIS_MAX_CONNECTIONS=20
|
||||||
DORIS_CONNECTION_TIMEOUT=30
|
DORIS_CONNECTION_TIMEOUT=30
|
||||||
DORIS_HEALTH_CHECK_INTERVAL=60
|
DORIS_HEALTH_CHECK_INTERVAL=60
|
||||||
DORIS_MAX_CONNECTION_AGE=3600
|
DORIS_MAX_CONNECTION_AGE=3600
|
||||||
|
|
||||||
# Security Settings
|
# Arrow Flight SQL Configuration (Required for ADBC tools)
|
||||||
|
# FE_ARROW_FLIGHT_SQL_PORT=
|
||||||
|
# BE_ARROW_FLIGHT_SQL_PORT=
|
||||||
|
|
||||||
|
# ===================================================================
|
||||||
|
# Security Configuration
|
||||||
|
# ===================================================================
|
||||||
|
|
||||||
|
# Independent Authentication Switches - NEW DESIGN!
|
||||||
|
# Each authentication method can be enabled/disabled independently
|
||||||
|
# Any enabled method that succeeds will allow access
|
||||||
|
# If all methods are disabled, anonymous access is allowed
|
||||||
|
|
||||||
|
# Legacy configuration - kept for backward compatibility
|
||||||
|
# AUTH_TYPE is now deprecated - use individual switches above
|
||||||
AUTH_TYPE=token
|
AUTH_TYPE=token
|
||||||
TOKEN_SECRET=your_256_bit_secret_key_here
|
|
||||||
|
# Token Authentication (Default method - simple and effective)
|
||||||
|
ENABLE_TOKEN_AUTH=false
|
||||||
|
|
||||||
|
# JWT Authentication (For stateless applications)
|
||||||
|
ENABLE_JWT_AUTH=false
|
||||||
|
|
||||||
|
# OAuth 2.0/OIDC Authentication (For enterprise integration)
|
||||||
|
ENABLE_OAUTH_AUTH=false
|
||||||
|
|
||||||
|
# ===================================================================
|
||||||
|
# Token Authentication Configuration (Enable with ENABLE_TOKEN_AUTH=true)
|
||||||
|
# ===================================================================
|
||||||
|
|
||||||
|
# Basic token authentication settings
|
||||||
|
TOKEN_FILE_PATH=tokens.json
|
||||||
|
ENABLE_TOKEN_EXPIRY=true
|
||||||
|
DEFAULT_TOKEN_EXPIRY_HOURS=720
|
||||||
|
TOKEN_HASH_ALGORITHM=sha256
|
||||||
|
|
||||||
|
# ===================================================================
|
||||||
|
# Token Management Security Configuration (NEW in v0.6.0) - CRITICAL SECURITY SETTINGS
|
||||||
|
# ===================================================================
|
||||||
|
|
||||||
|
# HTTP Token Management Endpoints (DISABLED BY DEFAULT FOR SECURITY)
|
||||||
|
# WARNING: These endpoints allow creation, deletion, and management of authentication tokens
|
||||||
|
# Only enable if you need HTTP-based token management and understand the security implications
|
||||||
|
ENABLE_HTTP_TOKEN_MANAGEMENT=true
|
||||||
|
|
||||||
|
# Admin Authentication Token (REQUIRED if HTTP token management is enabled)
|
||||||
|
# This token is required to access HTTP token management endpoints
|
||||||
|
# SECURITY: Generate a secure random token in production - NEVER use default values
|
||||||
|
TOKEN_MANAGEMENT_ADMIN_TOKEN=
|
||||||
|
|
||||||
|
# IP Address Restrictions for Token Management (CRITICAL SECURITY CONTROL)
|
||||||
|
# Only these IP addresses/networks can access token management endpoints
|
||||||
|
# DEFAULT: localhost only (most secure) - add other IPs/networks only if necessary
|
||||||
|
# Format: comma-separated list of IPs and CIDR networks
|
||||||
|
# Examples:
|
||||||
|
# - Localhost only: 127.0.0.1,::1
|
||||||
|
# - Private network: 127.0.0.1,192.168.1.0/24,10.0.0.0/8
|
||||||
|
# - Specific IPs: 127.0.0.1,192.168.1.10,192.168.1.11
|
||||||
|
TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,::1
|
||||||
|
|
||||||
|
# Require Admin Authentication (ENABLED BY DEFAULT FOR SECURITY)
|
||||||
|
# When true, all token management operations require valid admin token
|
||||||
|
# When false, only IP restrictions apply (NOT RECOMMENDED for production)
|
||||||
|
REQUIRE_ADMIN_AUTH=true
|
||||||
|
|
||||||
|
# ===================================================================
|
||||||
|
# JWT Authentication Configuration (Enable with ENABLE_JWT_AUTH=true)
|
||||||
|
# ===================================================================
|
||||||
|
|
||||||
|
# JWT token settings (when ENABLE_JWT_AUTH=true)
|
||||||
|
JWT_SECRET_KEY=your_jwt_secret_key_here_change_in_production
|
||||||
|
JWT_ALGORITHM=HS256
|
||||||
|
JWT_EXPIRATION_HOURS=24
|
||||||
|
JWT_ISSUER=doris-mcp-server
|
||||||
|
JWT_AUDIENCE=doris-mcp-client
|
||||||
|
|
||||||
|
# JWT token validation settings
|
||||||
|
JWT_VERIFY_SIGNATURE=true
|
||||||
|
JWT_VERIFY_EXPIRATION=true
|
||||||
|
JWT_VERIFY_AUDIENCE=true
|
||||||
|
JWT_VERIFY_ISSUER=true
|
||||||
|
|
||||||
|
# JWT refresh token settings
|
||||||
|
ENABLE_JWT_REFRESH=true
|
||||||
|
JWT_REFRESH_EXPIRATION_DAYS=30
|
||||||
|
JWT_REFRESH_SECRET_KEY=your_jwt_refresh_secret_key_here
|
||||||
|
|
||||||
|
# JWT user claims configuration
|
||||||
|
JWT_USER_ID_CLAIM=user_id
|
||||||
|
JWT_ROLES_CLAIM=roles
|
||||||
|
JWT_PERMISSIONS_CLAIM=permissions
|
||||||
|
JWT_SECURITY_LEVEL_CLAIM=security_level
|
||||||
|
|
||||||
|
# ===================================================================
|
||||||
|
# OAuth 2.0 / OpenID Connect Configuration (Enable with ENABLE_OAUTH_AUTH=true)
|
||||||
|
# ===================================================================
|
||||||
|
|
||||||
|
# OAuth provider settings (when ENABLE_OAUTH_AUTH=true)
|
||||||
|
OAUTH_PROVIDER_TYPE=generic
|
||||||
|
OAUTH_CLIENT_ID=your_oauth_client_id
|
||||||
|
OAUTH_CLIENT_SECRET=your_oauth_client_secret
|
||||||
|
OAUTH_REDIRECT_URI=http://localhost:3000/auth/callback
|
||||||
|
|
||||||
|
# OAuth endpoints (for generic provider)
|
||||||
|
OAUTH_AUTHORIZATION_URL=https://your-provider.com/auth
|
||||||
|
OAUTH_TOKEN_URL=https://your-provider.com/token
|
||||||
|
OAUTH_USERINFO_URL=https://your-provider.com/userinfo
|
||||||
|
OAUTH_JWKS_URL=https://your-provider.com/.well-known/jwks.json
|
||||||
|
|
||||||
|
# OAuth scope and claims
|
||||||
|
OAUTH_SCOPE=openid profile email
|
||||||
|
OAUTH_USER_ID_CLAIM=sub
|
||||||
|
OAUTH_USERNAME_CLAIM=preferred_username
|
||||||
|
OAUTH_EMAIL_CLAIM=email
|
||||||
|
OAUTH_ROLES_CLAIM=roles
|
||||||
|
OAUTH_GROUPS_CLAIM=groups
|
||||||
|
|
||||||
|
# OAuth session settings
|
||||||
|
OAUTH_SESSION_SECRET=your_oauth_session_secret_here
|
||||||
|
OAUTH_SESSION_EXPIRY=3600
|
||||||
|
OAUTH_STATE_EXPIRY=300
|
||||||
|
|
||||||
|
# Popular OAuth providers presets (uncomment and configure as needed)
|
||||||
|
|
||||||
|
# Google OAuth Configuration
|
||||||
|
# OAUTH_PROVIDER_TYPE=google
|
||||||
|
# OAUTH_CLIENT_ID=your_google_client_id.apps.googleusercontent.com
|
||||||
|
# OAUTH_CLIENT_SECRET=your_google_client_secret
|
||||||
|
# OAUTH_AUTHORIZATION_URL=https://accounts.google.com/o/oauth2/auth
|
||||||
|
# OAUTH_TOKEN_URL=https://oauth2.googleapis.com/token
|
||||||
|
# OAUTH_USERINFO_URL=https://www.googleapis.com/oauth2/v1/userinfo
|
||||||
|
# OAUTH_JWKS_URL=https://www.googleapis.com/oauth2/v3/certs
|
||||||
|
# OAUTH_SCOPE=openid profile email
|
||||||
|
|
||||||
|
# Microsoft Azure AD Configuration
|
||||||
|
# OAUTH_PROVIDER_TYPE=azure
|
||||||
|
# OAUTH_CLIENT_ID=your_azure_client_id
|
||||||
|
# OAUTH_CLIENT_SECRET=your_azure_client_secret
|
||||||
|
# OAUTH_TENANT_ID=your_tenant_id
|
||||||
|
# OAUTH_AUTHORIZATION_URL=https://login.microsoftonline.com/{tenant}/oauth2/v2.0/authorize
|
||||||
|
# OAUTH_TOKEN_URL=https://login.microsoftonline.com/{tenant}/oauth2/v2.0/token
|
||||||
|
# OAUTH_USERINFO_URL=https://graph.microsoft.com/v1.0/me
|
||||||
|
# OAUTH_JWKS_URL=https://login.microsoftonline.com/{tenant}/discovery/v2.0/keys
|
||||||
|
# OAUTH_SCOPE=openid profile email
|
||||||
|
|
||||||
|
# GitHub OAuth Configuration
|
||||||
|
# OAUTH_PROVIDER_TYPE=github
|
||||||
|
# OAUTH_CLIENT_ID=your_github_client_id
|
||||||
|
# OAUTH_CLIENT_SECRET=your_github_client_secret
|
||||||
|
# OAUTH_AUTHORIZATION_URL=https://github.com/login/oauth/authorize
|
||||||
|
# OAUTH_TOKEN_URL=https://github.com/login/oauth/access_token
|
||||||
|
# OAUTH_USERINFO_URL=https://api.github.com/user
|
||||||
|
# OAUTH_SCOPE=user:email
|
||||||
|
|
||||||
|
# GitLab OAuth Configuration
|
||||||
|
# OAUTH_PROVIDER_TYPE=gitlab
|
||||||
|
# OAUTH_CLIENT_ID=your_gitlab_client_id
|
||||||
|
# OAUTH_CLIENT_SECRET=your_gitlab_client_secret
|
||||||
|
# OAUTH_AUTHORIZATION_URL=https://gitlab.com/oauth/authorize
|
||||||
|
# OAUTH_TOKEN_URL=https://gitlab.com/oauth/token
|
||||||
|
# OAUTH_USERINFO_URL=https://gitlab.com/api/v4/user
|
||||||
|
# OAUTH_SCOPE=read_user
|
||||||
|
|
||||||
|
# Keycloak OAuth Configuration
|
||||||
|
# OAUTH_PROVIDER_TYPE=keycloak
|
||||||
|
# OAUTH_CLIENT_ID=your_keycloak_client_id
|
||||||
|
# OAUTH_CLIENT_SECRET=your_keycloak_client_secret
|
||||||
|
# OAUTH_REALM=your_realm
|
||||||
|
# OAUTH_SERVER_URL=https://your-keycloak-server.com
|
||||||
|
# OAUTH_AUTHORIZATION_URL=https://your-keycloak-server.com/auth/realms/{realm}/protocol/openid-connect/auth
|
||||||
|
# OAUTH_TOKEN_URL=https://your-keycloak-server.com/auth/realms/{realm}/protocol/openid-connect/token
|
||||||
|
# OAUTH_USERINFO_URL=https://your-keycloak-server.com/auth/realms/{realm}/protocol/openid-connect/userinfo
|
||||||
|
# OAUTH_JWKS_URL=https://your-keycloak-server.com/auth/realms/{realm}/protocol/openid-connect/certs
|
||||||
|
# OAUTH_SCOPE=openid profile email
|
||||||
|
|
||||||
|
# Legacy token settings (for backward compatibility)
|
||||||
|
TOKEN_SECRET=your_secret_key_here
|
||||||
TOKEN_EXPIRY=3600
|
TOKEN_EXPIRY=3600
|
||||||
|
|
||||||
|
# SQL security check
|
||||||
|
ENABLE_SECURITY_CHECK=true
|
||||||
|
|
||||||
|
# Blocked keywords (comma separated)
|
||||||
|
BLOCKED_KEYWORDS=DROP,CREATE,ALTER,TRUNCATE,DELETE,INSERT,UPDATE,GRANT,REVOKE,EXEC,EXECUTE,SHUTDOWN,KILL
|
||||||
|
|
||||||
|
# Query limits
|
||||||
|
MAX_QUERY_COMPLEXITY=100
|
||||||
MAX_RESULT_ROWS=10000
|
MAX_RESULT_ROWS=10000
|
||||||
|
|
||||||
|
# Data masking
|
||||||
ENABLE_MASKING=true
|
ENABLE_MASKING=true
|
||||||
|
|
||||||
# Performance Settings
|
# ===================================================================
|
||||||
|
# Performance Configuration
|
||||||
|
# ===================================================================
|
||||||
|
|
||||||
|
# Query cache
|
||||||
ENABLE_QUERY_CACHE=true
|
ENABLE_QUERY_CACHE=true
|
||||||
CACHE_TTL=300
|
CACHE_TTL=300
|
||||||
MAX_CACHE_SIZE=1000
|
MAX_CACHE_SIZE=1000
|
||||||
|
|
||||||
|
# Concurrency control
|
||||||
MAX_CONCURRENT_QUERIES=50
|
MAX_CONCURRENT_QUERIES=50
|
||||||
QUERY_TIMEOUT=300
|
QUERY_TIMEOUT=300
|
||||||
|
|
||||||
# Logging Configuration
|
# Response content size limit (characters)
|
||||||
LOG_LEVEL=INFO
|
MAX_RESPONSE_CONTENT_SIZE=4096
|
||||||
LOG_FILE_PATH=./log/doris-mcp-server.log
|
|
||||||
ENABLE_AUDIT=true
|
|
||||||
AUDIT_FILE_PATH=./log/doris-mcp-audit.log
|
|
||||||
|
|
||||||
# Monitoring Settings
|
# ===================================================================
|
||||||
|
# ADBC (Arrow Flight SQL) Configuration
|
||||||
|
# ===================================================================
|
||||||
|
# Enable/disable ADBC tools
|
||||||
|
ADBC_ENABLED=true
|
||||||
|
|
||||||
|
# Default ADBC query parameters
|
||||||
|
ADBC_DEFAULT_MAX_ROWS=100000
|
||||||
|
ADBC_DEFAULT_TIMEOUT=60
|
||||||
|
# Format: "arrow", "pandas", "dict"
|
||||||
|
ADBC_DEFAULT_RETURN_FORMAT=arrow
|
||||||
|
|
||||||
|
# ADBC connection timeout
|
||||||
|
ADBC_CONNECTION_TIMEOUT=300
|
||||||
|
|
||||||
|
# ===================================================================
|
||||||
|
# Logging Configuration
|
||||||
|
# ===================================================================
|
||||||
|
|
||||||
|
# Basic logging configuration
|
||||||
|
LOG_LEVEL=INFO
|
||||||
|
LOG_FILE_PATH=
|
||||||
|
|
||||||
|
# Audit logging
|
||||||
|
ENABLE_AUDIT=true
|
||||||
|
AUDIT_FILE_PATH=
|
||||||
|
|
||||||
|
# Log file rotation configuration
|
||||||
|
LOG_MAX_FILE_SIZE=10485760
|
||||||
|
LOG_BACKUP_COUNT=5
|
||||||
|
|
||||||
|
# ===================================================================
|
||||||
|
# Log Cleanup Configuration - NEW!
|
||||||
|
# ===================================================================
|
||||||
|
|
||||||
|
# Enable automatic log cleanup
|
||||||
|
ENABLE_LOG_CLEANUP=true
|
||||||
|
|
||||||
|
# Maximum age of log files in days (files older than this will be deleted)
|
||||||
|
LOG_MAX_AGE_DAYS=30
|
||||||
|
|
||||||
|
# Cleanup check interval in hours
|
||||||
|
LOG_CLEANUP_INTERVAL_HOURS=24
|
||||||
|
|
||||||
|
# ===================================================================
|
||||||
|
# Monitoring Configuration
|
||||||
|
# ===================================================================
|
||||||
|
|
||||||
|
# Metrics collection
|
||||||
ENABLE_METRICS=true
|
ENABLE_METRICS=true
|
||||||
METRICS_PORT=3001
|
METRICS_PORT=3001
|
||||||
METRICS_PATH=/metrics
|
|
||||||
HEALTH_CHECK_PORT=3002
|
HEALTH_CHECK_PORT=3002
|
||||||
HEALTH_CHECK_PATH=/health
|
|
||||||
|
# Alert configuration
|
||||||
ENABLE_ALERTS=false
|
ENABLE_ALERTS=false
|
||||||
ALERT_WEBHOOK_URL=
|
ALERT_WEBHOOK_URL=
|
||||||
|
|
||||||
# Server Settings
|
# ===================================================================
|
||||||
|
# Server Configuration
|
||||||
|
# ===================================================================
|
||||||
|
|
||||||
|
# Basic server information
|
||||||
SERVER_NAME=doris-mcp-server
|
SERVER_NAME=doris-mcp-server
|
||||||
SERVER_VERSION=0.3.0
|
SERVER_VERSION=0.6.0
|
||||||
SERVER_PORT=3000
|
SERVER_PORT=3000
|
||||||
|
|
||||||
# Development Settings (for development environment only)
|
# Temporary files directory
|
||||||
DEBUG=false
|
TEMP_FILES_DIR=tmp
|
||||||
VERBOSE=false
|
|
||||||
|
# ===================================================================
|
||||||
|
# Configuration Examples for Different Environments
|
||||||
|
# ===================================================================
|
||||||
|
|
||||||
|
# Development Environment Example:
|
||||||
|
# LOG_LEVEL=DEBUG
|
||||||
|
# LOG_MAX_AGE_DAYS=7
|
||||||
|
# LOG_CLEANUP_INTERVAL_HOURS=6
|
||||||
|
# ENABLE_SECURITY_CHECK=false
|
||||||
|
|
||||||
|
# Production Environment Example:
|
||||||
|
# LOG_LEVEL=INFO
|
||||||
|
# LOG_MAX_AGE_DAYS=30
|
||||||
|
# LOG_CLEANUP_INTERVAL_HOURS=24
|
||||||
|
# ENABLE_SECURITY_CHECK=true
|
||||||
|
# ENABLE_LOG_CLEANUP=true
|
||||||
|
|
||||||
|
# Testing Environment Example:
|
||||||
|
# LOG_LEVEL=WARNING
|
||||||
|
# LOG_MAX_AGE_DAYS=3
|
||||||
|
# LOG_CLEANUP_INTERVAL_HOURS=1
|
||||||
|
# MAX_RESULT_ROWS=1000
|
||||||
|
|
||||||
|
# ===================================================================
|
||||||
|
# Advanced Configuration Notes
|
||||||
|
# ===================================================================
|
||||||
|
|
||||||
|
# 1. Log Cleanup Feature:
|
||||||
|
# - ENABLE_LOG_CLEANUP: Controls whether to enable automatic cleanup
|
||||||
|
# - LOG_MAX_AGE_DAYS: File retention days, recommended 30 days for production, 7 days for development
|
||||||
|
# - LOG_CLEANUP_INTERVAL_HOURS: Check frequency, recommended 24 hours
|
||||||
|
|
||||||
|
# 2. Security Best Practices:
|
||||||
|
# - NEW: Enable individual authentication methods using ENABLE_TOKEN_AUTH, ENABLE_JWT_AUTH, ENABLE_OAUTH_AUTH
|
||||||
|
# - When all methods are disabled, ALL requests are allowed with anonymous access
|
||||||
|
# - Authentication methods work independently - any one succeeding allows access
|
||||||
|
# - Token Auth: Change default tokens (DEFAULT_ADMIN_TOKEN, etc.) in production
|
||||||
|
# - JWT Auth: Change JWT_SECRET_KEY and JWT_REFRESH_SECRET_KEY in production
|
||||||
|
# - OAuth Auth: Configure OAuth provider settings and secure client secrets
|
||||||
|
# - Must change TOKEN_SECRET in production environment (legacy compatibility)
|
||||||
|
# - Adjust BLOCKED_KEYWORDS according to business needs
|
||||||
|
# - Enable ENABLE_SECURITY_CHECK and ENABLE_MASKING
|
||||||
|
# - NEW v0.6.0: Token Management Security (CRITICAL):
|
||||||
|
# * ENABLE_HTTP_TOKEN_MANAGEMENT=false by default (SECURE BY DEFAULT)
|
||||||
|
# * Only enable if you need HTTP token management endpoints
|
||||||
|
# * TOKEN_MANAGEMENT_ADMIN_TOKEN: Use secure random token in production
|
||||||
|
# * TOKEN_MANAGEMENT_ALLOWED_IPS: Restrict to localhost (127.0.0.1,::1) only
|
||||||
|
# * REQUIRE_ADMIN_AUTH=true: Always require admin authentication
|
||||||
|
# * Never expose token management endpoints to external networks
|
||||||
|
|
||||||
|
# 3. Performance Tuning:
|
||||||
|
# - Adjust MAX_CONCURRENT_QUERIES based on hardware resources
|
||||||
|
# - Adjust QUERY_TIMEOUT based on query complexity
|
||||||
|
# - Adjust MAX_CACHE_SIZE based on memory size
|
||||||
|
|
||||||
|
# 4. Connection Pool Optimization:
|
||||||
|
# - DORIS_MAX_CONNECTIONS recommended to be 2-4 times the number of CPU cores
|
||||||
|
# - DORIS_CONNECTION_TIMEOUT adjust based on network latency
|
||||||
|
# - DORIS_MAX_CONNECTION_AGE recommended 1 hour to avoid long connection issues
|
||||||
|
|
||||||
|
# 5. ADBC (Arrow Flight SQL) Configuration:
|
||||||
|
# - FE_ARROW_FLIGHT_SQL_PORT and BE_ARROW_FLIGHT_SQL_PORT: Required for ADBC functionality
|
||||||
|
# - ADBC_DEFAULT_MAX_ROWS: Default maximum rows for ADBC queries (recommended: 100000)
|
||||||
|
# - ADBC_DEFAULT_TIMEOUT: Default timeout for ADBC queries in seconds (recommended: 60)
|
||||||
|
# - ADBC_DEFAULT_RETURN_FORMAT: Default return format (arrow/pandas/dict, recommended: arrow)
|
||||||
|
# - ADBC_CONNECTION_TIMEOUT: Connection timeout for ADBC (recommended: 30)
|
||||||
|
# - ADBC_ENABLED: Enable or disable ADBC tools (true/false)
|
||||||
|
# - Prerequisites: Install adbc_driver_manager, adbc_driver_flightsql, pyarrow packages
|
||||||
|
|
||||||
|
# 6. Authentication Configuration Guide - UPDATED DESIGN!
|
||||||
|
#
|
||||||
|
# Independent Authentication Control (NEW):
|
||||||
|
# - ENABLE_TOKEN_AUTH=false (default): Disable token authentication
|
||||||
|
# - ENABLE_JWT_AUTH=false (default): Disable JWT authentication
|
||||||
|
# - ENABLE_OAUTH_AUTH=false (default): Disable OAuth authentication
|
||||||
|
# - When all methods are disabled, no authentication is required (anonymous access)
|
||||||
|
# - When multiple methods are enabled, any one succeeding allows access
|
||||||
|
# - Recommended for development/testing: all false, production: enable needed methods
|
||||||
|
#
|
||||||
|
# Token Authentication (ENABLE_TOKEN_AUTH=true) - Recommended for most use cases:
|
||||||
|
# - Simple and secure token-based authentication
|
||||||
|
# - Configurable default tokens via environment variables
|
||||||
|
# - Support for custom tokens via TOKEN_* environment variables
|
||||||
|
# - Token file configuration via tokens.json
|
||||||
|
# - Built-in token management HTTP endpoints
|
||||||
|
# - No user management complexity - pure API access control
|
||||||
|
#
|
||||||
|
# JWT Authentication (ENABLE_JWT_AUTH=true) - For stateless applications:
|
||||||
|
# - JSON Web Token based authentication
|
||||||
|
# - Configurable token expiration and refresh
|
||||||
|
# - Support for standard JWT claims
|
||||||
|
# - RSA/ECDSA/HS256 algorithm support
|
||||||
|
# - Suitable for microservices and distributed systems
|
||||||
|
#
|
||||||
|
# OAuth 2.0/OIDC (ENABLE_OAUTH_AUTH=true) - For enterprise integration:
|
||||||
|
# - Integration with external identity providers
|
||||||
|
# - Support for popular providers (Google, Microsoft, GitHub, GitLab, Keycloak)
|
||||||
|
# - OpenID Connect compatibility
|
||||||
|
# - Automatic user provisioning from provider
|
||||||
|
# - Secure authorization code flow
|
||||||
|
#
|
||||||
|
# Authentication Method Selection Guide:
|
||||||
|
# - No Auth (all switches false): Development, testing, trusted networks
|
||||||
|
# - Token Auth only: Small teams, simple deployment, direct API access
|
||||||
|
# - JWT Auth only: Stateless apps, microservices, mobile clients
|
||||||
|
# - OAuth Auth only: Enterprise SSO, large teams, external identity providers
|
||||||
|
# - Multiple methods: Flexible access, different client types, migration scenarios
|
||||||
|
|
||||||
|
# 7. Token Management Security Configuration Guide (NEW in v0.6.0) - CRITICAL!
|
||||||
|
#
|
||||||
|
# ⚠️ SECURITY WARNING: Token management endpoints are POWERFUL and DANGEROUS
|
||||||
|
# They allow creation, revocation, and management of authentication tokens.
|
||||||
|
# Improper configuration can lead to complete system compromise.
|
||||||
|
#
|
||||||
|
# 🔒 SECURE BY DEFAULT:
|
||||||
|
# - ENABLE_HTTP_TOKEN_MANAGEMENT=false (disabled by default)
|
||||||
|
# - REQUIRE_ADMIN_AUTH=true (admin auth required by default)
|
||||||
|
# - TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,::1 (localhost only by default)
|
||||||
|
#
|
||||||
|
# 🛡️ SECURITY LAYERS (Applied in order):
|
||||||
|
# 1. Configuration Check: HTTP token management must be explicitly enabled
|
||||||
|
# 2. IP Restrictions: Only allowed IP addresses/networks can access endpoints
|
||||||
|
# 3. Admin Authentication: Valid admin token required for all operations
|
||||||
|
#
|
||||||
|
# 📋 CONFIGURATION OPTIONS:
|
||||||
|
#
|
||||||
|
# Disable Token Management (RECOMMENDED for most deployments):
|
||||||
|
# ENABLE_HTTP_TOKEN_MANAGEMENT=false
|
||||||
|
# # All token management endpoints will return 403 Forbidden
|
||||||
|
#
|
||||||
|
# Enable with Maximum Security (Production):
|
||||||
|
# ENABLE_HTTP_TOKEN_MANAGEMENT=true
|
||||||
|
# TOKEN_MANAGEMENT_ADMIN_TOKEN=<secure-random-token-256-bit>
|
||||||
|
# TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,::1
|
||||||
|
# REQUIRE_ADMIN_AUTH=true
|
||||||
|
#
|
||||||
|
# Enable for Private Network (Advanced):
|
||||||
|
# ENABLE_HTTP_TOKEN_MANAGEMENT=true
|
||||||
|
# TOKEN_MANAGEMENT_ADMIN_TOKEN=<secure-random-token-256-bit>
|
||||||
|
# TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,192.168.1.0/24,10.0.0.0/8
|
||||||
|
# REQUIRE_ADMIN_AUTH=true
|
||||||
|
#
|
||||||
|
# 🔑 ADMIN TOKEN GENERATION:
|
||||||
|
# # Generate secure admin token (Linux/macOS):
|
||||||
|
# openssl rand -hex 32
|
||||||
|
#
|
||||||
|
# # Generate secure admin token (Python):
|
||||||
|
# python -c "import secrets; print(secrets.token_urlsafe(32))"
|
||||||
|
#
|
||||||
|
# 🌐 IP CONFIGURATION EXAMPLES:
|
||||||
|
# # Localhost only (most secure):
|
||||||
|
# TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,::1
|
||||||
|
#
|
||||||
|
# # Private network + localhost:
|
||||||
|
# TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,::1,192.168.1.0/24,10.0.0.0/8
|
||||||
|
#
|
||||||
|
# # Specific servers only:
|
||||||
|
# TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,192.168.1.10,192.168.1.11
|
||||||
|
#
|
||||||
|
# # Corporate network (be careful):
|
||||||
|
# TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,172.16.0.0/12,192.168.0.0/16
|
||||||
|
#
|
||||||
|
# 🚫 NEVER DO THIS (Security Anti-Patterns):
|
||||||
|
# # NEVER allow all IPs:
|
||||||
|
# # TOKEN_MANAGEMENT_ALLOWED_IPS=0.0.0.0/0 # DANGEROUS!
|
||||||
|
#
|
||||||
|
# # NEVER disable admin auth in production:
|
||||||
|
# # REQUIRE_ADMIN_AUTH=false # DANGEROUS!
|
||||||
|
#
|
||||||
|
# # NEVER use weak admin tokens:
|
||||||
|
# # TOKEN_MANAGEMENT_ADMIN_TOKEN=admin # DANGEROUS!
|
||||||
|
# # TOKEN_MANAGEMENT_ADMIN_TOKEN=123456 # DANGEROUS!
|
||||||
|
#
|
||||||
|
# 📊 ENDPOINT SECURITY TESTING:
|
||||||
|
# # Test security (should fail):
|
||||||
|
# curl -X POST http://external-ip:3000/token/create
|
||||||
|
# # Expected: 403 Forbidden (IP not allowed)
|
||||||
|
#
|
||||||
|
# # Test without auth (should fail):
|
||||||
|
# curl -X POST http://127.0.0.1:3000/token/create
|
||||||
|
# # Expected: 401 Unauthorized (missing admin token)
|
||||||
|
#
|
||||||
|
# # Test with valid auth (should succeed if enabled):
|
||||||
|
# curl -H "Authorization: Bearer your-admin-token" http://127.0.0.1:3000/token/stats
|
||||||
|
# # Expected: 200 OK with token statistics
|
||||||
|
#
|
||||||
|
# 🔍 MONITORING & AUDITING:
|
||||||
|
# # All token management access attempts are logged:
|
||||||
|
# tail -f logs/doris_mcp_server_audit.log | grep "token management"
|
||||||
|
#
|
||||||
|
# # Monitor security events:
|
||||||
|
# tail -f logs/doris_mcp_server_info.log | grep -E "(access denied|token management)"
|
||||||
|
#
|
||||||
|
# ✅ SECURITY BEST PRACTICES:
|
||||||
|
# - Keep ENABLE_HTTP_TOKEN_MANAGEMENT=false unless absolutely necessary
|
||||||
|
# - Use file-based token management (tokens.json) instead of HTTP endpoints
|
||||||
|
# - Generate strong admin tokens using cryptographically secure methods
|
||||||
|
# - Restrict access to localhost (127.0.0.1,::1) whenever possible
|
||||||
|
# - Never expose token management endpoints to public internet
|
||||||
|
# - Regularly audit token management access logs
|
||||||
|
# - Use firewall rules as additional protection layer
|
||||||
|
# - Consider VPN access for remote token management needs
|
||||||
23
.gitignore
vendored
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
*.log
|
||||||
|
*.log.*
|
||||||
|
*.bak
|
||||||
|
logs
|
||||||
|
/configs/*.py
|
||||||
|
.vscode/
|
||||||
|
__pycache__/
|
||||||
|
*.log
|
||||||
|
|
||||||
|
.python-version
|
||||||
|
Pipfile.lock
|
||||||
|
poetry.lock
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
.idea/
|
||||||
|
.coverage
|
||||||
|
coverage.xml
|
||||||
10
Dockerfile
@@ -32,6 +32,7 @@ RUN apt-get update && apt-get install -y \
|
|||||||
g++ \
|
g++ \
|
||||||
pkg-config \
|
pkg-config \
|
||||||
default-libmysqlclient-dev \
|
default-libmysqlclient-dev \
|
||||||
|
dos2unix \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
# Copy requirements file
|
# Copy requirements file
|
||||||
@@ -43,12 +44,13 @@ RUN pip install --no-cache-dir -r requirements.txt
|
|||||||
# Copy application code
|
# Copy application code
|
||||||
COPY . .
|
COPY . .
|
||||||
|
|
||||||
|
# Convert line endings for shell scripts and ensure proper execution format
|
||||||
|
RUN find . -name "*.sh" -exec dos2unix {} \; && \
|
||||||
|
find . -name "*.sh" -exec chmod +x {} \;
|
||||||
|
|
||||||
# Create necessary directories
|
# Create necessary directories
|
||||||
RUN mkdir -p /app/logs /app/config /app/data
|
RUN mkdir -p /app/logs /app/config /app/data
|
||||||
|
|
||||||
# Set permissions
|
|
||||||
RUN chmod +x /app/start.sh
|
|
||||||
|
|
||||||
# Create non-root user
|
# Create non-root user
|
||||||
RUN groupadd -r doris && useradd -r -g doris doris
|
RUN groupadd -r doris && useradd -r -g doris doris
|
||||||
RUN chown -R doris:doris /app
|
RUN chown -R doris:doris /app
|
||||||
@@ -62,4 +64,4 @@ HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
|
|||||||
EXPOSE 3000 3001 3002
|
EXPOSE 3000 3001 3002
|
||||||
|
|
||||||
# Start command
|
# Start command
|
||||||
CMD ["/app/start.sh"]
|
CMD ["/app/start_server.sh"]
|
||||||
|
|||||||
@@ -133,9 +133,6 @@ async def database_operations(client):
|
|||||||
|
|
||||||
# Get table schema
|
# Get table schema
|
||||||
schema = await client.get_table_schema("table_name", "db_name")
|
schema = await client.get_table_schema("table_name", "db_name")
|
||||||
|
|
||||||
# Column data analysis
|
|
||||||
analysis = await client.analyze_column("table", "column", "basic")
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## 🧪 Testing
|
## 🧪 Testing
|
||||||
@@ -177,7 +174,6 @@ python test_unified_client.py benchmark
|
|||||||
2. get_table_list: Get table list for specified database
|
2. get_table_list: Get table list for specified database
|
||||||
3. get_table_schema: Get table structure information
|
3. get_table_schema: Get table structure information
|
||||||
4. exec_query: Execute SQL query
|
4. exec_query: Execute SQL query
|
||||||
5. column_analysis: Analyze column data distribution and statistics
|
|
||||||
...
|
...
|
||||||
|
|
||||||
🧪 Testing basic functionality...
|
🧪 Testing basic functionality...
|
||||||
@@ -189,8 +185,6 @@ python test_unified_client.py benchmark
|
|||||||
✅ SSB query successful
|
✅ SSB query successful
|
||||||
4️⃣ Getting table structure...
|
4️⃣ Getting table structure...
|
||||||
✅ Table structure retrieved successfully
|
✅ Table structure retrieved successfully
|
||||||
5️⃣ Column data analysis...
|
|
||||||
✅ Column analysis successful
|
|
||||||
|
|
||||||
✅ HTTP mode testing completed!
|
✅ HTTP mode testing completed!
|
||||||
```
|
```
|
||||||
@@ -255,12 +249,6 @@ async def comprehensive_example():
|
|||||||
# Get table schema
|
# Get table schema
|
||||||
schema_result = await client.get_table_schema("lineorder", "ssb")
|
schema_result = await client.get_table_schema("lineorder", "ssb")
|
||||||
print(f"Table schema: {schema_result}")
|
print(f"Table schema: {schema_result}")
|
||||||
|
|
||||||
# Column analysis
|
|
||||||
analysis_result = await client.analyze_column(
|
|
||||||
"lineorder", "lo_orderkey", "basic"
|
|
||||||
)
|
|
||||||
print(f"Column analysis: {analysis_result}")
|
|
||||||
|
|
||||||
await client.connect_and_run(demo_operations)
|
await client.connect_and_run(demo_operations)
|
||||||
|
|
||||||
|
|||||||
@@ -323,7 +323,7 @@ class DorisUnifiedClient:
|
|||||||
async with streamablehttp_client(
|
async with streamablehttp_client(
|
||||||
self.config.server_url,
|
self.config.server_url,
|
||||||
timeout=timedelta(seconds=self.config.timeout)
|
timeout=timedelta(seconds=self.config.timeout)
|
||||||
) as (read, write):
|
) as (read, write, _):
|
||||||
async with ClientSession(read, write) as session:
|
async with ClientSession(read, write) as session:
|
||||||
self.session = session
|
self.session = session
|
||||||
self._init_sub_clients()
|
self._init_sub_clients()
|
||||||
@@ -422,18 +422,14 @@ class DorisUnifiedClient:
|
|||||||
|
|
||||||
return await self.call_tool(tool_name, kwargs)
|
return await self.call_tool(tool_name, kwargs)
|
||||||
|
|
||||||
async def analyze_column(self, table_name: str, column_name: str, analysis_type: str = "basic", **kwargs) -> dict[str, Any]:
|
async def get_memory_stats(self, tracker_type: str = "overview", include_details: bool = True, **kwargs) -> dict[str, Any]:
|
||||||
"""Analyze column"""
|
"""Get memory statistics"""
|
||||||
tool_name = await self._find_tool_by_pattern(["column_analysis", "analyze_column", "column"])
|
tool_name = await self._find_tool_by_pattern(["memory", "realtime_memory"])
|
||||||
if not tool_name:
|
if not tool_name:
|
||||||
return {"success": False, "error": "Column analysis tool not found"}
|
return {"success": False, "error": "Memory stats tool not found"}
|
||||||
|
|
||||||
arguments = {
|
arguments = {"tracker_type": tracker_type, "include_details": include_details}
|
||||||
"table_name": table_name,
|
arguments.update(kwargs)
|
||||||
"column_name": column_name,
|
|
||||||
"analysis_type": analysis_type,
|
|
||||||
**kwargs
|
|
||||||
}
|
|
||||||
return await self.call_tool(tool_name, arguments)
|
return await self.call_tool(tool_name, arguments)
|
||||||
|
|
||||||
async def call_tool_by_function(self, function_description: str, arguments: dict[str, Any]) -> dict[str, Any]:
|
async def call_tool_by_function(self, function_description: str, arguments: dict[str, Any]) -> dict[str, Any]:
|
||||||
@@ -467,7 +463,7 @@ async def create_http_client(server_url: str, timeout: int = 60) -> DorisUnified
|
|||||||
# Example usage
|
# Example usage
|
||||||
async def example_stdio():
|
async def example_stdio():
|
||||||
"""stdio mode example"""
|
"""stdio mode example"""
|
||||||
client = await create_stdio_client("python", ["doris_mcp_server/main.py"])
|
client = await create_stdio_client("python", ["-m", "doris_mcp_server.main", "--transport", "stdio"])
|
||||||
|
|
||||||
async def test_client(client: DorisUnifiedClient):
|
async def test_client(client: DorisUnifiedClient):
|
||||||
# Get server capabilities
|
# Get server capabilities
|
||||||
@@ -510,4 +506,4 @@ if __name__ == "__main__":
|
|||||||
asyncio.run(example_stdio())
|
asyncio.run(example_stdio())
|
||||||
|
|
||||||
# Run HTTP example
|
# Run HTTP example
|
||||||
# asyncio.run(example_http())
|
# asyncio.run(example_http())
|
||||||
|
|||||||
56
doris_mcp_server/auth/__init__.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
"""
|
||||||
|
Doris MCP Server Authentication Module
|
||||||
|
Provides JWT-based, Token-based, and OAuth 2.0/OIDC authentication and authorization services
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .jwt_manager import JWTManager
|
||||||
|
from .key_manager import KeyManager
|
||||||
|
from .token_validators import TokenValidator, TokenBlacklist
|
||||||
|
from .auth_middleware import AuthMiddleware
|
||||||
|
from .token_manager import TokenManager, TokenInfo, TokenValidationResult
|
||||||
|
from .token_handlers import TokenHandlers
|
||||||
|
from .oauth_client import OAuthClient, OAuthStateManager
|
||||||
|
from .oauth_provider import OAuthAuthenticationProvider
|
||||||
|
from .oauth_types import (
|
||||||
|
OAuthProvider, OAuthState, OAuthTokens, OAuthUserInfo,
|
||||||
|
OIDCDiscovery, OAuthError, OAuthProviderConfig
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"JWTManager",
|
||||||
|
"KeyManager",
|
||||||
|
"TokenValidator",
|
||||||
|
"TokenBlacklist",
|
||||||
|
"AuthMiddleware",
|
||||||
|
"TokenManager",
|
||||||
|
"TokenInfo",
|
||||||
|
"TokenValidationResult",
|
||||||
|
"TokenHandlers",
|
||||||
|
"OAuthClient",
|
||||||
|
"OAuthStateManager",
|
||||||
|
"OAuthAuthenticationProvider",
|
||||||
|
"OAuthProvider",
|
||||||
|
"OAuthState",
|
||||||
|
"OAuthTokens",
|
||||||
|
"OAuthUserInfo",
|
||||||
|
"OIDCDiscovery",
|
||||||
|
"OAuthError",
|
||||||
|
"OAuthProviderConfig"
|
||||||
|
]
|
||||||
271
doris_mcp_server/auth/auth_middleware.py
Normal file
@@ -0,0 +1,271 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
"""
|
||||||
|
Authentication Middleware Module
|
||||||
|
Provides middleware for JWT authentication in HTTP and MCP contexts
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional, Dict, Any, Callable, Awaitable
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from .jwt_manager import JWTManager
|
||||||
|
from ..utils.security import AuthContext, SecurityLevel
|
||||||
|
from ..utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AuthMiddleware:
|
||||||
|
"""Authentication Middleware
|
||||||
|
|
||||||
|
Provides JWT authentication functionality for HTTP and MCP requests
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, jwt_manager: JWTManager):
|
||||||
|
"""Initialize authentication middleware
|
||||||
|
|
||||||
|
Args:
|
||||||
|
jwt_manager: JWT manager instance
|
||||||
|
"""
|
||||||
|
self.jwt_manager = jwt_manager
|
||||||
|
logger.info("AuthMiddleware initialized")
|
||||||
|
|
||||||
|
def extract_token_from_header(self, authorization: str) -> Optional[str]:
|
||||||
|
"""Extract JWT token from Authorization header
|
||||||
|
|
||||||
|
Args:
|
||||||
|
authorization: Authorization header value
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JWT token string, or None if not found
|
||||||
|
"""
|
||||||
|
if not authorization:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Support Bearer format
|
||||||
|
if authorization.startswith('Bearer '):
|
||||||
|
return authorization[7:] # Remove "Bearer " prefix
|
||||||
|
|
||||||
|
# Support direct token format
|
||||||
|
if not authorization.startswith('Basic '):
|
||||||
|
return authorization
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def authenticate_request(self, auth_info: Dict[str, Any]) -> AuthContext:
|
||||||
|
"""Authenticate request and return authentication context
|
||||||
|
|
||||||
|
Args:
|
||||||
|
auth_info: Authentication information dictionary
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AuthContext authentication context
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: Authentication failed
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
auth_type = auth_info.get("type", "jwt")
|
||||||
|
|
||||||
|
if auth_type == "jwt" or auth_type == "token":
|
||||||
|
return await self._authenticate_jwt(auth_info)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported authentication type: {auth_type}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Request authentication failed: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _authenticate_jwt(self, auth_info: Dict[str, Any]) -> AuthContext:
|
||||||
|
"""JWT authentication processing
|
||||||
|
|
||||||
|
Args:
|
||||||
|
auth_info: Authentication information containing JWT token
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AuthContext authentication context
|
||||||
|
"""
|
||||||
|
# Get token
|
||||||
|
token = auth_info.get("token")
|
||||||
|
if not token:
|
||||||
|
# Try to get from Authorization header
|
||||||
|
authorization = auth_info.get("authorization")
|
||||||
|
token = self.extract_token_from_header(authorization)
|
||||||
|
|
||||||
|
if not token:
|
||||||
|
raise ValueError("Missing JWT token")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Validate token
|
||||||
|
validation_result = await self.jwt_manager.validate_token(token, 'access')
|
||||||
|
payload = validation_result['payload']
|
||||||
|
|
||||||
|
# Build authentication context
|
||||||
|
auth_context = AuthContext(
|
||||||
|
token_id=payload.get('jti', ''),
|
||||||
|
user_id=payload.get('sub'),
|
||||||
|
roles=payload.get('roles', []),
|
||||||
|
permissions=payload.get('permissions', []),
|
||||||
|
security_level=SecurityLevel(payload.get('security_level', 'internal')),
|
||||||
|
session_id=payload.get('jti'), # Use JWT ID as session ID
|
||||||
|
login_time=datetime.fromtimestamp(payload.get('iat', 0)),
|
||||||
|
last_activity=datetime.utcnow(),
|
||||||
|
token=token # Store raw token for token-bound database configuration
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"JWT authentication successful for user: {auth_context.user_id}")
|
||||||
|
return auth_context
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"JWT authentication failed: {e}")
|
||||||
|
raise ValueError(f"JWT authentication failed: {str(e)}")
|
||||||
|
|
||||||
|
async def create_auth_response_headers(self, auth_context: AuthContext) -> Dict[str, str]:
|
||||||
|
"""Create authentication response headers
|
||||||
|
|
||||||
|
Args:
|
||||||
|
auth_context: Authentication context
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Response headers dictionary
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
'X-Auth-User': auth_context.user_id,
|
||||||
|
'X-Auth-Roles': ','.join(auth_context.roles),
|
||||||
|
'X-Auth-Session': auth_context.session_id,
|
||||||
|
'X-Auth-Security-Level': auth_context.security_level.value
|
||||||
|
}
|
||||||
|
|
||||||
|
def create_http_middleware(self, skip_paths: Optional[list] = None):
|
||||||
|
"""Create HTTP middleware function
|
||||||
|
|
||||||
|
Args:
|
||||||
|
skip_paths: List of paths to skip authentication
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ASGI middleware function
|
||||||
|
"""
|
||||||
|
skip_paths = skip_paths or ['/health', '/docs', '/openapi.json']
|
||||||
|
|
||||||
|
async def middleware(scope, receive, send):
|
||||||
|
"""HTTP authentication middleware"""
|
||||||
|
if scope['type'] != 'http':
|
||||||
|
# Pass through non-HTTP requests directly
|
||||||
|
return await self.app(scope, receive, send)
|
||||||
|
|
||||||
|
path = scope.get('path', '')
|
||||||
|
|
||||||
|
# Check if authentication should be skipped
|
||||||
|
if any(path.startswith(skip) for skip in skip_paths):
|
||||||
|
return await self.app(scope, receive, send)
|
||||||
|
|
||||||
|
# Extract authentication information
|
||||||
|
headers = dict(scope.get('headers', []))
|
||||||
|
authorization = headers.get(b'authorization', b'').decode()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Perform authentication
|
||||||
|
auth_info = {
|
||||||
|
'type': 'jwt',
|
||||||
|
'authorization': authorization
|
||||||
|
}
|
||||||
|
auth_context = await self.authenticate_request(auth_info)
|
||||||
|
|
||||||
|
# Add authentication context to scope
|
||||||
|
scope['auth_context'] = auth_context
|
||||||
|
|
||||||
|
# Create response wrapper to add authentication headers
|
||||||
|
async def send_wrapper(message):
|
||||||
|
if message['type'] == 'http.response.start':
|
||||||
|
headers = dict(message.get('headers', []))
|
||||||
|
auth_headers = await self.create_auth_response_headers(auth_context)
|
||||||
|
|
||||||
|
for key, value in auth_headers.items():
|
||||||
|
headers[key.encode()] = value.encode()
|
||||||
|
|
||||||
|
message['headers'] = list(headers.items())
|
||||||
|
|
||||||
|
await send(message)
|
||||||
|
|
||||||
|
return await self.app(scope, receive, send_wrapper)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Authentication failed, return 401 error
|
||||||
|
response_body = f'{{"error": "Authentication failed", "message": "{str(e)}"}}'
|
||||||
|
|
||||||
|
await send({
|
||||||
|
'type': 'http.response.start',
|
||||||
|
'status': 401,
|
||||||
|
'headers': [
|
||||||
|
(b'content-type', b'application/json'),
|
||||||
|
(b'www-authenticate', b'Bearer')
|
||||||
|
]
|
||||||
|
})
|
||||||
|
await send({
|
||||||
|
'type': 'http.response.body',
|
||||||
|
'body': response_body.encode()
|
||||||
|
})
|
||||||
|
|
||||||
|
return middleware
|
||||||
|
|
||||||
|
async def authenticate_mcp_request(self, headers: Dict[str, str]) -> AuthContext:
|
||||||
|
"""Authenticate MCP request
|
||||||
|
|
||||||
|
Args:
|
||||||
|
headers: MCP request headers
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AuthContext authentication context
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Extract authentication information from multiple possible header fields
|
||||||
|
authorization = (
|
||||||
|
headers.get('Authorization') or
|
||||||
|
headers.get('authorization') or
|
||||||
|
headers.get('X-Auth-Token') or
|
||||||
|
headers.get('x-auth-token')
|
||||||
|
)
|
||||||
|
|
||||||
|
auth_info = {
|
||||||
|
'type': 'jwt',
|
||||||
|
'authorization': authorization
|
||||||
|
}
|
||||||
|
|
||||||
|
return await self.authenticate_request(auth_info)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"MCP request authentication failed: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
class AuthenticationError(Exception):
|
||||||
|
"""Authentication error exception"""
|
||||||
|
|
||||||
|
def __init__(self, message: str, error_code: str = "AUTH_FAILED"):
|
||||||
|
self.message = message
|
||||||
|
self.error_code = error_code
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class AuthorizationError(Exception):
|
||||||
|
"""Authorization error exception"""
|
||||||
|
|
||||||
|
def __init__(self, message: str, error_code: str = "ACCESS_DENIED"):
|
||||||
|
self.message = message
|
||||||
|
self.error_code = error_code
|
||||||
|
super().__init__(message)
|
||||||
471
doris_mcp_server/auth/jwt_manager.py
Normal file
@@ -0,0 +1,471 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
"""
|
||||||
|
JWT Manager Module
|
||||||
|
Provides comprehensive JWT token management including generation, validation, refresh and revocation
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
import asyncio
|
||||||
|
from typing import Dict, Any, Optional, Tuple
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
try:
|
||||||
|
import jwt
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("PyJWT is required for JWT functionality. Install with: pip install PyJWT[crypto]")
|
||||||
|
|
||||||
|
from .key_manager import KeyManager
|
||||||
|
from .token_validators import TokenValidator, TokenBlacklist
|
||||||
|
from ..utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class JWTManager:
|
||||||
|
"""JWT Token Manager
|
||||||
|
|
||||||
|
Provides comprehensive JWT token lifecycle management, including:
|
||||||
|
- Token generation and signing
|
||||||
|
- Token validation and parsing
|
||||||
|
- Token refresh mechanism
|
||||||
|
- Token revocation and blacklist
|
||||||
|
- Automatic key rotation
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
"""Initialize JWT manager
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: DorisConfig configuration object (with security attribute)
|
||||||
|
"""
|
||||||
|
self.config = config
|
||||||
|
# Access JWT settings through the security configuration
|
||||||
|
if hasattr(config, 'security'):
|
||||||
|
security_config = config.security
|
||||||
|
else:
|
||||||
|
# Fallback if config is passed directly as SecurityConfig
|
||||||
|
security_config = config
|
||||||
|
|
||||||
|
self.algorithm = security_config.jwt_algorithm
|
||||||
|
self.issuer = security_config.jwt_issuer
|
||||||
|
self.audience = security_config.jwt_audience
|
||||||
|
self.access_token_expiry = security_config.jwt_access_token_expiry
|
||||||
|
self.refresh_token_expiry = security_config.jwt_refresh_token_expiry
|
||||||
|
self.enable_refresh = security_config.enable_token_refresh
|
||||||
|
self.enable_revocation = security_config.enable_token_revocation
|
||||||
|
|
||||||
|
# Initialize components
|
||||||
|
self.key_manager = KeyManager(config)
|
||||||
|
self.token_blacklist = TokenBlacklist()
|
||||||
|
self.validator = TokenValidator(config, self.token_blacklist)
|
||||||
|
|
||||||
|
# Automatic key rotation task
|
||||||
|
self._key_rotation_task = None
|
||||||
|
|
||||||
|
logger.info(f"JWTManager initialized with algorithm: {self.algorithm}")
|
||||||
|
|
||||||
|
async def initialize(self) -> bool:
|
||||||
|
"""Initialize JWT manager"""
|
||||||
|
try:
|
||||||
|
# Initialize key manager
|
||||||
|
if not await self.key_manager.initialize():
|
||||||
|
logger.error("Failed to initialize key manager")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Start token validator
|
||||||
|
await self.validator.start()
|
||||||
|
|
||||||
|
# Start automatic key rotation
|
||||||
|
if self.key_manager.key_rotation_interval > 0:
|
||||||
|
self._key_rotation_task = asyncio.create_task(self._auto_key_rotation())
|
||||||
|
|
||||||
|
logger.info("JWTManager initialization completed")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to initialize JWTManager: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def shutdown(self):
|
||||||
|
"""Shutdown JWT manager"""
|
||||||
|
try:
|
||||||
|
# Stop key rotation task
|
||||||
|
if self._key_rotation_task:
|
||||||
|
self._key_rotation_task.cancel()
|
||||||
|
try:
|
||||||
|
await self._key_rotation_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Stop validator
|
||||||
|
await self.validator.stop()
|
||||||
|
|
||||||
|
logger.info("JWTManager shutdown completed")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error during JWTManager shutdown: {e}")
|
||||||
|
|
||||||
|
async def generate_tokens(self, user_info: Dict[str, Any],
|
||||||
|
custom_claims: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||||
|
"""Generate access token and refresh token
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_info: User information dictionary, containing user_id, roles, permissions, etc.
|
||||||
|
custom_claims: Custom claims
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary containing access_token and refresh_token
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
current_time = int(time.time())
|
||||||
|
jti = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# Build base payload
|
||||||
|
base_payload = {
|
||||||
|
'iss': self.issuer,
|
||||||
|
'aud': self.audience,
|
||||||
|
'iat': current_time,
|
||||||
|
'jti': jti,
|
||||||
|
'sub': user_info.get('user_id'),
|
||||||
|
'roles': user_info.get('roles', []),
|
||||||
|
'permissions': user_info.get('permissions', []),
|
||||||
|
'security_level': user_info.get('security_level', 'internal')
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add custom claims
|
||||||
|
if custom_claims:
|
||||||
|
base_payload.update(custom_claims)
|
||||||
|
|
||||||
|
# Generate access token
|
||||||
|
access_payload = base_payload.copy()
|
||||||
|
access_payload.update({
|
||||||
|
'exp': current_time + self.access_token_expiry,
|
||||||
|
'token_type': 'access'
|
||||||
|
})
|
||||||
|
|
||||||
|
access_token = await self._sign_token(access_payload)
|
||||||
|
|
||||||
|
result = {
|
||||||
|
'access_token': access_token,
|
||||||
|
'token_type': 'Bearer',
|
||||||
|
'expires_in': self.access_token_expiry,
|
||||||
|
'user_id': user_info.get('user_id'),
|
||||||
|
'issued_at': current_time
|
||||||
|
}
|
||||||
|
|
||||||
|
# Generate refresh token (if enabled)
|
||||||
|
if self.enable_refresh:
|
||||||
|
refresh_jti = str(uuid.uuid4())
|
||||||
|
refresh_payload = {
|
||||||
|
'iss': self.issuer,
|
||||||
|
'aud': self.audience,
|
||||||
|
'iat': current_time,
|
||||||
|
'exp': current_time + self.refresh_token_expiry,
|
||||||
|
'jti': refresh_jti,
|
||||||
|
'sub': user_info.get('user_id'),
|
||||||
|
'token_type': 'refresh',
|
||||||
|
'access_jti': jti # Associated access token ID
|
||||||
|
}
|
||||||
|
|
||||||
|
refresh_token = await self._sign_token(refresh_payload)
|
||||||
|
result.update({
|
||||||
|
'refresh_token': refresh_token,
|
||||||
|
'refresh_expires_in': self.refresh_token_expiry
|
||||||
|
})
|
||||||
|
|
||||||
|
logger.info(f"Generated tokens for user: {user_info.get('user_id')}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to generate tokens: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _sign_token(self, payload: Dict[str, Any]) -> str:
|
||||||
|
"""Sign JWT token
|
||||||
|
|
||||||
|
Args:
|
||||||
|
payload: JWT payload
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Signed JWT token
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
signing_key = self.key_manager.get_private_key()
|
||||||
|
|
||||||
|
if self.algorithm == "HS256":
|
||||||
|
# Symmetric key signing
|
||||||
|
token = jwt.encode(payload, signing_key, algorithm=self.algorithm)
|
||||||
|
else:
|
||||||
|
# Asymmetric key signing
|
||||||
|
token = jwt.encode(payload, signing_key, algorithm=self.algorithm)
|
||||||
|
|
||||||
|
return token
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to sign token: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def validate_token(self, token: str, token_type: str = 'access') -> Dict[str, Any]:
|
||||||
|
"""Validate JWT token
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: JWT token string
|
||||||
|
token_type: Token type ('access' or 'refresh')
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Validation result and user information
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: Token validation failed
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Decode token
|
||||||
|
verification_key = self.key_manager.get_public_key()
|
||||||
|
|
||||||
|
# Get security configuration
|
||||||
|
if hasattr(self.config, 'security'):
|
||||||
|
security_config = self.config.security
|
||||||
|
else:
|
||||||
|
security_config = self.config
|
||||||
|
|
||||||
|
# JWT decoding options
|
||||||
|
options = {
|
||||||
|
'verify_signature': security_config.jwt_verify_signature,
|
||||||
|
'verify_exp': security_config.jwt_require_exp,
|
||||||
|
'verify_iat': security_config.jwt_require_iat,
|
||||||
|
'verify_nbf': security_config.jwt_require_nbf,
|
||||||
|
'verify_aud': security_config.jwt_verify_audience,
|
||||||
|
'verify_iss': security_config.jwt_verify_issuer,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Decode JWT
|
||||||
|
payload = jwt.decode(
|
||||||
|
token,
|
||||||
|
verification_key,
|
||||||
|
algorithms=[self.algorithm],
|
||||||
|
audience=self.audience if security_config.jwt_verify_audience else None,
|
||||||
|
issuer=self.issuer if security_config.jwt_verify_issuer else None,
|
||||||
|
leeway=security_config.jwt_leeway,
|
||||||
|
options=options
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check token type
|
||||||
|
if payload.get('token_type') != token_type:
|
||||||
|
raise ValueError(f"Invalid token type: expected {token_type}")
|
||||||
|
|
||||||
|
# Use validator for additional checks
|
||||||
|
validation_result = await self.validator.validate_claims(payload)
|
||||||
|
|
||||||
|
logger.info(f"Token validation successful for user: {payload.get('sub')}")
|
||||||
|
return validation_result
|
||||||
|
|
||||||
|
except jwt.ExpiredSignatureError:
|
||||||
|
raise ValueError("Token has expired")
|
||||||
|
except jwt.InvalidTokenError as e:
|
||||||
|
raise ValueError(f"Invalid token: {str(e)}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Token validation failed: {e}")
|
||||||
|
raise ValueError(f"Token validation failed: {str(e)}")
|
||||||
|
|
||||||
|
async def refresh_token(self, refresh_token: str) -> Dict[str, Any]:
|
||||||
|
"""Refresh access token
|
||||||
|
|
||||||
|
Args:
|
||||||
|
refresh_token: Refresh token
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
New token pair
|
||||||
|
"""
|
||||||
|
if not self.enable_refresh:
|
||||||
|
raise ValueError("Token refresh is disabled")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Validate refresh token
|
||||||
|
refresh_result = await self.validate_token(refresh_token, 'refresh')
|
||||||
|
refresh_payload = refresh_result['payload']
|
||||||
|
|
||||||
|
# Revoke associated access token (if revocation is enabled)
|
||||||
|
if self.enable_revocation:
|
||||||
|
access_jti = refresh_payload.get('access_jti')
|
||||||
|
if access_jti:
|
||||||
|
# Should revoke old access token here, but since we don't know its expiration time,
|
||||||
|
# in practice might need to store more information or use different strategy
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Build new user information
|
||||||
|
user_info = {
|
||||||
|
'user_id': refresh_payload.get('sub'),
|
||||||
|
'roles': refresh_payload.get('roles', []),
|
||||||
|
'permissions': refresh_payload.get('permissions', []),
|
||||||
|
'security_level': refresh_payload.get('security_level', 'internal')
|
||||||
|
}
|
||||||
|
|
||||||
|
# Generate new token pair
|
||||||
|
new_tokens = await self.generate_tokens(user_info)
|
||||||
|
|
||||||
|
logger.info(f"Token refreshed for user: {user_info['user_id']}")
|
||||||
|
return new_tokens
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Token refresh failed: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def revoke_token(self, token: str) -> bool:
|
||||||
|
"""Revoke token
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: Token to revoke
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Whether revocation was successful
|
||||||
|
"""
|
||||||
|
if not self.enable_revocation:
|
||||||
|
logger.warning("Token revocation is disabled")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Decode token to get JTI and expiration time
|
||||||
|
verification_key = self.key_manager.get_public_key()
|
||||||
|
payload = jwt.decode(
|
||||||
|
token,
|
||||||
|
verification_key,
|
||||||
|
algorithms=[self.algorithm],
|
||||||
|
options={'verify_exp': False} # Allow decoding expired tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
jti = payload.get('jti')
|
||||||
|
exp = payload.get('exp')
|
||||||
|
|
||||||
|
if not jti or not exp:
|
||||||
|
logger.error("Token missing required claims for revocation")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Add to blacklist
|
||||||
|
await self.validator.revoke_token(jti, exp)
|
||||||
|
|
||||||
|
logger.info(f"Token {jti} revoked successfully")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Token revocation failed: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def decode_token_unsafe(self, token: str) -> Dict[str, Any]:
|
||||||
|
"""Decode token without verifying signature (for debugging only)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: JWT token
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Token payload
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(token, options={'verify_signature': False})
|
||||||
|
return payload
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to decode token: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_token_info(self, token: str) -> Dict[str, Any]:
|
||||||
|
"""Get token information (without verifying signature)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: JWT token
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Token information
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
payload = await self.decode_token_unsafe(token)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'jti': payload.get('jti'),
|
||||||
|
'sub': payload.get('sub'),
|
||||||
|
'iss': payload.get('iss'),
|
||||||
|
'aud': payload.get('aud'),
|
||||||
|
'iat': payload.get('iat'),
|
||||||
|
'exp': payload.get('exp'),
|
||||||
|
'token_type': payload.get('token_type'),
|
||||||
|
'roles': payload.get('roles'),
|
||||||
|
'permissions': payload.get('permissions'),
|
||||||
|
'security_level': payload.get('security_level'),
|
||||||
|
'is_expired': payload.get('exp', 0) < time.time() if payload.get('exp') else None
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to get token info: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _auto_key_rotation(self):
|
||||||
|
"""Automatic key rotation task"""
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
# Check if key rotation is needed
|
||||||
|
if await self.key_manager.is_key_expired():
|
||||||
|
logger.info("Key rotation needed, rotating keys...")
|
||||||
|
await self.key_manager.rotate_keys()
|
||||||
|
|
||||||
|
# Wait until next check
|
||||||
|
await asyncio.sleep(3600) # Check every hour
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in auto key rotation: {e}")
|
||||||
|
# Wait longer before retry after error
|
||||||
|
await asyncio.sleep(3600)
|
||||||
|
|
||||||
|
async def get_public_key_info(self) -> Dict[str, Any]:
|
||||||
|
"""Get public key information (for client verification)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Public key information
|
||||||
|
"""
|
||||||
|
key_info = await self.key_manager.get_key_info()
|
||||||
|
public_key_pem = await self.key_manager.export_public_key_pem()
|
||||||
|
|
||||||
|
return {
|
||||||
|
'algorithm': self.algorithm,
|
||||||
|
'public_key_pem': public_key_pem,
|
||||||
|
'key_info': key_info
|
||||||
|
}
|
||||||
|
|
||||||
|
async def get_manager_stats(self) -> Dict[str, Any]:
|
||||||
|
"""Get manager statistics
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Statistics information
|
||||||
|
"""
|
||||||
|
key_info = await self.key_manager.get_key_info()
|
||||||
|
validation_stats = await self.validator.get_validation_stats()
|
||||||
|
|
||||||
|
return {
|
||||||
|
'jwt_config': {
|
||||||
|
'algorithm': self.algorithm,
|
||||||
|
'issuer': self.issuer,
|
||||||
|
'audience': self.audience,
|
||||||
|
'access_token_expiry': self.access_token_expiry,
|
||||||
|
'refresh_token_expiry': self.refresh_token_expiry,
|
||||||
|
'enable_refresh': self.enable_refresh,
|
||||||
|
'enable_revocation': self.enable_revocation
|
||||||
|
},
|
||||||
|
'key_manager': key_info,
|
||||||
|
'validator': validation_stats
|
||||||
|
}
|
||||||
343
doris_mcp_server/auth/key_manager.py
Normal file
@@ -0,0 +1,343 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
"""
|
||||||
|
JWT Key Management Module
|
||||||
|
Provides secure key generation, loading, rotation and management for JWT tokens
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import secrets
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from cryptography.hazmat.primitives import serialization
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import rsa, ec
|
||||||
|
from cryptography.hazmat.backends import default_backend
|
||||||
|
|
||||||
|
from ..utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class KeyManager:
|
||||||
|
"""JWT Key Manager
|
||||||
|
|
||||||
|
Responsible for generating, loading, rotating and securely storing JWT signing keys
|
||||||
|
Supports RSA and EC algorithms, provides automatic key rotation functionality
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
"""Initialize key manager
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: DorisConfig configuration object (with security attribute)
|
||||||
|
"""
|
||||||
|
self.config = config
|
||||||
|
# Access JWT settings through the security configuration
|
||||||
|
if hasattr(config, 'security'):
|
||||||
|
security_config = config.security
|
||||||
|
else:
|
||||||
|
# Fallback if config is passed directly as SecurityConfig
|
||||||
|
security_config = config
|
||||||
|
|
||||||
|
self.algorithm = security_config.jwt_algorithm
|
||||||
|
self.key_rotation_interval = security_config.key_rotation_interval
|
||||||
|
self.private_key_path = security_config.jwt_private_key_path
|
||||||
|
self.public_key_path = security_config.jwt_public_key_path
|
||||||
|
self.secret_key = security_config.jwt_secret_key
|
||||||
|
|
||||||
|
# Key storage
|
||||||
|
self._private_key = None
|
||||||
|
self._public_key = None
|
||||||
|
self._secret_key = None
|
||||||
|
self._key_generated_at = None
|
||||||
|
|
||||||
|
logger.info(f"KeyManager initialized with algorithm: {self.algorithm}")
|
||||||
|
|
||||||
|
async def initialize(self) -> bool:
|
||||||
|
"""Initialize key manager, load or generate keys"""
|
||||||
|
try:
|
||||||
|
if self.algorithm == "HS256":
|
||||||
|
await self._initialize_symmetric_key()
|
||||||
|
else:
|
||||||
|
await self._initialize_asymmetric_keys()
|
||||||
|
|
||||||
|
logger.info("KeyManager initialization completed")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to initialize KeyManager: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _initialize_symmetric_key(self):
|
||||||
|
"""Initialize symmetric key (HS256)"""
|
||||||
|
if self.secret_key:
|
||||||
|
# Use configured key
|
||||||
|
self._secret_key = self.secret_key.encode()
|
||||||
|
logger.info("Loaded symmetric key from configuration")
|
||||||
|
else:
|
||||||
|
# Generate new key
|
||||||
|
self._secret_key = await self.generate_symmetric_key()
|
||||||
|
logger.info("Generated new symmetric key")
|
||||||
|
|
||||||
|
self._key_generated_at = datetime.utcnow()
|
||||||
|
|
||||||
|
async def _initialize_asymmetric_keys(self):
|
||||||
|
"""Initialize asymmetric key pair (RS256/ES256)"""
|
||||||
|
# Try to load keys from files
|
||||||
|
if await self._load_keys_from_files():
|
||||||
|
logger.info("Loaded asymmetric keys from files")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Try to load from environment variables
|
||||||
|
if await self._load_keys_from_env():
|
||||||
|
logger.info("Loaded asymmetric keys from environment")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Generate new key pair
|
||||||
|
await self.generate_key_pair()
|
||||||
|
logger.info("Generated new asymmetric key pair")
|
||||||
|
|
||||||
|
async def _load_keys_from_files(self) -> bool:
|
||||||
|
"""Load keys from files"""
|
||||||
|
try:
|
||||||
|
if not self.private_key_path or not self.public_key_path:
|
||||||
|
return False
|
||||||
|
|
||||||
|
private_path = Path(self.private_key_path)
|
||||||
|
public_path = Path(self.public_key_path)
|
||||||
|
|
||||||
|
if not (private_path.exists() and public_path.exists()):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Read private key
|
||||||
|
with open(private_path, 'rb') as f:
|
||||||
|
private_key_data = f.read()
|
||||||
|
self._private_key = serialization.load_pem_private_key(
|
||||||
|
private_key_data, password=None, backend=default_backend()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Read public key
|
||||||
|
with open(public_path, 'rb') as f:
|
||||||
|
public_key_data = f.read()
|
||||||
|
self._public_key = serialization.load_pem_public_key(
|
||||||
|
public_key_data, backend=default_backend()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get key generation time (using file modification time)
|
||||||
|
self._key_generated_at = datetime.fromtimestamp(private_path.stat().st_mtime)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load keys from files: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _load_keys_from_env(self) -> bool:
|
||||||
|
"""Load keys from environment variables"""
|
||||||
|
try:
|
||||||
|
private_key_env = os.getenv('JWT_PRIVATE_KEY')
|
||||||
|
public_key_env = os.getenv('JWT_PUBLIC_KEY')
|
||||||
|
|
||||||
|
if not (private_key_env and public_key_env):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Parse private key
|
||||||
|
self._private_key = serialization.load_pem_private_key(
|
||||||
|
private_key_env.encode(), password=None, backend=default_backend()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse public key
|
||||||
|
self._public_key = serialization.load_pem_public_key(
|
||||||
|
public_key_env.encode(), backend=default_backend()
|
||||||
|
)
|
||||||
|
|
||||||
|
self._key_generated_at = datetime.utcnow()
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load keys from environment: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def generate_symmetric_key(self, length: int = 32) -> bytes:
|
||||||
|
"""Generate symmetric key
|
||||||
|
|
||||||
|
Args:
|
||||||
|
length: Key length (bytes), default 32 bytes (256 bits)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Generated key
|
||||||
|
"""
|
||||||
|
return secrets.token_bytes(length)
|
||||||
|
|
||||||
|
async def generate_key_pair(self) -> Tuple[bytes, bytes]:
|
||||||
|
"""Generate asymmetric key pair
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(private key PEM, public key PEM) tuple
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if self.algorithm == "RS256":
|
||||||
|
private_key = rsa.generate_private_key(
|
||||||
|
public_exponent=65537,
|
||||||
|
key_size=2048,
|
||||||
|
backend=default_backend()
|
||||||
|
)
|
||||||
|
elif self.algorithm == "ES256":
|
||||||
|
private_key = ec.generate_private_key(
|
||||||
|
ec.SECP256R1(), backend=default_backend()
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported algorithm for key generation: {self.algorithm}")
|
||||||
|
|
||||||
|
# Get public key
|
||||||
|
public_key = private_key.public_key()
|
||||||
|
|
||||||
|
# Serialize private key
|
||||||
|
private_pem = private_key.private_bytes(
|
||||||
|
encoding=serialization.Encoding.PEM,
|
||||||
|
format=serialization.PrivateFormat.PKCS8,
|
||||||
|
encryption_algorithm=serialization.NoEncryption()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Serialize public key
|
||||||
|
public_pem = public_key.public_bytes(
|
||||||
|
encoding=serialization.Encoding.PEM,
|
||||||
|
format=serialization.PublicFormat.SubjectPublicKeyInfo
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store keys
|
||||||
|
self._private_key = private_key
|
||||||
|
self._public_key = public_key
|
||||||
|
self._key_generated_at = datetime.utcnow()
|
||||||
|
|
||||||
|
# If file paths are configured, save to files
|
||||||
|
if self.private_key_path and self.public_key_path:
|
||||||
|
await self._save_keys_to_files(private_pem, public_pem)
|
||||||
|
|
||||||
|
logger.info(f"Generated new {self.algorithm} key pair")
|
||||||
|
return private_pem, public_pem
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to generate key pair: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _save_keys_to_files(self, private_pem: bytes, public_pem: bytes):
|
||||||
|
"""Save keys to files"""
|
||||||
|
try:
|
||||||
|
# Ensure directories exist
|
||||||
|
private_path = Path(self.private_key_path)
|
||||||
|
public_path = Path(self.public_key_path)
|
||||||
|
|
||||||
|
private_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
public_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Save private key (set secure permissions)
|
||||||
|
with open(private_path, 'wb') as f:
|
||||||
|
f.write(private_pem)
|
||||||
|
os.chmod(private_path, 0o600) # Only owner can read/write
|
||||||
|
|
||||||
|
# Save public key
|
||||||
|
with open(public_path, 'wb') as f:
|
||||||
|
f.write(public_pem)
|
||||||
|
os.chmod(public_path, 0o644) # Owner read/write, others read only
|
||||||
|
|
||||||
|
logger.info(f"Saved keys to files: {private_path}, {public_path}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to save keys to files: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def get_private_key(self):
|
||||||
|
"""Get private key for signing"""
|
||||||
|
if self.algorithm == "HS256":
|
||||||
|
return self._secret_key
|
||||||
|
else:
|
||||||
|
return self._private_key
|
||||||
|
|
||||||
|
def get_public_key(self):
|
||||||
|
"""Get public key for verification"""
|
||||||
|
if self.algorithm == "HS256":
|
||||||
|
return self._secret_key
|
||||||
|
else:
|
||||||
|
return self._public_key
|
||||||
|
|
||||||
|
def get_algorithm(self) -> str:
|
||||||
|
"""Get signing algorithm"""
|
||||||
|
return self.algorithm
|
||||||
|
|
||||||
|
async def is_key_expired(self) -> bool:
|
||||||
|
"""Check if key is expired"""
|
||||||
|
if not self._key_generated_at:
|
||||||
|
return True
|
||||||
|
|
||||||
|
expiry_time = self._key_generated_at + timedelta(seconds=self.key_rotation_interval)
|
||||||
|
return datetime.utcnow() > expiry_time
|
||||||
|
|
||||||
|
async def rotate_keys(self) -> bool:
|
||||||
|
"""Rotate keys"""
|
||||||
|
try:
|
||||||
|
logger.info("Starting key rotation")
|
||||||
|
|
||||||
|
if self.algorithm == "HS256":
|
||||||
|
# Generate new symmetric key
|
||||||
|
self._secret_key = await self.generate_symmetric_key()
|
||||||
|
self._key_generated_at = datetime.utcnow()
|
||||||
|
else:
|
||||||
|
# Generate new asymmetric key pair
|
||||||
|
await self.generate_key_pair()
|
||||||
|
|
||||||
|
logger.info("Key rotation completed successfully")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Key rotation failed: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def get_key_info(self) -> dict:
|
||||||
|
"""Get key information"""
|
||||||
|
return {
|
||||||
|
"algorithm": self.algorithm,
|
||||||
|
"key_generated_at": self._key_generated_at.isoformat() if self._key_generated_at else None,
|
||||||
|
"key_expires_at": (
|
||||||
|
self._key_generated_at + timedelta(seconds=self.key_rotation_interval)
|
||||||
|
).isoformat() if self._key_generated_at else None,
|
||||||
|
"is_expired": await self.is_key_expired(),
|
||||||
|
"has_private_key": self._private_key is not None or self._secret_key is not None,
|
||||||
|
"has_public_key": self._public_key is not None or self._secret_key is not None
|
||||||
|
}
|
||||||
|
|
||||||
|
async def export_public_key_pem(self) -> Optional[str]:
|
||||||
|
"""Export public key in PEM format"""
|
||||||
|
if self.algorithm == "HS256":
|
||||||
|
return None # Symmetric key not exported
|
||||||
|
|
||||||
|
if not self._public_key:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
public_pem = self._public_key.public_bytes(
|
||||||
|
encoding=serialization.Encoding.PEM,
|
||||||
|
format=serialization.PublicFormat.SubjectPublicKeyInfo
|
||||||
|
)
|
||||||
|
return public_pem.decode()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to export public key: {e}")
|
||||||
|
return None
|
||||||
536
doris_mcp_server/auth/oauth_client.py
Normal file
@@ -0,0 +1,536 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
"""
|
||||||
|
OAuth 2.0/OIDC Client Manager
|
||||||
|
Provides OAuth authentication client implementation with PKCE and OIDC support
|
||||||
|
"""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import secrets
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Dict, Optional, Any, Tuple
|
||||||
|
from urllib.parse import urlencode, parse_qs, urlparse
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
|
||||||
|
try:
|
||||||
|
import aiohttp
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("aiohttp is required for OAuth functionality. Install with: pip install aiohttp")
|
||||||
|
|
||||||
|
from .oauth_types import (
|
||||||
|
OAuthProvider, OAuthState, OAuthTokens, OAuthUserInfo,
|
||||||
|
OIDCDiscovery, OAuthError, OAuthProviderConfig, OAUTH_PROVIDERS
|
||||||
|
)
|
||||||
|
from ..utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthStateManager:
|
||||||
|
"""Manages OAuth state parameters for CSRF protection"""
|
||||||
|
|
||||||
|
def __init__(self, state_expiry: int = 600):
|
||||||
|
"""Initialize state manager
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_expiry: State expiry time in seconds
|
||||||
|
"""
|
||||||
|
self.state_expiry = state_expiry
|
||||||
|
self._states: Dict[str, OAuthState] = {}
|
||||||
|
self._cleanup_task = None
|
||||||
|
|
||||||
|
logger.info("OAuthStateManager initialized")
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
"""Start periodic cleanup task"""
|
||||||
|
self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
|
||||||
|
logger.info("OAuth state manager started")
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
"""Stop periodic cleanup task"""
|
||||||
|
if self._cleanup_task:
|
||||||
|
self._cleanup_task.cancel()
|
||||||
|
try:
|
||||||
|
await self._cleanup_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
logger.info("OAuth state manager stopped")
|
||||||
|
|
||||||
|
def create_state(self, redirect_uri: str, pkce_enabled: bool = True,
|
||||||
|
nonce_enabled: bool = True) -> OAuthState:
|
||||||
|
"""Create new OAuth state
|
||||||
|
|
||||||
|
Args:
|
||||||
|
redirect_uri: OAuth redirect URI
|
||||||
|
pkce_enabled: Whether to enable PKCE
|
||||||
|
nonce_enabled: Whether to enable nonce (for OIDC)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OAuth state object
|
||||||
|
"""
|
||||||
|
state = secrets.token_urlsafe(32)
|
||||||
|
nonce = secrets.token_urlsafe(32) if nonce_enabled else None
|
||||||
|
|
||||||
|
pkce_verifier = None
|
||||||
|
pkce_challenge = None
|
||||||
|
if pkce_enabled:
|
||||||
|
pkce_verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).decode('utf-8').rstrip('=')
|
||||||
|
challenge_bytes = hashlib.sha256(pkce_verifier.encode()).digest()
|
||||||
|
pkce_challenge = base64.urlsafe_b64encode(challenge_bytes).decode('utf-8').rstrip('=')
|
||||||
|
|
||||||
|
oauth_state = OAuthState(
|
||||||
|
state=state,
|
||||||
|
nonce=nonce,
|
||||||
|
pkce_verifier=pkce_verifier,
|
||||||
|
pkce_challenge=pkce_challenge,
|
||||||
|
redirect_uri=redirect_uri,
|
||||||
|
created_at=datetime.utcnow(),
|
||||||
|
expires_at=datetime.utcnow() + timedelta(seconds=self.state_expiry)
|
||||||
|
)
|
||||||
|
|
||||||
|
self._states[state] = oauth_state
|
||||||
|
logger.debug(f"Created OAuth state: {state}")
|
||||||
|
return oauth_state
|
||||||
|
|
||||||
|
def get_state(self, state: str) -> Optional[OAuthState]:
|
||||||
|
"""Get OAuth state by state parameter
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: State parameter
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OAuth state object or None if not found/expired
|
||||||
|
"""
|
||||||
|
oauth_state = self._states.get(state)
|
||||||
|
if oauth_state and oauth_state.expires_at > datetime.utcnow():
|
||||||
|
return oauth_state
|
||||||
|
elif oauth_state:
|
||||||
|
# Remove expired state
|
||||||
|
del self._states[state]
|
||||||
|
logger.debug(f"Removed expired OAuth state: {state}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def consume_state(self, state: str) -> Optional[OAuthState]:
|
||||||
|
"""Get and remove OAuth state
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: State parameter
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OAuth state object or None if not found/expired
|
||||||
|
"""
|
||||||
|
oauth_state = self.get_state(state)
|
||||||
|
if oauth_state:
|
||||||
|
del self._states[state]
|
||||||
|
logger.debug(f"Consumed OAuth state: {state}")
|
||||||
|
return oauth_state
|
||||||
|
|
||||||
|
async def _periodic_cleanup(self):
|
||||||
|
"""Periodic cleanup of expired states"""
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(300) # Clean up every 5 minutes
|
||||||
|
current_time = datetime.utcnow()
|
||||||
|
expired_states = [
|
||||||
|
state for state, oauth_state in self._states.items()
|
||||||
|
if oauth_state.expires_at <= current_time
|
||||||
|
]
|
||||||
|
|
||||||
|
for state in expired_states:
|
||||||
|
del self._states[state]
|
||||||
|
|
||||||
|
if expired_states:
|
||||||
|
logger.info(f"Cleaned up {len(expired_states)} expired OAuth states")
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error during OAuth state cleanup: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthClient:
|
||||||
|
"""OAuth 2.0/OIDC Client implementation"""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
"""Initialize OAuth client
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: DorisConfig with OAuth configuration
|
||||||
|
"""
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
# Access OAuth settings through security configuration
|
||||||
|
if hasattr(config, 'security'):
|
||||||
|
security_config = config.security
|
||||||
|
else:
|
||||||
|
security_config = config
|
||||||
|
|
||||||
|
self.enabled = security_config.oauth_enabled
|
||||||
|
if not self.enabled:
|
||||||
|
logger.info("OAuth client disabled by configuration")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Build provider configuration
|
||||||
|
self.provider_config = self._build_provider_config(security_config)
|
||||||
|
self.state_manager = OAuthStateManager(security_config.oauth_state_expiry)
|
||||||
|
|
||||||
|
# HTTP client session
|
||||||
|
self._session: Optional[aiohttp.ClientSession] = None
|
||||||
|
|
||||||
|
# Discovery cache
|
||||||
|
self._discovery_cache: Optional[OIDCDiscovery] = None
|
||||||
|
self._discovery_cache_time: Optional[datetime] = None
|
||||||
|
|
||||||
|
logger.info(f"OAuthClient initialized for provider: {self.provider_config.provider.value}")
|
||||||
|
|
||||||
|
def _build_provider_config(self, security_config) -> OAuthProviderConfig:
|
||||||
|
"""Build OAuth provider configuration
|
||||||
|
|
||||||
|
Args:
|
||||||
|
security_config: Security configuration object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OAuth provider configuration
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
provider = OAuthProvider(security_config.oauth_provider)
|
||||||
|
except ValueError:
|
||||||
|
provider = OAuthProvider.CUSTOM
|
||||||
|
|
||||||
|
# Get default configuration for known providers
|
||||||
|
defaults = OAUTH_PROVIDERS.get(provider, {})
|
||||||
|
|
||||||
|
return OAuthProviderConfig(
|
||||||
|
provider=provider,
|
||||||
|
client_id=security_config.oauth_client_id,
|
||||||
|
client_secret=security_config.oauth_client_secret,
|
||||||
|
redirect_uri=security_config.oauth_redirect_uri,
|
||||||
|
scopes=security_config.oauth_scopes or defaults.get("scopes", ["openid", "email", "profile"]),
|
||||||
|
|
||||||
|
# Endpoints (use configured or defaults)
|
||||||
|
authorization_endpoint=security_config.oauth_authorization_endpoint or defaults.get("authorization_endpoint", ""),
|
||||||
|
token_endpoint=security_config.oauth_token_endpoint or defaults.get("token_endpoint", ""),
|
||||||
|
userinfo_endpoint=security_config.oauth_userinfo_endpoint or defaults.get("userinfo_endpoint"),
|
||||||
|
jwks_uri=security_config.oauth_jwks_uri or defaults.get("jwks_uri"),
|
||||||
|
|
||||||
|
# Discovery
|
||||||
|
discovery_url=security_config.oidc_discovery_url or defaults.get("discovery_url"),
|
||||||
|
|
||||||
|
# Settings
|
||||||
|
pkce_enabled=security_config.oauth_pkce_enabled,
|
||||||
|
nonce_enabled=security_config.oauth_nonce_enabled,
|
||||||
|
|
||||||
|
# User mapping
|
||||||
|
user_id_claim=security_config.oauth_user_id_claim or defaults.get("user_id_claim", "sub"),
|
||||||
|
email_claim=security_config.oauth_email_claim or defaults.get("email_claim", "email"),
|
||||||
|
name_claim=security_config.oauth_name_claim or defaults.get("name_claim", "name"),
|
||||||
|
roles_claim=security_config.oauth_roles_claim,
|
||||||
|
default_roles=security_config.oauth_default_roles
|
||||||
|
)
|
||||||
|
|
||||||
|
async def initialize(self) -> bool:
|
||||||
|
"""Initialize OAuth client
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if initialization successful
|
||||||
|
"""
|
||||||
|
if not self.enabled:
|
||||||
|
return True
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create HTTP session
|
||||||
|
self._session = aiohttp.ClientSession()
|
||||||
|
|
||||||
|
# Start state manager
|
||||||
|
await self.state_manager.start()
|
||||||
|
|
||||||
|
# Perform OIDC discovery if configured
|
||||||
|
if self.provider_config.discovery_url:
|
||||||
|
await self._discover_oidc_endpoints()
|
||||||
|
|
||||||
|
logger.info("OAuth client initialization completed")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to initialize OAuth client: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def shutdown(self):
|
||||||
|
"""Shutdown OAuth client"""
|
||||||
|
if not self.enabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Stop state manager
|
||||||
|
await self.state_manager.stop()
|
||||||
|
|
||||||
|
# Close HTTP session
|
||||||
|
if self._session:
|
||||||
|
await self._session.close()
|
||||||
|
|
||||||
|
logger.info("OAuth client shutdown completed")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error during OAuth client shutdown: {e}")
|
||||||
|
|
||||||
|
async def _discover_oidc_endpoints(self):
|
||||||
|
"""Discover OIDC endpoints using discovery URL"""
|
||||||
|
try:
|
||||||
|
# Check cache first
|
||||||
|
if (self._discovery_cache and self._discovery_cache_time and
|
||||||
|
datetime.utcnow() - self._discovery_cache_time < timedelta(hours=1)):
|
||||||
|
return self._discovery_cache
|
||||||
|
|
||||||
|
logger.info(f"Discovering OIDC endpoints: {self.provider_config.discovery_url}")
|
||||||
|
|
||||||
|
async with self._session.get(self.provider_config.discovery_url) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
data = await response.json()
|
||||||
|
|
||||||
|
discovery = OIDCDiscovery(
|
||||||
|
issuer=data["issuer"],
|
||||||
|
authorization_endpoint=data["authorization_endpoint"],
|
||||||
|
token_endpoint=data["token_endpoint"],
|
||||||
|
userinfo_endpoint=data.get("userinfo_endpoint"),
|
||||||
|
jwks_uri=data.get("jwks_uri"),
|
||||||
|
scopes_supported=data.get("scopes_supported"),
|
||||||
|
response_types_supported=data.get("response_types_supported"),
|
||||||
|
subject_types_supported=data.get("subject_types_supported"),
|
||||||
|
id_token_signing_alg_values_supported=data.get("id_token_signing_alg_values_supported")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update provider configuration with discovered endpoints
|
||||||
|
if not self.provider_config.authorization_endpoint:
|
||||||
|
self.provider_config.authorization_endpoint = discovery.authorization_endpoint
|
||||||
|
if not self.provider_config.token_endpoint:
|
||||||
|
self.provider_config.token_endpoint = discovery.token_endpoint
|
||||||
|
if not self.provider_config.userinfo_endpoint:
|
||||||
|
self.provider_config.userinfo_endpoint = discovery.userinfo_endpoint
|
||||||
|
if not self.provider_config.jwks_uri:
|
||||||
|
self.provider_config.jwks_uri = discovery.jwks_uri
|
||||||
|
|
||||||
|
# Cache discovery result
|
||||||
|
self._discovery_cache = discovery
|
||||||
|
self._discovery_cache_time = datetime.utcnow()
|
||||||
|
|
||||||
|
logger.info("OIDC endpoint discovery completed successfully")
|
||||||
|
return discovery
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"OIDC endpoint discovery failed: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def build_authorization_url(self) -> Tuple[str, OAuthState]:
|
||||||
|
"""Build OAuth authorization URL
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (authorization_url, oauth_state)
|
||||||
|
"""
|
||||||
|
if not self.enabled:
|
||||||
|
raise ValueError("OAuth client is not enabled")
|
||||||
|
|
||||||
|
# Create state for CSRF protection
|
||||||
|
oauth_state = self.state_manager.create_state(
|
||||||
|
redirect_uri=self.provider_config.redirect_uri,
|
||||||
|
pkce_enabled=self.provider_config.pkce_enabled,
|
||||||
|
nonce_enabled=self.provider_config.nonce_enabled
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build authorization parameters
|
||||||
|
params = {
|
||||||
|
'response_type': 'code',
|
||||||
|
'client_id': self.provider_config.client_id,
|
||||||
|
'redirect_uri': self.provider_config.redirect_uri,
|
||||||
|
'scope': ' '.join(self.provider_config.scopes),
|
||||||
|
'state': oauth_state.state
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add PKCE challenge
|
||||||
|
if oauth_state.pkce_challenge:
|
||||||
|
params['code_challenge'] = oauth_state.pkce_challenge
|
||||||
|
params['code_challenge_method'] = 'S256'
|
||||||
|
|
||||||
|
# Add nonce for OIDC
|
||||||
|
if oauth_state.nonce:
|
||||||
|
params['nonce'] = oauth_state.nonce
|
||||||
|
|
||||||
|
# Build URL
|
||||||
|
authorization_url = f"{self.provider_config.authorization_endpoint}?{urlencode(params)}"
|
||||||
|
|
||||||
|
logger.info(f"Built OAuth authorization URL for state: {oauth_state.state}")
|
||||||
|
return authorization_url, oauth_state
|
||||||
|
|
||||||
|
async def exchange_code_for_tokens(self, code: str, state: str) -> Tuple[OAuthTokens, OAuthState]:
|
||||||
|
"""Exchange authorization code for tokens
|
||||||
|
|
||||||
|
Args:
|
||||||
|
code: Authorization code
|
||||||
|
state: State parameter
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (OAuth tokens, OAuth state)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If state is invalid or exchange fails
|
||||||
|
"""
|
||||||
|
if not self.enabled:
|
||||||
|
raise ValueError("OAuth client is not enabled")
|
||||||
|
|
||||||
|
# Validate and consume state
|
||||||
|
oauth_state = self.state_manager.consume_state(state)
|
||||||
|
if not oauth_state:
|
||||||
|
raise ValueError("Invalid or expired state parameter")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Prepare token request
|
||||||
|
data = {
|
||||||
|
'grant_type': 'authorization_code',
|
||||||
|
'client_id': self.provider_config.client_id,
|
||||||
|
'client_secret': self.provider_config.client_secret,
|
||||||
|
'code': code,
|
||||||
|
'redirect_uri': oauth_state.redirect_uri
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add PKCE verifier
|
||||||
|
if oauth_state.pkce_verifier:
|
||||||
|
data['code_verifier'] = oauth_state.pkce_verifier
|
||||||
|
|
||||||
|
# Make token request
|
||||||
|
async with self._session.post(
|
||||||
|
self.provider_config.token_endpoint,
|
||||||
|
data=data,
|
||||||
|
headers={'Content-Type': 'application/x-www-form-urlencoded'}
|
||||||
|
) as response:
|
||||||
|
response_data = await response.json()
|
||||||
|
|
||||||
|
if response.status != 200:
|
||||||
|
error_msg = response_data.get('error_description', response_data.get('error', 'Token exchange failed'))
|
||||||
|
raise ValueError(f"Token exchange failed: {error_msg}")
|
||||||
|
|
||||||
|
tokens = OAuthTokens(
|
||||||
|
access_token=response_data['access_token'],
|
||||||
|
token_type=response_data.get('token_type', 'Bearer'),
|
||||||
|
expires_in=response_data.get('expires_in'),
|
||||||
|
refresh_token=response_data.get('refresh_token'),
|
||||||
|
scope=response_data.get('scope'),
|
||||||
|
id_token=response_data.get('id_token')
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Successfully exchanged authorization code for tokens")
|
||||||
|
return tokens, oauth_state
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Token exchange failed: {e}")
|
||||||
|
raise ValueError(f"Token exchange failed: {str(e)}")
|
||||||
|
|
||||||
|
async def get_user_info(self, tokens: OAuthTokens) -> OAuthUserInfo:
|
||||||
|
"""Get user information from OAuth provider
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokens: OAuth tokens
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OAuth user information
|
||||||
|
"""
|
||||||
|
if not self.enabled:
|
||||||
|
raise ValueError("OAuth client is not enabled")
|
||||||
|
|
||||||
|
if not self.provider_config.userinfo_endpoint:
|
||||||
|
raise ValueError("Userinfo endpoint not configured")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Make userinfo request
|
||||||
|
headers = {'Authorization': f'{tokens.token_type} {tokens.access_token}'}
|
||||||
|
|
||||||
|
async with self._session.get(
|
||||||
|
self.provider_config.userinfo_endpoint,
|
||||||
|
headers=headers
|
||||||
|
) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
user_data = await response.json()
|
||||||
|
|
||||||
|
# Extract user information using configured claims
|
||||||
|
user_info = OAuthUserInfo(
|
||||||
|
sub=str(user_data.get(self.provider_config.user_id_claim, '')),
|
||||||
|
email=user_data.get(self.provider_config.email_claim),
|
||||||
|
name=user_data.get(self.provider_config.name_claim),
|
||||||
|
given_name=user_data.get('given_name'),
|
||||||
|
family_name=user_data.get('family_name'),
|
||||||
|
picture=user_data.get('picture'),
|
||||||
|
locale=user_data.get('locale'),
|
||||||
|
email_verified=user_data.get('email_verified'),
|
||||||
|
roles=user_data.get(self.provider_config.roles_claim, self.provider_config.default_roles.copy()),
|
||||||
|
raw_claims=user_data
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Retrieved user info for user: {user_info.sub}")
|
||||||
|
return user_info
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to get user info: {e}")
|
||||||
|
raise ValueError(f"Failed to get user info: {str(e)}")
|
||||||
|
|
||||||
|
async def refresh_tokens(self, refresh_token: str) -> OAuthTokens:
|
||||||
|
"""Refresh OAuth tokens
|
||||||
|
|
||||||
|
Args:
|
||||||
|
refresh_token: Refresh token
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
New OAuth tokens
|
||||||
|
"""
|
||||||
|
if not self.enabled:
|
||||||
|
raise ValueError("OAuth client is not enabled")
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = {
|
||||||
|
'grant_type': 'refresh_token',
|
||||||
|
'client_id': self.provider_config.client_id,
|
||||||
|
'client_secret': self.provider_config.client_secret,
|
||||||
|
'refresh_token': refresh_token
|
||||||
|
}
|
||||||
|
|
||||||
|
async with self._session.post(
|
||||||
|
self.provider_config.token_endpoint,
|
||||||
|
data=data,
|
||||||
|
headers={'Content-Type': 'application/x-www-form-urlencoded'}
|
||||||
|
) as response:
|
||||||
|
response_data = await response.json()
|
||||||
|
|
||||||
|
if response.status != 200:
|
||||||
|
error_msg = response_data.get('error_description', response_data.get('error', 'Token refresh failed'))
|
||||||
|
raise ValueError(f"Token refresh failed: {error_msg}")
|
||||||
|
|
||||||
|
tokens = OAuthTokens(
|
||||||
|
access_token=response_data['access_token'],
|
||||||
|
token_type=response_data.get('token_type', 'Bearer'),
|
||||||
|
expires_in=response_data.get('expires_in'),
|
||||||
|
refresh_token=response_data.get('refresh_token', refresh_token), # Keep old if not provided
|
||||||
|
scope=response_data.get('scope'),
|
||||||
|
id_token=response_data.get('id_token')
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Successfully refreshed OAuth tokens")
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Token refresh failed: {e}")
|
||||||
|
raise ValueError(f"Token refresh failed: {str(e)}")
|
||||||
312
doris_mcp_server/auth/oauth_handlers.py
Normal file
@@ -0,0 +1,312 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
"""
|
||||||
|
OAuth HTTP Handlers
|
||||||
|
Provides HTTP endpoints for OAuth authentication flow
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, Any
|
||||||
|
from urllib.parse import parse_qs, urlparse
|
||||||
|
import json
|
||||||
|
|
||||||
|
from starlette.responses import JSONResponse, RedirectResponse, HTMLResponse
|
||||||
|
from starlette.requests import Request
|
||||||
|
|
||||||
|
from ..utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthHandlers:
|
||||||
|
"""OAuth HTTP request handlers"""
|
||||||
|
|
||||||
|
def __init__(self, security_manager):
|
||||||
|
"""Initialize OAuth handlers
|
||||||
|
|
||||||
|
Args:
|
||||||
|
security_manager: DorisSecurityManager instance
|
||||||
|
"""
|
||||||
|
self.security_manager = security_manager
|
||||||
|
logger.info("OAuth handlers initialized")
|
||||||
|
|
||||||
|
async def handle_login(self, request: Request) -> JSONResponse:
|
||||||
|
"""Handle OAuth login initiation
|
||||||
|
|
||||||
|
Returns JSON with authorization URL and state
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Check if OAuth is enabled
|
||||||
|
oauth_info = self.security_manager.get_oauth_provider_info()
|
||||||
|
if not oauth_info.get("enabled"):
|
||||||
|
return JSONResponse(
|
||||||
|
{"error": "OAuth authentication is not enabled"},
|
||||||
|
status_code=400
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get authorization URL
|
||||||
|
authorization_url, state = self.security_manager.get_oauth_authorization_url()
|
||||||
|
|
||||||
|
return JSONResponse({
|
||||||
|
"authorization_url": authorization_url,
|
||||||
|
"state": state,
|
||||||
|
"provider": oauth_info.get("provider"),
|
||||||
|
"message": "Navigate to authorization_url to complete OAuth login"
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"OAuth login initiation failed: {e}")
|
||||||
|
return JSONResponse(
|
||||||
|
{"error": f"OAuth login failed: {str(e)}"},
|
||||||
|
status_code=500
|
||||||
|
)
|
||||||
|
|
||||||
|
async def handle_callback(self, request: Request) -> JSONResponse:
|
||||||
|
"""Handle OAuth callback
|
||||||
|
|
||||||
|
Processes the OAuth callback and returns authentication result
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Get query parameters
|
||||||
|
query_params = dict(request.query_params)
|
||||||
|
|
||||||
|
# Check for error in callback
|
||||||
|
if "error" in query_params:
|
||||||
|
error_description = query_params.get("error_description", "Unknown error")
|
||||||
|
logger.warning(f"OAuth callback error: {query_params['error']} - {error_description}")
|
||||||
|
return JSONResponse(
|
||||||
|
{
|
||||||
|
"error": query_params["error"],
|
||||||
|
"error_description": error_description,
|
||||||
|
"error_uri": query_params.get("error_uri")
|
||||||
|
},
|
||||||
|
status_code=400
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract required parameters
|
||||||
|
code = query_params.get("code")
|
||||||
|
state = query_params.get("state")
|
||||||
|
|
||||||
|
if not code or not state:
|
||||||
|
return JSONResponse(
|
||||||
|
{"error": "Missing required parameters: code and state"},
|
||||||
|
status_code=400
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle OAuth callback
|
||||||
|
auth_context = await self.security_manager.handle_oauth_callback(code, state)
|
||||||
|
|
||||||
|
# Return successful authentication response
|
||||||
|
return JSONResponse({
|
||||||
|
"success": True,
|
||||||
|
"user_id": auth_context.user_id,
|
||||||
|
"roles": auth_context.roles,
|
||||||
|
"permissions": auth_context.permissions,
|
||||||
|
"security_level": auth_context.security_level.value,
|
||||||
|
"session_id": auth_context.session_id,
|
||||||
|
"message": "OAuth authentication successful"
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"OAuth callback handling failed: {e}")
|
||||||
|
return JSONResponse(
|
||||||
|
{"error": f"OAuth callback failed: {str(e)}"},
|
||||||
|
status_code=500
|
||||||
|
)
|
||||||
|
|
||||||
|
async def handle_provider_info(self, request: Request) -> JSONResponse:
|
||||||
|
"""Handle OAuth provider information request
|
||||||
|
|
||||||
|
Returns information about the configured OAuth provider
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
provider_info = self.security_manager.get_oauth_provider_info()
|
||||||
|
return JSONResponse(provider_info)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to get OAuth provider info: {e}")
|
||||||
|
return JSONResponse(
|
||||||
|
{"error": f"Failed to get provider info: {str(e)}"},
|
||||||
|
status_code=500
|
||||||
|
)
|
||||||
|
|
||||||
|
async def handle_demo_page(self, request: Request) -> HTMLResponse:
|
||||||
|
"""Handle OAuth demo page
|
||||||
|
|
||||||
|
Returns a simple HTML page for testing OAuth flow
|
||||||
|
"""
|
||||||
|
oauth_info = self.security_manager.get_oauth_provider_info()
|
||||||
|
if not oauth_info.get("enabled"):
|
||||||
|
return HTMLResponse("""
|
||||||
|
<html>
|
||||||
|
<head><title>OAuth Demo</title></head>
|
||||||
|
<body>
|
||||||
|
<h1>OAuth Demo</h1>
|
||||||
|
<p style="color: red;">OAuth authentication is not enabled.</p>
|
||||||
|
<p>Please configure OAuth settings in your security configuration.</p>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
""")
|
||||||
|
|
||||||
|
html_content = f"""
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<title>Doris MCP Server - OAuth Demo</title>
|
||||||
|
<style>
|
||||||
|
body {{
|
||||||
|
font-family: Arial, sans-serif;
|
||||||
|
max-width: 800px;
|
||||||
|
margin: 0 auto;
|
||||||
|
padding: 20px;
|
||||||
|
}}
|
||||||
|
.info {{
|
||||||
|
background-color: #f0f8ff;
|
||||||
|
padding: 15px;
|
||||||
|
border-left: 4px solid #0066cc;
|
||||||
|
margin: 20px 0;
|
||||||
|
}}
|
||||||
|
.error {{
|
||||||
|
background-color: #ffe6e6;
|
||||||
|
padding: 15px;
|
||||||
|
border-left: 4px solid #cc0000;
|
||||||
|
margin: 20px 0;
|
||||||
|
}}
|
||||||
|
.success {{
|
||||||
|
background-color: #e6ffe6;
|
||||||
|
padding: 15px;
|
||||||
|
border-left: 4px solid #00cc00;
|
||||||
|
margin: 20px 0;
|
||||||
|
}}
|
||||||
|
button {{
|
||||||
|
background-color: #0066cc;
|
||||||
|
color: white;
|
||||||
|
padding: 10px 20px;
|
||||||
|
border: none;
|
||||||
|
border-radius: 4px;
|
||||||
|
cursor: pointer;
|
||||||
|
font-size: 16px;
|
||||||
|
}}
|
||||||
|
button:hover {{
|
||||||
|
background-color: #0052a3;
|
||||||
|
}}
|
||||||
|
pre {{
|
||||||
|
background-color: #f5f5f5;
|
||||||
|
padding: 10px;
|
||||||
|
border-radius: 4px;
|
||||||
|
overflow-x: auto;
|
||||||
|
}}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<h1>Doris MCP Server - OAuth Demo</h1>
|
||||||
|
|
||||||
|
<div class="info">
|
||||||
|
<h3>OAuth Configuration</h3>
|
||||||
|
<p><strong>Provider:</strong> {oauth_info.get('provider', 'N/A')}</p>
|
||||||
|
<p><strong>Client ID:</strong> {oauth_info.get('client_id', 'N/A')}</p>
|
||||||
|
<p><strong>Scopes:</strong> {', '.join(oauth_info.get('scopes', []))}</p>
|
||||||
|
<p><strong>PKCE Enabled:</strong> {oauth_info.get('pkce_enabled', False)}</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div>
|
||||||
|
<h3>OAuth Authentication Test</h3>
|
||||||
|
<p>Click the button below to start OAuth authentication flow:</p>
|
||||||
|
<button onclick="startOAuthFlow()">Start OAuth Login</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div id="result" style="margin-top: 20px;"></div>
|
||||||
|
|
||||||
|
<div>
|
||||||
|
<h3>API Endpoints</h3>
|
||||||
|
<ul>
|
||||||
|
<li><code>GET /auth/login</code> - Initiate OAuth login</li>
|
||||||
|
<li><code>GET /auth/callback</code> - OAuth callback handler</li>
|
||||||
|
<li><code>GET /auth/provider</code> - Provider information</li>
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<script>
|
||||||
|
async function startOAuthFlow() {{
|
||||||
|
const resultDiv = document.getElementById('result');
|
||||||
|
resultDiv.innerHTML = '<div class="info">Initiating OAuth flow...</div>';
|
||||||
|
|
||||||
|
try {{
|
||||||
|
const response = await fetch('/auth/login');
|
||||||
|
const data = await response.json();
|
||||||
|
|
||||||
|
if (response.ok) {{
|
||||||
|
resultDiv.innerHTML = `
|
||||||
|
<div class="success">
|
||||||
|
<h4>OAuth URL Generated Successfully</h4>
|
||||||
|
<p><strong>State:</strong> ${{data.state}}</p>
|
||||||
|
<p><strong>Provider:</strong> ${{data.provider}}</p>
|
||||||
|
<p><a href="${{data.authorization_url}}" target="_blank">Click here to authenticate</a></p>
|
||||||
|
<p><em>Note: After authentication, you will be redirected to the callback URL.</em></p>
|
||||||
|
</div>
|
||||||
|
`;
|
||||||
|
|
||||||
|
// Automatically redirect to OAuth provider
|
||||||
|
// window.open(data.authorization_url, '_blank');
|
||||||
|
}} else {{
|
||||||
|
resultDiv.innerHTML = `
|
||||||
|
<div class="error">
|
||||||
|
<h4>Error</h4>
|
||||||
|
<p>${{data.error}}</p>
|
||||||
|
</div>
|
||||||
|
`;
|
||||||
|
}}
|
||||||
|
}} catch (error) {{
|
||||||
|
resultDiv.innerHTML = `
|
||||||
|
<div class="error">
|
||||||
|
<h4>Network Error</h4>
|
||||||
|
<p>${{error.message}}</p>
|
||||||
|
</div>
|
||||||
|
`;
|
||||||
|
}}
|
||||||
|
}}
|
||||||
|
|
||||||
|
// Handle OAuth callback result if present in URL
|
||||||
|
window.addEventListener('load', function() {{
|
||||||
|
const urlParams = new URLSearchParams(window.location.search);
|
||||||
|
if (urlParams.has('code') && urlParams.has('state')) {{
|
||||||
|
const resultDiv = document.getElementById('result');
|
||||||
|
resultDiv.innerHTML = `
|
||||||
|
<div class="success">
|
||||||
|
<h4>OAuth Callback Received</h4>
|
||||||
|
<p>Code: ${{urlParams.get('code')}}</p>
|
||||||
|
<p>State: ${{urlParams.get('state')}}</p>
|
||||||
|
<p>The authentication was successful!</p>
|
||||||
|
</div>
|
||||||
|
`;
|
||||||
|
}} else if (urlParams.has('error')) {{
|
||||||
|
const resultDiv = document.getElementById('result');
|
||||||
|
resultDiv.innerHTML = `
|
||||||
|
<div class="error">
|
||||||
|
<h4>OAuth Error</h4>
|
||||||
|
<p>Error: ${{urlParams.get('error')}}</p>
|
||||||
|
<p>Description: ${{urlParams.get('error_description') || 'No description'}}</p>
|
||||||
|
</div>
|
||||||
|
`;
|
||||||
|
}}
|
||||||
|
}});
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
"""
|
||||||
|
|
||||||
|
return HTMLResponse(html_content)
|
||||||
289
doris_mcp_server/auth/oauth_provider.py
Normal file
@@ -0,0 +1,289 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
"""
|
||||||
|
OAuth Authentication Provider
|
||||||
|
Integrates OAuth 2.0/OIDC authentication with the existing authentication framework
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, Any, Optional, Tuple
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from .oauth_client import OAuthClient
|
||||||
|
from .oauth_types import OAuthTokens, OAuthUserInfo, OAuthState
|
||||||
|
from ..utils.security import AuthContext, SecurityLevel
|
||||||
|
from ..utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthAuthenticationProvider:
|
||||||
|
"""OAuth authentication provider for Doris MCP Server"""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
"""Initialize OAuth authentication provider
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: DorisConfig with OAuth configuration
|
||||||
|
"""
|
||||||
|
self.config = config
|
||||||
|
self.oauth_client = OAuthClient(config)
|
||||||
|
self.enabled = self.oauth_client.enabled
|
||||||
|
|
||||||
|
logger.info(f"OAuthAuthenticationProvider initialized (enabled: {self.enabled})")
|
||||||
|
|
||||||
|
async def initialize(self) -> bool:
|
||||||
|
"""Initialize OAuth authentication provider
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if initialization successful
|
||||||
|
"""
|
||||||
|
if not self.enabled:
|
||||||
|
return True
|
||||||
|
|
||||||
|
success = await self.oauth_client.initialize()
|
||||||
|
if success:
|
||||||
|
logger.info("OAuth authentication provider initialized successfully")
|
||||||
|
else:
|
||||||
|
logger.error("Failed to initialize OAuth authentication provider")
|
||||||
|
return success
|
||||||
|
|
||||||
|
async def shutdown(self):
|
||||||
|
"""Shutdown OAuth authentication provider"""
|
||||||
|
if self.enabled:
|
||||||
|
await self.oauth_client.shutdown()
|
||||||
|
logger.info("OAuth authentication provider shutdown completed")
|
||||||
|
|
||||||
|
def get_authorization_url(self) -> Tuple[str, str]:
|
||||||
|
"""Get OAuth authorization URL
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (authorization_url, state)
|
||||||
|
"""
|
||||||
|
if not self.enabled:
|
||||||
|
raise ValueError("OAuth authentication is not enabled")
|
||||||
|
|
||||||
|
authorization_url, oauth_state = self.oauth_client.build_authorization_url()
|
||||||
|
return authorization_url, oauth_state.state
|
||||||
|
|
||||||
|
async def handle_callback(self, code: str, state: str) -> AuthContext:
|
||||||
|
"""Handle OAuth callback and create authentication context
|
||||||
|
|
||||||
|
Args:
|
||||||
|
code: Authorization code from OAuth provider
|
||||||
|
state: State parameter for CSRF protection
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AuthContext for the authenticated user
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If authentication fails
|
||||||
|
"""
|
||||||
|
if not self.enabled:
|
||||||
|
raise ValueError("OAuth authentication is not enabled")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Exchange code for tokens
|
||||||
|
tokens, oauth_state = await self.oauth_client.exchange_code_for_tokens(code, state)
|
||||||
|
|
||||||
|
# Get user information
|
||||||
|
user_info = await self.oauth_client.get_user_info(tokens)
|
||||||
|
|
||||||
|
# Create authentication context
|
||||||
|
auth_context = await self._create_auth_context(user_info, tokens)
|
||||||
|
|
||||||
|
logger.info(f"OAuth authentication successful for user: {auth_context.user_id}")
|
||||||
|
return auth_context
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"OAuth callback handling failed: {e}")
|
||||||
|
raise ValueError(f"OAuth authentication failed: {str(e)}")
|
||||||
|
|
||||||
|
async def authenticate_with_token(self, access_token: str) -> AuthContext:
|
||||||
|
"""Authenticate using OAuth access token
|
||||||
|
|
||||||
|
Args:
|
||||||
|
access_token: OAuth access token
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AuthContext for the authenticated user
|
||||||
|
"""
|
||||||
|
if not self.enabled:
|
||||||
|
raise ValueError("OAuth authentication is not enabled")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create token object
|
||||||
|
tokens = OAuthTokens(access_token=access_token)
|
||||||
|
|
||||||
|
# Get user information
|
||||||
|
user_info = await self.oauth_client.get_user_info(tokens)
|
||||||
|
|
||||||
|
# Create authentication context
|
||||||
|
auth_context = await self._create_auth_context(user_info, tokens)
|
||||||
|
|
||||||
|
logger.info(f"OAuth token authentication successful for user: {auth_context.user_id}")
|
||||||
|
return auth_context
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"OAuth token authentication failed: {e}")
|
||||||
|
raise ValueError(f"OAuth token authentication failed: {str(e)}")
|
||||||
|
|
||||||
|
async def refresh_authentication(self, refresh_token: str) -> Tuple[AuthContext, str]:
|
||||||
|
"""Refresh OAuth authentication
|
||||||
|
|
||||||
|
Args:
|
||||||
|
refresh_token: OAuth refresh token
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (AuthContext, new_access_token)
|
||||||
|
"""
|
||||||
|
if not self.enabled:
|
||||||
|
raise ValueError("OAuth authentication is not enabled")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Refresh tokens
|
||||||
|
tokens = await self.oauth_client.refresh_tokens(refresh_token)
|
||||||
|
|
||||||
|
# Get updated user information
|
||||||
|
user_info = await self.oauth_client.get_user_info(tokens)
|
||||||
|
|
||||||
|
# Create authentication context
|
||||||
|
auth_context = await self._create_auth_context(user_info, tokens)
|
||||||
|
|
||||||
|
logger.info(f"OAuth refresh successful for user: {auth_context.user_id}")
|
||||||
|
return auth_context, tokens.access_token
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"OAuth refresh failed: {e}")
|
||||||
|
raise ValueError(f"OAuth refresh failed: {str(e)}")
|
||||||
|
|
||||||
|
async def _create_auth_context(self, user_info: OAuthUserInfo, tokens: OAuthTokens) -> AuthContext:
|
||||||
|
"""Create authentication context from OAuth user info
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_info: OAuth user information
|
||||||
|
tokens: OAuth tokens
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AuthContext for the user
|
||||||
|
"""
|
||||||
|
# Determine security level based on roles or email domain
|
||||||
|
security_level = await self._determine_security_level(user_info)
|
||||||
|
|
||||||
|
# Map OAuth roles to application permissions
|
||||||
|
permissions = await self._map_permissions(user_info.roles)
|
||||||
|
|
||||||
|
# Generate session ID
|
||||||
|
session_id = f"oauth_{user_info.sub}_{datetime.utcnow().timestamp()}"
|
||||||
|
|
||||||
|
return AuthContext(
|
||||||
|
token_id=f"oauth_{user_info.sub}",
|
||||||
|
user_id=user_info.sub,
|
||||||
|
roles=user_info.roles,
|
||||||
|
permissions=permissions,
|
||||||
|
security_level=security_level,
|
||||||
|
session_id=session_id,
|
||||||
|
login_time=datetime.utcnow(),
|
||||||
|
last_activity=datetime.utcnow(),
|
||||||
|
token="" # OAuth doesn't have raw token, use empty string
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _determine_security_level(self, user_info: OAuthUserInfo) -> SecurityLevel:
|
||||||
|
"""Determine security level for OAuth user
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_info: OAuth user information
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SecurityLevel for the user
|
||||||
|
"""
|
||||||
|
# Check if user has admin roles
|
||||||
|
admin_roles = {"admin", "administrator", "data_admin", "super_admin"}
|
||||||
|
if any(role.lower() in admin_roles for role in user_info.roles):
|
||||||
|
return SecurityLevel.SECRET
|
||||||
|
|
||||||
|
# Check email domain for internal users
|
||||||
|
if user_info.email:
|
||||||
|
# You can configure trusted domains for internal access
|
||||||
|
trusted_domains = ["yourcompany.com", "internal.org"] # Configure as needed
|
||||||
|
email_domain = user_info.email.split("@")[-1].lower()
|
||||||
|
if email_domain in trusted_domains:
|
||||||
|
return SecurityLevel.CONFIDENTIAL
|
||||||
|
|
||||||
|
# Check for special roles
|
||||||
|
elevated_roles = {"data_analyst", "developer", "manager"}
|
||||||
|
if any(role.lower() in elevated_roles for role in user_info.roles):
|
||||||
|
return SecurityLevel.CONFIDENTIAL
|
||||||
|
|
||||||
|
# Default to internal level for OAuth users
|
||||||
|
return SecurityLevel.INTERNAL
|
||||||
|
|
||||||
|
async def _map_permissions(self, roles: list[str]) -> list[str]:
|
||||||
|
"""Map OAuth roles to application permissions
|
||||||
|
|
||||||
|
Args:
|
||||||
|
roles: OAuth user roles
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of application permissions
|
||||||
|
"""
|
||||||
|
permissions = set()
|
||||||
|
|
||||||
|
# Role to permission mapping
|
||||||
|
role_permissions = {
|
||||||
|
"admin": ["admin", "read_data", "write_data", "manage_users"],
|
||||||
|
"administrator": ["admin", "read_data", "write_data", "manage_users"],
|
||||||
|
"data_admin": ["admin", "read_data", "write_data"],
|
||||||
|
"super_admin": ["admin", "read_data", "write_data", "manage_users", "system_admin"],
|
||||||
|
"data_analyst": ["read_data", "query_database"],
|
||||||
|
"developer": ["read_data", "query_database", "debug"],
|
||||||
|
"viewer": ["read_data"],
|
||||||
|
"user": ["read_data"],
|
||||||
|
"oauth_user": ["read_data"] # Default OAuth user permission
|
||||||
|
}
|
||||||
|
|
||||||
|
# Map roles to permissions
|
||||||
|
for role in roles:
|
||||||
|
role_lower = role.lower()
|
||||||
|
if role_lower in role_permissions:
|
||||||
|
permissions.update(role_permissions[role_lower])
|
||||||
|
|
||||||
|
# Ensure OAuth users have at least basic permissions
|
||||||
|
if not permissions:
|
||||||
|
permissions.add("read_data")
|
||||||
|
|
||||||
|
return list(permissions)
|
||||||
|
|
||||||
|
def get_provider_info(self) -> Dict[str, Any]:
|
||||||
|
"""Get OAuth provider information
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Provider information dictionary
|
||||||
|
"""
|
||||||
|
if not self.enabled:
|
||||||
|
return {"enabled": False}
|
||||||
|
|
||||||
|
config = self.oauth_client.provider_config
|
||||||
|
return {
|
||||||
|
"enabled": True,
|
||||||
|
"provider": config.provider.value,
|
||||||
|
"client_id": config.client_id,
|
||||||
|
"scopes": config.scopes,
|
||||||
|
"redirect_uri": config.redirect_uri,
|
||||||
|
"pkce_enabled": config.pkce_enabled,
|
||||||
|
"nonce_enabled": config.nonce_enabled
|
||||||
|
}
|
||||||
196
doris_mcp_server/auth/oauth_types.py
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
"""
|
||||||
|
OAuth 2.0/OIDC Type Definitions
|
||||||
|
Provides data types and models for OAuth authentication flow
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Dict, Any, Optional, List
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthProvider(Enum):
|
||||||
|
"""OAuth provider enumeration"""
|
||||||
|
GOOGLE = "google"
|
||||||
|
MICROSOFT = "microsoft"
|
||||||
|
GITHUB = "github"
|
||||||
|
CUSTOM = "custom"
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthGrantType(Enum):
|
||||||
|
"""OAuth grant type enumeration"""
|
||||||
|
AUTHORIZATION_CODE = "authorization_code"
|
||||||
|
REFRESH_TOKEN = "refresh_token"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OAuthState:
|
||||||
|
"""OAuth state parameter for CSRF protection"""
|
||||||
|
state: str
|
||||||
|
nonce: Optional[str] = None
|
||||||
|
pkce_verifier: Optional[str] = None
|
||||||
|
pkce_challenge: Optional[str] = None
|
||||||
|
redirect_uri: str = ""
|
||||||
|
created_at: datetime = None
|
||||||
|
expires_at: datetime = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.created_at is None:
|
||||||
|
self.created_at = datetime.utcnow()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OAuthTokens:
|
||||||
|
"""OAuth token response"""
|
||||||
|
access_token: str
|
||||||
|
token_type: str = "Bearer"
|
||||||
|
expires_in: Optional[int] = None
|
||||||
|
refresh_token: Optional[str] = None
|
||||||
|
scope: Optional[str] = None
|
||||||
|
id_token: Optional[str] = None # OIDC ID token
|
||||||
|
created_at: datetime = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.created_at is None:
|
||||||
|
self.created_at = datetime.utcnow()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OAuthUserInfo:
|
||||||
|
"""OAuth/OIDC user information"""
|
||||||
|
sub: str # Subject identifier
|
||||||
|
email: Optional[str] = None
|
||||||
|
email_verified: Optional[bool] = None
|
||||||
|
name: Optional[str] = None
|
||||||
|
given_name: Optional[str] = None
|
||||||
|
family_name: Optional[str] = None
|
||||||
|
picture: Optional[str] = None
|
||||||
|
locale: Optional[str] = None
|
||||||
|
roles: List[str] = None
|
||||||
|
raw_claims: Dict[str, Any] = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.roles is None:
|
||||||
|
self.roles = []
|
||||||
|
if self.raw_claims is None:
|
||||||
|
self.raw_claims = {}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OIDCDiscovery:
|
||||||
|
"""OIDC Discovery document"""
|
||||||
|
issuer: str
|
||||||
|
authorization_endpoint: str
|
||||||
|
token_endpoint: str
|
||||||
|
userinfo_endpoint: Optional[str] = None
|
||||||
|
jwks_uri: Optional[str] = None
|
||||||
|
scopes_supported: List[str] = None
|
||||||
|
response_types_supported: List[str] = None
|
||||||
|
subject_types_supported: List[str] = None
|
||||||
|
id_token_signing_alg_values_supported: List[str] = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.scopes_supported is None:
|
||||||
|
self.scopes_supported = ["openid"]
|
||||||
|
if self.response_types_supported is None:
|
||||||
|
self.response_types_supported = ["code"]
|
||||||
|
if self.subject_types_supported is None:
|
||||||
|
self.subject_types_supported = ["public"]
|
||||||
|
if self.id_token_signing_alg_values_supported is None:
|
||||||
|
self.id_token_signing_alg_values_supported = ["RS256"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OAuthError:
|
||||||
|
"""OAuth error response"""
|
||||||
|
error: str
|
||||||
|
error_description: Optional[str] = None
|
||||||
|
error_uri: Optional[str] = None
|
||||||
|
state: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OAuthProviderConfig:
|
||||||
|
"""OAuth provider configuration"""
|
||||||
|
provider: OAuthProvider
|
||||||
|
client_id: str
|
||||||
|
client_secret: str
|
||||||
|
redirect_uri: str
|
||||||
|
scopes: List[str]
|
||||||
|
|
||||||
|
# Endpoints
|
||||||
|
authorization_endpoint: str
|
||||||
|
token_endpoint: str
|
||||||
|
userinfo_endpoint: Optional[str] = None
|
||||||
|
jwks_uri: Optional[str] = None
|
||||||
|
|
||||||
|
# Discovery
|
||||||
|
discovery_url: Optional[str] = None
|
||||||
|
|
||||||
|
# Settings
|
||||||
|
pkce_enabled: bool = True
|
||||||
|
nonce_enabled: bool = True
|
||||||
|
|
||||||
|
# User mapping
|
||||||
|
user_id_claim: str = "sub"
|
||||||
|
email_claim: str = "email"
|
||||||
|
name_claim: str = "name"
|
||||||
|
roles_claim: str = "roles"
|
||||||
|
default_roles: List[str] = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.default_roles is None:
|
||||||
|
self.default_roles = ["oauth_user"]
|
||||||
|
|
||||||
|
|
||||||
|
# Pre-defined provider configurations
|
||||||
|
OAUTH_PROVIDERS = {
|
||||||
|
OAuthProvider.GOOGLE: {
|
||||||
|
"authorization_endpoint": "https://accounts.google.com/o/oauth2/auth",
|
||||||
|
"token_endpoint": "https://oauth2.googleapis.com/token",
|
||||||
|
"userinfo_endpoint": "https://openidconnect.googleapis.com/v1/userinfo",
|
||||||
|
"jwks_uri": "https://www.googleapis.com/oauth2/v3/certs",
|
||||||
|
"discovery_url": "https://accounts.google.com/.well-known/openid_configuration",
|
||||||
|
"scopes": ["openid", "email", "profile"],
|
||||||
|
"user_id_claim": "sub",
|
||||||
|
"email_claim": "email",
|
||||||
|
"name_claim": "name"
|
||||||
|
},
|
||||||
|
OAuthProvider.MICROSOFT: {
|
||||||
|
"authorization_endpoint": "https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
|
||||||
|
"token_endpoint": "https://login.microsoftonline.com/common/oauth2/v2.0/token",
|
||||||
|
"userinfo_endpoint": "https://graph.microsoft.com/v1.0/me",
|
||||||
|
"jwks_uri": "https://login.microsoftonline.com/common/discovery/v2.0/keys",
|
||||||
|
"discovery_url": "https://login.microsoftonline.com/common/v2.0/.well-known/openid_configuration",
|
||||||
|
"scopes": ["openid", "profile", "email", "User.Read"],
|
||||||
|
"user_id_claim": "sub",
|
||||||
|
"email_claim": "email",
|
||||||
|
"name_claim": "name"
|
||||||
|
},
|
||||||
|
OAuthProvider.GITHUB: {
|
||||||
|
"authorization_endpoint": "https://github.com/login/oauth/authorize",
|
||||||
|
"token_endpoint": "https://github.com/login/oauth/access_token",
|
||||||
|
"userinfo_endpoint": "https://api.github.com/user",
|
||||||
|
"scopes": ["user:email", "read:user"],
|
||||||
|
"user_id_claim": "id",
|
||||||
|
"email_claim": "email",
|
||||||
|
"name_claim": "name"
|
||||||
|
}
|
||||||
|
}
|
||||||
677
doris_mcp_server/auth/token_handlers.py
Normal file
@@ -0,0 +1,677 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
"""
|
||||||
|
Token Authentication HTTP Handlers
|
||||||
|
|
||||||
|
Provides HTTP endpoints for token management including creation, revocation,
|
||||||
|
listing, and statistics. Used for administrative token management in HTTP mode.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import JSONResponse, HTMLResponse
|
||||||
|
|
||||||
|
from ..utils.logger import get_logger
|
||||||
|
from ..utils.security import SecurityLevel
|
||||||
|
from ..utils.config import DatabaseConfig
|
||||||
|
from .token_security_middleware import TokenSecurityMiddleware
|
||||||
|
|
||||||
|
|
||||||
|
class TokenHandlers:
|
||||||
|
"""Token Authentication HTTP Handlers"""
|
||||||
|
|
||||||
|
def __init__(self, security_manager, config=None):
|
||||||
|
self.security_manager = security_manager
|
||||||
|
self.logger = get_logger(__name__)
|
||||||
|
|
||||||
|
# Initialize security middleware if config is provided
|
||||||
|
if config:
|
||||||
|
self.security_middleware = TokenSecurityMiddleware(config)
|
||||||
|
else:
|
||||||
|
self.security_middleware = None
|
||||||
|
self.logger.warning("Token handlers initialized without security middleware - access control disabled")
|
||||||
|
|
||||||
|
async def handle_create_token(self, request: Request) -> JSONResponse:
|
||||||
|
"""Handle token creation request"""
|
||||||
|
# Apply security checks
|
||||||
|
if self.security_middleware:
|
||||||
|
security_response = await self.security_middleware.check_token_management_access(request)
|
||||||
|
if security_response:
|
||||||
|
return security_response
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Check if token manager is available
|
||||||
|
if not self.security_manager.auth_provider.token_manager:
|
||||||
|
return JSONResponse({
|
||||||
|
"error": "Token authentication is not enabled"
|
||||||
|
}, status_code=503)
|
||||||
|
|
||||||
|
# Parse request data
|
||||||
|
if request.method == "GET":
|
||||||
|
# GET request with query parameters
|
||||||
|
query_params = dict(request.query_params)
|
||||||
|
token_id = query_params.get("token_id")
|
||||||
|
expires_hours_str = query_params.get("expires_hours")
|
||||||
|
description = query_params.get("description", "")
|
||||||
|
custom_token = query_params.get("custom_token")
|
||||||
|
# Database configuration from query params
|
||||||
|
db_config = None
|
||||||
|
if query_params.get("db_host"):
|
||||||
|
db_config = DatabaseConfig(
|
||||||
|
host=query_params.get("db_host", "localhost"),
|
||||||
|
port=int(query_params.get("db_port", "9030")),
|
||||||
|
user=query_params.get("db_user", "root"),
|
||||||
|
password=query_params.get("db_password", ""),
|
||||||
|
database=query_params.get("db_database", "information_schema"),
|
||||||
|
fe_http_port=int(query_params.get("db_fe_http_port", "8030"))
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# POST request with JSON body
|
||||||
|
try:
|
||||||
|
body = await request.json()
|
||||||
|
except:
|
||||||
|
return JSONResponse({
|
||||||
|
"error": "Invalid JSON body"
|
||||||
|
}, status_code=400)
|
||||||
|
|
||||||
|
token_id = body.get("token_id")
|
||||||
|
expires_hours_str = body.get("expires_hours")
|
||||||
|
description = body.get("description", "")
|
||||||
|
custom_token = body.get("custom_token")
|
||||||
|
# Database configuration from JSON body
|
||||||
|
db_config = None
|
||||||
|
if body.get("database_config"):
|
||||||
|
db_data = body["database_config"]
|
||||||
|
try:
|
||||||
|
db_config = DatabaseConfig(
|
||||||
|
host=db_data.get("host", "localhost"),
|
||||||
|
port=int(db_data.get("port", 9030)),
|
||||||
|
user=db_data.get("user", "root"),
|
||||||
|
password=db_data.get("password", ""),
|
||||||
|
database=db_data.get("database", "information_schema"),
|
||||||
|
fe_http_port=int(db_data.get("fe_http_port", 8030))
|
||||||
|
)
|
||||||
|
except (ValueError, TypeError) as e:
|
||||||
|
return JSONResponse({
|
||||||
|
"error": f"Invalid database configuration: {str(e)}"
|
||||||
|
}, status_code=400)
|
||||||
|
|
||||||
|
# Validate required fields
|
||||||
|
if not token_id:
|
||||||
|
return JSONResponse({
|
||||||
|
"error": "token_id is required"
|
||||||
|
}, status_code=400)
|
||||||
|
|
||||||
|
# Parse expires_hours
|
||||||
|
expires_hours = None
|
||||||
|
if expires_hours_str:
|
||||||
|
try:
|
||||||
|
expires_hours = int(expires_hours_str)
|
||||||
|
except ValueError:
|
||||||
|
return JSONResponse({
|
||||||
|
"error": "expires_hours must be an integer"
|
||||||
|
}, status_code=400)
|
||||||
|
|
||||||
|
# Create token using the actual API
|
||||||
|
try:
|
||||||
|
token = await self.security_manager.create_token(
|
||||||
|
token_id=token_id,
|
||||||
|
expires_hours=expires_hours,
|
||||||
|
description=description,
|
||||||
|
custom_token=custom_token,
|
||||||
|
database_config=db_config
|
||||||
|
)
|
||||||
|
|
||||||
|
return JSONResponse({
|
||||||
|
"success": True,
|
||||||
|
"token_id": token_id,
|
||||||
|
"token": token,
|
||||||
|
"expires_hours": expires_hours,
|
||||||
|
"description": description,
|
||||||
|
"message": "Token created successfully"
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Token creation failed: {e}")
|
||||||
|
return JSONResponse({
|
||||||
|
"error": f"Token creation failed: {str(e)}"
|
||||||
|
}, status_code=400)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error in handle_create_token: {e}")
|
||||||
|
return JSONResponse({
|
||||||
|
"error": f"Internal server error: {str(e)}"
|
||||||
|
}, status_code=500)
|
||||||
|
|
||||||
|
async def handle_revoke_token(self, request: Request) -> JSONResponse:
|
||||||
|
"""Handle token revocation request"""
|
||||||
|
# Apply security checks
|
||||||
|
if self.security_middleware:
|
||||||
|
security_response = await self.security_middleware.check_token_management_access(request)
|
||||||
|
if security_response:
|
||||||
|
return security_response
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Check if token manager is available
|
||||||
|
if not self.security_manager.auth_provider.token_manager:
|
||||||
|
return JSONResponse({
|
||||||
|
"error": "Token authentication is not enabled"
|
||||||
|
}, status_code=503)
|
||||||
|
|
||||||
|
# Get token_id from query parameters or path
|
||||||
|
token_id = request.query_params.get("token_id")
|
||||||
|
if not token_id and request.method == "DELETE":
|
||||||
|
# Try to get from path: /token/revoke/{token_id}
|
||||||
|
path_parts = str(request.url.path).split("/")
|
||||||
|
if len(path_parts) >= 4:
|
||||||
|
token_id = path_parts[-1]
|
||||||
|
|
||||||
|
if not token_id:
|
||||||
|
return JSONResponse({
|
||||||
|
"error": "token_id is required"
|
||||||
|
}, status_code=400)
|
||||||
|
|
||||||
|
# Revoke token
|
||||||
|
success = await self.security_manager.revoke_token(token_id)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
return JSONResponse({
|
||||||
|
"success": True,
|
||||||
|
"token_id": token_id,
|
||||||
|
"message": "Token revoked successfully"
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
return JSONResponse({
|
||||||
|
"success": False,
|
||||||
|
"token_id": token_id,
|
||||||
|
"message": "Token not found or already revoked"
|
||||||
|
}, status_code=404)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error in handle_revoke_token: {e}")
|
||||||
|
return JSONResponse({
|
||||||
|
"error": f"Internal server error: {str(e)}"
|
||||||
|
}, status_code=500)
|
||||||
|
|
||||||
|
async def handle_list_tokens(self, request: Request) -> JSONResponse:
|
||||||
|
"""Handle token listing request"""
|
||||||
|
# Apply security checks
|
||||||
|
if self.security_middleware:
|
||||||
|
security_response = await self.security_middleware.check_token_management_access(request)
|
||||||
|
if security_response:
|
||||||
|
return security_response
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Check if token manager is available
|
||||||
|
if not self.security_manager.auth_provider.token_manager:
|
||||||
|
return JSONResponse({
|
||||||
|
"error": "Token authentication is not enabled"
|
||||||
|
}, status_code=503)
|
||||||
|
|
||||||
|
# Get tokens list
|
||||||
|
tokens = await self.security_manager.list_tokens()
|
||||||
|
|
||||||
|
return JSONResponse({
|
||||||
|
"success": True,
|
||||||
|
"count": len(tokens),
|
||||||
|
"tokens": tokens
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error in handle_list_tokens: {e}")
|
||||||
|
return JSONResponse({
|
||||||
|
"error": f"Internal server error: {str(e)}"
|
||||||
|
}, status_code=500)
|
||||||
|
|
||||||
|
async def handle_token_stats(self, request: Request) -> JSONResponse:
|
||||||
|
"""Handle token statistics request"""
|
||||||
|
# Apply security checks
|
||||||
|
if self.security_middleware:
|
||||||
|
security_response = await self.security_middleware.check_token_management_access(request)
|
||||||
|
if security_response:
|
||||||
|
return security_response
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Check if token manager is available
|
||||||
|
if not self.security_manager.auth_provider.token_manager:
|
||||||
|
return JSONResponse({
|
||||||
|
"error": "Token authentication is not enabled"
|
||||||
|
}, status_code=503)
|
||||||
|
|
||||||
|
# Get token statistics
|
||||||
|
stats = self.security_manager.get_token_stats()
|
||||||
|
|
||||||
|
return JSONResponse({
|
||||||
|
"success": True,
|
||||||
|
"stats": stats
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error in handle_token_stats: {e}")
|
||||||
|
return JSONResponse({
|
||||||
|
"error": f"Internal server error: {str(e)}"
|
||||||
|
}, status_code=500)
|
||||||
|
|
||||||
|
async def handle_cleanup_tokens(self, request: Request) -> JSONResponse:
|
||||||
|
"""Handle expired tokens cleanup request"""
|
||||||
|
# Apply security checks
|
||||||
|
if self.security_middleware:
|
||||||
|
security_response = await self.security_middleware.check_token_management_access(request)
|
||||||
|
if security_response:
|
||||||
|
return security_response
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Check if token manager is available
|
||||||
|
if not self.security_manager.auth_provider.token_manager:
|
||||||
|
return JSONResponse({
|
||||||
|
"error": "Token authentication is not enabled"
|
||||||
|
}, status_code=503)
|
||||||
|
|
||||||
|
# Cleanup expired tokens
|
||||||
|
cleaned_count = await self.security_manager.cleanup_expired_tokens()
|
||||||
|
|
||||||
|
return JSONResponse({
|
||||||
|
"success": True,
|
||||||
|
"cleaned_count": cleaned_count,
|
||||||
|
"message": f"Cleaned up {cleaned_count} expired tokens"
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error in handle_cleanup_tokens: {e}")
|
||||||
|
return JSONResponse({
|
||||||
|
"error": f"Internal server error: {str(e)}"
|
||||||
|
}, status_code=500)
|
||||||
|
|
||||||
|
async def handle_management_page(self, request: Request) -> HTMLResponse:
|
||||||
|
"""Handle token management demo page"""
|
||||||
|
# Apply security checks
|
||||||
|
if self.security_middleware:
|
||||||
|
security_response = await self.security_middleware.check_token_management_access(request)
|
||||||
|
if security_response:
|
||||||
|
# Convert JSON response to HTML for demo page
|
||||||
|
error_data = security_response.body.decode('utf-8') if hasattr(security_response, 'body') else '{"error": "Access denied"}'
|
||||||
|
try:
|
||||||
|
error_info = json.loads(error_data)
|
||||||
|
except:
|
||||||
|
error_info = {"error": "Access denied"}
|
||||||
|
|
||||||
|
error_html = f"""
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<title>Access Denied - Token Management</title>
|
||||||
|
<style>
|
||||||
|
body {{ font-family: Arial, sans-serif; margin: 50px; background: #f5f5f5; }}
|
||||||
|
.container {{ max-width: 600px; margin: 0 auto; background: white; padding: 30px; border-radius: 8px; }}
|
||||||
|
.error {{ color: #dc3545; background: #f8d7da; border: 1px solid #f5c6cb; padding: 15px; border-radius: 5px; }}
|
||||||
|
.security-info {{ background: #d1ecf1; border: 1px solid #bee5eb; padding: 15px; border-radius: 5px; margin-top: 20px; }}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="container">
|
||||||
|
<h1>🔐 Token Management - Access Denied</h1>
|
||||||
|
<div class="error">
|
||||||
|
<h3>Access Denied</h3>
|
||||||
|
<p><strong>Error:</strong> {error_info.get('error', 'Access denied')}</p>
|
||||||
|
<p><strong>Message:</strong> {error_info.get('message', 'Token management access is restricted')}</p>
|
||||||
|
{'<p><strong>Your IP:</strong> ' + str(error_info.get('client_ip', 'Unknown')) + '</p>' if 'client_ip' in error_info else ''}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="security-info">
|
||||||
|
<h3>🛡️ Security Information</h3>
|
||||||
|
<p>Token management endpoints are protected by the following security measures:</p>
|
||||||
|
<ul>
|
||||||
|
<li><strong>IP Restrictions:</strong> Only localhost/127.0.0.1 access allowed</li>
|
||||||
|
<li><strong>Admin Authentication:</strong> Valid admin token required</li>
|
||||||
|
<li><strong>Configuration Control:</strong> Must be explicitly enabled</li>
|
||||||
|
</ul>
|
||||||
|
<p>If you need access, please:</p>
|
||||||
|
<ol>
|
||||||
|
<li>Access from the server host (127.0.0.1)</li>
|
||||||
|
<li>Ensure HTTP token management is enabled in configuration</li>
|
||||||
|
<li>Provide valid admin authentication</li>
|
||||||
|
</ol>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
"""
|
||||||
|
return HTMLResponse(error_html, status_code=security_response.status_code)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Check if token manager is available
|
||||||
|
if not self.security_manager.auth_provider.token_manager:
|
||||||
|
html_content = """
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<title>Token Management - Not Available</title>
|
||||||
|
<style>
|
||||||
|
body { font-family: Arial, sans-serif; margin: 50px; }
|
||||||
|
.error { color: red; font-size: 18px; }
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<h1>Token Management</h1>
|
||||||
|
<div class="error">Token authentication is not enabled on this server.</div>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
"""
|
||||||
|
return HTMLResponse(html_content)
|
||||||
|
|
||||||
|
# Get current stats for demo
|
||||||
|
stats = self.security_manager.get_token_stats()
|
||||||
|
|
||||||
|
html_content = f"""
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<title>Doris MCP Server - Token Management</title>
|
||||||
|
<style>
|
||||||
|
body {{ font-family: Arial, sans-serif; margin: 50px; background: #f5f5f5; }}
|
||||||
|
.container {{ max-width: 1200px; margin: 0 auto; background: white; padding: 30px; border-radius: 8px; }}
|
||||||
|
h1 {{ color: #333; }}
|
||||||
|
.section {{ margin: 30px 0; padding: 20px; border: 1px solid #ddd; border-radius: 5px; }}
|
||||||
|
.stats {{ display: grid; grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); gap: 15px; }}
|
||||||
|
.stat-item {{ padding: 15px; background: #f8f9fa; border-radius: 5px; text-align: center; }}
|
||||||
|
.stat-value {{ font-size: 24px; font-weight: bold; color: #007bff; }}
|
||||||
|
.form-group {{ margin: 15px 0; }}
|
||||||
|
.form-group label {{ display: block; margin-bottom: 5px; font-weight: bold; }}
|
||||||
|
.form-group input, .form-group textarea {{ width: 100%; padding: 8px; border: 1px solid #ddd; border-radius: 4px; }}
|
||||||
|
button {{ padding: 10px 20px; margin: 5px; border: none; border-radius: 4px; cursor: pointer; }}
|
||||||
|
.btn-primary {{ background: #007bff; color: white; }}
|
||||||
|
.btn-danger {{ background: #dc3545; color: white; }}
|
||||||
|
.btn-success {{ background: #28a745; color: white; }}
|
||||||
|
.response {{ margin: 15px 0; padding: 15px; border-radius: 5px; }}
|
||||||
|
.response.success {{ background: #d4edda; border: 1px solid #c3e6cb; }}
|
||||||
|
.response.error {{ background: #f8d7da; border: 1px solid #f5c6cb; }}
|
||||||
|
.token-list {{ margin: 15px 0; }}
|
||||||
|
.token-item {{ padding: 10px; margin: 5px 0; background: #f8f9fa; border-radius: 4px; }}
|
||||||
|
pre {{ background: #f8f9fa; padding: 10px; border-radius: 4px; overflow-x: auto; }}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="container">
|
||||||
|
<h1>🔐 Doris MCP Server - Token Management</h1>
|
||||||
|
|
||||||
|
<div class="section">
|
||||||
|
<h2>📊 Token Statistics</h2>
|
||||||
|
<div class="stats">
|
||||||
|
<div class="stat-item">
|
||||||
|
<div class="stat-value">{stats.get('total_tokens', 0)}</div>
|
||||||
|
<div>Total Tokens</div>
|
||||||
|
</div>
|
||||||
|
<div class="stat-item">
|
||||||
|
<div class="stat-value">{stats.get('active_tokens', 0)}</div>
|
||||||
|
<div>Active Tokens</div>
|
||||||
|
</div>
|
||||||
|
<div class="stat-item">
|
||||||
|
<div class="stat-value">{stats.get('expired_tokens', 0)}</div>
|
||||||
|
<div>Expired Tokens</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<p><strong>Token Expiry:</strong> {'Enabled' if stats.get('expiry_enabled') else 'Disabled'}</p>
|
||||||
|
<p><strong>Default Expiry:</strong> {stats.get('default_expiry_hours', 0)} hours</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="section">
|
||||||
|
<h2>➕ Create New Token</h2>
|
||||||
|
<form id="createTokenForm">
|
||||||
|
<div class="form-group">
|
||||||
|
<label for="token_id">Token ID (required):</label>
|
||||||
|
<input type="text" id="token_id" name="token_id" placeholder="e.g., my-app-token" required>
|
||||||
|
</div>
|
||||||
|
<div class="form-group">
|
||||||
|
<label for="expires_hours">Expires Hours (optional):</label>
|
||||||
|
<input type="number" id="expires_hours" name="expires_hours" placeholder="e.g., 720 (30 days), leave empty for default">
|
||||||
|
</div>
|
||||||
|
<div class="form-group">
|
||||||
|
<label for="description">Description (optional):</label>
|
||||||
|
<textarea id="description" name="description" placeholder="Token description"></textarea>
|
||||||
|
</div>
|
||||||
|
<div class="form-group">
|
||||||
|
<label for="custom_token">Custom Token (optional):</label>
|
||||||
|
<input type="text" id="custom_token" name="custom_token" placeholder="Leave empty to auto-generate">
|
||||||
|
<small style="color: #666; display: block; margin-top: 5px;">If not provided, a secure token will be generated automatically</small>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="section" style="margin: 20px 0; padding: 15px; background: #f8f9fa; border-radius: 5px;">
|
||||||
|
<h3>🗄️ Database Configuration (Optional)</h3>
|
||||||
|
<p style="color: #666; font-size: 14px; margin-bottom: 15px;">Configure database connection for this token. Leave empty to use system defaults.</p>
|
||||||
|
|
||||||
|
<div style="display: grid; grid-template-columns: 1fr 1fr; gap: 10px;">
|
||||||
|
<div class="form-group">
|
||||||
|
<label for="db_host">Host:</label>
|
||||||
|
<input type="text" id="db_host" name="db_host" placeholder="localhost">
|
||||||
|
</div>
|
||||||
|
<div class="form-group">
|
||||||
|
<label for="db_port">Port:</label>
|
||||||
|
<input type="number" id="db_port" name="db_port" placeholder="9030">
|
||||||
|
</div>
|
||||||
|
<div class="form-group">
|
||||||
|
<label for="db_user">User:</label>
|
||||||
|
<input type="text" id="db_user" name="db_user" placeholder="root">
|
||||||
|
</div>
|
||||||
|
<div class="form-group">
|
||||||
|
<label for="db_password">Password:</label>
|
||||||
|
<input type="password" id="db_password" name="db_password" placeholder="(optional)">
|
||||||
|
</div>
|
||||||
|
<div class="form-group">
|
||||||
|
<label for="db_database">Database:</label>
|
||||||
|
<input type="text" id="db_database" name="db_database" placeholder="information_schema">
|
||||||
|
</div>
|
||||||
|
<div class="form-group">
|
||||||
|
<label for="db_fe_http_port">FE HTTP Port:</label>
|
||||||
|
<input type="number" id="db_fe_http_port" name="db_fe_http_port" placeholder="8030">
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<button type="submit" class="btn-primary">Create Token</button>
|
||||||
|
</form>
|
||||||
|
<div id="createTokenResponse"></div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="section">
|
||||||
|
<h2>📋 Token Management</h2>
|
||||||
|
<button id="listTokensBtn" class="btn-success">Refresh Token List</button>
|
||||||
|
<button id="cleanupTokensBtn" class="btn-primary">Cleanup Expired Tokens</button>
|
||||||
|
<div id="tokenListResponse"></div>
|
||||||
|
|
||||||
|
<h3>Revoke Token</h3>
|
||||||
|
<div class="form-group">
|
||||||
|
<input type="text" id="revokeTokenId" placeholder="Enter token ID to revoke">
|
||||||
|
<button id="revokeTokenBtn" class="btn-danger">Revoke Token</button>
|
||||||
|
</div>
|
||||||
|
<div id="revokeTokenResponse"></div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="section">
|
||||||
|
<h2>🔧 API Endpoints</h2>
|
||||||
|
<p>Use these endpoints for programmatic token management:</p>
|
||||||
|
<ul>
|
||||||
|
<li><strong>POST /token/create</strong> - Create new token</li>
|
||||||
|
<li><strong>DELETE /token/revoke?token_id=...</strong> - Revoke token</li>
|
||||||
|
<li><strong>GET /token/list</strong> - List all tokens</li>
|
||||||
|
<li><strong>GET /token/stats</strong> - Get token statistics</li>
|
||||||
|
<li><strong>POST /token/cleanup</strong> - Cleanup expired tokens</li>
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<script>
|
||||||
|
// Get admin token from URL parameters
|
||||||
|
const urlParams = new URLSearchParams(window.location.search);
|
||||||
|
const adminToken = urlParams.get('admin_token');
|
||||||
|
|
||||||
|
// Create request headers with admin token
|
||||||
|
function getAuthHeaders() {{
|
||||||
|
if (adminToken) {{
|
||||||
|
return {{
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
'Authorization': `Bearer ${{adminToken}}`
|
||||||
|
}};
|
||||||
|
}} else {{
|
||||||
|
return {{'Content-Type': 'application/json'}};
|
||||||
|
}}
|
||||||
|
}}
|
||||||
|
|
||||||
|
// Create URL with admin token parameter
|
||||||
|
function getAuthURL(baseUrl) {{
|
||||||
|
if (adminToken) {{
|
||||||
|
const separator = baseUrl.includes('?') ? '&' : '?';
|
||||||
|
return `${{baseUrl}}${{separator}}admin_token=${{encodeURIComponent(adminToken)}}`;
|
||||||
|
}}
|
||||||
|
return baseUrl;
|
||||||
|
}}
|
||||||
|
|
||||||
|
function showResponse(elementId, data, isSuccess = true) {{
|
||||||
|
const element = document.getElementById(elementId);
|
||||||
|
element.innerHTML = '<pre>' + JSON.stringify(data, null, 2) + '</pre>';
|
||||||
|
element.className = 'response ' + (isSuccess ? 'success' : 'error');
|
||||||
|
}}
|
||||||
|
|
||||||
|
// Create token form - updated to match actual API
|
||||||
|
document.getElementById('createTokenForm').addEventListener('submit', async (e) => {{
|
||||||
|
e.preventDefault();
|
||||||
|
const formData = new FormData(e.target);
|
||||||
|
const data = Object.fromEntries(formData.entries());
|
||||||
|
|
||||||
|
// Remove empty fields for optional parameters
|
||||||
|
if (!data.expires_hours) delete data.expires_hours;
|
||||||
|
if (!data.description) delete data.description;
|
||||||
|
if (!data.custom_token) delete data.custom_token;
|
||||||
|
|
||||||
|
// Handle database configuration
|
||||||
|
if (data.db_host) {{
|
||||||
|
data.database_config = {{
|
||||||
|
host: data.db_host,
|
||||||
|
port: data.db_port ? parseInt(data.db_port) : 9030,
|
||||||
|
user: data.db_user || 'root',
|
||||||
|
password: data.db_password || '',
|
||||||
|
database: data.db_database || 'information_schema',
|
||||||
|
fe_http_port: data.db_fe_http_port ? parseInt(data.db_fe_http_port) : 8030
|
||||||
|
}};
|
||||||
|
}}
|
||||||
|
|
||||||
|
// Remove individual database fields from data
|
||||||
|
delete data.db_host;
|
||||||
|
delete data.db_port;
|
||||||
|
delete data.db_user;
|
||||||
|
delete data.db_password;
|
||||||
|
delete data.db_database;
|
||||||
|
delete data.db_fe_http_port;
|
||||||
|
|
||||||
|
try {{
|
||||||
|
const response = await fetch(getAuthURL('/token/create'), {{
|
||||||
|
method: 'POST',
|
||||||
|
headers: getAuthHeaders(),
|
||||||
|
body: JSON.stringify(data)
|
||||||
|
}});
|
||||||
|
const result = await response.json();
|
||||||
|
showResponse('createTokenResponse', result, response.ok);
|
||||||
|
|
||||||
|
// Refresh token list if creation was successful
|
||||||
|
if (response.ok) {{
|
||||||
|
document.getElementById('listTokensBtn').click();
|
||||||
|
}}
|
||||||
|
}} catch (error) {{
|
||||||
|
showResponse('createTokenResponse', {{error: error.message}}, false);
|
||||||
|
}}
|
||||||
|
}});
|
||||||
|
|
||||||
|
// List tokens
|
||||||
|
document.getElementById('listTokensBtn').addEventListener('click', async () => {{
|
||||||
|
try {{
|
||||||
|
const response = await fetch(getAuthURL('/token/list'), {{
|
||||||
|
headers: getAuthHeaders()
|
||||||
|
}});
|
||||||
|
const result = await response.json();
|
||||||
|
showResponse('tokenListResponse', result, response.ok);
|
||||||
|
}} catch (error) {{
|
||||||
|
showResponse('tokenListResponse', {{error: error.message}}, false);
|
||||||
|
}}
|
||||||
|
}});
|
||||||
|
|
||||||
|
// Cleanup tokens
|
||||||
|
document.getElementById('cleanupTokensBtn').addEventListener('click', async () => {{
|
||||||
|
try {{
|
||||||
|
const response = await fetch(getAuthURL('/token/cleanup'), {{
|
||||||
|
method: 'POST',
|
||||||
|
headers: getAuthHeaders()
|
||||||
|
}});
|
||||||
|
const result = await response.json();
|
||||||
|
showResponse('tokenListResponse', result, response.ok);
|
||||||
|
}} catch (error) {{
|
||||||
|
showResponse('tokenListResponse', {{error: error.message}}, false);
|
||||||
|
}}
|
||||||
|
}});
|
||||||
|
|
||||||
|
// Revoke token
|
||||||
|
document.getElementById('revokeTokenBtn').addEventListener('click', async () => {{
|
||||||
|
const tokenId = document.getElementById('revokeTokenId').value;
|
||||||
|
if (!tokenId) {{
|
||||||
|
showResponse('revokeTokenResponse', {{error: 'Token ID is required'}}, false);
|
||||||
|
return;
|
||||||
|
}}
|
||||||
|
|
||||||
|
try {{
|
||||||
|
const response = await fetch(getAuthURL(`/token/revoke?token_id=${{encodeURIComponent(tokenId)}}`), {{
|
||||||
|
method: 'DELETE',
|
||||||
|
headers: getAuthHeaders()
|
||||||
|
}});
|
||||||
|
const result = await response.json();
|
||||||
|
showResponse('revokeTokenResponse', result, response.ok);
|
||||||
|
|
||||||
|
// Refresh token list if revocation was successful
|
||||||
|
if (response.ok) {{
|
||||||
|
document.getElementById('listTokensBtn').click();
|
||||||
|
document.getElementById('revokeTokenId').value = '';
|
||||||
|
}}
|
||||||
|
}} catch (error) {{
|
||||||
|
showResponse('revokeTokenResponse', {{error: error.message}}, false);
|
||||||
|
}}
|
||||||
|
}});
|
||||||
|
|
||||||
|
// Load token list on page load
|
||||||
|
document.getElementById('listTokensBtn').click();
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
"""
|
||||||
|
|
||||||
|
return HTMLResponse(html_content)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error in handle_demo_page: {e}")
|
||||||
|
error_html = f"""
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<title>Token Management Error</title>
|
||||||
|
<style>body {{ font-family: Arial, sans-serif; margin: 50px; }}</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<h1>Token Management Error</h1>
|
||||||
|
<p>Error loading token management page: {str(e)}</p>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
"""
|
||||||
|
return HTMLResponse(error_html, status_code=500)
|
||||||
827
doris_mcp_server/auth/token_manager.py
Normal file
@@ -0,0 +1,827 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
"""
|
||||||
|
Token Authentication Management Module
|
||||||
|
|
||||||
|
Provides enterprise-grade token authentication system with configurable tokens,
|
||||||
|
expiration management, role-based access control and secure token storage.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import secrets
|
||||||
|
import time
|
||||||
|
import asyncio
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Dict, List, Optional, Any
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from ..utils.logger import get_logger
|
||||||
|
from ..utils.security import SecurityLevel
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DatabaseConfig:
|
||||||
|
"""Database connection configuration for token binding"""
|
||||||
|
|
||||||
|
host: str
|
||||||
|
port: int = 9030
|
||||||
|
user: str = ""
|
||||||
|
password: str = ""
|
||||||
|
database: str = "information_schema"
|
||||||
|
charset: str = "UTF8"
|
||||||
|
fe_http_port: int = 8030
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TokenInfo:
|
||||||
|
"""Token information structure with optional database binding"""
|
||||||
|
|
||||||
|
token_id: str # Unique token identifier for audit and management
|
||||||
|
created_at: datetime = field(default_factory=datetime.utcnow)
|
||||||
|
expires_at: Optional[datetime] = None
|
||||||
|
last_used: Optional[datetime] = None
|
||||||
|
description: str = "" # Optional description for token purpose
|
||||||
|
is_active: bool = True
|
||||||
|
database_config: Optional[DatabaseConfig] = None # Optional database binding
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TokenValidationResult:
|
||||||
|
"""Token validation result"""
|
||||||
|
|
||||||
|
is_valid: bool
|
||||||
|
token_info: Optional[TokenInfo] = None
|
||||||
|
error_message: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class TokenManager:
|
||||||
|
"""Enterprise Token Authentication Manager
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Configurable token storage (file-based or environment variables)
|
||||||
|
- Token expiration management
|
||||||
|
- Secure token hashing
|
||||||
|
- Role-based access control
|
||||||
|
- Token lifecycle management
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
self.config = config
|
||||||
|
self.logger = get_logger(__name__)
|
||||||
|
|
||||||
|
# Token storage
|
||||||
|
self._tokens: Dict[str, TokenInfo] = {} # token_hash -> TokenInfo
|
||||||
|
self._token_ids: Dict[str, str] = {} # token_id -> token_hash
|
||||||
|
|
||||||
|
# Configuration
|
||||||
|
self.token_file_path = getattr(config.security, 'token_file_path', 'tokens.json')
|
||||||
|
self.enable_token_expiry = getattr(config.security, 'enable_token_expiry', True)
|
||||||
|
self.default_token_expiry_hours = getattr(config.security, 'default_token_expiry_hours', 24 * 30) # 30 days
|
||||||
|
self.token_hash_algorithm = getattr(config.security, 'token_hash_algorithm', 'sha256')
|
||||||
|
|
||||||
|
# Hot reload configuration
|
||||||
|
self.enable_hot_reload = True
|
||||||
|
self.hot_reload_interval = 10 # Check every 10 seconds
|
||||||
|
self._file_last_modified = 0
|
||||||
|
self._hot_reload_task = None
|
||||||
|
|
||||||
|
# Initialize with default tokens if none exist
|
||||||
|
self._initialize_default_tokens()
|
||||||
|
|
||||||
|
# Load tokens from configuration
|
||||||
|
self._load_tokens()
|
||||||
|
|
||||||
|
# Start hot reload monitoring
|
||||||
|
if self.enable_hot_reload:
|
||||||
|
self._start_hot_reload()
|
||||||
|
|
||||||
|
self.logger.info(f"TokenManager initialized with {len(self._tokens)} tokens, hot reload: {self.enable_hot_reload}")
|
||||||
|
|
||||||
|
def _initialize_default_tokens(self):
|
||||||
|
"""Initialize default tokens for basic authentication (configurable via environment)"""
|
||||||
|
# Default token configurations (can be overridden by environment variables)
|
||||||
|
default_tokens = [
|
||||||
|
{
|
||||||
|
'token_id': 'admin-token',
|
||||||
|
'token': os.getenv('DEFAULT_ADMIN_TOKEN', 'doris_admin_token_123456'),
|
||||||
|
'description': os.getenv('DEFAULT_ADMIN_DESCRIPTION', 'Default admin API access token'),
|
||||||
|
'expires_hours': None # Never expires
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'token_id': 'analyst-token',
|
||||||
|
'token': os.getenv('DEFAULT_ANALYST_TOKEN', 'doris_analyst_token_123456'),
|
||||||
|
'description': os.getenv('DEFAULT_ANALYST_DESCRIPTION', 'Default data analysis API access token'),
|
||||||
|
'expires_hours': None # Never expires
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'token_id': 'readonly-token',
|
||||||
|
'token': os.getenv('DEFAULT_READONLY_TOKEN', 'doris_readonly_token_123456'),
|
||||||
|
'description': os.getenv('DEFAULT_READONLY_DESCRIPTION', 'Default read-only API access token'),
|
||||||
|
'expires_hours': None # Never expires
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# Only add default tokens if no custom tokens are defined via environment variables
|
||||||
|
# Check if any TOKEN_* environment variables exist (excluding system and legacy configs)
|
||||||
|
excluded_prefixes = ('DEFAULT_', 'TOKEN_FILE_PATH', 'TOKEN_HASH_')
|
||||||
|
excluded_vars = {'TOKEN_SECRET', 'TOKEN_EXPIRY'}
|
||||||
|
|
||||||
|
custom_tokens_exist = any(
|
||||||
|
key.startswith('TOKEN_') and
|
||||||
|
not key.startswith(excluded_prefixes) and
|
||||||
|
not key.endswith(('_EXPIRES_HOURS', '_DESCRIPTION')) and
|
||||||
|
key not in excluded_vars
|
||||||
|
for key in os.environ.keys()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Also check if token file exists and has content
|
||||||
|
token_file_exists = False
|
||||||
|
if os.path.exists(self.token_file_path):
|
||||||
|
try:
|
||||||
|
with open(self.token_file_path, 'r') as f:
|
||||||
|
content = f.read().strip()
|
||||||
|
if content and content != '{}':
|
||||||
|
token_file_exists = True
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Add default tokens only if no custom configuration exists
|
||||||
|
if not custom_tokens_exist and not token_file_exists:
|
||||||
|
for token_config in default_tokens:
|
||||||
|
self._add_token_from_config(token_config)
|
||||||
|
|
||||||
|
self.logger.info(f"Initialized {len(default_tokens)} default tokens (no custom config found)")
|
||||||
|
else:
|
||||||
|
self.logger.info("Skipped default tokens initialization (custom tokens detected)")
|
||||||
|
|
||||||
|
def _add_token_from_config(self, token_config: Dict[str, Any]):
|
||||||
|
"""Add token from configuration with optional database binding"""
|
||||||
|
try:
|
||||||
|
# Calculate expiration time
|
||||||
|
expires_at = None
|
||||||
|
if self.enable_token_expiry:
|
||||||
|
expires_hours = token_config.get('expires_hours', self.default_token_expiry_hours)
|
||||||
|
if expires_hours is not None:
|
||||||
|
expires_at = datetime.utcnow() + timedelta(hours=expires_hours)
|
||||||
|
|
||||||
|
# Parse database configuration if provided
|
||||||
|
database_config = None
|
||||||
|
if 'database_config' in token_config:
|
||||||
|
db_config = token_config['database_config']
|
||||||
|
database_config = DatabaseConfig(
|
||||||
|
host=db_config.get('host', 'localhost'),
|
||||||
|
port=db_config.get('port', 9030),
|
||||||
|
user=db_config.get('user', 'root'),
|
||||||
|
password=db_config.get('password', ''),
|
||||||
|
database=db_config.get('database', 'information_schema'),
|
||||||
|
charset=db_config.get('charset', 'UTF8'),
|
||||||
|
fe_http_port=db_config.get('fe_http_port', 8030)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create token info
|
||||||
|
token_info = TokenInfo(
|
||||||
|
token_id=token_config['token_id'],
|
||||||
|
expires_at=expires_at,
|
||||||
|
description=token_config.get('description', ''),
|
||||||
|
is_active=token_config.get('is_active', True),
|
||||||
|
database_config=database_config
|
||||||
|
)
|
||||||
|
|
||||||
|
# Hash the token
|
||||||
|
raw_token = token_config['token']
|
||||||
|
token_hash = self._hash_token(raw_token)
|
||||||
|
|
||||||
|
# Store token
|
||||||
|
self._tokens[token_hash] = token_info
|
||||||
|
self._token_ids[token_info.token_id] = token_hash
|
||||||
|
|
||||||
|
db_info = f" with DB binding ({database_config.host})" if database_config else ""
|
||||||
|
self.logger.debug(f"Added token '{token_info.token_id}'{db_info}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Failed to add token from config: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _load_tokens(self):
|
||||||
|
"""Load tokens from configuration sources"""
|
||||||
|
# 1. Load from environment variables
|
||||||
|
self._load_tokens_from_env()
|
||||||
|
|
||||||
|
# 2. Load from token file if exists
|
||||||
|
if os.path.exists(self.token_file_path):
|
||||||
|
self._load_tokens_from_file()
|
||||||
|
|
||||||
|
self.logger.info(f"Token loading completed, total tokens: {len(self._tokens)}")
|
||||||
|
|
||||||
|
def _load_tokens_from_env(self):
|
||||||
|
"""Load tokens from environment variables
|
||||||
|
|
||||||
|
Simplified format:
|
||||||
|
TOKEN_<ID>=<token>
|
||||||
|
TOKEN_<ID>_EXPIRES_HOURS=<hours>
|
||||||
|
TOKEN_<ID>_DESCRIPTION=<description>
|
||||||
|
"""
|
||||||
|
token_prefixes = set()
|
||||||
|
|
||||||
|
# Find all TOKEN_ environment variables (exclude legacy and system variables)
|
||||||
|
excluded_token_vars = {
|
||||||
|
'TOKEN_SECRET', # Legacy token secret
|
||||||
|
'TOKEN_EXPIRY', # Legacy token expiry
|
||||||
|
'TOKEN_FILE_PATH', # System config
|
||||||
|
'TOKEN_HASH_ALGORITHM' # System config
|
||||||
|
}
|
||||||
|
|
||||||
|
for key in os.environ:
|
||||||
|
if (key.startswith('TOKEN_') and
|
||||||
|
not key.endswith(('_EXPIRES_HOURS', '_DESCRIPTION')) and
|
||||||
|
key not in excluded_token_vars):
|
||||||
|
token_id = key[6:] # Remove 'TOKEN_' prefix
|
||||||
|
token_prefixes.add(token_id)
|
||||||
|
|
||||||
|
# Load each token
|
||||||
|
for token_id in token_prefixes:
|
||||||
|
try:
|
||||||
|
token = os.environ.get(f'TOKEN_{token_id}')
|
||||||
|
if not token:
|
||||||
|
continue
|
||||||
|
|
||||||
|
expires_hours_str = os.environ.get(f'TOKEN_{token_id}_EXPIRES_HOURS', str(self.default_token_expiry_hours))
|
||||||
|
description = os.environ.get(f'TOKEN_{token_id}_DESCRIPTION', f'Environment token {token_id}')
|
||||||
|
|
||||||
|
expires_hours = None
|
||||||
|
try:
|
||||||
|
if expires_hours_str and expires_hours_str.lower() != 'none':
|
||||||
|
expires_hours = int(expires_hours_str)
|
||||||
|
except ValueError:
|
||||||
|
expires_hours = self.default_token_expiry_hours
|
||||||
|
|
||||||
|
# Add token
|
||||||
|
token_config = {
|
||||||
|
'token_id': token_id.lower(),
|
||||||
|
'token': token,
|
||||||
|
'expires_hours': expires_hours,
|
||||||
|
'description': description
|
||||||
|
}
|
||||||
|
|
||||||
|
self._add_token_from_config(token_config)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Failed to load token {token_id} from environment: {e}")
|
||||||
|
|
||||||
|
def _load_tokens_from_file(self):
|
||||||
|
"""Load tokens from JSON file"""
|
||||||
|
try:
|
||||||
|
with open(self.token_file_path, 'r', encoding='utf-8') as f:
|
||||||
|
tokens_data = json.load(f)
|
||||||
|
|
||||||
|
if isinstance(tokens_data, dict) and 'tokens' in tokens_data:
|
||||||
|
tokens_list = tokens_data['tokens']
|
||||||
|
elif isinstance(tokens_data, list):
|
||||||
|
tokens_list = tokens_data
|
||||||
|
else:
|
||||||
|
self.logger.error(f"Invalid token file format: {self.token_file_path}")
|
||||||
|
return
|
||||||
|
|
||||||
|
for token_config in tokens_list:
|
||||||
|
self._add_token_from_config(token_config)
|
||||||
|
|
||||||
|
self.logger.info(f"Loaded {len(tokens_list)} tokens from file: {self.token_file_path}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Failed to load tokens from file {self.token_file_path}: {e}")
|
||||||
|
|
||||||
|
def _hash_token(self, token: str) -> str:
|
||||||
|
"""Hash token for secure storage"""
|
||||||
|
if self.token_hash_algorithm == 'sha256':
|
||||||
|
return hashlib.sha256(token.encode('utf-8')).hexdigest()
|
||||||
|
elif self.token_hash_algorithm == 'sha512':
|
||||||
|
return hashlib.sha512(token.encode('utf-8')).hexdigest()
|
||||||
|
else:
|
||||||
|
# Fallback to sha256
|
||||||
|
return hashlib.sha256(token.encode('utf-8')).hexdigest()
|
||||||
|
|
||||||
|
async def validate_token(self, token: str) -> TokenValidationResult:
|
||||||
|
"""Validate token and return user information"""
|
||||||
|
try:
|
||||||
|
# Hash the provided token
|
||||||
|
token_hash = self._hash_token(token)
|
||||||
|
|
||||||
|
# Find token info
|
||||||
|
token_info = self._tokens.get(token_hash)
|
||||||
|
if not token_info:
|
||||||
|
return TokenValidationResult(
|
||||||
|
is_valid=False,
|
||||||
|
error_message="Invalid token"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if token is active
|
||||||
|
if not token_info.is_active:
|
||||||
|
return TokenValidationResult(
|
||||||
|
is_valid=False,
|
||||||
|
error_message="Token is inactive"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check expiration
|
||||||
|
if token_info.expires_at and datetime.utcnow() > token_info.expires_at:
|
||||||
|
return TokenValidationResult(
|
||||||
|
is_valid=False,
|
||||||
|
error_message="Token has expired"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update last used time
|
||||||
|
token_info.last_used = datetime.utcnow()
|
||||||
|
|
||||||
|
return TokenValidationResult(
|
||||||
|
is_valid=True,
|
||||||
|
token_info=token_info
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Token validation error: {e}")
|
||||||
|
return TokenValidationResult(
|
||||||
|
is_valid=False,
|
||||||
|
error_message=f"Token validation failed: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def generate_token(self, length: int = 32) -> str:
|
||||||
|
"""Generate a cryptographically secure random token"""
|
||||||
|
return secrets.token_urlsafe(length)
|
||||||
|
|
||||||
|
async def create_token(
|
||||||
|
self,
|
||||||
|
token_id: str,
|
||||||
|
expires_hours: Optional[int] = None,
|
||||||
|
description: str = "",
|
||||||
|
custom_token: Optional[str] = None,
|
||||||
|
database_config: Optional[DatabaseConfig] = None
|
||||||
|
) -> str:
|
||||||
|
"""Create a new token"""
|
||||||
|
try:
|
||||||
|
# Check if token_id already exists
|
||||||
|
if token_id in self._token_ids:
|
||||||
|
raise ValueError(f"Token ID '{token_id}' already exists")
|
||||||
|
|
||||||
|
# Generate or use provided token
|
||||||
|
if custom_token:
|
||||||
|
raw_token = custom_token
|
||||||
|
else:
|
||||||
|
raw_token = self.generate_token()
|
||||||
|
|
||||||
|
# Calculate expiration
|
||||||
|
expires_at = None
|
||||||
|
if expires_hours is not None:
|
||||||
|
expires_at = datetime.utcnow() + timedelta(hours=expires_hours)
|
||||||
|
elif self.enable_token_expiry:
|
||||||
|
expires_at = datetime.utcnow() + timedelta(hours=self.default_token_expiry_hours)
|
||||||
|
|
||||||
|
# Create token info
|
||||||
|
token_info = TokenInfo(
|
||||||
|
token_id=token_id,
|
||||||
|
expires_at=expires_at,
|
||||||
|
description=description,
|
||||||
|
database_config=database_config
|
||||||
|
)
|
||||||
|
|
||||||
|
# Hash and store token
|
||||||
|
token_hash = self._hash_token(raw_token)
|
||||||
|
self._tokens[token_hash] = token_info
|
||||||
|
self._token_ids[token_id] = token_hash
|
||||||
|
|
||||||
|
self.logger.info(f"Created new token '{token_id}'")
|
||||||
|
|
||||||
|
# Save token to file
|
||||||
|
self._save_token_to_file(token_id, raw_token, token_info)
|
||||||
|
|
||||||
|
return raw_token
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Failed to create token: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def revoke_token(self, token_id: str) -> bool:
|
||||||
|
"""Revoke a token by token ID"""
|
||||||
|
try:
|
||||||
|
if token_id not in self._token_ids:
|
||||||
|
self.logger.warning(f"Token ID '{token_id}' not found")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Get token hash and remove from storage
|
||||||
|
token_hash = self._token_ids[token_id]
|
||||||
|
if token_hash in self._tokens:
|
||||||
|
del self._tokens[token_hash]
|
||||||
|
del self._token_ids[token_id]
|
||||||
|
|
||||||
|
self.logger.info(f"Revoked token '{token_id}'")
|
||||||
|
|
||||||
|
# Save updated tokens to file
|
||||||
|
self._remove_token_from_file(token_id)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Failed to revoke token '{token_id}': {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _save_tokens_to_file(self):
|
||||||
|
"""Save current tokens to JSON file"""
|
||||||
|
try:
|
||||||
|
# Convert current tokens to file format
|
||||||
|
tokens_list = []
|
||||||
|
|
||||||
|
for token_hash, token_info in self._tokens.items():
|
||||||
|
# Find the raw token for this token_info
|
||||||
|
raw_token = None
|
||||||
|
for tid, thash in self._token_ids.items():
|
||||||
|
if thash == token_hash and tid == token_info.token_id:
|
||||||
|
# We can't recover the original token from hash,
|
||||||
|
# so we'll create a placeholder for existing tokens
|
||||||
|
raw_token = f"<existing_token_hash_{token_hash[:8]}>"
|
||||||
|
break
|
||||||
|
|
||||||
|
if raw_token is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
token_config = {
|
||||||
|
"token_id": token_info.token_id,
|
||||||
|
"token": raw_token,
|
||||||
|
"description": token_info.description,
|
||||||
|
"expires_hours": None,
|
||||||
|
"is_active": token_info.is_active
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add expiration info
|
||||||
|
if token_info.expires_at:
|
||||||
|
# Calculate remaining hours from now
|
||||||
|
remaining = token_info.expires_at - datetime.utcnow()
|
||||||
|
if remaining.total_seconds() > 0:
|
||||||
|
token_config["expires_hours"] = int(remaining.total_seconds() / 3600)
|
||||||
|
else:
|
||||||
|
token_config["expires_hours"] = 0
|
||||||
|
|
||||||
|
# Add database config if present
|
||||||
|
if token_info.database_config:
|
||||||
|
token_config["database_config"] = {
|
||||||
|
"host": token_info.database_config.host,
|
||||||
|
"port": token_info.database_config.port,
|
||||||
|
"user": token_info.database_config.user,
|
||||||
|
"password": token_info.database_config.password,
|
||||||
|
"database": token_info.database_config.database,
|
||||||
|
"charset": token_info.database_config.charset,
|
||||||
|
"fe_http_port": token_info.database_config.fe_http_port
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens_list.append(token_config)
|
||||||
|
|
||||||
|
# Create file content
|
||||||
|
file_content = {
|
||||||
|
"version": "1.0",
|
||||||
|
"description": "Doris MCP Server Token configuration file",
|
||||||
|
"created_at": datetime.utcnow().isoformat() + "Z",
|
||||||
|
"tokens": tokens_list,
|
||||||
|
"notes": [
|
||||||
|
"This file is automatically updated when tokens are created or revoked",
|
||||||
|
"Please backup this file before making manual changes",
|
||||||
|
"Tokens with hash placeholders were loaded from previous configuration"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
# Save to file
|
||||||
|
with open(self.token_file_path, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(file_content, f, indent=2, ensure_ascii=False)
|
||||||
|
|
||||||
|
self.logger.info(f"Saved {len(tokens_list)} tokens to file: {self.token_file_path}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Failed to save tokens to file {self.token_file_path}: {e}")
|
||||||
|
|
||||||
|
def _save_token_to_file(self, token_id: str, raw_token: str, token_info: TokenInfo):
|
||||||
|
"""Save a single new token to file (for newly created tokens only)"""
|
||||||
|
try:
|
||||||
|
# Load existing file
|
||||||
|
existing_data = {"tokens": []}
|
||||||
|
if os.path.exists(self.token_file_path):
|
||||||
|
try:
|
||||||
|
with open(self.token_file_path, 'r', encoding='utf-8') as f:
|
||||||
|
existing_data = json.load(f)
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warning(f"Could not load existing token file: {e}")
|
||||||
|
|
||||||
|
# Ensure tokens list exists
|
||||||
|
if 'tokens' not in existing_data or not isinstance(existing_data['tokens'], list):
|
||||||
|
existing_data['tokens'] = []
|
||||||
|
|
||||||
|
# Check if token already exists in file
|
||||||
|
token_exists = False
|
||||||
|
for i, token_config in enumerate(existing_data['tokens']):
|
||||||
|
if token_config.get('token_id') == token_id:
|
||||||
|
# Update existing token
|
||||||
|
existing_data['tokens'][i] = self._token_info_to_config(token_id, raw_token, token_info)
|
||||||
|
token_exists = True
|
||||||
|
break
|
||||||
|
|
||||||
|
# Add new token if it doesn't exist
|
||||||
|
if not token_exists:
|
||||||
|
new_token_config = self._token_info_to_config(token_id, raw_token, token_info)
|
||||||
|
existing_data['tokens'].append(new_token_config)
|
||||||
|
|
||||||
|
# Update metadata
|
||||||
|
existing_data.update({
|
||||||
|
"version": "1.0",
|
||||||
|
"description": "Doris MCP Server Token configuration file",
|
||||||
|
"updated_at": datetime.utcnow().isoformat() + "Z"
|
||||||
|
})
|
||||||
|
|
||||||
|
# Save to file
|
||||||
|
with open(self.token_file_path, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(existing_data, f, indent=2, ensure_ascii=False)
|
||||||
|
|
||||||
|
self.logger.info(f"Saved token '{token_id}' to file: {self.token_file_path}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Failed to save token '{token_id}' to file: {e}")
|
||||||
|
|
||||||
|
def _token_info_to_config(self, token_id: str, raw_token: str, token_info: TokenInfo) -> dict:
|
||||||
|
"""Convert TokenInfo to file configuration format"""
|
||||||
|
token_config = {
|
||||||
|
"token_id": token_id,
|
||||||
|
"token": raw_token,
|
||||||
|
"description": token_info.description,
|
||||||
|
"expires_hours": None,
|
||||||
|
"is_active": token_info.is_active
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add expiration info
|
||||||
|
if token_info.expires_at:
|
||||||
|
# Calculate remaining hours from creation time
|
||||||
|
remaining = token_info.expires_at - token_info.created_at
|
||||||
|
token_config["expires_hours"] = int(remaining.total_seconds() / 3600) if remaining.total_seconds() > 0 else None
|
||||||
|
|
||||||
|
# Add database config if present
|
||||||
|
if token_info.database_config:
|
||||||
|
token_config["database_config"] = {
|
||||||
|
"host": token_info.database_config.host,
|
||||||
|
"port": token_info.database_config.port,
|
||||||
|
"user": token_info.database_config.user,
|
||||||
|
"password": token_info.database_config.password,
|
||||||
|
"database": token_info.database_config.database,
|
||||||
|
"charset": token_info.database_config.charset,
|
||||||
|
"fe_http_port": token_info.database_config.fe_http_port
|
||||||
|
}
|
||||||
|
|
||||||
|
return token_config
|
||||||
|
|
||||||
|
def _remove_token_from_file(self, token_id: str):
|
||||||
|
"""Remove a token from the JSON file"""
|
||||||
|
try:
|
||||||
|
if not os.path.exists(self.token_file_path):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Load existing file
|
||||||
|
with open(self.token_file_path, 'r', encoding='utf-8') as f:
|
||||||
|
existing_data = json.load(f)
|
||||||
|
|
||||||
|
if 'tokens' not in existing_data or not isinstance(existing_data['tokens'], list):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Remove the token
|
||||||
|
original_count = len(existing_data['tokens'])
|
||||||
|
existing_data['tokens'] = [
|
||||||
|
token for token in existing_data['tokens']
|
||||||
|
if token.get('token_id') != token_id
|
||||||
|
]
|
||||||
|
|
||||||
|
if len(existing_data['tokens']) < original_count:
|
||||||
|
# Update metadata
|
||||||
|
existing_data.update({
|
||||||
|
"version": "1.0",
|
||||||
|
"description": "Doris MCP Server Token configuration file",
|
||||||
|
"updated_at": datetime.utcnow().isoformat() + "Z"
|
||||||
|
})
|
||||||
|
|
||||||
|
# Save to file
|
||||||
|
with open(self.token_file_path, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(existing_data, f, indent=2, ensure_ascii=False)
|
||||||
|
|
||||||
|
self.logger.info(f"Removed token '{token_id}' from file: {self.token_file_path}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Failed to remove token '{token_id}' from file: {e}")
|
||||||
|
|
||||||
|
async def list_tokens(self) -> List[Dict[str, Any]]:
|
||||||
|
"""List all tokens (without sensitive data)"""
|
||||||
|
tokens = []
|
||||||
|
|
||||||
|
for token_hash, token_info in self._tokens.items():
|
||||||
|
token_data = {
|
||||||
|
'token_id': token_info.token_id,
|
||||||
|
'created_at': token_info.created_at.isoformat(),
|
||||||
|
'expires_at': token_info.expires_at.isoformat() if token_info.expires_at else None,
|
||||||
|
'last_used': token_info.last_used.isoformat() if token_info.last_used else None,
|
||||||
|
'is_active': token_info.is_active,
|
||||||
|
'description': token_info.description,
|
||||||
|
'is_expired': token_info.expires_at and datetime.utcnow() > token_info.expires_at if token_info.expires_at else False
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add database binding info (without sensitive data)
|
||||||
|
if token_info.database_config:
|
||||||
|
token_data['database_binding'] = {
|
||||||
|
'host': token_info.database_config.host,
|
||||||
|
'port': token_info.database_config.port,
|
||||||
|
'user': token_info.database_config.user,
|
||||||
|
'database': token_info.database_config.database,
|
||||||
|
'has_password': bool(token_info.database_config.password)
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
token_data['database_binding'] = None
|
||||||
|
|
||||||
|
tokens.append(token_data)
|
||||||
|
|
||||||
|
# Sort by creation time
|
||||||
|
tokens.sort(key=lambda x: x['created_at'], reverse=True)
|
||||||
|
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
async def cleanup_expired_tokens(self) -> int:
|
||||||
|
"""Remove expired tokens and return count"""
|
||||||
|
if not self.enable_token_expiry:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
now = datetime.utcnow()
|
||||||
|
expired_tokens = []
|
||||||
|
|
||||||
|
# Find expired tokens
|
||||||
|
for token_hash, token_info in self._tokens.items():
|
||||||
|
if token_info.expires_at and now > token_info.expires_at:
|
||||||
|
expired_tokens.append((token_hash, token_info.token_id))
|
||||||
|
|
||||||
|
# Remove expired tokens
|
||||||
|
for token_hash, token_id in expired_tokens:
|
||||||
|
del self._tokens[token_hash]
|
||||||
|
if token_id in self._token_ids:
|
||||||
|
del self._token_ids[token_id]
|
||||||
|
|
||||||
|
if expired_tokens:
|
||||||
|
self.logger.info(f"Cleaned up {len(expired_tokens)} expired tokens")
|
||||||
|
|
||||||
|
return len(expired_tokens)
|
||||||
|
|
||||||
|
async def save_tokens_to_file(self, file_path: Optional[str] = None) -> bool:
|
||||||
|
"""Save current tokens to JSON file"""
|
||||||
|
try:
|
||||||
|
file_path = file_path or self.token_file_path
|
||||||
|
tokens_list = await self.list_tokens()
|
||||||
|
|
||||||
|
tokens_data = {
|
||||||
|
'version': '1.0',
|
||||||
|
'created_at': datetime.utcnow().isoformat(),
|
||||||
|
'tokens': tokens_list
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(file_path, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(tokens_data, f, indent=2, ensure_ascii=False)
|
||||||
|
|
||||||
|
self.logger.info(f"Saved {len(tokens_list)} tokens to file: {file_path}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Failed to save tokens to file: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_database_config_by_token(self, token: str) -> Optional[DatabaseConfig]:
|
||||||
|
"""Get database configuration bound to a token
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The raw token string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DatabaseConfig if token exists and has database binding, None otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
token_hash = self._hash_token(token)
|
||||||
|
token_info = self._tokens.get(token_hash)
|
||||||
|
|
||||||
|
if not token_info or not token_info.is_active:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Check expiration
|
||||||
|
if token_info.expires_at and datetime.utcnow() > token_info.expires_at:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return token_info.database_config
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Failed to get database config for token: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_token_stats(self) -> Dict[str, Any]:
|
||||||
|
"""Get token statistics"""
|
||||||
|
now = datetime.utcnow()
|
||||||
|
total_tokens = len(self._tokens)
|
||||||
|
active_tokens = sum(1 for info in self._tokens.values() if info.is_active)
|
||||||
|
expired_tokens = sum(1 for info in self._tokens.values()
|
||||||
|
if info.expires_at and now > info.expires_at)
|
||||||
|
tokens_with_db = sum(1 for info in self._tokens.values()
|
||||||
|
if info.database_config is not None)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'total_tokens': total_tokens,
|
||||||
|
'active_tokens': active_tokens,
|
||||||
|
'expired_tokens': expired_tokens,
|
||||||
|
'tokens_with_database_binding': tokens_with_db,
|
||||||
|
'expiry_enabled': self.enable_token_expiry,
|
||||||
|
'default_expiry_hours': self.default_token_expiry_hours,
|
||||||
|
'hot_reload_enabled': self.enable_hot_reload,
|
||||||
|
'last_file_check': datetime.fromtimestamp(self._file_last_modified).isoformat() if self._file_last_modified else None
|
||||||
|
}
|
||||||
|
|
||||||
|
def _start_hot_reload(self):
|
||||||
|
"""Start hot reload monitoring task"""
|
||||||
|
if self._hot_reload_task:
|
||||||
|
return # Already running
|
||||||
|
|
||||||
|
# Update initial file modification time
|
||||||
|
self._update_file_modified_time()
|
||||||
|
|
||||||
|
# Start monitoring task
|
||||||
|
self._hot_reload_task = asyncio.create_task(self._hot_reload_monitor())
|
||||||
|
self.logger.info(f"Started hot reload monitoring for {self.token_file_path}")
|
||||||
|
|
||||||
|
def stop_hot_reload(self):
|
||||||
|
"""Stop hot reload monitoring"""
|
||||||
|
if self._hot_reload_task:
|
||||||
|
self._hot_reload_task.cancel()
|
||||||
|
self._hot_reload_task = None
|
||||||
|
self.logger.info("Stopped hot reload monitoring")
|
||||||
|
|
||||||
|
def _update_file_modified_time(self):
|
||||||
|
"""Update the last modified time of tokens file"""
|
||||||
|
try:
|
||||||
|
if os.path.exists(self.token_file_path):
|
||||||
|
self._file_last_modified = os.path.getmtime(self.token_file_path)
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.debug(f"Failed to get file modification time: {e}")
|
||||||
|
|
||||||
|
async def _hot_reload_monitor(self):
|
||||||
|
"""Background task to monitor tokens.json file changes"""
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(self.hot_reload_interval)
|
||||||
|
|
||||||
|
if not os.path.exists(self.token_file_path):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if file was modified
|
||||||
|
current_mtime = os.path.getmtime(self.token_file_path)
|
||||||
|
if current_mtime > self._file_last_modified:
|
||||||
|
self.logger.info(f"Detected changes in {self.token_file_path}, reloading tokens...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Backup current tokens
|
||||||
|
old_tokens = self._tokens.copy()
|
||||||
|
old_token_ids = self._token_ids.copy()
|
||||||
|
|
||||||
|
# Clear and reload
|
||||||
|
self._tokens.clear()
|
||||||
|
self._token_ids.clear()
|
||||||
|
|
||||||
|
# Reinitialize default tokens
|
||||||
|
self._initialize_default_tokens()
|
||||||
|
|
||||||
|
# Load from file
|
||||||
|
self._load_tokens_from_file()
|
||||||
|
|
||||||
|
# Update modification time
|
||||||
|
self._file_last_modified = current_mtime
|
||||||
|
|
||||||
|
self.logger.info(f"Hot reload completed, {len(self._tokens)} tokens loaded")
|
||||||
|
|
||||||
|
except Exception as reload_error:
|
||||||
|
# Restore backup on failure
|
||||||
|
self.logger.error(f"Hot reload failed, restoring previous tokens: {reload_error}")
|
||||||
|
self._tokens = old_tokens
|
||||||
|
self._token_ids = old_token_ids
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
self.logger.info("Hot reload monitor stopped")
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error in hot reload monitor: {e}")
|
||||||
227
doris_mcp_server/auth/token_security_middleware.py
Normal file
@@ -0,0 +1,227 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
"""
|
||||||
|
Token Management Security Middleware
|
||||||
|
|
||||||
|
Provides comprehensive security controls for token management endpoints including
|
||||||
|
IP restrictions, admin authentication, and configuration-based access control.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import ipaddress
|
||||||
|
import secrets
|
||||||
|
import time
|
||||||
|
from typing import Optional, List, Dict, Any
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import JSONResponse
|
||||||
|
|
||||||
|
from ..utils.logger import get_logger
|
||||||
|
from ..utils.config import DorisConfig
|
||||||
|
|
||||||
|
|
||||||
|
class TokenSecurityMiddleware:
|
||||||
|
"""Security middleware for token management endpoints"""
|
||||||
|
|
||||||
|
def __init__(self, config: DorisConfig):
|
||||||
|
self.config = config
|
||||||
|
self.logger = get_logger(__name__)
|
||||||
|
|
||||||
|
# Initialize admin token hash if provided
|
||||||
|
self._admin_token_hash = None
|
||||||
|
if config.security.token_management_admin_token:
|
||||||
|
self._admin_token_hash = self._hash_token(config.security.token_management_admin_token)
|
||||||
|
|
||||||
|
# Normalize allowed IPs
|
||||||
|
self._allowed_networks = self._parse_allowed_networks(
|
||||||
|
config.security.token_management_allowed_ips
|
||||||
|
)
|
||||||
|
|
||||||
|
self.logger.info(f"Token management security initialized: "
|
||||||
|
f"HTTP endpoints {'enabled' if config.security.enable_http_token_management else 'disabled'}, "
|
||||||
|
f"Admin auth {'required' if config.security.require_admin_auth else 'optional'}, "
|
||||||
|
f"Allowed networks: {len(self._allowed_networks)}")
|
||||||
|
|
||||||
|
def _hash_token(self, token: str) -> str:
|
||||||
|
"""Hash token using SHA-256"""
|
||||||
|
return hashlib.sha256(token.encode('utf-8')).hexdigest()
|
||||||
|
|
||||||
|
def _parse_allowed_networks(self, allowed_ips: List[str]) -> List[ipaddress.IPv4Network | ipaddress.IPv6Network]:
|
||||||
|
"""Parse allowed IP addresses and networks"""
|
||||||
|
networks = []
|
||||||
|
for ip_str in allowed_ips:
|
||||||
|
try:
|
||||||
|
# Try to parse as network (CIDR notation)
|
||||||
|
if '/' in ip_str:
|
||||||
|
networks.append(ipaddress.ip_network(ip_str, strict=False))
|
||||||
|
else:
|
||||||
|
# Parse as single IP and convert to /32 network
|
||||||
|
ip = ipaddress.ip_address(ip_str)
|
||||||
|
if isinstance(ip, ipaddress.IPv4Address):
|
||||||
|
networks.append(ipaddress.IPv4Network(f"{ip}/32"))
|
||||||
|
else:
|
||||||
|
networks.append(ipaddress.IPv6Network(f"{ip}/128"))
|
||||||
|
except ValueError as e:
|
||||||
|
self.logger.warning(f"Invalid IP/network '{ip_str}': {e}")
|
||||||
|
|
||||||
|
return networks
|
||||||
|
|
||||||
|
def _get_client_ip(self, request: Request) -> str:
|
||||||
|
"""Extract client IP from request, considering proxies"""
|
||||||
|
# Check X-Forwarded-For header first (for proxy setups)
|
||||||
|
forwarded_for = request.headers.get('X-Forwarded-For')
|
||||||
|
if forwarded_for:
|
||||||
|
# Take the first IP (original client)
|
||||||
|
client_ip = forwarded_for.split(',')[0].strip()
|
||||||
|
elif request.headers.get('X-Real-IP'):
|
||||||
|
client_ip = request.headers.get('X-Real-IP')
|
||||||
|
else:
|
||||||
|
# Direct connection
|
||||||
|
client_ip = request.client.host if request.client else "unknown"
|
||||||
|
|
||||||
|
return client_ip
|
||||||
|
|
||||||
|
def _is_ip_allowed(self, client_ip: str) -> bool:
|
||||||
|
"""Check if client IP is in allowed networks"""
|
||||||
|
try:
|
||||||
|
client_addr = ipaddress.ip_address(client_ip)
|
||||||
|
|
||||||
|
for network in self._allowed_networks:
|
||||||
|
if client_addr in network:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
except ValueError:
|
||||||
|
self.logger.warning(f"Invalid client IP format: {client_ip}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _extract_admin_token(self, request: Request) -> Optional[str]:
|
||||||
|
"""Extract admin token from request headers"""
|
||||||
|
# Try Authorization header first
|
||||||
|
auth_header = request.headers.get('Authorization', '')
|
||||||
|
if auth_header.startswith('Bearer '):
|
||||||
|
return auth_header[7:]
|
||||||
|
elif auth_header.startswith('Token '):
|
||||||
|
return auth_header[6:]
|
||||||
|
|
||||||
|
# Try X-Admin-Token header
|
||||||
|
admin_token = request.headers.get('X-Admin-Token', '')
|
||||||
|
if admin_token:
|
||||||
|
return admin_token
|
||||||
|
|
||||||
|
# Try query parameter as fallback (not recommended for production)
|
||||||
|
admin_token = request.query_params.get('admin_token', '')
|
||||||
|
if admin_token:
|
||||||
|
self.logger.warning("Admin token passed via query parameter - this is insecure for production")
|
||||||
|
return admin_token
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _verify_admin_token(self, provided_token: str) -> bool:
|
||||||
|
"""Verify provided admin token against configured token"""
|
||||||
|
if not self._admin_token_hash:
|
||||||
|
self.logger.warning("No admin token configured for token management")
|
||||||
|
return False
|
||||||
|
|
||||||
|
provided_hash = self._hash_token(provided_token)
|
||||||
|
|
||||||
|
# Use constant-time comparison to prevent timing attacks
|
||||||
|
return hmac.compare_digest(self._admin_token_hash, provided_hash)
|
||||||
|
|
||||||
|
async def check_token_management_access(self, request: Request) -> Optional[JSONResponse]:
|
||||||
|
"""
|
||||||
|
Check if request is authorized for token management operations
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None if access is granted
|
||||||
|
JSONResponse with error if access is denied
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Check if HTTP token management is enabled
|
||||||
|
if not self.config.security.enable_http_token_management:
|
||||||
|
self.logger.warning(f"Token management endpoint access denied - HTTP management disabled: {request.url.path}")
|
||||||
|
return JSONResponse({
|
||||||
|
"error": "Token management endpoints are disabled for security",
|
||||||
|
"message": "HTTP token management is disabled. Use file-based token management instead.",
|
||||||
|
"suggestion": "Edit tokens.json file directly or enable HTTP management with proper security configuration"
|
||||||
|
}, status_code=403)
|
||||||
|
|
||||||
|
# Extract client IP
|
||||||
|
client_ip = self._get_client_ip(request)
|
||||||
|
|
||||||
|
# Check IP restrictions
|
||||||
|
if not self._is_ip_allowed(client_ip):
|
||||||
|
self.logger.warning(f"Token management access denied for IP {client_ip}: not in allowed list")
|
||||||
|
return JSONResponse({
|
||||||
|
"error": "Access denied - IP not allowed",
|
||||||
|
"client_ip": client_ip,
|
||||||
|
"message": "Token management is restricted to specific IP addresses",
|
||||||
|
"allowed_networks": [str(net) for net in self._allowed_networks]
|
||||||
|
}, status_code=403)
|
||||||
|
|
||||||
|
# Check admin authentication if required
|
||||||
|
if self.config.security.require_admin_auth:
|
||||||
|
admin_token = self._extract_admin_token(request)
|
||||||
|
|
||||||
|
if not admin_token:
|
||||||
|
self.logger.warning(f"Token management access denied for IP {client_ip}: missing admin token")
|
||||||
|
return JSONResponse({
|
||||||
|
"error": "Admin authentication required",
|
||||||
|
"message": "Token management requires admin authentication",
|
||||||
|
"hint": "Provide admin token in Authorization header: 'Bearer <admin_token>' or 'X-Admin-Token: <admin_token>'"
|
||||||
|
}, status_code=401)
|
||||||
|
|
||||||
|
if not self._verify_admin_token(admin_token):
|
||||||
|
self.logger.warning(f"Token management access denied for IP {client_ip}: invalid admin token")
|
||||||
|
return JSONResponse({
|
||||||
|
"error": "Invalid admin token",
|
||||||
|
"message": "The provided admin token is invalid"
|
||||||
|
}, status_code=401)
|
||||||
|
|
||||||
|
# Log successful access
|
||||||
|
self.logger.info(f"Token management access granted for IP {client_ip} to {request.url.path}")
|
||||||
|
|
||||||
|
# Access granted
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_security_info(self) -> Dict[str, Any]:
|
||||||
|
"""Get current security configuration info (for demo/status pages)"""
|
||||||
|
return {
|
||||||
|
"http_token_management_enabled": self.config.security.enable_http_token_management,
|
||||||
|
"admin_auth_required": self.config.security.require_admin_auth,
|
||||||
|
"admin_token_configured": bool(self._admin_token_hash),
|
||||||
|
"allowed_networks_count": len(self._allowed_networks),
|
||||||
|
"allowed_networks": [str(net) for net in self._allowed_networks],
|
||||||
|
"security_features": [
|
||||||
|
"IP address restrictions",
|
||||||
|
"Admin token authentication" if self.config.security.require_admin_auth else "Optional admin authentication",
|
||||||
|
"Secure token hashing",
|
||||||
|
"Request logging and auditing"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
def generate_admin_token(self) -> str:
|
||||||
|
"""Generate a secure admin token"""
|
||||||
|
return secrets.token_urlsafe(32)
|
||||||
|
|
||||||
|
|
||||||
|
# Convenience function for middleware creation
|
||||||
|
def create_token_security_middleware(config: DorisConfig) -> TokenSecurityMiddleware:
|
||||||
|
"""Create token security middleware with configuration"""
|
||||||
|
return TokenSecurityMiddleware(config)
|
||||||
365
doris_mcp_server/auth/token_validators.py
Normal file
@@ -0,0 +1,365 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
"""
|
||||||
|
JWT Token Validation Module
|
||||||
|
Provides token validation, blacklist management and security features
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import asyncio
|
||||||
|
from typing import Dict, Set, Optional, Any
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
from ..utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TokenBlacklist:
|
||||||
|
"""JWT Token Blacklist Manager
|
||||||
|
|
||||||
|
Manages revoked tokens to prevent revoked tokens from being used again
|
||||||
|
Supports both in-memory and persistent storage
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cleanup_interval: int = 3600):
|
||||||
|
"""Initialize token blacklist
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cleanup_interval: Interval for cleaning up expired tokens (seconds)
|
||||||
|
"""
|
||||||
|
self.cleanup_interval = cleanup_interval
|
||||||
|
# Storage format: {token_jti: expiry_timestamp}
|
||||||
|
self._blacklisted_tokens: Dict[str, float] = {}
|
||||||
|
self._cleanup_task = None
|
||||||
|
|
||||||
|
logger.info("TokenBlacklist initialized")
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
"""Start blacklist manager"""
|
||||||
|
self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
|
||||||
|
logger.info("TokenBlacklist started with periodic cleanup")
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
"""Stop blacklist manager"""
|
||||||
|
if self._cleanup_task:
|
||||||
|
self._cleanup_task.cancel()
|
||||||
|
try:
|
||||||
|
await self._cleanup_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
logger.info("TokenBlacklist stopped")
|
||||||
|
|
||||||
|
async def add_token(self, jti: str, exp: float):
|
||||||
|
"""Add token to blacklist
|
||||||
|
|
||||||
|
Args:
|
||||||
|
jti: JWT ID (unique identifier)
|
||||||
|
exp: Token expiration timestamp
|
||||||
|
"""
|
||||||
|
self._blacklisted_tokens[jti] = exp
|
||||||
|
logger.info(f"Token {jti} added to blacklist")
|
||||||
|
|
||||||
|
async def is_blacklisted(self, jti: str) -> bool:
|
||||||
|
"""Check if token is blacklisted
|
||||||
|
|
||||||
|
Args:
|
||||||
|
jti: JWT ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if blacklisted, False otherwise
|
||||||
|
"""
|
||||||
|
return jti in self._blacklisted_tokens
|
||||||
|
|
||||||
|
async def remove_token(self, jti: str) -> bool:
|
||||||
|
"""Remove token from blacklist
|
||||||
|
|
||||||
|
Args:
|
||||||
|
jti: JWT ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if removed, False if not found
|
||||||
|
"""
|
||||||
|
if jti in self._blacklisted_tokens:
|
||||||
|
del self._blacklisted_tokens[jti]
|
||||||
|
logger.info(f"Token {jti} removed from blacklist")
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def cleanup_expired(self) -> int:
|
||||||
|
"""Clean up expired blacklisted tokens
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of tokens cleaned up
|
||||||
|
"""
|
||||||
|
current_time = time.time()
|
||||||
|
expired_tokens = [
|
||||||
|
jti for jti, exp in self._blacklisted_tokens.items()
|
||||||
|
if exp <= current_time
|
||||||
|
]
|
||||||
|
|
||||||
|
for jti in expired_tokens:
|
||||||
|
del self._blacklisted_tokens[jti]
|
||||||
|
|
||||||
|
if expired_tokens:
|
||||||
|
logger.info(f"Cleaned up {len(expired_tokens)} expired tokens from blacklist")
|
||||||
|
|
||||||
|
return len(expired_tokens)
|
||||||
|
|
||||||
|
async def get_stats(self) -> Dict[str, Any]:
|
||||||
|
"""Get blacklist statistics"""
|
||||||
|
current_time = time.time()
|
||||||
|
active_tokens = sum(1 for exp in self._blacklisted_tokens.values() if exp > current_time)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total_blacklisted": len(self._blacklisted_tokens),
|
||||||
|
"active_blacklisted": active_tokens,
|
||||||
|
"expired_blacklisted": len(self._blacklisted_tokens) - active_tokens,
|
||||||
|
"cleanup_interval": self.cleanup_interval
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _periodic_cleanup(self):
|
||||||
|
"""Periodically clean up expired tokens"""
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(self.cleanup_interval)
|
||||||
|
await self.cleanup_expired()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error during periodic cleanup: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimiter:
|
||||||
|
"""Token usage rate limiter"""
|
||||||
|
|
||||||
|
def __init__(self, max_requests: int = 100, time_window: int = 3600):
|
||||||
|
"""Initialize rate limiter
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_requests: Maximum requests within time window
|
||||||
|
time_window: Time window in seconds
|
||||||
|
"""
|
||||||
|
self.max_requests = max_requests
|
||||||
|
self.time_window = time_window
|
||||||
|
# Storage format: {user_id: [timestamp1, timestamp2, ...]}
|
||||||
|
self._request_history: Dict[str, list] = defaultdict(list)
|
||||||
|
|
||||||
|
logger.info(f"RateLimiter initialized: {max_requests} requests per {time_window} seconds")
|
||||||
|
|
||||||
|
async def is_allowed(self, user_id: str) -> bool:
|
||||||
|
"""Check if user is allowed to make request
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: User ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if allowed, False otherwise
|
||||||
|
"""
|
||||||
|
current_time = time.time()
|
||||||
|
user_requests = self._request_history[user_id]
|
||||||
|
|
||||||
|
# Clean up expired request records
|
||||||
|
cutoff_time = current_time - self.time_window
|
||||||
|
user_requests[:] = [t for t in user_requests if t > cutoff_time]
|
||||||
|
|
||||||
|
# Check if limit exceeded
|
||||||
|
if len(user_requests) >= self.max_requests:
|
||||||
|
logger.warning(f"Rate limit exceeded for user {user_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Record current request
|
||||||
|
user_requests.append(current_time)
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def get_usage(self, user_id: str) -> Dict[str, Any]:
|
||||||
|
"""Get user usage information
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: User ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Usage statistics
|
||||||
|
"""
|
||||||
|
current_time = time.time()
|
||||||
|
user_requests = self._request_history[user_id]
|
||||||
|
|
||||||
|
# Clean up expired records
|
||||||
|
cutoff_time = current_time - self.time_window
|
||||||
|
active_requests = [t for t in user_requests if t > cutoff_time]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"user_id": user_id,
|
||||||
|
"requests_in_window": len(active_requests),
|
||||||
|
"max_requests": self.max_requests,
|
||||||
|
"time_window": self.time_window,
|
||||||
|
"remaining_requests": max(0, self.max_requests - len(active_requests))
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TokenValidator:
|
||||||
|
"""JWT Token Validator
|
||||||
|
|
||||||
|
Provides comprehensive JWT token validation functionality, including signature verification,
|
||||||
|
claim validation, blacklist checking and rate limiting
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config, blacklist: Optional[TokenBlacklist] = None):
|
||||||
|
"""Initialize token validator
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: DorisConfig configuration object (with security attribute)
|
||||||
|
blacklist: Token blacklist manager
|
||||||
|
"""
|
||||||
|
self.config = config
|
||||||
|
self.blacklist = blacklist or TokenBlacklist()
|
||||||
|
self.rate_limiter = RateLimiter()
|
||||||
|
|
||||||
|
# Access JWT settings through the security configuration
|
||||||
|
if hasattr(config, 'security'):
|
||||||
|
security_config = config.security
|
||||||
|
else:
|
||||||
|
# Fallback if config is passed directly as SecurityConfig
|
||||||
|
security_config = config
|
||||||
|
|
||||||
|
# Validation options
|
||||||
|
self.verify_signature = security_config.jwt_verify_signature
|
||||||
|
self.verify_audience = security_config.jwt_verify_audience
|
||||||
|
self.verify_issuer = security_config.jwt_verify_issuer
|
||||||
|
self.require_exp = security_config.jwt_require_exp
|
||||||
|
self.require_iat = security_config.jwt_require_iat
|
||||||
|
self.require_nbf = security_config.jwt_require_nbf
|
||||||
|
self.leeway = security_config.jwt_leeway
|
||||||
|
|
||||||
|
# Expected values
|
||||||
|
self.expected_audience = security_config.jwt_audience
|
||||||
|
self.expected_issuer = security_config.jwt_issuer
|
||||||
|
|
||||||
|
logger.info("TokenValidator initialized")
|
||||||
|
|
||||||
|
async def validate_claims(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Validate JWT claims
|
||||||
|
|
||||||
|
Args:
|
||||||
|
payload: JWT payload
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Validation result
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: Validation failed
|
||||||
|
"""
|
||||||
|
current_time = time.time()
|
||||||
|
|
||||||
|
# Validate issuer
|
||||||
|
if self.verify_issuer:
|
||||||
|
if payload.get('iss') != self.expected_issuer:
|
||||||
|
raise ValueError(f"Invalid issuer: expected {self.expected_issuer}")
|
||||||
|
|
||||||
|
# Validate audience
|
||||||
|
if self.verify_audience:
|
||||||
|
aud = payload.get('aud')
|
||||||
|
if isinstance(aud, list):
|
||||||
|
if self.expected_audience not in aud:
|
||||||
|
raise ValueError(f"Invalid audience: {self.expected_audience} not in {aud}")
|
||||||
|
elif aud != self.expected_audience:
|
||||||
|
raise ValueError(f"Invalid audience: expected {self.expected_audience}")
|
||||||
|
|
||||||
|
# Validate expiration time
|
||||||
|
if self.require_exp or 'exp' in payload:
|
||||||
|
exp = payload.get('exp')
|
||||||
|
if not exp:
|
||||||
|
raise ValueError("Missing 'exp' claim")
|
||||||
|
if current_time > exp + self.leeway:
|
||||||
|
raise ValueError("Token has expired")
|
||||||
|
|
||||||
|
# Validate not before time
|
||||||
|
if self.require_nbf or 'nbf' in payload:
|
||||||
|
nbf = payload.get('nbf')
|
||||||
|
if not nbf:
|
||||||
|
raise ValueError("Missing 'nbf' claim")
|
||||||
|
if current_time < nbf - self.leeway:
|
||||||
|
raise ValueError("Token not yet valid")
|
||||||
|
|
||||||
|
# Validate issued at time
|
||||||
|
if self.require_iat or 'iat' in payload:
|
||||||
|
iat = payload.get('iat')
|
||||||
|
if not iat:
|
||||||
|
raise ValueError("Missing 'iat' claim")
|
||||||
|
# Allow some clock skew, but cannot be future time
|
||||||
|
if iat > current_time + self.leeway:
|
||||||
|
raise ValueError("Token issued in the future")
|
||||||
|
|
||||||
|
# Check blacklist
|
||||||
|
jti = payload.get('jti')
|
||||||
|
if jti and await self.blacklist.is_blacklisted(jti):
|
||||||
|
raise ValueError("Token has been revoked")
|
||||||
|
|
||||||
|
# Rate limit check
|
||||||
|
user_id = payload.get('sub')
|
||||||
|
if user_id:
|
||||||
|
if not await self.rate_limiter.is_allowed(user_id):
|
||||||
|
raise ValueError("Rate limit exceeded")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"valid": True,
|
||||||
|
"user_id": user_id,
|
||||||
|
"payload": payload
|
||||||
|
}
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
"""Start validator"""
|
||||||
|
await self.blacklist.start()
|
||||||
|
logger.info("TokenValidator started")
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
"""Stop validator"""
|
||||||
|
await self.blacklist.stop()
|
||||||
|
logger.info("TokenValidator stopped")
|
||||||
|
|
||||||
|
async def revoke_token(self, jti: str, exp: float):
|
||||||
|
"""Revoke token
|
||||||
|
|
||||||
|
Args:
|
||||||
|
jti: JWT ID
|
||||||
|
exp: Token expiration time
|
||||||
|
"""
|
||||||
|
await self.blacklist.add_token(jti, exp)
|
||||||
|
logger.info(f"Token {jti} has been revoked")
|
||||||
|
|
||||||
|
async def get_validation_stats(self) -> Dict[str, Any]:
|
||||||
|
"""Get validation statistics"""
|
||||||
|
blacklist_stats = await self.blacklist.get_stats()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"blacklist": blacklist_stats,
|
||||||
|
"validation_config": {
|
||||||
|
"verify_signature": self.verify_signature,
|
||||||
|
"verify_audience": self.verify_audience,
|
||||||
|
"verify_issuer": self.verify_issuer,
|
||||||
|
"require_exp": self.require_exp,
|
||||||
|
"require_iat": self.require_iat,
|
||||||
|
"require_nbf": self.require_nbf,
|
||||||
|
"leeway": self.leeway
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async def get_user_rate_limit_info(self, user_id: str) -> Dict[str, Any]:
|
||||||
|
"""Get user rate limit information"""
|
||||||
|
return await self.rate_limiter.get_usage(user_id)
|
||||||
@@ -28,15 +28,183 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from mcp.server import Server
|
# MCP version compatibility handling
|
||||||
from mcp.server.models import InitializationOptions
|
MCP_VERSION = 'unknown'
|
||||||
|
Server = None
|
||||||
|
InitializationOptions = None
|
||||||
|
Prompt = None
|
||||||
|
Resource = None
|
||||||
|
TextContent = None
|
||||||
|
Tool = None
|
||||||
|
|
||||||
from mcp.types import (
|
def _import_mcp_with_compatibility():
|
||||||
Prompt,
|
"""Import MCP components with multi-version compatibility"""
|
||||||
Resource,
|
global MCP_VERSION, Server, InitializationOptions, Prompt, Resource, TextContent, Tool
|
||||||
TextContent,
|
|
||||||
Tool,
|
try:
|
||||||
)
|
# Strategy 1: Try direct server-only imports to avoid client-side issues
|
||||||
|
from mcp.server import Server as _Server
|
||||||
|
from mcp.server.models import InitializationOptions as _InitOptions
|
||||||
|
from mcp.types import (
|
||||||
|
Prompt as _Prompt,
|
||||||
|
Resource as _Resource,
|
||||||
|
TextContent as _TextContent,
|
||||||
|
Tool as _Tool,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assign to globals
|
||||||
|
Server = _Server
|
||||||
|
InitializationOptions = _InitOptions
|
||||||
|
Prompt = _Prompt
|
||||||
|
Resource = _Resource
|
||||||
|
TextContent = _TextContent
|
||||||
|
Tool = _Tool
|
||||||
|
|
||||||
|
# Try to get version safely
|
||||||
|
try:
|
||||||
|
import mcp
|
||||||
|
MCP_VERSION = getattr(mcp, '__version__', None)
|
||||||
|
if not MCP_VERSION:
|
||||||
|
# Fallback: try to get version from package metadata
|
||||||
|
try:
|
||||||
|
import importlib.metadata
|
||||||
|
MCP_VERSION = importlib.metadata.version('mcp')
|
||||||
|
except Exception:
|
||||||
|
# Second fallback: try pkg_resources
|
||||||
|
try:
|
||||||
|
import pkg_resources
|
||||||
|
MCP_VERSION = pkg_resources.get_distribution('mcp').version
|
||||||
|
except Exception:
|
||||||
|
MCP_VERSION = 'detected-but-version-unknown'
|
||||||
|
except Exception:
|
||||||
|
# Version detection failed, but imports worked
|
||||||
|
try:
|
||||||
|
import importlib.metadata
|
||||||
|
MCP_VERSION = importlib.metadata.version('mcp')
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
import pkg_resources
|
||||||
|
MCP_VERSION = pkg_resources.get_distribution('mcp').version
|
||||||
|
except Exception:
|
||||||
|
MCP_VERSION = 'imported-successfully'
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.info(f"MCP components imported successfully, version: {MCP_VERSION}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as import_error:
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Strategy 2: Handle RequestContext compatibility issues in 1.9.x versions
|
||||||
|
error_str = str(import_error).lower()
|
||||||
|
if 'requestcontext' in error_str and 'too few arguments' in error_str:
|
||||||
|
logger.warning(f"Detected MCP RequestContext compatibility issue: {import_error}")
|
||||||
|
logger.info("Attempting comprehensive workaround for MCP 1.9.x RequestContext issue...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Comprehensive monkey patch approach
|
||||||
|
import sys
|
||||||
|
import types
|
||||||
|
|
||||||
|
# Create and install mock modules before any MCP imports
|
||||||
|
if 'mcp.shared.context' not in sys.modules:
|
||||||
|
mock_context_module = types.ModuleType('mcp.shared.context')
|
||||||
|
|
||||||
|
class FlexibleRequestContext:
|
||||||
|
"""Flexible RequestContext that accepts variable arguments"""
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
self.args = args
|
||||||
|
self.kwargs = kwargs
|
||||||
|
|
||||||
|
def __class_getitem__(cls, params):
|
||||||
|
# Accept any number of parameters and return cls
|
||||||
|
return cls
|
||||||
|
|
||||||
|
# Add other methods that might be called
|
||||||
|
def __getattr__(self, name):
|
||||||
|
return lambda *args, **kwargs: None
|
||||||
|
|
||||||
|
mock_context_module.RequestContext = FlexibleRequestContext
|
||||||
|
sys.modules['mcp.shared.context'] = mock_context_module
|
||||||
|
|
||||||
|
# Also patch the typing system to be more permissive
|
||||||
|
original_check_generic = None
|
||||||
|
try:
|
||||||
|
import typing
|
||||||
|
if hasattr(typing, '_check_generic'):
|
||||||
|
original_check_generic = typing._check_generic
|
||||||
|
def permissive_check_generic(cls, params, elen):
|
||||||
|
# Don't enforce strict parameter count checking
|
||||||
|
return
|
||||||
|
typing._check_generic = permissive_check_generic
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Clear any cached imports that might have failed
|
||||||
|
modules_to_clear = [k for k in sys.modules.keys() if k.startswith('mcp.')]
|
||||||
|
for module in modules_to_clear:
|
||||||
|
if module in sys.modules:
|
||||||
|
del sys.modules[module]
|
||||||
|
|
||||||
|
# Now try importing again with the patches in place
|
||||||
|
from mcp.server import Server as _Server
|
||||||
|
from mcp.server.models import InitializationOptions as _InitOptions
|
||||||
|
from mcp.types import (
|
||||||
|
Prompt as _Prompt,
|
||||||
|
Resource as _Resource,
|
||||||
|
TextContent as _TextContent,
|
||||||
|
Tool as _Tool,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assign to globals
|
||||||
|
Server = _Server
|
||||||
|
InitializationOptions = _InitOptions
|
||||||
|
Prompt = _Prompt
|
||||||
|
Resource = _Resource
|
||||||
|
TextContent = _TextContent
|
||||||
|
Tool = _Tool
|
||||||
|
|
||||||
|
# Try to detect actual version even in compatibility mode
|
||||||
|
try:
|
||||||
|
import importlib.metadata
|
||||||
|
actual_version = importlib.metadata.version('mcp')
|
||||||
|
MCP_VERSION = f'compatibility-mode-{actual_version}'
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
import pkg_resources
|
||||||
|
actual_version = pkg_resources.get_distribution('mcp').version
|
||||||
|
MCP_VERSION = f'compatibility-mode-{actual_version}'
|
||||||
|
except Exception:
|
||||||
|
MCP_VERSION = 'compatibility-mode-1.9.x'
|
||||||
|
|
||||||
|
logger.info("MCP 1.9.x compatibility workaround successful!")
|
||||||
|
|
||||||
|
# Restore original typing function if we patched it
|
||||||
|
if original_check_generic:
|
||||||
|
typing._check_generic = original_check_generic
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as workaround_error:
|
||||||
|
logger.error(f"MCP compatibility workaround failed: {workaround_error}")
|
||||||
|
|
||||||
|
# Restore original typing function if we patched it
|
||||||
|
if original_check_generic:
|
||||||
|
try:
|
||||||
|
import typing
|
||||||
|
typing._check_generic = original_check_generic
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
logger.error(f"Failed to import MCP components: {import_error}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Perform MCP import with compatibility handling
|
||||||
|
if not _import_mcp_with_compatibility():
|
||||||
|
raise ImportError(
|
||||||
|
"Failed to import MCP components. Please ensure MCP is properly installed. "
|
||||||
|
"Supported versions: 1.8.x, 1.9.x"
|
||||||
|
)
|
||||||
|
|
||||||
from .tools.tools_manager import DorisToolsManager
|
from .tools.tools_manager import DorisToolsManager
|
||||||
from .tools.prompts_manager import DorisPromptsManager
|
from .tools.prompts_manager import DorisPromptsManager
|
||||||
@@ -44,11 +212,16 @@ from .tools.resources_manager import DorisResourcesManager
|
|||||||
from .utils.config import DorisConfig
|
from .utils.config import DorisConfig
|
||||||
from .utils.db import DorisConnectionManager
|
from .utils.db import DorisConnectionManager
|
||||||
from .utils.security import DorisSecurityManager
|
from .utils.security import DorisSecurityManager
|
||||||
|
import os
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging - will be properly initialized later
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Create a default config instance for getting default values
|
||||||
|
_default_config = DorisConfig()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class DorisServer:
|
class DorisServer:
|
||||||
"""Apache Doris MCP Server main class"""
|
"""Apache Doris MCP Server main class"""
|
||||||
@@ -57,20 +230,101 @@ class DorisServer:
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.server = Server("doris-mcp-server")
|
self.server = Server("doris-mcp-server")
|
||||||
|
|
||||||
# Initialize security manager
|
# Initialize security manager (without connection_manager initially)
|
||||||
self.security_manager = DorisSecurityManager(config)
|
self.security_manager = DorisSecurityManager(config)
|
||||||
|
|
||||||
# Initialize connection manager, pass in security manager
|
# Initialize connection manager, pass in security manager and token manager for token-bound DB config
|
||||||
self.connection_manager = DorisConnectionManager(config, self.security_manager)
|
token_manager = self.security_manager.auth_provider.token_manager if hasattr(self.security_manager, 'auth_provider') and hasattr(self.security_manager.auth_provider, 'token_manager') else None
|
||||||
|
self.connection_manager = DorisConnectionManager(config, self.security_manager, token_manager)
|
||||||
|
|
||||||
|
# Set connection manager reference in security manager for database validation
|
||||||
|
self.security_manager.connection_manager = self.connection_manager
|
||||||
|
|
||||||
# Initialize independent managers
|
# Initialize independent managers
|
||||||
self.resources_manager = DorisResourcesManager(self.connection_manager)
|
self.resources_manager = DorisResourcesManager(self.connection_manager)
|
||||||
self.tools_manager = DorisToolsManager(self.connection_manager)
|
self.tools_manager = DorisToolsManager(self.connection_manager)
|
||||||
self.prompts_manager = DorisPromptsManager(self.connection_manager)
|
self.prompts_manager = DorisPromptsManager(self.connection_manager)
|
||||||
|
|
||||||
self.logger = logging.getLogger(f"{__name__}.DorisServer")
|
# Import here to avoid circular imports
|
||||||
|
from .utils.logger import get_logger
|
||||||
|
self.logger = get_logger(f"{__name__}.DorisServer")
|
||||||
self._setup_handlers()
|
self._setup_handlers()
|
||||||
|
|
||||||
|
async def _extract_auth_info_from_scope(self, scope, headers):
|
||||||
|
"""Extract authentication information from ASGI scope and headers"""
|
||||||
|
auth_info = {}
|
||||||
|
|
||||||
|
# Extract client IP
|
||||||
|
client = scope.get("client")
|
||||||
|
if client:
|
||||||
|
auth_info["client_ip"] = client[0]
|
||||||
|
else:
|
||||||
|
auth_info["client_ip"] = "unknown"
|
||||||
|
|
||||||
|
# Extract token from Authorization header
|
||||||
|
authorization = headers.get(b'authorization', b'').decode('utf-8')
|
||||||
|
if authorization:
|
||||||
|
if authorization.startswith('Bearer '):
|
||||||
|
auth_info["token"] = authorization[7:]
|
||||||
|
auth_info["authorization"] = authorization
|
||||||
|
elif authorization.startswith('Token '):
|
||||||
|
auth_info["token"] = authorization[6:]
|
||||||
|
auth_info["authorization"] = authorization
|
||||||
|
|
||||||
|
# Extract token from query parameters (for compatibility)
|
||||||
|
query_string = scope.get("query_string", b"").decode('utf-8')
|
||||||
|
if query_string and "token=" in query_string:
|
||||||
|
import urllib.parse
|
||||||
|
query_params = urllib.parse.parse_qs(query_string)
|
||||||
|
if "token" in query_params:
|
||||||
|
auth_info["token"] = query_params["token"][0]
|
||||||
|
|
||||||
|
# If no token found, this will be handled by the authentication system
|
||||||
|
# (either return anonymous context if auth disabled, or raise error if auth enabled)
|
||||||
|
|
||||||
|
return auth_info
|
||||||
|
|
||||||
|
def _get_mcp_capabilities(self):
|
||||||
|
"""Get MCP capabilities with version compatibility"""
|
||||||
|
try:
|
||||||
|
# For MCP 1.9.x and newer
|
||||||
|
from mcp.server.lowlevel.server import NotificationOptions
|
||||||
|
|
||||||
|
return self.server.get_capabilities(
|
||||||
|
notification_options=NotificationOptions(
|
||||||
|
prompts_changed=True,
|
||||||
|
resources_changed=True,
|
||||||
|
tools_changed=True
|
||||||
|
),
|
||||||
|
experimental_capabilities={}
|
||||||
|
)
|
||||||
|
except TypeError:
|
||||||
|
try:
|
||||||
|
# For MCP 1.8.x
|
||||||
|
from mcp.server.lowlevel.server import NotificationOptions
|
||||||
|
|
||||||
|
return self.server.get_capabilities(
|
||||||
|
notification_options=NotificationOptions(
|
||||||
|
prompts_changed=True,
|
||||||
|
resources_changed=True,
|
||||||
|
tools_changed=True
|
||||||
|
),
|
||||||
|
experimental_capabilities={}
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warning(f"Could not get capabilities with NotificationOptions: {e}")
|
||||||
|
# Fallback for older versions
|
||||||
|
try:
|
||||||
|
return self.server.get_capabilities()
|
||||||
|
except Exception as fallback_e:
|
||||||
|
self.logger.error(f"Failed to get capabilities: {fallback_e}")
|
||||||
|
# Return minimal capabilities
|
||||||
|
return {
|
||||||
|
"resources": {},
|
||||||
|
"tools": {},
|
||||||
|
"prompts": {}
|
||||||
|
}
|
||||||
|
|
||||||
def _setup_handlers(self):
|
def _setup_handlers(self):
|
||||||
"""Setup MCP protocol handlers"""
|
"""Setup MCP protocol handlers"""
|
||||||
|
|
||||||
@@ -174,12 +428,24 @@ class DorisServer:
|
|||||||
self.logger.info("Starting Doris MCP Server (stdio mode)")
|
self.logger.info("Starting Doris MCP Server (stdio mode)")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Ensure connection manager is initialized
|
# Initialize security manager first (includes JWT setup if enabled)
|
||||||
await self.connection_manager.initialize()
|
await self.security_manager.initialize()
|
||||||
self.logger.info("Connection manager initialization completed")
|
self.logger.info("Security manager initialization completed")
|
||||||
|
|
||||||
|
# For stdio mode, we must establish a working database connection
|
||||||
|
# Use the dedicated stdio mode initialization method
|
||||||
|
await self.connection_manager.initialize_for_stdio_mode()
|
||||||
|
|
||||||
# Start stdio server - using simpler approach
|
# Start stdio server - using compatible import approach
|
||||||
from mcp.server.stdio import stdio_server
|
try:
|
||||||
|
from mcp.server.stdio import stdio_server
|
||||||
|
except ImportError:
|
||||||
|
# Fallback for different MCP versions
|
||||||
|
try:
|
||||||
|
from mcp.server import stdio_server
|
||||||
|
except ImportError as stdio_import_error:
|
||||||
|
self.logger.error(f"Failed to import stdio_server: {stdio_import_error}")
|
||||||
|
raise RuntimeError("stdio_server module not available in this MCP version")
|
||||||
|
|
||||||
self.logger.info("Creating stdio_server transport...")
|
self.logger.info("Creating stdio_server transport...")
|
||||||
|
|
||||||
@@ -189,22 +455,12 @@ class DorisServer:
|
|||||||
read_stream, write_stream = streams
|
read_stream, write_stream = streams
|
||||||
self.logger.info("stdio_server streams created successfully")
|
self.logger.info("stdio_server streams created successfully")
|
||||||
|
|
||||||
# Create initialization options
|
# Create initialization options with version compatibility
|
||||||
# MCP 1.8.0 requires parameters for get_capabilities
|
capabilities = self._get_mcp_capabilities()
|
||||||
from mcp.server.lowlevel.server import NotificationOptions
|
|
||||||
|
|
||||||
capabilities = self.server.get_capabilities(
|
|
||||||
notification_options=NotificationOptions(
|
|
||||||
prompts_changed=True,
|
|
||||||
resources_changed=True,
|
|
||||||
tools_changed=True
|
|
||||||
),
|
|
||||||
experimental_capabilities={}
|
|
||||||
)
|
|
||||||
|
|
||||||
init_options = InitializationOptions(
|
init_options = InitializationOptions(
|
||||||
server_name="doris-mcp-server",
|
server_name="doris-mcp-server",
|
||||||
server_version="1.0.0",
|
server_version=os.getenv("SERVER_VERSION", _default_config.server_version),
|
||||||
capabilities=capabilities,
|
capabilities=capabilities,
|
||||||
)
|
)
|
||||||
self.logger.info("Initialization options created successfully")
|
self.logger.info("Initialization options created successfully")
|
||||||
@@ -237,13 +493,21 @@ class DorisServer:
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def start_http(self, host: str = "localhost", port: int = 3000):
|
async def start_http(self, host: str = os.getenv("SERVER_HOST", _default_config.database.host), port: int = os.getenv("SERVER_PORT", _default_config.server_port), workers: int = 1):
|
||||||
"""Start Streamable HTTP transport mode"""
|
"""Start Streamable HTTP transport mode with workers support"""
|
||||||
self.logger.info(f"Starting Doris MCP Server (Streamable HTTP mode) - {host}:{port}")
|
self.logger.info(f"Starting Doris MCP Server (Streamable HTTP mode) - {host}:{port}, workers: {workers}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Ensure connection manager is initialized
|
# Initialize security manager first (includes JWT setup if enabled)
|
||||||
await self.connection_manager.initialize()
|
await self.security_manager.initialize()
|
||||||
|
self.logger.info("Security manager initialization completed")
|
||||||
|
|
||||||
|
# For HTTP mode, try to initialize global connection pool with graceful degradation
|
||||||
|
global_pool_created = await self.connection_manager.initialize_for_http_mode()
|
||||||
|
if global_pool_created:
|
||||||
|
self.logger.info("Global database connection pool available for HTTP mode")
|
||||||
|
else:
|
||||||
|
self.logger.info("HTTP mode running without global database pool, will use token-bound configurations")
|
||||||
|
|
||||||
# Use Starlette and StreamableHTTPSessionManager according to official example
|
# Use Starlette and StreamableHTTPSessionManager according to official example
|
||||||
import uvicorn
|
import uvicorn
|
||||||
@@ -251,9 +515,9 @@ class DorisServer:
|
|||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
|
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
|
||||||
from starlette.applications import Starlette
|
from starlette.applications import Starlette
|
||||||
from starlette.routing import Mount, Route
|
from starlette.routing import Route
|
||||||
from starlette.responses import JSONResponse, Response
|
from starlette.responses import JSONResponse, Response
|
||||||
from starlette.types import Receive, Scope, Send
|
from starlette.types import Scope
|
||||||
|
|
||||||
# Create session manager
|
# Create session manager
|
||||||
session_manager = StreamableHTTPSessionManager(
|
session_manager = StreamableHTTPSessionManager(
|
||||||
@@ -268,6 +532,44 @@ class DorisServer:
|
|||||||
async def health_check(request):
|
async def health_check(request):
|
||||||
return JSONResponse({"status": "healthy", "service": "doris-mcp-server"})
|
return JSONResponse({"status": "healthy", "service": "doris-mcp-server"})
|
||||||
|
|
||||||
|
# OAuth endpoints
|
||||||
|
from .auth.oauth_handlers import OAuthHandlers
|
||||||
|
oauth_handlers = OAuthHandlers(self.security_manager)
|
||||||
|
|
||||||
|
async def oauth_login(request):
|
||||||
|
return await oauth_handlers.handle_login(request)
|
||||||
|
|
||||||
|
async def oauth_callback(request):
|
||||||
|
return await oauth_handlers.handle_callback(request)
|
||||||
|
|
||||||
|
async def oauth_provider_info(request):
|
||||||
|
return await oauth_handlers.handle_provider_info(request)
|
||||||
|
|
||||||
|
async def oauth_demo(request):
|
||||||
|
return await oauth_handlers.handle_demo_page(request)
|
||||||
|
|
||||||
|
# Token management endpoints
|
||||||
|
from .auth.token_handlers import TokenHandlers
|
||||||
|
token_handlers = TokenHandlers(self.security_manager, self.config)
|
||||||
|
|
||||||
|
async def token_create(request):
|
||||||
|
return await token_handlers.handle_create_token(request)
|
||||||
|
|
||||||
|
async def token_revoke(request):
|
||||||
|
return await token_handlers.handle_revoke_token(request)
|
||||||
|
|
||||||
|
async def token_list(request):
|
||||||
|
return await token_handlers.handle_list_tokens(request)
|
||||||
|
|
||||||
|
async def token_stats(request):
|
||||||
|
return await token_handlers.handle_token_stats(request)
|
||||||
|
|
||||||
|
async def token_cleanup(request):
|
||||||
|
return await token_handlers.handle_cleanup_tokens(request)
|
||||||
|
|
||||||
|
async def token_management(request):
|
||||||
|
return await token_handlers.handle_management_page(request)
|
||||||
|
|
||||||
# Lifecycle manager - simplified since we manage session_manager externally
|
# Lifecycle manager - simplified since we manage session_manager externally
|
||||||
@contextlib.asynccontextmanager
|
@contextlib.asynccontextmanager
|
||||||
async def lifespan(app: Starlette) -> AsyncIterator[None]:
|
async def lifespan(app: Starlette) -> AsyncIterator[None]:
|
||||||
@@ -283,6 +585,18 @@ class DorisServer:
|
|||||||
debug=True,
|
debug=True,
|
||||||
routes=[
|
routes=[
|
||||||
Route("/health", health_check, methods=["GET"]),
|
Route("/health", health_check, methods=["GET"]),
|
||||||
|
# OAuth endpoints
|
||||||
|
Route("/auth/login", oauth_login, methods=["GET"]),
|
||||||
|
Route("/auth/callback", oauth_callback, methods=["GET"]),
|
||||||
|
Route("/auth/provider", oauth_provider_info, methods=["GET"]),
|
||||||
|
Route("/auth/demo", oauth_demo, methods=["GET"]),
|
||||||
|
# Token management endpoints
|
||||||
|
Route("/token/create", token_create, methods=["GET", "POST"]),
|
||||||
|
Route("/token/revoke", token_revoke, methods=["GET", "DELETE"]),
|
||||||
|
Route("/token/list", token_list, methods=["GET"]),
|
||||||
|
Route("/token/stats", token_stats, methods=["GET"]),
|
||||||
|
Route("/token/cleanup", token_cleanup, methods=["GET", "POST"]),
|
||||||
|
Route("/token/management", token_management, methods=["GET"]),
|
||||||
],
|
],
|
||||||
lifespan=lifespan,
|
lifespan=lifespan,
|
||||||
)
|
)
|
||||||
@@ -300,8 +614,10 @@ class DorisServer:
|
|||||||
self.logger.info(f"Received request for path: {path}")
|
self.logger.info(f"Received request for path: {path}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Handle health check
|
# Handle health check, auth, and token management endpoints
|
||||||
if path.startswith("/health"):
|
if (path.startswith("/health") or
|
||||||
|
path.startswith("/auth/") or
|
||||||
|
path.startswith("/token/")):
|
||||||
await starlette_app(scope, receive, send)
|
await starlette_app(scope, receive, send)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -314,6 +630,39 @@ class DorisServer:
|
|||||||
self.logger.info(f"MCP Request - Method: {method}")
|
self.logger.info(f"MCP Request - Method: {method}")
|
||||||
self.logger.info(f"MCP Request - Headers: {headers}")
|
self.logger.info(f"MCP Request - Headers: {headers}")
|
||||||
|
|
||||||
|
# Authentication check for MCP requests
|
||||||
|
try:
|
||||||
|
# Extract authentication information
|
||||||
|
auth_info = await self._extract_auth_info_from_scope(scope, headers)
|
||||||
|
|
||||||
|
# Authenticate the request
|
||||||
|
auth_context = await self.security_manager.authenticate_request(auth_info)
|
||||||
|
self.logger.info(f"MCP request authenticated: token_id={auth_context.token_id}, client_ip={auth_context.client_ip}")
|
||||||
|
|
||||||
|
# Store auth context in scope for potential use by tools/resources
|
||||||
|
scope["auth_context"] = auth_context
|
||||||
|
|
||||||
|
# FIX for Issue #62 Bug 1: Set auth_context in context variable
|
||||||
|
# This allows tools to access token information for token-bound database configuration
|
||||||
|
# CRITICAL: Use the global ContextVar from security.py to ensure same instance is used everywhere
|
||||||
|
try:
|
||||||
|
from .utils.security import mcp_auth_context_var
|
||||||
|
mcp_auth_context_var.set(auth_context)
|
||||||
|
self.logger.debug(f"Set auth_context in context variable with token: {bool(hasattr(auth_context, 'token') and auth_context.token)}")
|
||||||
|
except Exception as ctx_error:
|
||||||
|
self.logger.warning(f"Failed to set auth_context in context variable: {ctx_error}")
|
||||||
|
|
||||||
|
except Exception as auth_error:
|
||||||
|
self.logger.error(f"MCP authentication failed: {auth_error}")
|
||||||
|
# Return 401 Unauthorized
|
||||||
|
from starlette.responses import JSONResponse
|
||||||
|
response = JSONResponse(
|
||||||
|
{"error": "Authentication required", "message": str(auth_error)},
|
||||||
|
status_code=401
|
||||||
|
)
|
||||||
|
await response(scope, receive, send)
|
||||||
|
return
|
||||||
|
|
||||||
# Handle Dify compatibility for GET requests
|
# Handle Dify compatibility for GET requests
|
||||||
if method == "GET":
|
if method == "GET":
|
||||||
accept_header = headers.get(b'accept', b'').decode('utf-8')
|
accept_header = headers.get(b'accept', b'').decode('utf-8')
|
||||||
@@ -356,19 +705,35 @@ class DorisServer:
|
|||||||
self.logger.warning(f"Unsupported scope type: {scope['type']}")
|
self.logger.warning(f"Unsupported scope type: {scope['type']}")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Start uvicorn server with session manager lifecycle
|
# Choose startup method based on worker count
|
||||||
config = uvicorn.Config(
|
if workers > 1:
|
||||||
app=mcp_app,
|
self.logger.info(f"Using multi-process mode with {workers} workers")
|
||||||
host=host,
|
self.logger.info("Note: Multi-worker mode provides full MCP functionality with independent worker processes")
|
||||||
port=port,
|
|
||||||
log_level="info"
|
# Use the dedicated multiworker app module with full MCP support
|
||||||
)
|
uvicorn.run(
|
||||||
server = uvicorn.Server(config)
|
"doris_mcp_server.multiworker_app:app",
|
||||||
|
host=host,
|
||||||
# Run session manager and server together
|
port=port,
|
||||||
async with session_manager.run():
|
workers=workers,
|
||||||
self.logger.info("Session manager started, now starting HTTP server")
|
log_level="info"
|
||||||
await server.serve()
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.logger.info("Using single-process mode")
|
||||||
|
# Single worker mode, use original logic with session manager lifecycle
|
||||||
|
config = uvicorn.Config(
|
||||||
|
app=mcp_app,
|
||||||
|
host=host,
|
||||||
|
port=port,
|
||||||
|
log_level="info"
|
||||||
|
)
|
||||||
|
server = uvicorn.Server(config)
|
||||||
|
|
||||||
|
# Run session manager and server together
|
||||||
|
async with session_manager.run():
|
||||||
|
self.logger.info("Session manager started, now starting HTTP server")
|
||||||
|
await server.serve()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(f"Streamable HTTP server startup failed: {e}")
|
self.logger.error(f"Streamable HTTP server startup failed: {e}")
|
||||||
@@ -383,10 +748,16 @@ class DorisServer:
|
|||||||
self.logger.error(f" Exception {i+1}: {type(exc).__name__}: {exc}")
|
self.logger.error(f" Exception {i+1}: {type(exc).__name__}: {exc}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def shutdown(self):
|
async def shutdown(self):
|
||||||
"""Shutdown server"""
|
"""Shutdown server"""
|
||||||
self.logger.info("Shutting down Doris MCP Server")
|
self.logger.info("Shutting down Doris MCP Server")
|
||||||
try:
|
try:
|
||||||
|
# Shutdown security manager first (includes JWT cleanup)
|
||||||
|
await self.security_manager.shutdown()
|
||||||
|
self.logger.info("Security manager shutdown completed")
|
||||||
|
|
||||||
await self.connection_manager.close()
|
await self.connection_manager.close()
|
||||||
self.logger.info("Doris MCP Server has been shut down")
|
self.logger.info("Doris MCP Server has been shut down")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -406,6 +777,11 @@ Transport Modes:
|
|||||||
Examples:
|
Examples:
|
||||||
python -m doris_mcp_server --transport stdio
|
python -m doris_mcp_server --transport stdio
|
||||||
python -m doris_mcp_server --transport http --host 0.0.0.0 --port 3000
|
python -m doris_mcp_server --transport http --host 0.0.0.0 --port 3000
|
||||||
|
python -m doris_mcp_server --transport stdio --doris-host localhost --doris-port 9030
|
||||||
|
python -m doris_mcp_server --transport http --doris-user admin --doris-database test_db
|
||||||
|
|
||||||
|
# Backward compatibility: --db-* parameters are also supported
|
||||||
|
python -m doris_mcp_server --transport stdio --db-host localhost --db-port 9030
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -413,91 +789,151 @@ Examples:
|
|||||||
"--transport",
|
"--transport",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["stdio", "http"],
|
choices=["stdio", "http"],
|
||||||
default="stdio",
|
default=os.getenv("TRANSPORT", _default_config.transport),
|
||||||
help="Transport protocol type: stdio (local), http (Streamable HTTP)",
|
help=f"Transport protocol type: stdio (local), http (Streamable HTTP) (default: {_default_config.transport})",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--host",
|
"--host",
|
||||||
type=str,
|
type=str,
|
||||||
default="localhost",
|
default=os.getenv("SERVER_HOST", _default_config.server_host),
|
||||||
help="Host address for HTTP mode (default: localhost)",
|
help=f"Host address for HTTP mode (default: {_default_config.server_host})",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--port", type=int, default=3000, help="Port number for HTTP mode (default: 3000)"
|
"--port",
|
||||||
|
type=int,
|
||||||
|
default=3000,
|
||||||
|
help="Port number for HTTP mode (default: 3000)"
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--db-host",
|
"--workers",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Number of worker processes for HTTP mode (default: 1, use 0 for auto-detect CPU cores)"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--doris-host", "--db-host",
|
||||||
type=str,
|
type=str,
|
||||||
default="localhost",
|
default=os.getenv("DORIS_HOST", _default_config.database.host),
|
||||||
help="Doris database host address (default: localhost)",
|
help=f"Doris database host address (default: {_default_config.database.host})",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--db-port", type=int, default=9030, help="Doris database port number (default: 9030)"
|
"--doris-port", "--db-port", type=int, default=9030, help="Doris database port number (default: 9030)"
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--db-user", type=str, default="root", help="Doris database username (default: root)"
|
"--doris-user", "--db-user", type=str, default=os.getenv("DORIS_USER", _default_config.database.user), help=f"Doris database username (default: {_default_config.database.user})"
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument("--db-password", type=str, default="", help="Doris database password")
|
parser.add_argument("--doris-password", "--db-password", type=str, default=os.getenv("DORIS_PASSWORD", ""), help="Doris database password")
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--db-database",
|
"--doris-database", "--db-database",
|
||||||
type=str,
|
type=str,
|
||||||
default="information_schema",
|
default=os.getenv("DORIS_DATABASE", _default_config.database.database),
|
||||||
help="Doris database name (default: information_schema)",
|
help=f"Doris database name (default: {_default_config.database.database})",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--log-level",
|
"--log-level",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
|
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
|
||||||
default="INFO",
|
default=os.getenv("LOG_LEVEL", _default_config.logging.level),
|
||||||
help="Log level (default: INFO)",
|
help=f"Log level (default: {_default_config.logging.level})",
|
||||||
)
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
def update_configuration(config: DorisConfig):
|
||||||
"""Main function"""
|
"""Update doris configuration object"""
|
||||||
|
# For some arguments, if not specified, environment variables or default configurations will be used as default values
|
||||||
parser = create_arg_parser()
|
parser = create_arg_parser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Set log level
|
# Update config values
|
||||||
logging.getLogger().setLevel(getattr(logging, args.log_level))
|
|
||||||
|
|
||||||
# Create configuration - priority: command line arguments > .env file > default values
|
|
||||||
config = DorisConfig.from_env() # First load from .env file and environment variables
|
|
||||||
|
|
||||||
# Command line arguments override configuration (if provided)
|
# Command line arguments override configuration (if provided)
|
||||||
if args.db_host != "localhost": # If not default value, use command line argument
|
# basic
|
||||||
config.database.host = args.db_host
|
if args.transport != _default_config.transport:
|
||||||
if args.db_port != 9030:
|
config.transport = args.transport
|
||||||
config.database.port = args.db_port
|
if args.host != _default_config.server_host:
|
||||||
if args.db_user != "root":
|
config.server_host = args.host
|
||||||
config.database.user = args.db_user
|
if args.port != _default_config.server_port:
|
||||||
if args.db_password: # Use password if provided
|
config.server_port = args.port
|
||||||
config.database.password = args.db_password
|
server_name = os.getenv("SERVER_NAME")
|
||||||
if args.db_database != "information_schema":
|
if server_name:
|
||||||
config.database.database = args.db_database
|
config.server_name = server_name
|
||||||
if args.log_level != "INFO":
|
server_version = os.getenv("SERVER_VERSION")
|
||||||
|
if server_version:
|
||||||
|
config.server_version = server_version
|
||||||
|
|
||||||
|
# database
|
||||||
|
if args.doris_host != _default_config.database.host: # If not default value, use command line argument
|
||||||
|
config.database.host = args.doris_host
|
||||||
|
if args.doris_port != _default_config.database.port:
|
||||||
|
config.database.port = args.doris_port
|
||||||
|
if args.doris_user != _default_config.database.user:
|
||||||
|
config.database.user = args.doris_user
|
||||||
|
if args.doris_password: # Use password if provided
|
||||||
|
config.database.password = args.doris_password
|
||||||
|
if args.doris_database != _default_config.database.database:
|
||||||
|
config.database.database = args.doris_database
|
||||||
|
|
||||||
|
# logging
|
||||||
|
if args.log_level != _default_config.logging.level:
|
||||||
config.logging.level = args.log_level
|
config.logging.level = args.log_level
|
||||||
|
|
||||||
|
# workers (add to config for HTTP mode)
|
||||||
|
if hasattr(args, 'workers'):
|
||||||
|
config.workers = args.workers
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""Main function"""
|
||||||
|
# Create configuration - priority: command line arguments > env variables > .env file > default values
|
||||||
|
# First load from .env file and environment variables
|
||||||
|
config = DorisConfig.from_env()
|
||||||
|
|
||||||
|
# Then parse the command line arguments, and update the config object.
|
||||||
|
update_configuration(config)
|
||||||
|
|
||||||
|
# Initialize enhanced logging system
|
||||||
|
from .utils.config import ConfigManager
|
||||||
|
config_manager = ConfigManager(config)
|
||||||
|
config_manager.setup_logging()
|
||||||
|
|
||||||
|
# Get logger with proper configuration
|
||||||
|
from .utils.logger import get_logger, log_system_info
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
# Log system information for debugging
|
||||||
|
log_system_info()
|
||||||
|
|
||||||
|
logger.info("Starting Doris MCP Server...")
|
||||||
|
logger.info(f"Transport: {config.transport}")
|
||||||
|
logger.info(f"Log Level: {config.logging.level}")
|
||||||
|
|
||||||
# Create server instance
|
# Create server instance
|
||||||
server = DorisServer(config)
|
server = DorisServer(config)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if args.transport == "stdio":
|
if config.transport == "stdio":
|
||||||
await server.start_stdio()
|
await server.start_stdio()
|
||||||
elif args.transport == "http":
|
elif config.transport == "http":
|
||||||
await server.start_http(args.host, args.port)
|
# Get workers configuration with auto-detection support
|
||||||
|
workers = getattr(config, 'workers', 1)
|
||||||
|
if workers == 0:
|
||||||
|
import multiprocessing
|
||||||
|
workers = multiprocessing.cpu_count()
|
||||||
|
logger.info(f"Auto-detected {workers} CPU cores for worker processes")
|
||||||
|
|
||||||
|
await server.start_http(config.server_host, config.server_port, workers)
|
||||||
else:
|
else:
|
||||||
logger.error(f"Unsupported transport protocol: {args.transport}")
|
logger.error(f"Unsupported transport protocol: {config.transport}")
|
||||||
await server.shutdown()
|
await server.shutdown()
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
@@ -517,6 +953,10 @@ async def main():
|
|||||||
await server.shutdown()
|
await server.shutdown()
|
||||||
except Exception as shutdown_error:
|
except Exception as shutdown_error:
|
||||||
logger.error(f"Error occurred while shutting down server: {shutdown_error}")
|
logger.error(f"Error occurred while shutting down server: {shutdown_error}")
|
||||||
|
|
||||||
|
# Shutdown logging system
|
||||||
|
from .utils.logger import shutdown_logging
|
||||||
|
shutdown_logging()
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|||||||
627
doris_mcp_server/multiworker_app.py
Normal file
@@ -0,0 +1,627 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
"""
|
||||||
|
Multi-worker application module for doris-mcp-server
|
||||||
|
|
||||||
|
This module provides full MCP functionality with multi-worker support.
|
||||||
|
Each worker process creates its own MCP server and session manager using the same
|
||||||
|
robust architecture as the single-worker mode.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
# Import MCP components with compatibility handling
|
||||||
|
# Use the same import strategy as main.py for consistency
|
||||||
|
MCP_VERSION = 'unknown'
|
||||||
|
Server = None
|
||||||
|
InitializationOptions = None
|
||||||
|
Prompt = None
|
||||||
|
Resource = None
|
||||||
|
TextContent = None
|
||||||
|
Tool = None
|
||||||
|
|
||||||
|
def _import_mcp_with_compatibility():
|
||||||
|
"""Import MCP components with multi-version compatibility"""
|
||||||
|
global MCP_VERSION, Server, InitializationOptions, Prompt, Resource, TextContent, Tool
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Strategy 1: Try direct server-only imports to avoid client-side issues
|
||||||
|
from mcp.server import Server as _Server
|
||||||
|
from mcp.server.models import InitializationOptions as _InitOptions
|
||||||
|
from mcp.types import (
|
||||||
|
Prompt as _Prompt,
|
||||||
|
Resource as _Resource,
|
||||||
|
TextContent as _TextContent,
|
||||||
|
Tool as _Tool,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assign to globals
|
||||||
|
Server = _Server
|
||||||
|
InitializationOptions = _InitOptions
|
||||||
|
Prompt = _Prompt
|
||||||
|
Resource = _Resource
|
||||||
|
TextContent = _TextContent
|
||||||
|
Tool = _Tool
|
||||||
|
|
||||||
|
# Try to get version safely
|
||||||
|
try:
|
||||||
|
import mcp
|
||||||
|
MCP_VERSION = getattr(mcp, '__version__', None)
|
||||||
|
if not MCP_VERSION:
|
||||||
|
# Fallback: try to get version from package metadata
|
||||||
|
try:
|
||||||
|
import importlib.metadata
|
||||||
|
MCP_VERSION = importlib.metadata.version('mcp')
|
||||||
|
except Exception:
|
||||||
|
# Second fallback: try pkg_resources
|
||||||
|
try:
|
||||||
|
import pkg_resources
|
||||||
|
MCP_VERSION = pkg_resources.get_distribution('mcp').version
|
||||||
|
except Exception:
|
||||||
|
MCP_VERSION = 'detected-but-version-unknown'
|
||||||
|
except Exception:
|
||||||
|
# Version detection failed, but imports worked
|
||||||
|
try:
|
||||||
|
import importlib.metadata
|
||||||
|
MCP_VERSION = importlib.metadata.version('mcp')
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
import pkg_resources
|
||||||
|
MCP_VERSION = pkg_resources.get_distribution('mcp').version
|
||||||
|
except Exception:
|
||||||
|
MCP_VERSION = 'imported-successfully'
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.info(f"MCP components imported successfully in multiworker, version: {MCP_VERSION}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as import_error:
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Strategy 2: Handle RequestContext compatibility issues in 1.9.x versions
|
||||||
|
error_str = str(import_error).lower()
|
||||||
|
if 'requestcontext' in error_str and 'too few arguments' in error_str:
|
||||||
|
logger.warning(f"Detected MCP RequestContext compatibility issue: {import_error}")
|
||||||
|
logger.info("Attempting comprehensive workaround for MCP 1.9.x RequestContext issue...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Comprehensive monkey patch approach
|
||||||
|
import sys
|
||||||
|
import types
|
||||||
|
|
||||||
|
# Create and install mock modules before any MCP imports
|
||||||
|
if 'mcp.shared.context' not in sys.modules:
|
||||||
|
mock_context_module = types.ModuleType('mcp.shared.context')
|
||||||
|
|
||||||
|
class FlexibleRequestContext:
|
||||||
|
"""Flexible RequestContext that accepts variable arguments"""
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
self.args = args
|
||||||
|
self.kwargs = kwargs
|
||||||
|
|
||||||
|
def __class_getitem__(cls, params):
|
||||||
|
# Accept any number of parameters and return cls
|
||||||
|
return cls
|
||||||
|
|
||||||
|
# Add other methods that might be called
|
||||||
|
def __getattr__(self, name):
|
||||||
|
return lambda *args, **kwargs: None
|
||||||
|
|
||||||
|
mock_context_module.RequestContext = FlexibleRequestContext
|
||||||
|
sys.modules['mcp.shared.context'] = mock_context_module
|
||||||
|
|
||||||
|
# Also patch the typing system to be more permissive
|
||||||
|
original_check_generic = None
|
||||||
|
try:
|
||||||
|
import typing
|
||||||
|
if hasattr(typing, '_check_generic'):
|
||||||
|
original_check_generic = typing._check_generic
|
||||||
|
def permissive_check_generic(cls, params, elen):
|
||||||
|
# Don't enforce strict parameter count checking
|
||||||
|
return
|
||||||
|
typing._check_generic = permissive_check_generic
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Clear any cached imports that might have failed
|
||||||
|
modules_to_clear = [k for k in sys.modules.keys() if k.startswith('mcp.')]
|
||||||
|
for module in modules_to_clear:
|
||||||
|
if module in sys.modules:
|
||||||
|
del sys.modules[module]
|
||||||
|
|
||||||
|
# Now try importing again with the patches in place
|
||||||
|
from mcp.server import Server as _Server
|
||||||
|
from mcp.server.models import InitializationOptions as _InitOptions
|
||||||
|
from mcp.types import (
|
||||||
|
Prompt as _Prompt,
|
||||||
|
Resource as _Resource,
|
||||||
|
TextContent as _TextContent,
|
||||||
|
Tool as _Tool,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assign to globals
|
||||||
|
Server = _Server
|
||||||
|
InitializationOptions = _InitOptions
|
||||||
|
Prompt = _Prompt
|
||||||
|
Resource = _Resource
|
||||||
|
TextContent = _TextContent
|
||||||
|
Tool = _Tool
|
||||||
|
|
||||||
|
# Try to detect actual version even in compatibility mode
|
||||||
|
try:
|
||||||
|
import importlib.metadata
|
||||||
|
actual_version = importlib.metadata.version('mcp')
|
||||||
|
MCP_VERSION = f'compatibility-mode-{actual_version}'
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
import pkg_resources
|
||||||
|
actual_version = pkg_resources.get_distribution('mcp').version
|
||||||
|
MCP_VERSION = f'compatibility-mode-{actual_version}'
|
||||||
|
except Exception:
|
||||||
|
MCP_VERSION = 'compatibility-mode-1.9.x'
|
||||||
|
|
||||||
|
logger.info("MCP 1.9.x compatibility workaround successful in multiworker!")
|
||||||
|
|
||||||
|
# Restore original typing function if we patched it
|
||||||
|
if original_check_generic:
|
||||||
|
typing._check_generic = original_check_generic
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as workaround_error:
|
||||||
|
logger.error(f"MCP compatibility workaround failed in multiworker: {workaround_error}")
|
||||||
|
|
||||||
|
# Restore original typing function if we patched it
|
||||||
|
if original_check_generic:
|
||||||
|
try:
|
||||||
|
import typing
|
||||||
|
typing._check_generic = original_check_generic
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
logger.error(f"Failed to import MCP components in multiworker: {import_error}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Perform MCP import with compatibility handling
|
||||||
|
if not _import_mcp_with_compatibility():
|
||||||
|
raise ImportError(
|
||||||
|
"Failed to import MCP components in multiworker. Please ensure MCP is properly installed. "
|
||||||
|
"Supported versions: 1.8.x, 1.9.x"
|
||||||
|
)
|
||||||
|
|
||||||
|
from starlette.applications import Starlette
|
||||||
|
from starlette.routing import Route
|
||||||
|
from starlette.responses import JSONResponse, Response
|
||||||
|
|
||||||
|
# Import Doris MCP components
|
||||||
|
from .tools.tools_manager import DorisToolsManager
|
||||||
|
from .tools.prompts_manager import DorisPromptsManager
|
||||||
|
from .tools.resources_manager import DorisResourcesManager
|
||||||
|
from .utils.config import DorisConfig
|
||||||
|
from .utils.db import DorisConnectionManager
|
||||||
|
from .utils.security import DorisSecurityManager
|
||||||
|
|
||||||
|
# Global variables for worker-specific instances
|
||||||
|
_worker_server = None
|
||||||
|
_worker_session_manager = None
|
||||||
|
_worker_connection_manager = None
|
||||||
|
_worker_security_manager = None
|
||||||
|
_worker_session_manager_context = None
|
||||||
|
_worker_initialized = False
|
||||||
|
|
||||||
|
def get_mcp_capabilities():
|
||||||
|
"""Get MCP capabilities for worker - use the same logic as main.py"""
|
||||||
|
try:
|
||||||
|
# For MCP 1.9.x and newer
|
||||||
|
from mcp.server.lowlevel.server import NotificationOptions
|
||||||
|
|
||||||
|
capabilities = {
|
||||||
|
"resources": {},
|
||||||
|
"tools": {},
|
||||||
|
"prompts": {},
|
||||||
|
"notification_options": {
|
||||||
|
"prompts_changed": True,
|
||||||
|
"resources_changed": True,
|
||||||
|
"tools_changed": True
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return capabilities
|
||||||
|
except Exception as e:
|
||||||
|
# Import logger properly
|
||||||
|
from .utils.logger import get_logger
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
logger.warning(f"Failed to get full capabilities in multiworker: {e}")
|
||||||
|
return {
|
||||||
|
"resources": {},
|
||||||
|
"tools": {},
|
||||||
|
"prompts": {}
|
||||||
|
}
|
||||||
|
|
||||||
|
async def initialize_worker():
|
||||||
|
"""Initialize MCP server and managers for this worker process"""
|
||||||
|
global _worker_server, _worker_session_manager, _worker_connection_manager, _worker_security_manager, _worker_session_manager_context, _worker_initialized, _oauth_handlers, _token_handlers
|
||||||
|
|
||||||
|
if _worker_initialized:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Import logger properly
|
||||||
|
from .utils.logger import get_logger
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
logger.info(f"Initializing MCP worker process {os.getpid()}")
|
||||||
|
|
||||||
|
# Create configuration
|
||||||
|
config = DorisConfig.from_env()
|
||||||
|
|
||||||
|
# Initialize enhanced logging system
|
||||||
|
from .utils.config import ConfigManager
|
||||||
|
config_manager = ConfigManager(config)
|
||||||
|
config_manager.setup_logging()
|
||||||
|
|
||||||
|
# Create security manager
|
||||||
|
_worker_security_manager = DorisSecurityManager(config)
|
||||||
|
|
||||||
|
# Initialize security manager first (includes JWT setup if enabled)
|
||||||
|
await _worker_security_manager.initialize()
|
||||||
|
logger.info(f"Worker {os.getpid()} security manager initialization completed")
|
||||||
|
|
||||||
|
# Create connection manager with token manager for token-bound DB config
|
||||||
|
token_manager = _worker_security_manager.auth_provider.token_manager if hasattr(_worker_security_manager, 'auth_provider') and hasattr(_worker_security_manager.auth_provider, 'token_manager') else None
|
||||||
|
_worker_connection_manager = DorisConnectionManager(config, _worker_security_manager, token_manager)
|
||||||
|
|
||||||
|
# Set connection manager reference in security manager for database validation
|
||||||
|
_worker_security_manager.connection_manager = _worker_connection_manager
|
||||||
|
|
||||||
|
await _worker_connection_manager.initialize()
|
||||||
|
|
||||||
|
# Create MCP server
|
||||||
|
_worker_server = Server("doris-mcp-server")
|
||||||
|
|
||||||
|
# Create managers
|
||||||
|
resources_manager = DorisResourcesManager(_worker_connection_manager)
|
||||||
|
tools_manager = DorisToolsManager(_worker_connection_manager)
|
||||||
|
prompts_manager = DorisPromptsManager(_worker_connection_manager)
|
||||||
|
|
||||||
|
# Setup MCP handlers
|
||||||
|
@_worker_server.list_resources()
|
||||||
|
async def handle_list_resources() -> list[Resource]:
|
||||||
|
"""Handle resource list request"""
|
||||||
|
try:
|
||||||
|
logger.info("Handling resource list request in worker")
|
||||||
|
resources = await resources_manager.list_resources()
|
||||||
|
logger.info(f"Returning {len(resources)} resources from worker")
|
||||||
|
return resources
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to handle resource list request in worker: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
@_worker_server.read_resource()
|
||||||
|
async def handle_read_resource(uri: str) -> str:
|
||||||
|
"""Handle resource read request"""
|
||||||
|
try:
|
||||||
|
logger.info(f"Handling resource read request in worker: {uri}")
|
||||||
|
content = await resources_manager.read_resource(uri)
|
||||||
|
return content
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to handle resource read request in worker: {e}")
|
||||||
|
return json.dumps(
|
||||||
|
{"error": f"Failed to read resource: {str(e)}", "uri": uri},
|
||||||
|
ensure_ascii=False,
|
||||||
|
indent=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
@_worker_server.list_tools()
|
||||||
|
async def handle_list_tools() -> list[Tool]:
|
||||||
|
"""Handle tool list request"""
|
||||||
|
try:
|
||||||
|
logger.info("Handling tool list request in worker")
|
||||||
|
tools = await tools_manager.list_tools()
|
||||||
|
logger.info(f"Returning {len(tools)} tools from worker")
|
||||||
|
return tools
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to handle tool list request in worker: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
@_worker_server.call_tool()
|
||||||
|
async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
|
||||||
|
"""Handle tool call request"""
|
||||||
|
try:
|
||||||
|
logger.info(f"Handling tool call request in worker: {name}")
|
||||||
|
result = await tools_manager.call_tool(name, arguments)
|
||||||
|
return [TextContent(type="text", text=result)]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to handle tool call request in worker: {e}")
|
||||||
|
error_result = json.dumps(
|
||||||
|
{
|
||||||
|
"error": f"Tool call failed: {str(e)}",
|
||||||
|
"tool_name": name,
|
||||||
|
"arguments": arguments,
|
||||||
|
},
|
||||||
|
ensure_ascii=False,
|
||||||
|
indent=2,
|
||||||
|
)
|
||||||
|
return [TextContent(type="text", text=error_result)]
|
||||||
|
|
||||||
|
@_worker_server.list_prompts()
|
||||||
|
async def handle_list_prompts() -> list[Prompt]:
|
||||||
|
"""Handle prompt list request"""
|
||||||
|
try:
|
||||||
|
logger.info("Handling prompt list request in worker")
|
||||||
|
prompts = await prompts_manager.list_prompts()
|
||||||
|
logger.info(f"Returning {len(prompts)} prompts from worker")
|
||||||
|
return prompts
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to handle prompt list request in worker: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
@_worker_server.get_prompt()
|
||||||
|
async def handle_get_prompt(name: str, arguments: dict[str, Any]) -> str:
|
||||||
|
"""Handle prompt get request"""
|
||||||
|
try:
|
||||||
|
logger.info(f"Handling prompt get request in worker: {name}")
|
||||||
|
result = await prompts_manager.get_prompt(name, arguments)
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to handle prompt get request in worker: {e}")
|
||||||
|
error_result = json.dumps(
|
||||||
|
{
|
||||||
|
"error": f"Failed to get prompt: {str(e)}",
|
||||||
|
"prompt_name": name,
|
||||||
|
"arguments": arguments,
|
||||||
|
},
|
||||||
|
ensure_ascii=False,
|
||||||
|
indent=2,
|
||||||
|
)
|
||||||
|
return error_result
|
||||||
|
|
||||||
|
# Create session manager for this worker
|
||||||
|
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
|
||||||
|
|
||||||
|
_worker_session_manager = StreamableHTTPSessionManager(
|
||||||
|
app=_worker_server,
|
||||||
|
json_response=True,
|
||||||
|
stateless=True # Use stateless mode for multi-worker compatibility
|
||||||
|
)
|
||||||
|
|
||||||
|
# Start the session manager context
|
||||||
|
_worker_session_manager_context = _worker_session_manager.run()
|
||||||
|
await _worker_session_manager_context.__aenter__()
|
||||||
|
|
||||||
|
# Initialize OAuth and Token handlers
|
||||||
|
from .auth.oauth_handlers import OAuthHandlers
|
||||||
|
from .auth.token_handlers import TokenHandlers
|
||||||
|
_oauth_handlers = OAuthHandlers(_worker_security_manager)
|
||||||
|
_token_handlers = TokenHandlers(_worker_security_manager, config)
|
||||||
|
|
||||||
|
_worker_initialized = True
|
||||||
|
logger.info(f"Worker {os.getpid()} MCP initialization completed successfully")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
from .utils.logger import get_logger
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
logger.error(f"Failed to initialize worker {os.getpid()}: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.error("Complete error stack:")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def health_check(request):
|
||||||
|
"""Health check endpoint that shows worker PID"""
|
||||||
|
return JSONResponse({
|
||||||
|
"status": "healthy",
|
||||||
|
"service": "doris-mcp-server",
|
||||||
|
"worker_pid": os.getpid(),
|
||||||
|
"worker_mode": "multi-process-full-mcp",
|
||||||
|
"mcp_initialized": _worker_initialized,
|
||||||
|
"mcp_version": MCP_VERSION
|
||||||
|
})
|
||||||
|
|
||||||
|
# OAuth and Token handlers (initialize after worker setup)
|
||||||
|
_oauth_handlers = None
|
||||||
|
_token_handlers = None
|
||||||
|
|
||||||
|
async def oauth_login(request):
|
||||||
|
"""OAuth login endpoint"""
|
||||||
|
if not _oauth_handlers:
|
||||||
|
return JSONResponse({"error": "OAuth not initialized"}, status_code=503)
|
||||||
|
return await _oauth_handlers.handle_login(request)
|
||||||
|
|
||||||
|
async def oauth_callback(request):
|
||||||
|
"""OAuth callback endpoint"""
|
||||||
|
if not _oauth_handlers:
|
||||||
|
return JSONResponse({"error": "OAuth not initialized"}, status_code=503)
|
||||||
|
return await _oauth_handlers.handle_callback(request)
|
||||||
|
|
||||||
|
async def oauth_provider_info(request):
|
||||||
|
"""OAuth provider info endpoint"""
|
||||||
|
if not _oauth_handlers:
|
||||||
|
return JSONResponse({"error": "OAuth not initialized"}, status_code=503)
|
||||||
|
return await _oauth_handlers.handle_provider_info(request)
|
||||||
|
|
||||||
|
async def oauth_demo(request):
|
||||||
|
"""OAuth demo page endpoint"""
|
||||||
|
if not _oauth_handlers:
|
||||||
|
from starlette.responses import HTMLResponse
|
||||||
|
return HTMLResponse("<h1>OAuth not initialized</h1>")
|
||||||
|
return await _oauth_handlers.handle_demo_page(request)
|
||||||
|
|
||||||
|
# Token management endpoints
|
||||||
|
async def token_create(request):
|
||||||
|
"""Token creation endpoint"""
|
||||||
|
if not _token_handlers:
|
||||||
|
return JSONResponse({"error": "Token handlers not initialized"}, status_code=503)
|
||||||
|
return await _token_handlers.handle_create_token(request)
|
||||||
|
|
||||||
|
async def token_revoke(request):
|
||||||
|
"""Token revocation endpoint"""
|
||||||
|
if not _token_handlers:
|
||||||
|
return JSONResponse({"error": "Token handlers not initialized"}, status_code=503)
|
||||||
|
return await _token_handlers.handle_revoke_token(request)
|
||||||
|
|
||||||
|
async def token_list(request):
|
||||||
|
"""Token listing endpoint"""
|
||||||
|
if not _token_handlers:
|
||||||
|
return JSONResponse({"error": "Token handlers not initialized"}, status_code=503)
|
||||||
|
return await _token_handlers.handle_list_tokens(request)
|
||||||
|
|
||||||
|
async def token_stats(request):
|
||||||
|
"""Token statistics endpoint"""
|
||||||
|
if not _token_handlers:
|
||||||
|
return JSONResponse({"error": "Token handlers not initialized"}, status_code=503)
|
||||||
|
return await _token_handlers.handle_token_stats(request)
|
||||||
|
|
||||||
|
async def token_cleanup(request):
|
||||||
|
"""Token cleanup endpoint"""
|
||||||
|
if not _token_handlers:
|
||||||
|
return JSONResponse({"error": "Token handlers not initialized"}, status_code=503)
|
||||||
|
return await _token_handlers.handle_cleanup_tokens(request)
|
||||||
|
|
||||||
|
async def token_management(request):
|
||||||
|
"""Token management page endpoint"""
|
||||||
|
if not _token_handlers:
|
||||||
|
from starlette.responses import HTMLResponse
|
||||||
|
return HTMLResponse("<h1>Token handlers not initialized</h1>")
|
||||||
|
return await _token_handlers.handle_management_page(request)
|
||||||
|
|
||||||
|
async def root_info(request):
|
||||||
|
"""Root endpoint"""
|
||||||
|
return JSONResponse({
|
||||||
|
"service": "doris-mcp-server",
|
||||||
|
"mode": "multi-worker-full-mcp",
|
||||||
|
"worker_pid": os.getpid(),
|
||||||
|
"mcp_initialized": _worker_initialized,
|
||||||
|
"mcp_version": MCP_VERSION,
|
||||||
|
"endpoints": {
|
||||||
|
"health": "/health",
|
||||||
|
"mcp": "/mcp"
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app):
|
||||||
|
"""Application lifespan manager"""
|
||||||
|
# Startup
|
||||||
|
try:
|
||||||
|
await initialize_worker()
|
||||||
|
# Import logger properly
|
||||||
|
from .utils.logger import get_logger
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
logger.info(f"Worker {os.getpid()} startup completed")
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Shutdown
|
||||||
|
from .utils.logger import get_logger
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
# Close session manager context
|
||||||
|
if _worker_session_manager_context:
|
||||||
|
try:
|
||||||
|
await _worker_session_manager_context.__aexit__(None, None, None)
|
||||||
|
logger.info(f"Worker {os.getpid()} session manager context closed")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error closing worker session manager context: {e}")
|
||||||
|
|
||||||
|
if _worker_connection_manager:
|
||||||
|
try:
|
||||||
|
await _worker_connection_manager.close()
|
||||||
|
logger.info(f"Worker {os.getpid()} connection manager closed")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error closing worker connection manager: {e}")
|
||||||
|
|
||||||
|
if _worker_security_manager:
|
||||||
|
try:
|
||||||
|
await _worker_security_manager.shutdown()
|
||||||
|
logger.info(f"Worker {os.getpid()} security manager shutdown completed")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error shutting down worker security manager: {e}")
|
||||||
|
|
||||||
|
# Shutdown logging system
|
||||||
|
try:
|
||||||
|
from .utils.logger import shutdown_logging
|
||||||
|
shutdown_logging()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error shutting down logging system: {e}")
|
||||||
|
|
||||||
|
async def mcp_asgi_app(scope, receive, send):
|
||||||
|
"""ASGI app that handles MCP requests"""
|
||||||
|
if not _worker_initialized:
|
||||||
|
# Send error response if worker not initialized
|
||||||
|
await send({
|
||||||
|
'type': 'http.response.start',
|
||||||
|
'status': 503,
|
||||||
|
'headers': [(b'content-type', b'application/json')]
|
||||||
|
})
|
||||||
|
await send({
|
||||||
|
'type': 'http.response.body',
|
||||||
|
'body': b'{"error": "Worker not initialized"}'
|
||||||
|
})
|
||||||
|
return
|
||||||
|
|
||||||
|
# Import logger properly
|
||||||
|
from .utils.logger import get_logger
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
# Get request path for logging
|
||||||
|
path = scope.get('path', '')
|
||||||
|
method = scope.get('method', 'UNKNOWN')
|
||||||
|
logger.debug(f"Worker {os.getpid()} handling MCP request: {method} {path}")
|
||||||
|
|
||||||
|
# Handle the request directly without nested run context
|
||||||
|
await _worker_session_manager.handle_request(scope, receive, send)
|
||||||
|
|
||||||
|
# Create Starlette app with basic routes
|
||||||
|
basic_app = Starlette(
|
||||||
|
debug=True,
|
||||||
|
routes=[
|
||||||
|
Route("/", root_info, methods=["GET"]),
|
||||||
|
Route("/health", health_check, methods=["GET"]),
|
||||||
|
# OAuth endpoints
|
||||||
|
Route("/auth/login", oauth_login, methods=["GET"]),
|
||||||
|
Route("/auth/callback", oauth_callback, methods=["GET"]),
|
||||||
|
Route("/auth/provider", oauth_provider_info, methods=["GET"]),
|
||||||
|
Route("/auth/demo", oauth_demo, methods=["GET"]),
|
||||||
|
# Token management endpoints
|
||||||
|
Route("/token/create", token_create, methods=["GET", "POST"]),
|
||||||
|
Route("/token/revoke", token_revoke, methods=["GET", "DELETE"]),
|
||||||
|
Route("/token/list", token_list, methods=["GET"]),
|
||||||
|
Route("/token/stats", token_stats, methods=["GET"]),
|
||||||
|
Route("/token/cleanup", token_cleanup, methods=["GET", "POST"]),
|
||||||
|
Route("/token/management", token_management, methods=["GET"]),
|
||||||
|
],
|
||||||
|
lifespan=lifespan
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create main ASGI app that routes between basic app and MCP
|
||||||
|
async def app(scope, receive, send):
|
||||||
|
"""Main ASGI app that routes requests"""
|
||||||
|
path = scope.get('path', '/')
|
||||||
|
|
||||||
|
if path == "/mcp" or path.startswith('/mcp/'):
|
||||||
|
# Handle MCP requests with session manager
|
||||||
|
await mcp_asgi_app(scope, receive, send)
|
||||||
|
else:
|
||||||
|
# Handle other requests with basic Starlette app (includes auth endpoints)
|
||||||
|
await basic_app(scope, receive, send)
|
||||||
@@ -31,6 +31,7 @@ from mcp.types import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from ..utils.db import DorisConnectionManager
|
from ..utils.db import DorisConnectionManager
|
||||||
|
from ..utils.sql_security_utils import get_auth_context
|
||||||
|
|
||||||
|
|
||||||
class PromptTemplate:
|
class PromptTemplate:
|
||||||
@@ -422,7 +423,8 @@ Please generate accurate and efficient SQL queries based on the above requiremen
|
|||||||
AND table_type = 'BASE TABLE'
|
AND table_type = 'BASE TABLE'
|
||||||
"""
|
"""
|
||||||
|
|
||||||
db_result = await connection.execute(db_info_sql)
|
auth_context = get_auth_context()
|
||||||
|
db_result = await connection.execute(db_info_sql, auth_context=auth_context)
|
||||||
db_info = db_result.data[0] if db_result.data else {}
|
db_info = db_result.data[0] if db_result.data else {}
|
||||||
|
|
||||||
# Get main table list
|
# Get main table list
|
||||||
@@ -438,7 +440,7 @@ Please generate accurate and efficient SQL queries based on the above requiremen
|
|||||||
LIMIT 10
|
LIMIT 10
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tables_result = await connection.execute(tables_sql)
|
tables_result = await connection.execute(tables_sql, auth_context=auth_context)
|
||||||
|
|
||||||
context = f"""Current database statistics:
|
context = f"""Current database statistics:
|
||||||
- Total number of tables: {db_info.get("table_count", 0)}
|
- Total number of tables: {db_info.get("table_count", 0)}
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from typing import Any
|
|||||||
from mcp.types import Resource
|
from mcp.types import Resource
|
||||||
|
|
||||||
from ..utils.db import DorisConnectionManager
|
from ..utils.db import DorisConnectionManager
|
||||||
|
from ..utils.sql_security_utils import get_auth_context
|
||||||
|
|
||||||
|
|
||||||
class TableMetadata:
|
class TableMetadata:
|
||||||
@@ -169,7 +170,8 @@ class DorisResourcesManager:
|
|||||||
ORDER BY table_name
|
ORDER BY table_name
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(tables_query)
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(tables_query, auth_context=auth_context)
|
||||||
tables = []
|
tables = []
|
||||||
|
|
||||||
for row in result.data:
|
for row in result.data:
|
||||||
@@ -204,7 +206,8 @@ class DorisResourcesManager:
|
|||||||
ORDER BY ordinal_position
|
ORDER BY ordinal_position
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(columns_query, (table_name,))
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(columns_query, params=(table_name,), auth_context=auth_context)
|
||||||
return [dict(row) for row in result.data]
|
return [dict(row) for row in result.data]
|
||||||
|
|
||||||
async def _get_view_metadata(self) -> list[ViewMetadata]:
|
async def _get_view_metadata(self) -> list[ViewMetadata]:
|
||||||
@@ -226,7 +229,8 @@ class DorisResourcesManager:
|
|||||||
ORDER BY table_name
|
ORDER BY table_name
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(views_query)
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(views_query, auth_context=auth_context)
|
||||||
views = []
|
views = []
|
||||||
|
|
||||||
for row in result.data:
|
for row in result.data:
|
||||||
@@ -257,7 +261,8 @@ class DorisResourcesManager:
|
|||||||
AND table_name = %s
|
AND table_name = %s
|
||||||
"""
|
"""
|
||||||
|
|
||||||
table_result = await connection.execute(table_info_query, (table_name,))
|
auth_context = get_auth_context()
|
||||||
|
table_result = await connection.execute(table_info_query, params=(table_name,), auth_context=auth_context)
|
||||||
if not table_result.data:
|
if not table_result.data:
|
||||||
raise ValueError(f"Table {table_name} does not exist")
|
raise ValueError(f"Table {table_name} does not exist")
|
||||||
|
|
||||||
@@ -295,7 +300,8 @@ class DorisResourcesManager:
|
|||||||
ORDER BY index_name, seq_in_index
|
ORDER BY index_name, seq_in_index
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(indexes_query, (table_name,))
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(indexes_query, params=(table_name,), auth_context=auth_context)
|
||||||
return [dict(row) for row in result.data]
|
return [dict(row) for row in result.data]
|
||||||
|
|
||||||
async def _get_view_definition(self, view_name: str) -> str:
|
async def _get_view_definition(self, view_name: str) -> str:
|
||||||
@@ -312,7 +318,8 @@ class DorisResourcesManager:
|
|||||||
AND table_name = %s
|
AND table_name = %s
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(view_query, (view_name,))
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(view_query, params=(view_name,), auth_context=auth_context)
|
||||||
if not result.data:
|
if not result.data:
|
||||||
raise ValueError(f"View {view_name} does not exist")
|
raise ValueError(f"View {view_name} does not exist")
|
||||||
|
|
||||||
@@ -340,7 +347,8 @@ class DorisResourcesManager:
|
|||||||
AND table_type = 'BASE TABLE'
|
AND table_type = 'BASE TABLE'
|
||||||
"""
|
"""
|
||||||
|
|
||||||
table_result = await connection.execute(table_stats_query)
|
auth_context = get_auth_context()
|
||||||
|
table_result = await connection.execute(table_stats_query, auth_context=auth_context)
|
||||||
table_stats = table_result.data[0] if table_result.data else {}
|
table_stats = table_result.data[0] if table_result.data else {}
|
||||||
|
|
||||||
# Get view statistics
|
# Get view statistics
|
||||||
@@ -350,7 +358,7 @@ class DorisResourcesManager:
|
|||||||
WHERE table_schema = DATABASE()
|
WHERE table_schema = DATABASE()
|
||||||
"""
|
"""
|
||||||
|
|
||||||
view_result = await connection.execute(view_stats_query)
|
view_result = await connection.execute(view_stats_query, auth_context=auth_context)
|
||||||
view_stats = view_result.data[0] if view_result.data else {}
|
view_stats = view_result.data[0] if view_result.data else {}
|
||||||
|
|
||||||
stats_info = {
|
stats_info = {
|
||||||
|
|||||||
542
doris_mcp_server/utils/adbc_query_tools.py
Normal file
@@ -0,0 +1,542 @@
|
|||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Apache Doris ADBC Query Tools
|
||||||
|
High-performance data querying using Apache Arrow Flight SQL protocol
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import socket
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from ..utils.logger import get_logger
|
||||||
|
from ..utils.db import DorisConnectionManager
|
||||||
|
from ..utils.sql_security_utils import get_auth_context
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_numpy_types(obj):
|
||||||
|
"""Convert numpy types to native Python types for JSON serialization"""
|
||||||
|
try:
|
||||||
|
# Import numpy only when needed
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
if isinstance(obj, np.integer):
|
||||||
|
return int(obj)
|
||||||
|
elif isinstance(obj, np.floating):
|
||||||
|
return float(obj)
|
||||||
|
elif isinstance(obj, np.bool_):
|
||||||
|
return bool(obj)
|
||||||
|
elif isinstance(obj, np.ndarray):
|
||||||
|
return obj.tolist()
|
||||||
|
elif isinstance(obj, (pd.Timestamp, pd.NaT.__class__)):
|
||||||
|
return str(obj)
|
||||||
|
elif pd.isna(obj):
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return obj
|
||||||
|
except ImportError:
|
||||||
|
# If numpy/pandas not available, return as-is
|
||||||
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_dataframe_to_json_serializable(df):
|
||||||
|
"""Convert DataFrame to JSON serializable format"""
|
||||||
|
try:
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# Convert DataFrame to records
|
||||||
|
records = df.to_dict('records')
|
||||||
|
|
||||||
|
# Convert each record's values
|
||||||
|
converted_records = []
|
||||||
|
for record in records:
|
||||||
|
converted_record = {}
|
||||||
|
for key, value in record.items():
|
||||||
|
converted_record[key] = _convert_numpy_types(value)
|
||||||
|
converted_records.append(converted_record)
|
||||||
|
|
||||||
|
return converted_records
|
||||||
|
except ImportError:
|
||||||
|
# Fallback to basic dict conversion
|
||||||
|
return df.to_dict('records')
|
||||||
|
|
||||||
|
|
||||||
|
class DorisADBCQueryTools:
|
||||||
|
"""ADBC Query Tools for high-performance data transfer using Arrow Flight SQL"""
|
||||||
|
|
||||||
|
def __init__(self, connection_manager: DorisConnectionManager):
|
||||||
|
self.connection_manager = connection_manager
|
||||||
|
self.adbc_client = None
|
||||||
|
self.flight_sql_module = None
|
||||||
|
self.adbc_manager_module = None
|
||||||
|
|
||||||
|
async def exec_adbc_query(
|
||||||
|
self,
|
||||||
|
sql: str,
|
||||||
|
max_rows: int | None = None,
|
||||||
|
timeout: int | None = None,
|
||||||
|
return_format: str | None = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Execute SQL query using ADBC (Arrow Flight SQL) protocol
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sql: SQL statement to execute
|
||||||
|
max_rows: Maximum number of rows to return (uses config default if None)
|
||||||
|
timeout: Query timeout in seconds (uses config default if None)
|
||||||
|
return_format: Format for returned data ("arrow", "pandas", "dict", uses config default if None)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Query results in specified format with metadata
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Use configuration defaults if parameters not specified
|
||||||
|
adbc_config = self.connection_manager.config.adbc
|
||||||
|
max_rows = max_rows if max_rows is not None else adbc_config.default_max_rows
|
||||||
|
timeout = timeout if timeout is not None else adbc_config.default_timeout
|
||||||
|
return_format = return_format if return_format is not None else adbc_config.default_return_format
|
||||||
|
|
||||||
|
# Step 1: Check environment variables and port availability
|
||||||
|
port_check_result = await self._check_arrow_flight_ports()
|
||||||
|
if not port_check_result["success"]:
|
||||||
|
return port_check_result
|
||||||
|
|
||||||
|
# Step 2: Import required ADBC modules
|
||||||
|
import_result = await self._import_adbc_modules()
|
||||||
|
if not import_result["success"]:
|
||||||
|
return import_result
|
||||||
|
|
||||||
|
# Step 3: Create ADBC connection
|
||||||
|
connection_result = await self._create_adbc_connection()
|
||||||
|
if not connection_result["success"]:
|
||||||
|
return connection_result
|
||||||
|
|
||||||
|
# Step 4: Execute query using ADBC
|
||||||
|
query_result = await self._execute_query_with_adbc(
|
||||||
|
sql, max_rows, timeout, return_format
|
||||||
|
)
|
||||||
|
|
||||||
|
execution_time = time.time() - start_time
|
||||||
|
|
||||||
|
if query_result["success"]:
|
||||||
|
query_result["execution_time"] = round(execution_time, 3)
|
||||||
|
query_result["protocol"] = "ADBC_Arrow_Flight_SQL"
|
||||||
|
query_result["timestamp"] = datetime.now().isoformat()
|
||||||
|
|
||||||
|
return query_result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"ADBC query execution failed: {str(e)}")
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"ADBC query execution failed: {str(e)}",
|
||||||
|
"error_type": "execution_error",
|
||||||
|
"timestamp": datetime.now().isoformat()
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _check_arrow_flight_ports(self) -> Dict[str, Any]:
|
||||||
|
"""Check Arrow Flight SQL port configuration and availability"""
|
||||||
|
try:
|
||||||
|
# Check environment variables
|
||||||
|
fe_port = os.getenv("FE_ARROW_FLIGHT_SQL_PORT")
|
||||||
|
be_port = os.getenv("BE_ARROW_FLIGHT_SQL_PORT")
|
||||||
|
|
||||||
|
if not fe_port:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": "Missing environment variable FE_ARROW_FLIGHT_SQL_PORT, please configure Arrow Flight SQL FE port in .env file",
|
||||||
|
"error_type": "missing_fe_port_config"
|
||||||
|
}
|
||||||
|
|
||||||
|
if not be_port:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": "Missing environment variable BE_ARROW_FLIGHT_SQL_PORT, please configure Arrow Flight SQL BE port in .env file",
|
||||||
|
"error_type": "missing_be_port_config"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Convert to integer and validate
|
||||||
|
try:
|
||||||
|
fe_port = int(fe_port)
|
||||||
|
be_port = int(be_port)
|
||||||
|
except ValueError:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": "Invalid Arrow Flight SQL port configuration, please ensure FE_ARROW_FLIGHT_SQL_PORT and BE_ARROW_FLIGHT_SQL_PORT are valid numbers",
|
||||||
|
"error_type": "invalid_port_format"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Get host address
|
||||||
|
db_config = self.connection_manager.config.database
|
||||||
|
fe_host = db_config.host
|
||||||
|
|
||||||
|
# Check FE Arrow Flight SQL port availability
|
||||||
|
fe_available = self._check_port_connectivity(fe_host, fe_port)
|
||||||
|
if not fe_available:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"Cannot connect to FE Arrow Flight SQL port {fe_host}:{fe_port}, please check if service is running",
|
||||||
|
"error_type": "fe_port_unavailable",
|
||||||
|
"fe_host": fe_host,
|
||||||
|
"fe_port": fe_port
|
||||||
|
}
|
||||||
|
|
||||||
|
# Get BE host list
|
||||||
|
be_hosts = await self._get_be_hosts()
|
||||||
|
if not be_hosts:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": "Cannot get BE node information, please check cluster status",
|
||||||
|
"error_type": "no_be_hosts"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check at least one BE Arrow Flight SQL port availability
|
||||||
|
be_available_count = 0
|
||||||
|
be_check_results = []
|
||||||
|
|
||||||
|
for be_host in be_hosts[:3]: # Check first 3 BE nodes
|
||||||
|
be_available = self._check_port_connectivity(be_host, be_port)
|
||||||
|
be_check_results.append({
|
||||||
|
"host": be_host,
|
||||||
|
"port": be_port,
|
||||||
|
"available": be_available
|
||||||
|
})
|
||||||
|
if be_available:
|
||||||
|
be_available_count += 1
|
||||||
|
|
||||||
|
if be_available_count == 0:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"Cannot connect to any BE Arrow Flight SQL port (port: {be_port}), please check if BE services are running",
|
||||||
|
"error_type": "no_be_ports_available",
|
||||||
|
"be_check_results": be_check_results
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"fe_host": fe_host,
|
||||||
|
"fe_port": fe_port,
|
||||||
|
"be_port": be_port,
|
||||||
|
"be_hosts": be_hosts,
|
||||||
|
"be_available_count": be_available_count,
|
||||||
|
"be_check_results": be_check_results
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Arrow Flight port check failed: {str(e)}")
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"Arrow Flight port check failed: {str(e)}",
|
||||||
|
"error_type": "port_check_error"
|
||||||
|
}
|
||||||
|
|
||||||
|
def _check_port_connectivity(self, host: str, port: int, timeout: int | None = None) -> bool:
|
||||||
|
"""Check port connectivity"""
|
||||||
|
try:
|
||||||
|
# Use config timeout if not specified
|
||||||
|
if timeout is None:
|
||||||
|
timeout = self.connection_manager.config.adbc.connection_timeout
|
||||||
|
|
||||||
|
with socket.create_connection((host, port), timeout=timeout):
|
||||||
|
return True
|
||||||
|
except (socket.timeout, socket.error, OSError):
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _get_be_hosts(self) -> List[str]:
|
||||||
|
"""Get BE host list"""
|
||||||
|
try:
|
||||||
|
db_config = self.connection_manager.config.database
|
||||||
|
|
||||||
|
# Use configured BE hosts first
|
||||||
|
if db_config.be_hosts:
|
||||||
|
logger.info(f"Using configured BE hosts: {db_config.be_hosts}")
|
||||||
|
return db_config.be_hosts
|
||||||
|
|
||||||
|
# Get BE nodes via SHOW BACKENDS
|
||||||
|
logger.info("No BE hosts configured, getting BE node information via SHOW BACKENDS")
|
||||||
|
connection = await self.connection_manager.get_connection("query")
|
||||||
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute("SHOW BACKENDS", auth_context=auth_context)
|
||||||
|
|
||||||
|
be_hosts = []
|
||||||
|
for row in result.data:
|
||||||
|
host = row.get("Host")
|
||||||
|
alive = row.get("Alive", "").lower()
|
||||||
|
if host and alive == "true":
|
||||||
|
be_hosts.append(host)
|
||||||
|
|
||||||
|
logger.info(f"Got {len(be_hosts)} active BE nodes from SHOW BACKENDS")
|
||||||
|
return be_hosts
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to get BE hosts: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def _import_adbc_modules(self) -> Dict[str, Any]:
|
||||||
|
"""Import ADBC related modules"""
|
||||||
|
try:
|
||||||
|
# Import ADBC Driver Manager
|
||||||
|
try:
|
||||||
|
import adbc_driver_manager
|
||||||
|
self.adbc_manager_module = adbc_driver_manager
|
||||||
|
except ImportError:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": "Missing adbc_driver_manager module, please install: pip install adbc_driver_manager",
|
||||||
|
"error_type": "missing_adbc_manager"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Import ADBC Flight SQL Driver
|
||||||
|
try:
|
||||||
|
import adbc_driver_flightsql.dbapi as flight_sql
|
||||||
|
self.flight_sql_module = flight_sql
|
||||||
|
except ImportError:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": "Missing adbc_driver_flightsql module, please install: pip install adbc_driver_flightsql",
|
||||||
|
"error_type": "missing_flight_sql_driver"
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"adbc_manager_version": getattr(adbc_driver_manager, '__version__', 'unknown'),
|
||||||
|
"flight_sql_version": getattr(flight_sql, '__version__', 'unknown')
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"ADBC module import failed: {str(e)}")
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"ADBC module import failed: {str(e)}",
|
||||||
|
"error_type": "import_error"
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _create_adbc_connection(self) -> Dict[str, Any]:
|
||||||
|
"""Create ADBC connection"""
|
||||||
|
try:
|
||||||
|
db_config = self.connection_manager.config.database
|
||||||
|
fe_port = int(os.getenv("FE_ARROW_FLIGHT_SQL_PORT"))
|
||||||
|
|
||||||
|
# Build connection URI
|
||||||
|
uri = f"grpc://{db_config.host}:{fe_port}"
|
||||||
|
|
||||||
|
# Create database connection parameters
|
||||||
|
db_kwargs = {
|
||||||
|
self.adbc_manager_module.DatabaseOptions.USERNAME.value: db_config.user,
|
||||||
|
self.adbc_manager_module.DatabaseOptions.PASSWORD.value: db_config.password,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create connection
|
||||||
|
self.adbc_client = self.flight_sql_module.connect(
|
||||||
|
uri=uri,
|
||||||
|
db_kwargs=db_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"uri": uri,
|
||||||
|
"connection_established": True
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to create ADBC connection: {str(e)}")
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"Failed to create ADBC connection: {str(e)}",
|
||||||
|
"error_type": "connection_error"
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _execute_query_with_adbc(
|
||||||
|
self,
|
||||||
|
sql: str,
|
||||||
|
max_rows: int,
|
||||||
|
timeout: int,
|
||||||
|
return_format: str
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Execute query using ADBC"""
|
||||||
|
try:
|
||||||
|
if not self.adbc_client:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": "ADBC connection not established",
|
||||||
|
"error_type": "no_connection"
|
||||||
|
}
|
||||||
|
|
||||||
|
# SECURITY FIX: Perform SQL security validation before executing
|
||||||
|
auth_context = get_auth_context()
|
||||||
|
if self.connection_manager.security_manager:
|
||||||
|
# Always perform security validation, even without auth_context
|
||||||
|
# Use a default context for basic SQL security checks
|
||||||
|
validation_result = await self.connection_manager.security_manager.validate_sql_security(sql, auth_context)
|
||||||
|
if not validation_result.is_valid:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"SQL security validation failed: {validation_result.error_message}",
|
||||||
|
"error_type": "security_violation",
|
||||||
|
"risk_level": validation_result.risk_level
|
||||||
|
}
|
||||||
|
|
||||||
|
cursor = self.adbc_client.cursor()
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Execute query
|
||||||
|
cursor.execute(sql)
|
||||||
|
|
||||||
|
# Get results based on return format
|
||||||
|
if return_format == "arrow":
|
||||||
|
# Return Arrow format
|
||||||
|
arrow_data = cursor.fetchallarrow()
|
||||||
|
|
||||||
|
# Limit rows
|
||||||
|
if len(arrow_data) > max_rows:
|
||||||
|
arrow_data = arrow_data.slice(0, max_rows)
|
||||||
|
|
||||||
|
# Convert Arrow data to serializable format
|
||||||
|
preview_df = arrow_data.to_pandas().head(10) if len(arrow_data) > 0 else None
|
||||||
|
result_data = {
|
||||||
|
"format": "arrow",
|
||||||
|
"num_rows": len(arrow_data),
|
||||||
|
"num_columns": len(arrow_data.schema),
|
||||||
|
"column_names": arrow_data.schema.names,
|
||||||
|
"column_types": [str(field.type) for field in arrow_data.schema],
|
||||||
|
"data_preview": _convert_dataframe_to_json_serializable(preview_df) if preview_df is not None else [],
|
||||||
|
"total_bytes": arrow_data.nbytes if hasattr(arrow_data, 'nbytes') else 0
|
||||||
|
}
|
||||||
|
|
||||||
|
elif return_format == "pandas":
|
||||||
|
# Return Pandas DataFrame
|
||||||
|
df = cursor.fetch_df()
|
||||||
|
|
||||||
|
# Limit rows
|
||||||
|
if len(df) > max_rows:
|
||||||
|
df = df.head(max_rows)
|
||||||
|
|
||||||
|
result_data = {
|
||||||
|
"format": "pandas",
|
||||||
|
"num_rows": len(df),
|
||||||
|
"num_columns": len(df.columns),
|
||||||
|
"column_names": df.columns.tolist(),
|
||||||
|
"column_types": df.dtypes.astype(str).tolist(),
|
||||||
|
"data": _convert_dataframe_to_json_serializable(df),
|
||||||
|
"memory_usage": int(df.memory_usage(deep=True).sum())
|
||||||
|
}
|
||||||
|
|
||||||
|
else: # return_format == "dict"
|
||||||
|
# Return dictionary format
|
||||||
|
arrow_data = cursor.fetchallarrow()
|
||||||
|
df = arrow_data.to_pandas()
|
||||||
|
|
||||||
|
# Limit rows
|
||||||
|
if len(df) > max_rows:
|
||||||
|
df = df.head(max_rows)
|
||||||
|
|
||||||
|
result_data = {
|
||||||
|
"format": "dict",
|
||||||
|
"num_rows": len(df),
|
||||||
|
"num_columns": len(df.columns),
|
||||||
|
"column_names": df.columns.tolist(),
|
||||||
|
"column_types": df.dtypes.astype(str).tolist(),
|
||||||
|
"data": _convert_dataframe_to_json_serializable(df)
|
||||||
|
}
|
||||||
|
|
||||||
|
execution_time = time.time() - start_time
|
||||||
|
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"result": result_data,
|
||||||
|
"execution_time": round(execution_time, 3),
|
||||||
|
"sql": sql,
|
||||||
|
"max_rows_applied": len(result_data.get("data", [])) >= max_rows
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"ADBC query execution failed: {str(e)}")
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"ADBC query execution failed: {str(e)}",
|
||||||
|
"error_type": "query_execution_error",
|
||||||
|
"sql": sql
|
||||||
|
}
|
||||||
|
|
||||||
|
async def get_adbc_connection_info(self) -> Dict[str, Any]:
|
||||||
|
"""Get ADBC connection information and status"""
|
||||||
|
try:
|
||||||
|
# Check port status
|
||||||
|
port_status = await self._check_arrow_flight_ports()
|
||||||
|
|
||||||
|
# Check module status
|
||||||
|
module_status = await self._import_adbc_modules()
|
||||||
|
|
||||||
|
# Get configuration information
|
||||||
|
db_config = self.connection_manager.config.database
|
||||||
|
fe_port = os.getenv("FE_ARROW_FLIGHT_SQL_PORT")
|
||||||
|
be_port = os.getenv("BE_ARROW_FLIGHT_SQL_PORT")
|
||||||
|
|
||||||
|
connection_info = {
|
||||||
|
"adbc_available": module_status["success"],
|
||||||
|
"ports_available": port_status["success"],
|
||||||
|
"configuration": {
|
||||||
|
"fe_host": db_config.host,
|
||||||
|
"fe_arrow_flight_port": fe_port,
|
||||||
|
"be_arrow_flight_port": be_port,
|
||||||
|
"user": db_config.user
|
||||||
|
},
|
||||||
|
"port_status": port_status,
|
||||||
|
"module_status": module_status,
|
||||||
|
"timestamp": datetime.now().isoformat()
|
||||||
|
}
|
||||||
|
|
||||||
|
if port_status["success"] and module_status["success"]:
|
||||||
|
connection_info["status"] = "ready"
|
||||||
|
connection_info["message"] = "ADBC Arrow Flight SQL connection ready"
|
||||||
|
else:
|
||||||
|
connection_info["status"] = "not_ready"
|
||||||
|
errors = []
|
||||||
|
if not port_status["success"]:
|
||||||
|
errors.append(port_status["error"])
|
||||||
|
if not module_status["success"]:
|
||||||
|
errors.append(module_status["error"])
|
||||||
|
connection_info["message"] = "; ".join(errors)
|
||||||
|
|
||||||
|
return connection_info
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to get ADBC connection information: {str(e)}")
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"error": f"Failed to get ADBC connection information: {str(e)}",
|
||||||
|
"timestamp": datetime.now().isoformat()
|
||||||
|
}
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
"""Cleanup resources"""
|
||||||
|
try:
|
||||||
|
if self.adbc_client:
|
||||||
|
self.adbc_client.close()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
@@ -32,6 +32,8 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
load_dotenv = None
|
load_dotenv = None
|
||||||
|
|
||||||
|
from .logger import get_logger
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DatabaseConfig:
|
class DatabaseConfig:
|
||||||
@@ -41,38 +43,105 @@ class DatabaseConfig:
|
|||||||
port: int = 9030
|
port: int = 9030
|
||||||
user: str = "root"
|
user: str = "root"
|
||||||
password: str = ""
|
password: str = ""
|
||||||
database: str = "test"
|
database: str = "information_schema"
|
||||||
charset: str = "utf8mb4"
|
charset: str = "UTF8"
|
||||||
|
|
||||||
|
# FE HTTP API port for profile and other HTTP APIs
|
||||||
|
fe_http_port: int = 8030
|
||||||
|
|
||||||
|
# BE nodes configuration for external access
|
||||||
|
# If be_hosts is empty, will use "show backends" to get BE nodes
|
||||||
|
be_hosts: list[str] = field(default_factory=list)
|
||||||
|
be_webserver_port: int = 8040
|
||||||
|
|
||||||
|
# Arrow Flight SQL Configuration (Required for ADBC tools)
|
||||||
|
fe_arrow_flight_sql_port: int | None = None
|
||||||
|
be_arrow_flight_sql_port: int | None = None
|
||||||
|
|
||||||
# Connection pool configuration
|
# Connection pool configuration
|
||||||
min_connections: int = 5
|
# Note: min_connections is fixed at 0 to avoid at_eof connection issues
|
||||||
|
# This prevents pre-creation of connections which can cause state problems
|
||||||
|
_min_connections: int = field(default=0, init=False) # Internal use only, always 0
|
||||||
max_connections: int = 20
|
max_connections: int = 20
|
||||||
connection_timeout: int = 30
|
connection_timeout: int = 30
|
||||||
health_check_interval: int = 60
|
health_check_interval: int = 60
|
||||||
max_connection_age: int = 3600
|
max_connection_age: int = 3600
|
||||||
|
|
||||||
|
@property
|
||||||
|
def min_connections(self) -> int:
|
||||||
|
"""Minimum connections is always 0 to prevent at_eof issues"""
|
||||||
|
return self._min_connections
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SecurityConfig:
|
class SecurityConfig:
|
||||||
"""Security configuration"""
|
"""Security configuration"""
|
||||||
|
|
||||||
# Authentication configuration
|
# Independent authentication switches - any one enabled allows that method
|
||||||
auth_type: str = "token" # token, basic, oauth
|
enable_token_auth: bool = False # Enable token-based authentication (default: disabled)
|
||||||
token_secret: str = "default_secret"
|
enable_jwt_auth: bool = False # Enable JWT authentication (default: disabled)
|
||||||
|
enable_oauth_auth: bool = False # Enable OAuth 2.0/OIDC authentication (default: disabled)
|
||||||
|
|
||||||
|
# Legacy configuration (kept for backward compatibility)
|
||||||
|
auth_type: str = "token" # jwt, token, basic, oauth (deprecated: use individual switches)
|
||||||
|
token_secret: str = "default_secret" # Legacy token secret for backward compatibility
|
||||||
token_expiry: int = 3600
|
token_expiry: int = 3600
|
||||||
|
|
||||||
|
# Enhanced Token Authentication Configuration
|
||||||
|
token_file_path: str = "tokens.json" # Path to token configuration file
|
||||||
|
enable_token_expiry: bool = True # Enable token expiration
|
||||||
|
default_token_expiry_hours: int = 24 * 30 # Default expiry: 30 days
|
||||||
|
token_hash_algorithm: str = "sha256" # Token hashing algorithm: sha256, sha512
|
||||||
|
|
||||||
|
# Token Management Security (New in v0.6.0)
|
||||||
|
enable_http_token_management: bool = False # Enable HTTP token management endpoints (default: disabled for security)
|
||||||
|
token_management_admin_token: str = "" # Admin token for token management endpoints (required if HTTP management enabled)
|
||||||
|
token_management_allowed_ips: list[str] = field(default_factory=lambda: ["127.0.0.1", "::1", "localhost"]) # Allowed IPs for token management
|
||||||
|
require_admin_auth: bool = True # Require admin authentication for token management (default: true)
|
||||||
|
|
||||||
|
# JWT Configuration
|
||||||
|
jwt_algorithm: str = "RS256" # RS256, ES256, HS256
|
||||||
|
jwt_issuer: str = "doris-mcp-server"
|
||||||
|
jwt_audience: str = "doris-mcp-client"
|
||||||
|
jwt_private_key_path: str = ""
|
||||||
|
jwt_public_key_path: str = ""
|
||||||
|
jwt_secret_key: str = "" # Only used for HS256 algorithm
|
||||||
|
jwt_access_token_expiry: int = 3600 # 1 hour
|
||||||
|
jwt_refresh_token_expiry: int = 86400 # 24 hours
|
||||||
|
enable_token_refresh: bool = True
|
||||||
|
enable_token_revocation: bool = True
|
||||||
|
key_rotation_interval: int = 30 * 24 * 3600 # 30 days in seconds
|
||||||
|
|
||||||
|
# JWT Security Features
|
||||||
|
jwt_require_iat: bool = True # Require "issued at" claim
|
||||||
|
jwt_require_exp: bool = True # Require "expires at" claim
|
||||||
|
jwt_require_nbf: bool = False # Require "not before" claim
|
||||||
|
jwt_leeway: int = 10 # Clock skew tolerance in seconds
|
||||||
|
jwt_verify_signature: bool = True # Verify JWT signature
|
||||||
|
jwt_verify_audience: bool = True # Verify audience claim
|
||||||
|
jwt_verify_issuer: bool = True # Verify issuer claim
|
||||||
|
|
||||||
# SQL security configuration
|
# SQL security configuration
|
||||||
|
enable_security_check: bool = True # Main switch: whether to enable SQL security check
|
||||||
blocked_keywords: list[str] = field(
|
blocked_keywords: list[str] = field(
|
||||||
default_factory=lambda: [
|
default_factory=lambda: [
|
||||||
|
# DDL Operations (Data Definition Language)
|
||||||
"DROP",
|
"DROP",
|
||||||
"DELETE",
|
"CREATE",
|
||||||
"TRUNCATE",
|
|
||||||
"ALTER",
|
"ALTER",
|
||||||
"CREATE",
|
"TRUNCATE",
|
||||||
|
# DML Operations (Data Manipulation Language)
|
||||||
|
"DELETE",
|
||||||
"INSERT",
|
"INSERT",
|
||||||
"UPDATE",
|
"UPDATE",
|
||||||
|
# DCL Operations (Data Control Language)
|
||||||
"GRANT",
|
"GRANT",
|
||||||
"REVOKE",
|
"REVOKE",
|
||||||
|
# System Operations
|
||||||
|
"EXEC",
|
||||||
|
"EXECUTE",
|
||||||
|
"SHUTDOWN",
|
||||||
|
"KILL",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
max_query_complexity: int = 100
|
max_query_complexity: int = 100
|
||||||
@@ -85,6 +154,45 @@ class SecurityConfig:
|
|||||||
enable_masking: bool = True
|
enable_masking: bool = True
|
||||||
masking_rules: list[dict[str, Any]] = field(default_factory=list)
|
masking_rules: list[dict[str, Any]] = field(default_factory=list)
|
||||||
|
|
||||||
|
# OAuth 2.0/OIDC Configuration
|
||||||
|
oauth_enabled: bool = False
|
||||||
|
oauth_provider: str = "" # 'google', 'microsoft', 'github', 'custom'
|
||||||
|
oauth_client_id: str = ""
|
||||||
|
oauth_client_secret: str = ""
|
||||||
|
oauth_redirect_uri: str = "http://localhost:3000/auth/callback"
|
||||||
|
|
||||||
|
# OIDC Discovery
|
||||||
|
oidc_discovery_url: str = "" # e.g., https://accounts.google.com/.well-known/openid_configuration
|
||||||
|
oauth_authorization_endpoint: str = ""
|
||||||
|
oauth_token_endpoint: str = ""
|
||||||
|
oauth_userinfo_endpoint: str = ""
|
||||||
|
oauth_jwks_uri: str = ""
|
||||||
|
|
||||||
|
# OAuth Scopes and Settings
|
||||||
|
oauth_scopes: list[str] = field(default_factory=list)
|
||||||
|
oauth_state_expiry: int = 600 # State parameter expiry in seconds (10 minutes)
|
||||||
|
oauth_pkce_enabled: bool = True # Enable PKCE for better security
|
||||||
|
oauth_nonce_enabled: bool = True # Enable nonce for OIDC
|
||||||
|
|
||||||
|
# User Mapping Configuration
|
||||||
|
oauth_user_id_claim: str = "sub" # JWT claim for user ID
|
||||||
|
oauth_email_claim: str = "email"
|
||||||
|
oauth_name_claim: str = "name"
|
||||||
|
oauth_roles_claim: str = "roles" # Custom claim for roles
|
||||||
|
oauth_default_roles: list[str] = field(default_factory=lambda: ["oauth_user"])
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
"""Initialize default OAuth scopes based on provider"""
|
||||||
|
if not self.oauth_scopes and self.oauth_provider:
|
||||||
|
if self.oauth_provider == "google":
|
||||||
|
self.oauth_scopes = ["openid", "email", "profile"]
|
||||||
|
elif self.oauth_provider == "microsoft":
|
||||||
|
self.oauth_scopes = ["openid", "profile", "email", "User.Read"]
|
||||||
|
elif self.oauth_provider == "github":
|
||||||
|
self.oauth_scopes = ["user:email", "read:user"]
|
||||||
|
else:
|
||||||
|
self.oauth_scopes = ["openid", "email", "profile"]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PerformanceConfig:
|
class PerformanceConfig:
|
||||||
@@ -102,6 +210,52 @@ class PerformanceConfig:
|
|||||||
# Connection pool optimization configuration
|
# Connection pool optimization configuration
|
||||||
connection_pool_size: int = 20
|
connection_pool_size: int = 20
|
||||||
idle_timeout: int = 1800
|
idle_timeout: int = 1800
|
||||||
|
|
||||||
|
# Response content size limit (characters)
|
||||||
|
max_response_content_size: int = 4096
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DataQualityConfig:
|
||||||
|
"""Data quality analysis configuration"""
|
||||||
|
|
||||||
|
# Column analysis configuration
|
||||||
|
max_columns_per_batch: int = 20 # Maximum columns to analyze in a single batch
|
||||||
|
default_sample_size: int = 100000 # Default sample size for analysis
|
||||||
|
|
||||||
|
# Sampling strategy configuration
|
||||||
|
small_table_threshold: int = 100000 # Tables smaller than this use full table analysis
|
||||||
|
medium_table_threshold: int = 1000000 # Tables smaller than this use simple LIMIT sampling
|
||||||
|
# Tables larger than medium_table_threshold use systematic sampling
|
||||||
|
|
||||||
|
# Performance optimization
|
||||||
|
enable_batch_analysis: bool = True # Enable batch analysis for multiple columns
|
||||||
|
batch_timeout: int = 300 # Timeout for batch analysis in seconds
|
||||||
|
|
||||||
|
# Accuracy vs Performance trade-off
|
||||||
|
enable_fast_mode: bool = False # Use approximate algorithms for faster results
|
||||||
|
fast_mode_sample_size: int = 10000 # Sample size for fast mode
|
||||||
|
|
||||||
|
# Statistical analysis configuration
|
||||||
|
enable_distribution_analysis: bool = True # Enable distribution analysis
|
||||||
|
histogram_bins: int = 20 # Number of bins for histogram analysis
|
||||||
|
percentile_levels: list[float] = field(default_factory=lambda: [0.25, 0.5, 0.75, 0.95, 0.99]) # Percentile levels to calculate
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ADBCConfig:
|
||||||
|
"""ADBC (Arrow Flight SQL) configuration"""
|
||||||
|
|
||||||
|
# Default query parameters
|
||||||
|
default_max_rows: int = 100000
|
||||||
|
default_timeout: int = 60
|
||||||
|
default_return_format: str = "arrow" # "arrow", "pandas", "dict"
|
||||||
|
|
||||||
|
# Connection timeout for ADBC
|
||||||
|
connection_timeout: int = 30
|
||||||
|
|
||||||
|
# Whether to enable ADBC tools
|
||||||
|
enabled: bool = True
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -117,6 +271,11 @@ class LoggingConfig:
|
|||||||
# Audit log configuration
|
# Audit log configuration
|
||||||
enable_audit: bool = True
|
enable_audit: bool = True
|
||||||
audit_file_path: str | None = None
|
audit_file_path: str | None = None
|
||||||
|
|
||||||
|
# Log cleanup configuration
|
||||||
|
enable_cleanup: bool = True
|
||||||
|
max_age_days: int = 30
|
||||||
|
cleanup_interval_hours: int = 24
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -125,11 +284,11 @@ class MonitoringConfig:
|
|||||||
|
|
||||||
# Metrics collection configuration
|
# Metrics collection configuration
|
||||||
enable_metrics: bool = True
|
enable_metrics: bool = True
|
||||||
metrics_port: int = 8081
|
metrics_port: int = 3001
|
||||||
metrics_path: str = "/metrics"
|
metrics_path: str = "/metrics"
|
||||||
|
|
||||||
# Health check configuration
|
# Health check configuration
|
||||||
health_check_port: int = 8082
|
health_check_port: int = 3002
|
||||||
health_check_path: str = "/health"
|
health_check_path: str = "/health"
|
||||||
|
|
||||||
# Alert configuration
|
# Alert configuration
|
||||||
@@ -143,15 +302,22 @@ class DorisConfig:
|
|||||||
|
|
||||||
# Basic configuration
|
# Basic configuration
|
||||||
server_name: str = "doris-mcp-server"
|
server_name: str = "doris-mcp-server"
|
||||||
server_version: str = "1.0.0"
|
server_version: str = "0.4.1"
|
||||||
server_port: int = 8080
|
server_host: str = "localhost"
|
||||||
|
server_port: int = 3000
|
||||||
|
transport: str = "stdio"
|
||||||
|
|
||||||
|
# Temporary files configuration
|
||||||
|
temp_files_dir: str = "tmp" # Temporary files directory for Explain and Profile outputs
|
||||||
|
|
||||||
# Sub-configuration modules
|
# Sub-configuration modules
|
||||||
database: DatabaseConfig = field(default_factory=DatabaseConfig)
|
database: DatabaseConfig = field(default_factory=DatabaseConfig)
|
||||||
security: SecurityConfig = field(default_factory=SecurityConfig)
|
security: SecurityConfig = field(default_factory=SecurityConfig)
|
||||||
performance: PerformanceConfig = field(default_factory=PerformanceConfig)
|
performance: PerformanceConfig = field(default_factory=PerformanceConfig)
|
||||||
|
data_quality: DataQualityConfig = field(default_factory=DataQualityConfig)
|
||||||
logging: LoggingConfig = field(default_factory=LoggingConfig)
|
logging: LoggingConfig = field(default_factory=LoggingConfig)
|
||||||
monitoring: MonitoringConfig = field(default_factory=MonitoringConfig)
|
monitoring: MonitoringConfig = field(default_factory=MonitoringConfig)
|
||||||
|
adbc: ADBCConfig = field(default_factory=ADBCConfig)
|
||||||
|
|
||||||
# Custom configuration
|
# Custom configuration
|
||||||
custom_config: dict[str, Any] = field(default_factory=dict)
|
custom_config: dict[str, Any] = field(default_factory=dict)
|
||||||
@@ -180,6 +346,9 @@ class DorisConfig:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_env(cls, env_file: str | None = None) -> "DorisConfig":
|
def from_env(cls, env_file: str | None = None) -> "DorisConfig":
|
||||||
"""Load configuration from environment variables
|
"""Load configuration from environment variables
|
||||||
|
|
||||||
|
The kv pairs in the. env file will be loaded as environment variables,
|
||||||
|
but the existing environment variables will not be overridden.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
env_file: .env file path, if None, search in the following order:
|
env_file: .env file path, if None, search in the following order:
|
||||||
@@ -199,7 +368,7 @@ class DorisConfig:
|
|||||||
env_files = [".env", ".env.local", ".env.production", ".env.development"]
|
env_files = [".env", ".env.local", ".env.production", ".env.development"]
|
||||||
for env_path in env_files:
|
for env_path in env_files:
|
||||||
if Path(env_path).exists():
|
if Path(env_path).exists():
|
||||||
load_dotenv(env_path)
|
load_dotenv(env_path, override=False)
|
||||||
logging.getLogger(__name__).info(f"Loaded environment configuration file: {env_path}")
|
logging.getLogger(__name__).info(f"Loaded environment configuration file: {env_path}")
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
@@ -209,17 +378,45 @@ class DorisConfig:
|
|||||||
|
|
||||||
config = cls()
|
config = cls()
|
||||||
|
|
||||||
# Database configuration
|
# Database configuration - handle empty strings properly
|
||||||
config.database.host = os.getenv("DORIS_HOST", config.database.host)
|
doris_host = os.getenv("DORIS_HOST", "").strip()
|
||||||
config.database.port = int(os.getenv("DORIS_PORT", str(config.database.port)))
|
config.database.host = doris_host if doris_host else config.database.host
|
||||||
config.database.user = os.getenv("DORIS_USER", config.database.user)
|
|
||||||
config.database.password = os.getenv("DORIS_PASSWORD", config.database.password)
|
doris_port = os.getenv("DORIS_PORT", "").strip()
|
||||||
config.database.database = os.getenv("DORIS_DATABASE", config.database.database)
|
if doris_port and doris_port.isdigit():
|
||||||
|
config.database.port = int(doris_port)
|
||||||
|
|
||||||
|
doris_user = os.getenv("DORIS_USER", "").strip()
|
||||||
|
config.database.user = doris_user if doris_user else config.database.user
|
||||||
|
|
||||||
|
doris_password = os.getenv("DORIS_PASSWORD", "")
|
||||||
|
config.database.password = doris_password if doris_password else config.database.password
|
||||||
|
|
||||||
|
doris_database = os.getenv("DORIS_DATABASE", "").strip()
|
||||||
|
config.database.database = doris_database if doris_database else config.database.database
|
||||||
|
|
||||||
|
doris_fe_http_port = os.getenv("DORIS_FE_HTTP_PORT", "").strip()
|
||||||
|
if doris_fe_http_port and doris_fe_http_port.isdigit():
|
||||||
|
config.database.fe_http_port = int(doris_fe_http_port)
|
||||||
|
|
||||||
|
# BE nodes configuration
|
||||||
|
be_hosts_env = os.getenv("DORIS_BE_HOSTS", "")
|
||||||
|
if be_hosts_env:
|
||||||
|
config.database.be_hosts = [host.strip() for host in be_hosts_env.split(",") if host.strip()]
|
||||||
|
be_webserver_port = os.getenv("DORIS_BE_WEBSERVER_PORT", "").strip()
|
||||||
|
if be_webserver_port and be_webserver_port.isdigit():
|
||||||
|
config.database.be_webserver_port = int(be_webserver_port)
|
||||||
|
|
||||||
|
# Arrow Flight SQL Configuration
|
||||||
|
fe_arrow_port_env = os.getenv("FE_ARROW_FLIGHT_SQL_PORT")
|
||||||
|
if fe_arrow_port_env:
|
||||||
|
config.database.fe_arrow_flight_sql_port = int(fe_arrow_port_env)
|
||||||
|
|
||||||
|
be_arrow_port_env = os.getenv("BE_ARROW_FLIGHT_SQL_PORT")
|
||||||
|
if be_arrow_port_env:
|
||||||
|
config.database.be_arrow_flight_sql_port = int(be_arrow_port_env)
|
||||||
|
|
||||||
# Connection pool configuration
|
# Connection pool configuration
|
||||||
config.database.min_connections = int(
|
|
||||||
os.getenv("DORIS_MIN_CONNECTIONS", str(config.database.min_connections))
|
|
||||||
)
|
|
||||||
config.database.max_connections = int(
|
config.database.max_connections = int(
|
||||||
os.getenv("DORIS_MAX_CONNECTIONS", str(config.database.max_connections))
|
os.getenv("DORIS_MAX_CONNECTIONS", str(config.database.max_connections))
|
||||||
)
|
)
|
||||||
@@ -234,6 +431,10 @@ class DorisConfig:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Security configuration
|
# Security configuration
|
||||||
|
# Independent authentication switches
|
||||||
|
config.security.enable_token_auth = os.getenv("ENABLE_TOKEN_AUTH", str(config.security.enable_token_auth)).lower() == "true"
|
||||||
|
config.security.enable_jwt_auth = os.getenv("ENABLE_JWT_AUTH", str(config.security.enable_jwt_auth)).lower() == "true"
|
||||||
|
config.security.enable_oauth_auth = os.getenv("ENABLE_OAUTH_AUTH", str(config.security.enable_oauth_auth)).lower() == "true"
|
||||||
config.security.auth_type = os.getenv("AUTH_TYPE", config.security.auth_type)
|
config.security.auth_type = os.getenv("AUTH_TYPE", config.security.auth_type)
|
||||||
config.security.token_secret = os.getenv("TOKEN_SECRET", config.security.token_secret)
|
config.security.token_secret = os.getenv("TOKEN_SECRET", config.security.token_secret)
|
||||||
config.security.token_expiry = int(
|
config.security.token_expiry = int(
|
||||||
@@ -245,9 +446,50 @@ class DorisConfig:
|
|||||||
config.security.max_query_complexity = int(
|
config.security.max_query_complexity = int(
|
||||||
os.getenv("MAX_QUERY_COMPLEXITY", str(config.security.max_query_complexity))
|
os.getenv("MAX_QUERY_COMPLEXITY", str(config.security.max_query_complexity))
|
||||||
)
|
)
|
||||||
|
config.security.enable_security_check = (
|
||||||
|
os.getenv("ENABLE_SECURITY_CHECK", str(config.security.enable_security_check).lower()).lower() == "true"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle blocked keywords environment variable configuration
|
||||||
|
# Format: BLOCKED_KEYWORDS="DROP,DELETE,TRUNCATE,ALTER,CREATE,INSERT,UPDATE,GRANT,REVOKE"
|
||||||
|
blocked_keywords_env = os.getenv("BLOCKED_KEYWORDS", "")
|
||||||
|
if blocked_keywords_env:
|
||||||
|
# If environment variable is provided, use keywords list from environment variable
|
||||||
|
config.security.blocked_keywords = [
|
||||||
|
keyword.strip().upper()
|
||||||
|
for keyword in blocked_keywords_env.split(",")
|
||||||
|
if keyword.strip()
|
||||||
|
]
|
||||||
|
# If environment variable is empty, keep default configuration unchanged
|
||||||
|
|
||||||
config.security.enable_masking = (
|
config.security.enable_masking = (
|
||||||
os.getenv("ENABLE_MASKING", str(config.security.enable_masking).lower()).lower() == "true"
|
os.getenv("ENABLE_MASKING", str(config.security.enable_masking).lower()).lower() == "true"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Enhanced Token Authentication configuration
|
||||||
|
config.security.token_file_path = os.getenv("TOKEN_FILE_PATH", config.security.token_file_path)
|
||||||
|
config.security.enable_token_expiry = (
|
||||||
|
os.getenv("ENABLE_TOKEN_EXPIRY", str(config.security.enable_token_expiry).lower()).lower() == "true"
|
||||||
|
)
|
||||||
|
config.security.default_token_expiry_hours = int(
|
||||||
|
os.getenv("DEFAULT_TOKEN_EXPIRY_HOURS", str(config.security.default_token_expiry_hours))
|
||||||
|
)
|
||||||
|
config.security.token_hash_algorithm = os.getenv("TOKEN_HASH_ALGORITHM", config.security.token_hash_algorithm)
|
||||||
|
|
||||||
|
# Token Management Security Configuration (New in v0.6.0)
|
||||||
|
config.security.enable_http_token_management = (
|
||||||
|
os.getenv("ENABLE_HTTP_TOKEN_MANAGEMENT", str(config.security.enable_http_token_management).lower()).lower() == "true"
|
||||||
|
)
|
||||||
|
config.security.token_management_admin_token = os.getenv("TOKEN_MANAGEMENT_ADMIN_TOKEN", config.security.token_management_admin_token)
|
||||||
|
|
||||||
|
# Parse allowed IPs from comma-separated string
|
||||||
|
allowed_ips_str = os.getenv("TOKEN_MANAGEMENT_ALLOWED_IPS", "")
|
||||||
|
if allowed_ips_str:
|
||||||
|
config.security.token_management_allowed_ips = [ip.strip() for ip in allowed_ips_str.split(",") if ip.strip()]
|
||||||
|
|
||||||
|
config.security.require_admin_auth = (
|
||||||
|
os.getenv("REQUIRE_ADMIN_AUTH", str(config.security.require_admin_auth).lower()).lower() == "true"
|
||||||
|
)
|
||||||
|
|
||||||
# Performance configuration
|
# Performance configuration
|
||||||
config.performance.enable_query_cache = (
|
config.performance.enable_query_cache = (
|
||||||
@@ -265,6 +507,9 @@ class DorisConfig:
|
|||||||
config.performance.query_timeout = int(
|
config.performance.query_timeout = int(
|
||||||
os.getenv("QUERY_TIMEOUT", str(config.performance.query_timeout))
|
os.getenv("QUERY_TIMEOUT", str(config.performance.query_timeout))
|
||||||
)
|
)
|
||||||
|
config.performance.max_response_content_size = int(
|
||||||
|
os.getenv("MAX_RESPONSE_CONTENT_SIZE", str(config.performance.max_response_content_size))
|
||||||
|
)
|
||||||
|
|
||||||
# Logging configuration
|
# Logging configuration
|
||||||
config.logging.level = os.getenv("LOG_LEVEL", config.logging.level)
|
config.logging.level = os.getenv("LOG_LEVEL", config.logging.level)
|
||||||
@@ -273,6 +518,15 @@ class DorisConfig:
|
|||||||
os.getenv("ENABLE_AUDIT", str(config.logging.enable_audit).lower()).lower() == "true"
|
os.getenv("ENABLE_AUDIT", str(config.logging.enable_audit).lower()).lower() == "true"
|
||||||
)
|
)
|
||||||
config.logging.audit_file_path = os.getenv("AUDIT_FILE_PATH", config.logging.audit_file_path)
|
config.logging.audit_file_path = os.getenv("AUDIT_FILE_PATH", config.logging.audit_file_path)
|
||||||
|
config.logging.enable_cleanup = (
|
||||||
|
os.getenv("ENABLE_LOG_CLEANUP", str(config.logging.enable_cleanup).lower()).lower() == "true"
|
||||||
|
)
|
||||||
|
config.logging.max_age_days = int(
|
||||||
|
os.getenv("LOG_MAX_AGE_DAYS", str(config.logging.max_age_days))
|
||||||
|
)
|
||||||
|
config.logging.cleanup_interval_hours = int(
|
||||||
|
os.getenv("LOG_CLEANUP_INTERVAL_HOURS", str(config.logging.cleanup_interval_hours))
|
||||||
|
)
|
||||||
|
|
||||||
# Monitoring configuration
|
# Monitoring configuration
|
||||||
config.monitoring.enable_metrics = (
|
config.monitoring.enable_metrics = (
|
||||||
@@ -289,10 +543,60 @@ class DorisConfig:
|
|||||||
)
|
)
|
||||||
config.monitoring.alert_webhook_url = os.getenv("ALERT_WEBHOOK_URL", config.monitoring.alert_webhook_url)
|
config.monitoring.alert_webhook_url = os.getenv("ALERT_WEBHOOK_URL", config.monitoring.alert_webhook_url)
|
||||||
|
|
||||||
|
# ADBC configuration
|
||||||
|
config.adbc.default_max_rows = int(
|
||||||
|
os.getenv("ADBC_DEFAULT_MAX_ROWS", str(config.adbc.default_max_rows))
|
||||||
|
)
|
||||||
|
config.adbc.default_timeout = int(
|
||||||
|
os.getenv("ADBC_DEFAULT_TIMEOUT", str(config.adbc.default_timeout))
|
||||||
|
)
|
||||||
|
config.adbc.default_return_format = os.getenv("ADBC_DEFAULT_RETURN_FORMAT", config.adbc.default_return_format)
|
||||||
|
config.adbc.connection_timeout = int(
|
||||||
|
os.getenv("ADBC_CONNECTION_TIMEOUT", str(config.adbc.connection_timeout))
|
||||||
|
)
|
||||||
|
config.adbc.enabled = (
|
||||||
|
os.getenv("ADBC_ENABLED", str(config.adbc.enabled).lower()).lower() == "true"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Data quality configuration
|
||||||
|
config.data_quality.max_columns_per_batch = int(
|
||||||
|
os.getenv("DATA_QUALITY_MAX_COLUMNS_PER_BATCH", str(config.data_quality.max_columns_per_batch))
|
||||||
|
)
|
||||||
|
config.data_quality.default_sample_size = int(
|
||||||
|
os.getenv("DATA_QUALITY_DEFAULT_SAMPLE_SIZE", str(config.data_quality.default_sample_size))
|
||||||
|
)
|
||||||
|
config.data_quality.small_table_threshold = int(
|
||||||
|
os.getenv("DATA_QUALITY_SMALL_TABLE_THRESHOLD", str(config.data_quality.small_table_threshold))
|
||||||
|
)
|
||||||
|
config.data_quality.medium_table_threshold = int(
|
||||||
|
os.getenv("DATA_QUALITY_MEDIUM_TABLE_THRESHOLD", str(config.data_quality.medium_table_threshold))
|
||||||
|
)
|
||||||
|
config.data_quality.enable_batch_analysis = (
|
||||||
|
os.getenv("DATA_QUALITY_ENABLE_BATCH_ANALYSIS", str(config.data_quality.enable_batch_analysis).lower()).lower() == "true"
|
||||||
|
)
|
||||||
|
config.data_quality.batch_timeout = int(
|
||||||
|
os.getenv("DATA_QUALITY_BATCH_TIMEOUT", str(config.data_quality.batch_timeout))
|
||||||
|
)
|
||||||
|
config.data_quality.enable_fast_mode = (
|
||||||
|
os.getenv("DATA_QUALITY_ENABLE_FAST_MODE", str(config.data_quality.enable_fast_mode).lower()).lower() == "true"
|
||||||
|
)
|
||||||
|
config.data_quality.fast_mode_sample_size = int(
|
||||||
|
os.getenv("DATA_QUALITY_FAST_MODE_SAMPLE_SIZE", str(config.data_quality.fast_mode_sample_size))
|
||||||
|
)
|
||||||
|
config.data_quality.enable_distribution_analysis = (
|
||||||
|
os.getenv("DATA_QUALITY_ENABLE_DISTRIBUTION_ANALYSIS", str(config.data_quality.enable_distribution_analysis).lower()).lower() == "true"
|
||||||
|
)
|
||||||
|
config.data_quality.histogram_bins = int(
|
||||||
|
os.getenv("DATA_QUALITY_HISTOGRAM_BINS", str(config.data_quality.histogram_bins))
|
||||||
|
)
|
||||||
|
|
||||||
# Server configuration
|
# Server configuration
|
||||||
config.server_name = os.getenv("SERVER_NAME", config.server_name)
|
config.server_name = os.getenv("SERVER_NAME", config.server_name)
|
||||||
config.server_version = os.getenv("SERVER_VERSION", config.server_version)
|
config.server_version = os.getenv("SERVER_VERSION", config.server_version)
|
||||||
config.server_port = int(os.getenv("SERVER_PORT", str(config.server_port)))
|
server_port = os.getenv("SERVER_PORT", "").strip()
|
||||||
|
if server_port and server_port.isdigit():
|
||||||
|
config.server_port = int(server_port)
|
||||||
|
config.temp_files_dir = os.getenv("TEMP_FILES_DIR", config.temp_files_dir)
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
@@ -302,7 +606,7 @@ class DorisConfig:
|
|||||||
config = cls()
|
config = cls()
|
||||||
|
|
||||||
# Update basic configuration
|
# Update basic configuration
|
||||||
for key in ["server_name", "server_version", "server_port"]:
|
for key in ["server_name", "server_version", "server_port", "temp_files_dir"]:
|
||||||
if key in config_data:
|
if key in config_data:
|
||||||
setattr(config, key, config_data[key])
|
setattr(config, key, config_data[key])
|
||||||
|
|
||||||
@@ -327,6 +631,13 @@ class DorisConfig:
|
|||||||
if hasattr(config.performance, key):
|
if hasattr(config.performance, key):
|
||||||
setattr(config.performance, key, value)
|
setattr(config.performance, key, value)
|
||||||
|
|
||||||
|
# Update data quality configuration
|
||||||
|
if "data_quality" in config_data:
|
||||||
|
dq_config = config_data["data_quality"]
|
||||||
|
for key, value in dq_config.items():
|
||||||
|
if hasattr(config.data_quality, key):
|
||||||
|
setattr(config.data_quality, key, value)
|
||||||
|
|
||||||
# Update logging configuration
|
# Update logging configuration
|
||||||
if "logging" in config_data:
|
if "logging" in config_data:
|
||||||
log_config = config_data["logging"]
|
log_config = config_data["logging"]
|
||||||
@@ -341,6 +652,13 @@ class DorisConfig:
|
|||||||
if hasattr(config.monitoring, key):
|
if hasattr(config.monitoring, key):
|
||||||
setattr(config.monitoring, key, value)
|
setattr(config.monitoring, key, value)
|
||||||
|
|
||||||
|
# Update ADBC configuration
|
||||||
|
if "adbc" in config_data:
|
||||||
|
adbc_config = config_data["adbc"]
|
||||||
|
for key, value in adbc_config.items():
|
||||||
|
if hasattr(config.adbc, key):
|
||||||
|
setattr(config.adbc, key, value)
|
||||||
|
|
||||||
# Custom configuration
|
# Custom configuration
|
||||||
config.custom_config = config_data.get("custom", {})
|
config.custom_config = config_data.get("custom", {})
|
||||||
|
|
||||||
@@ -352,6 +670,7 @@ class DorisConfig:
|
|||||||
"server_name": self.server_name,
|
"server_name": self.server_name,
|
||||||
"server_version": self.server_version,
|
"server_version": self.server_version,
|
||||||
"server_port": self.server_port,
|
"server_port": self.server_port,
|
||||||
|
"temp_files_dir": self.temp_files_dir,
|
||||||
"database": {
|
"database": {
|
||||||
"host": self.database.host,
|
"host": self.database.host,
|
||||||
"port": self.database.port,
|
"port": self.database.port,
|
||||||
@@ -359,7 +678,12 @@ class DorisConfig:
|
|||||||
"password": "***", # Hide password
|
"password": "***", # Hide password
|
||||||
"database": self.database.database,
|
"database": self.database.database,
|
||||||
"charset": self.database.charset,
|
"charset": self.database.charset,
|
||||||
"min_connections": self.database.min_connections,
|
"fe_http_port": self.database.fe_http_port,
|
||||||
|
"be_hosts": self.database.be_hosts,
|
||||||
|
"be_webserver_port": self.database.be_webserver_port,
|
||||||
|
"fe_arrow_flight_sql_port": self.database.fe_arrow_flight_sql_port,
|
||||||
|
"be_arrow_flight_sql_port": self.database.be_arrow_flight_sql_port,
|
||||||
|
"min_connections": self.database.min_connections, # Always 0, shown for reference
|
||||||
"max_connections": self.database.max_connections,
|
"max_connections": self.database.max_connections,
|
||||||
"connection_timeout": self.database.connection_timeout,
|
"connection_timeout": self.database.connection_timeout,
|
||||||
"health_check_interval": self.database.health_check_interval,
|
"health_check_interval": self.database.health_check_interval,
|
||||||
@@ -369,6 +693,7 @@ class DorisConfig:
|
|||||||
"auth_type": self.security.auth_type,
|
"auth_type": self.security.auth_type,
|
||||||
"token_secret": "***", # Hide secret key
|
"token_secret": "***", # Hide secret key
|
||||||
"token_expiry": self.security.token_expiry,
|
"token_expiry": self.security.token_expiry,
|
||||||
|
"enable_security_check": self.security.enable_security_check,
|
||||||
"blocked_keywords": self.security.blocked_keywords,
|
"blocked_keywords": self.security.blocked_keywords,
|
||||||
"max_query_complexity": self.security.max_query_complexity,
|
"max_query_complexity": self.security.max_query_complexity,
|
||||||
"max_result_rows": self.security.max_result_rows,
|
"max_result_rows": self.security.max_result_rows,
|
||||||
@@ -384,6 +709,20 @@ class DorisConfig:
|
|||||||
"query_timeout": self.performance.query_timeout,
|
"query_timeout": self.performance.query_timeout,
|
||||||
"connection_pool_size": self.performance.connection_pool_size,
|
"connection_pool_size": self.performance.connection_pool_size,
|
||||||
"idle_timeout": self.performance.idle_timeout,
|
"idle_timeout": self.performance.idle_timeout,
|
||||||
|
"max_response_content_size": self.performance.max_response_content_size,
|
||||||
|
},
|
||||||
|
"data_quality": {
|
||||||
|
"max_columns_per_batch": self.data_quality.max_columns_per_batch,
|
||||||
|
"default_sample_size": self.data_quality.default_sample_size,
|
||||||
|
"small_table_threshold": self.data_quality.small_table_threshold,
|
||||||
|
"medium_table_threshold": self.data_quality.medium_table_threshold,
|
||||||
|
"enable_batch_analysis": self.data_quality.enable_batch_analysis,
|
||||||
|
"batch_timeout": self.data_quality.batch_timeout,
|
||||||
|
"enable_fast_mode": self.data_quality.enable_fast_mode,
|
||||||
|
"fast_mode_sample_size": self.data_quality.fast_mode_sample_size,
|
||||||
|
"enable_distribution_analysis": self.data_quality.enable_distribution_analysis,
|
||||||
|
"histogram_bins": self.data_quality.histogram_bins,
|
||||||
|
"percentile_levels": self.data_quality.percentile_levels,
|
||||||
},
|
},
|
||||||
"logging": {
|
"logging": {
|
||||||
"level": self.logging.level,
|
"level": self.logging.level,
|
||||||
@@ -393,6 +732,9 @@ class DorisConfig:
|
|||||||
"backup_count": self.logging.backup_count,
|
"backup_count": self.logging.backup_count,
|
||||||
"enable_audit": self.logging.enable_audit,
|
"enable_audit": self.logging.enable_audit,
|
||||||
"audit_file_path": self.logging.audit_file_path,
|
"audit_file_path": self.logging.audit_file_path,
|
||||||
|
"enable_cleanup": self.logging.enable_cleanup,
|
||||||
|
"max_age_days": self.logging.max_age_days,
|
||||||
|
"cleanup_interval_hours": self.logging.cleanup_interval_hours,
|
||||||
},
|
},
|
||||||
"monitoring": {
|
"monitoring": {
|
||||||
"enable_metrics": self.monitoring.enable_metrics,
|
"enable_metrics": self.monitoring.enable_metrics,
|
||||||
@@ -403,6 +745,13 @@ class DorisConfig:
|
|||||||
"enable_alerts": self.monitoring.enable_alerts,
|
"enable_alerts": self.monitoring.enable_alerts,
|
||||||
"alert_webhook_url": self.monitoring.alert_webhook_url,
|
"alert_webhook_url": self.monitoring.alert_webhook_url,
|
||||||
},
|
},
|
||||||
|
"adbc": {
|
||||||
|
"default_max_rows": self.adbc.default_max_rows,
|
||||||
|
"default_timeout": self.adbc.default_timeout,
|
||||||
|
"default_return_format": self.adbc.default_return_format,
|
||||||
|
"connection_timeout": self.adbc.connection_timeout,
|
||||||
|
"enabled": self.adbc.enabled,
|
||||||
|
},
|
||||||
"custom": self.custom_config,
|
"custom": self.custom_config,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -435,11 +784,8 @@ class DorisConfig:
|
|||||||
if not self.database.user:
|
if not self.database.user:
|
||||||
errors.append("Database username cannot be empty")
|
errors.append("Database username cannot be empty")
|
||||||
|
|
||||||
if self.database.min_connections <= 0:
|
if self.database.max_connections <= 0:
|
||||||
errors.append("Minimum connections must be greater than 0")
|
errors.append("Maximum connections must be greater than 0")
|
||||||
|
|
||||||
if self.database.max_connections <= self.database.min_connections:
|
|
||||||
errors.append("Maximum connections must be greater than minimum connections")
|
|
||||||
|
|
||||||
# Validate security configuration
|
# Validate security configuration
|
||||||
if self.security.auth_type not in ["token", "basic", "oauth"]:
|
if self.security.auth_type not in ["token", "basic", "oauth"]:
|
||||||
@@ -464,6 +810,31 @@ class DorisConfig:
|
|||||||
if self.performance.query_timeout <= 0:
|
if self.performance.query_timeout <= 0:
|
||||||
errors.append("Query timeout must be greater than 0")
|
errors.append("Query timeout must be greater than 0")
|
||||||
|
|
||||||
|
# Validate data quality configuration
|
||||||
|
if self.data_quality.max_columns_per_batch <= 0:
|
||||||
|
errors.append("Max columns per batch must be greater than 0")
|
||||||
|
|
||||||
|
if self.data_quality.default_sample_size <= 0:
|
||||||
|
errors.append("Default sample size must be greater than 0")
|
||||||
|
|
||||||
|
if self.data_quality.small_table_threshold <= 0:
|
||||||
|
errors.append("Small table threshold must be greater than 0")
|
||||||
|
|
||||||
|
if self.data_quality.medium_table_threshold <= 0:
|
||||||
|
errors.append("Medium table threshold must be greater than 0")
|
||||||
|
|
||||||
|
if self.data_quality.small_table_threshold >= self.data_quality.medium_table_threshold:
|
||||||
|
errors.append("Small table threshold must be less than medium table threshold")
|
||||||
|
|
||||||
|
if self.data_quality.batch_timeout <= 0:
|
||||||
|
errors.append("Batch timeout must be greater than 0")
|
||||||
|
|
||||||
|
if self.data_quality.fast_mode_sample_size <= 0:
|
||||||
|
errors.append("Fast mode sample size must be greater than 0")
|
||||||
|
|
||||||
|
if self.data_quality.histogram_bins <= 0:
|
||||||
|
errors.append("Histogram bins must be greater than 0")
|
||||||
|
|
||||||
# Validate logging configuration
|
# Validate logging configuration
|
||||||
if self.logging.level not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]:
|
if self.logging.level not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]:
|
||||||
errors.append("Log level must be one of DEBUG, INFO, WARNING, ERROR, or CRITICAL")
|
errors.append("Log level must be one of DEBUG, INFO, WARNING, ERROR, or CRITICAL")
|
||||||
@@ -473,6 +844,12 @@ class DorisConfig:
|
|||||||
|
|
||||||
if self.logging.backup_count < 0:
|
if self.logging.backup_count < 0:
|
||||||
errors.append("Log backup count cannot be negative")
|
errors.append("Log backup count cannot be negative")
|
||||||
|
|
||||||
|
if self.logging.max_age_days <= 0:
|
||||||
|
errors.append("Log max age days must be greater than 0")
|
||||||
|
|
||||||
|
if self.logging.cleanup_interval_hours <= 0:
|
||||||
|
errors.append("Log cleanup interval hours must be greater than 0")
|
||||||
|
|
||||||
# Validate monitoring configuration
|
# Validate monitoring configuration
|
||||||
if not (1 <= self.monitoring.metrics_port <= 65535):
|
if not (1 <= self.monitoring.metrics_port <= 65535):
|
||||||
@@ -481,6 +858,19 @@ class DorisConfig:
|
|||||||
if not (1 <= self.monitoring.health_check_port <= 65535):
|
if not (1 <= self.monitoring.health_check_port <= 65535):
|
||||||
errors.append("Health check port must be in the range 1-65535")
|
errors.append("Health check port must be in the range 1-65535")
|
||||||
|
|
||||||
|
# Validate ADBC configuration
|
||||||
|
if self.adbc.default_max_rows <= 0:
|
||||||
|
errors.append("ADBC default max rows must be greater than 0")
|
||||||
|
|
||||||
|
if self.adbc.default_timeout <= 0:
|
||||||
|
errors.append("ADBC default timeout must be greater than 0")
|
||||||
|
|
||||||
|
if self.adbc.default_return_format not in ["arrow", "pandas", "dict"]:
|
||||||
|
errors.append("ADBC default return format must be one of arrow, pandas, or dict")
|
||||||
|
|
||||||
|
if self.adbc.connection_timeout <= 0:
|
||||||
|
errors.append("ADBC connection timeout must be greater than 0")
|
||||||
|
|
||||||
return errors
|
return errors
|
||||||
|
|
||||||
def get_connection_string(self) -> str:
|
def get_connection_string(self) -> str:
|
||||||
@@ -492,7 +882,7 @@ class DorisConfig:
|
|||||||
return {
|
return {
|
||||||
"server": f"{self.server_name} v{self.server_version}",
|
"server": f"{self.server_name} v{self.server_version}",
|
||||||
"database": f"{self.database.host}:{self.database.port}/{self.database.database}",
|
"database": f"{self.database.host}:{self.database.port}/{self.database.database}",
|
||||||
"connection_pool": f"{self.database.min_connections}-{self.database.max_connections}",
|
"connection_pool": f"0-{self.database.max_connections} (min fixed at 0 for stability)",
|
||||||
"security": {
|
"security": {
|
||||||
"auth_type": self.security.auth_type,
|
"auth_type": self.security.auth_type,
|
||||||
"masking_enabled": self.security.enable_masking,
|
"masking_enabled": self.security.enable_masking,
|
||||||
@@ -518,56 +908,50 @@ class ConfigManager:
|
|||||||
self.logger = logging.getLogger(__name__)
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def setup_logging(self):
|
def setup_logging(self):
|
||||||
"""Setup logging configuration"""
|
"""Setup logging configuration using enhanced logger"""
|
||||||
# Configure root logger
|
from .logger import setup_logging, get_logger
|
||||||
root_logger = logging.getLogger()
|
import sys
|
||||||
root_logger.setLevel(getattr(logging, self.config.logging.level.upper()))
|
|
||||||
|
# Determine log directory
|
||||||
# Clear existing handlers
|
log_dir = "logs"
|
||||||
for handler in root_logger.handlers[:]:
|
|
||||||
root_logger.removeHandler(handler)
|
|
||||||
|
|
||||||
# Create formatter
|
|
||||||
formatter = logging.Formatter(self.config.logging.format)
|
|
||||||
|
|
||||||
# Console handler
|
|
||||||
console_handler = logging.StreamHandler()
|
|
||||||
console_handler.setFormatter(formatter)
|
|
||||||
root_logger.addHandler(console_handler)
|
|
||||||
|
|
||||||
# File handler (if configured)
|
|
||||||
if self.config.logging.file_path:
|
if self.config.logging.file_path:
|
||||||
try:
|
# Extract directory from file path if provided
|
||||||
from logging.handlers import RotatingFileHandler
|
from pathlib import Path
|
||||||
|
log_dir = str(Path(self.config.logging.file_path).parent)
|
||||||
file_handler = RotatingFileHandler(
|
|
||||||
self.config.logging.file_path,
|
# Detect if we're in stdio mode by checking if this is likely MCP stdio communication
|
||||||
maxBytes=self.config.logging.max_file_size,
|
# In stdio mode, we shouldn't output to console as it interferes with JSON protocol
|
||||||
backupCount=self.config.logging.backup_count,
|
is_stdio_mode = (
|
||||||
encoding="utf-8",
|
self.config.transport == "stdio" or
|
||||||
)
|
"--transport" in sys.argv and "stdio" in sys.argv or
|
||||||
file_handler.setFormatter(formatter)
|
not sys.stdout.isatty() # Not a terminal (likely piped/redirected)
|
||||||
root_logger.addHandler(file_handler)
|
)
|
||||||
except Exception as e:
|
|
||||||
self.logger.warning(f"Failed to setup file logging: {e}")
|
# Setup enhanced logging with cleanup functionality
|
||||||
|
setup_logging(
|
||||||
# Audit log handler (if configured)
|
level=self.config.logging.level,
|
||||||
if self.config.logging.enable_audit and self.config.logging.audit_file_path:
|
log_dir=log_dir,
|
||||||
try:
|
enable_console=not is_stdio_mode, # Disable console logging in stdio mode
|
||||||
from logging.handlers import RotatingFileHandler
|
enable_file=True,
|
||||||
|
enable_audit=self.config.logging.enable_audit,
|
||||||
audit_logger = logging.getLogger("audit")
|
audit_file=self.config.logging.audit_file_path,
|
||||||
audit_handler = RotatingFileHandler(
|
max_file_size=self.config.logging.max_file_size,
|
||||||
self.config.logging.audit_file_path,
|
backup_count=self.config.logging.backup_count,
|
||||||
maxBytes=self.config.logging.max_file_size,
|
enable_cleanup=self.config.logging.enable_cleanup,
|
||||||
backupCount=self.config.logging.backup_count,
|
max_age_days=self.config.logging.max_age_days,
|
||||||
encoding="utf-8",
|
cleanup_interval_hours=self.config.logging.cleanup_interval_hours
|
||||||
)
|
)
|
||||||
audit_handler.setFormatter(formatter)
|
|
||||||
audit_logger.addHandler(audit_handler)
|
# Update logger to use new system
|
||||||
audit_logger.setLevel(logging.INFO)
|
self.logger = get_logger(__name__)
|
||||||
except Exception as e:
|
|
||||||
self.logger.warning(f"Failed to setup audit logging: {e}")
|
self.logger.info("Enhanced logging system with cleanup initialized successfully")
|
||||||
|
self.logger.info(f"Log directory: {log_dir}")
|
||||||
|
self.logger.info(f"Log level: {self.config.logging.level}")
|
||||||
|
self.logger.info(f"Audit logging: {'Enabled' if self.config.logging.enable_audit else 'Disabled'}")
|
||||||
|
self.logger.info(f"Log cleanup: {'Enabled' if self.config.logging.enable_cleanup else 'Disabled'}")
|
||||||
|
if self.config.logging.enable_cleanup:
|
||||||
|
self.logger.info(f"Cleanup config: Max age {self.config.logging.max_age_days} days, interval {self.config.logging.cleanup_interval_hours}h")
|
||||||
|
|
||||||
def validate_config(self) -> bool:
|
def validate_config(self) -> bool:
|
||||||
"""Validate configuration"""
|
"""Validate configuration"""
|
||||||
|
|||||||
771
doris_mcp_server/utils/data_exploration_tools.py
Normal file
@@ -0,0 +1,771 @@
|
|||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
"""
|
||||||
|
Data Exploration Tools Module
|
||||||
|
Provides table data distribution analysis and exploration capabilities
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import math
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
from .db import DorisConnectionManager
|
||||||
|
from .logger import get_logger
|
||||||
|
from .sql_security_utils import (
|
||||||
|
SQLSecurityError,
|
||||||
|
validate_identifier,
|
||||||
|
quote_identifier,
|
||||||
|
build_table_reference,
|
||||||
|
get_auth_context
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DataExplorationTools:
|
||||||
|
"""Data exploration tools for table distribution analysis"""
|
||||||
|
|
||||||
|
def __init__(self, connection_manager: DorisConnectionManager):
|
||||||
|
self.connection_manager = connection_manager
|
||||||
|
logger.info("DataExplorationTools initialized")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== Private Helper Methods ====================
|
||||||
|
|
||||||
|
def _build_full_table_name(self, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> str:
|
||||||
|
"""Build full table name with catalog and database using three-part naming convention"""
|
||||||
|
# SECURITY FIX: Use build_table_reference for safe identifier handling
|
||||||
|
effective_catalog = catalog_name if catalog_name else "internal"
|
||||||
|
|
||||||
|
if db_name:
|
||||||
|
return build_table_reference(table_name, db_name, effective_catalog)
|
||||||
|
else:
|
||||||
|
return build_table_reference(table_name, catalog_name=effective_catalog)
|
||||||
|
|
||||||
|
async def _get_table_basic_info(self, connection, table_name: str) -> Optional[Dict]:
|
||||||
|
"""Get basic table information including row count"""
|
||||||
|
try:
|
||||||
|
# SECURITY FIX: Get auth_context for security validation
|
||||||
|
# table_name should already be validated by _build_full_table_name
|
||||||
|
auth_context = get_auth_context()
|
||||||
|
|
||||||
|
count_sql = f"SELECT COUNT(*) as row_count FROM {table_name}"
|
||||||
|
result = await connection.execute(count_sql, auth_context=auth_context)
|
||||||
|
|
||||||
|
if result.data:
|
||||||
|
return {"row_count": result.data[0]["row_count"]}
|
||||||
|
return None
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Security validation failed for table {table_name}: {str(e)}")
|
||||||
|
return {"row_count": 0}
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to get basic info for table {table_name}: {str(e)}")
|
||||||
|
return {"row_count": 0}
|
||||||
|
|
||||||
|
async def _get_table_columns_info(self, connection, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> List[Dict]:
|
||||||
|
"""Get detailed column information"""
|
||||||
|
try:
|
||||||
|
# SECURITY FIX: Validate identifiers and use parameterized query
|
||||||
|
auth_context = get_auth_context()
|
||||||
|
|
||||||
|
try:
|
||||||
|
validate_identifier(table_name, "table name")
|
||||||
|
if db_name:
|
||||||
|
validate_identifier(db_name, "database name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Invalid identifier rejected: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Build parameterized query
|
||||||
|
params = [table_name]
|
||||||
|
where_conditions = ["table_name = %s"]
|
||||||
|
|
||||||
|
if db_name:
|
||||||
|
where_conditions.append("table_schema = %s")
|
||||||
|
params.append(db_name)
|
||||||
|
else:
|
||||||
|
where_conditions.append("table_schema = DATABASE()")
|
||||||
|
|
||||||
|
columns_sql = f"""
|
||||||
|
SELECT
|
||||||
|
column_name,
|
||||||
|
data_type,
|
||||||
|
is_nullable,
|
||||||
|
column_comment,
|
||||||
|
ordinal_position
|
||||||
|
FROM information_schema.columns
|
||||||
|
WHERE {' AND '.join(where_conditions)}
|
||||||
|
ORDER BY ordinal_position
|
||||||
|
"""
|
||||||
|
|
||||||
|
result = await connection.execute(columns_sql, params=tuple(params), auth_context=auth_context)
|
||||||
|
return result.data if result.data else []
|
||||||
|
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Security validation failed: {str(e)}")
|
||||||
|
return []
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to get columns info for table {table_name}: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def _determine_sampling_strategy(self, connection, table_name: str, total_rows: int, sample_size: int) -> Dict[str, Any]:
|
||||||
|
"""Determine optimal sampling strategy based on table size"""
|
||||||
|
if total_rows <= sample_size:
|
||||||
|
# Use all data if table is small enough
|
||||||
|
return {
|
||||||
|
"total_rows": total_rows,
|
||||||
|
"sample_size": total_rows,
|
||||||
|
"sampling_method": "full_scan",
|
||||||
|
"sampling_ratio": 1.0,
|
||||||
|
"use_sampling": False,
|
||||||
|
"sample_table_expression": table_name
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# Use random sampling for large tables
|
||||||
|
sampling_ratio = sample_size / total_rows
|
||||||
|
return {
|
||||||
|
"total_rows": total_rows,
|
||||||
|
"sample_size": sample_size,
|
||||||
|
"sampling_method": "random_sample",
|
||||||
|
"sampling_ratio": round(sampling_ratio, 4),
|
||||||
|
"use_sampling": True,
|
||||||
|
"sample_table_expression": f"(SELECT * FROM {table_name} ORDER BY RAND() LIMIT {sample_size}) as sample_table"
|
||||||
|
}
|
||||||
|
|
||||||
|
def _select_analysis_columns(self, columns_info: List[Dict], include_all: bool) -> List[Dict]:
|
||||||
|
"""Select columns for analysis based on strategy"""
|
||||||
|
if include_all:
|
||||||
|
return columns_info
|
||||||
|
|
||||||
|
# If not analyzing all columns, prioritize key columns
|
||||||
|
priority_keywords = ['id', 'key', 'code', 'status', 'type', 'amount', 'count', 'date', 'time']
|
||||||
|
|
||||||
|
priority_columns = []
|
||||||
|
other_columns = []
|
||||||
|
|
||||||
|
for col in columns_info:
|
||||||
|
col_name_lower = col["column_name"].lower()
|
||||||
|
if any(keyword in col_name_lower for keyword in priority_keywords):
|
||||||
|
priority_columns.append(col)
|
||||||
|
else:
|
||||||
|
other_columns.append(col)
|
||||||
|
|
||||||
|
# Return priority columns plus first 10 other columns
|
||||||
|
return priority_columns + other_columns[:10]
|
||||||
|
|
||||||
|
def _is_numeric_type(self, data_type: str) -> bool:
|
||||||
|
"""Check if column type is numeric"""
|
||||||
|
numeric_types = [
|
||||||
|
'tinyint', 'smallint', 'int', 'bigint', 'largeint',
|
||||||
|
'float', 'double', 'decimal', 'numeric'
|
||||||
|
]
|
||||||
|
return any(num_type in data_type.lower() for num_type in numeric_types)
|
||||||
|
|
||||||
|
def _is_categorical_type(self, data_type: str) -> bool:
|
||||||
|
"""Check if column type is categorical"""
|
||||||
|
categorical_types = ['varchar', 'char', 'string', 'text', 'enum']
|
||||||
|
return any(cat_type in data_type.lower() for cat_type in categorical_types)
|
||||||
|
|
||||||
|
def _is_temporal_type(self, data_type: str) -> bool:
|
||||||
|
"""Check if column type is temporal"""
|
||||||
|
temporal_types = ['date', 'datetime', 'timestamp', 'time']
|
||||||
|
return any(temp_type in data_type.lower() for temp_type in temporal_types)
|
||||||
|
|
||||||
|
async def _analyze_numeric_distributions(self, connection, table_name: str, numeric_columns: List[Dict], sampling_info: Dict) -> Dict[str, Any]:
|
||||||
|
"""Analyze distribution patterns for numeric columns"""
|
||||||
|
numeric_analysis = {}
|
||||||
|
|
||||||
|
for column in numeric_columns:
|
||||||
|
col_name = column["column_name"]
|
||||||
|
try:
|
||||||
|
# Basic statistics
|
||||||
|
table_expr = sampling_info.get("sample_table_expression", table_name)
|
||||||
|
stats_sql = f"""
|
||||||
|
SELECT
|
||||||
|
COUNT({col_name}) as count,
|
||||||
|
MIN({col_name}) as min_value,
|
||||||
|
MAX({col_name}) as max_value,
|
||||||
|
AVG({col_name}) as mean_value,
|
||||||
|
STDDEV({col_name}) as std_dev
|
||||||
|
FROM {table_expr}
|
||||||
|
WHERE {col_name} IS NOT NULL
|
||||||
|
"""
|
||||||
|
|
||||||
|
auth_context = get_auth_context()
|
||||||
|
stats_result = await connection.execute(stats_sql, auth_context=auth_context)
|
||||||
|
|
||||||
|
if stats_result.data and stats_result.data[0]["count"] > 0:
|
||||||
|
stats = stats_result.data[0]
|
||||||
|
|
||||||
|
# Percentiles calculation
|
||||||
|
percentiles = await self._calculate_percentiles(connection, table_name, col_name, sampling_info)
|
||||||
|
|
||||||
|
# Outlier detection
|
||||||
|
outliers = await self._detect_numeric_outliers(connection, table_name, col_name, percentiles, sampling_info)
|
||||||
|
|
||||||
|
# Distribution shape analysis
|
||||||
|
distribution_shape = await self._analyze_distribution_shape(
|
||||||
|
connection, table_name, col_name, stats, percentiles, sampling_info
|
||||||
|
)
|
||||||
|
|
||||||
|
numeric_analysis[col_name] = {
|
||||||
|
"data_type": column["data_type"],
|
||||||
|
"statistics": {
|
||||||
|
"count": stats["count"],
|
||||||
|
"mean": round(float(stats["mean_value"]), 4) if stats["mean_value"] else None,
|
||||||
|
"std": round(float(stats["std_dev"]), 4) if stats["std_dev"] else None,
|
||||||
|
"min": float(stats["min_value"]) if stats["min_value"] else None,
|
||||||
|
"max": float(stats["max_value"]) if stats["max_value"] else None,
|
||||||
|
**percentiles
|
||||||
|
},
|
||||||
|
"distribution_shape": distribution_shape,
|
||||||
|
"outliers": outliers
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to analyze numeric column {col_name}: {str(e)}")
|
||||||
|
numeric_analysis[col_name] = {"error": str(e)}
|
||||||
|
|
||||||
|
return numeric_analysis
|
||||||
|
|
||||||
|
async def _calculate_percentiles(self, connection, table_name: str, col_name: str, sampling_info: Dict) -> Dict[str, float]:
|
||||||
|
"""Calculate percentiles for numeric column"""
|
||||||
|
try:
|
||||||
|
table_expr = sampling_info.get("sample_table_expression", table_name)
|
||||||
|
percentile_sql = f"""
|
||||||
|
SELECT
|
||||||
|
PERCENTILE({col_name}, 0.25) as p25,
|
||||||
|
PERCENTILE({col_name}, 0.50) as p50,
|
||||||
|
PERCENTILE({col_name}, 0.75) as p75,
|
||||||
|
PERCENTILE({col_name}, 0.90) as p90,
|
||||||
|
PERCENTILE({col_name}, 0.95) as p95,
|
||||||
|
PERCENTILE({col_name}, 0.99) as p99
|
||||||
|
FROM {table_expr}
|
||||||
|
WHERE {col_name} IS NOT NULL
|
||||||
|
"""
|
||||||
|
|
||||||
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(percentile_sql, auth_context=auth_context)
|
||||||
|
|
||||||
|
if result.data:
|
||||||
|
data = result.data[0]
|
||||||
|
return {
|
||||||
|
"25%": round(float(data["p25"]), 4) if data["p25"] else None,
|
||||||
|
"50%": round(float(data["p50"]), 4) if data["p50"] else None,
|
||||||
|
"75%": round(float(data["p75"]), 4) if data["p75"] else None,
|
||||||
|
"90%": round(float(data["p90"]), 4) if data["p90"] else None,
|
||||||
|
"95%": round(float(data["p95"]), 4) if data["p95"] else None,
|
||||||
|
"99%": round(float(data["p99"]), 4) if data["p99"] else None
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to calculate percentiles for {col_name}: {str(e)}")
|
||||||
|
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def _detect_numeric_outliers(self, connection, table_name: str, col_name: str, percentiles: Dict, sampling_info: Dict) -> Dict[str, Any]:
|
||||||
|
"""Detect outliers using IQR method"""
|
||||||
|
try:
|
||||||
|
if "25%" not in percentiles or "75%" not in percentiles:
|
||||||
|
return {"outlier_count": 0, "outlier_rate": 0.0}
|
||||||
|
|
||||||
|
q1 = percentiles["25%"]
|
||||||
|
q3 = percentiles["75%"]
|
||||||
|
iqr = q3 - q1
|
||||||
|
|
||||||
|
lower_bound = q1 - 1.5 * iqr
|
||||||
|
upper_bound = q3 + 1.5 * iqr
|
||||||
|
|
||||||
|
table_expr = sampling_info.get("sample_table_expression", table_name)
|
||||||
|
outlier_sql = f"""
|
||||||
|
SELECT
|
||||||
|
COUNT(*) as total_count,
|
||||||
|
SUM(CASE WHEN {col_name} < {lower_bound} OR {col_name} > {upper_bound} THEN 1 ELSE 0 END) as outlier_count
|
||||||
|
FROM {table_expr}
|
||||||
|
WHERE {col_name} IS NOT NULL
|
||||||
|
"""
|
||||||
|
|
||||||
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(outlier_sql, auth_context=auth_context)
|
||||||
|
|
||||||
|
if result.data:
|
||||||
|
data = result.data[0]
|
||||||
|
total_count = data["total_count"]
|
||||||
|
outlier_count = data["outlier_count"]
|
||||||
|
outlier_rate = outlier_count / total_count if total_count > 0 else 0
|
||||||
|
|
||||||
|
return {
|
||||||
|
"outlier_count": outlier_count,
|
||||||
|
"outlier_rate": round(outlier_rate, 4),
|
||||||
|
"outlier_threshold_lower": round(lower_bound, 4),
|
||||||
|
"outlier_threshold_upper": round(upper_bound, 4),
|
||||||
|
"iqr": round(iqr, 4)
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to detect outliers for {col_name}: {str(e)}")
|
||||||
|
|
||||||
|
return {"outlier_count": 0, "outlier_rate": 0.0}
|
||||||
|
|
||||||
|
async def _analyze_distribution_shape(self, connection, table_name: str, col_name: str, stats: Dict, percentiles: Dict, sampling_info: Dict) -> Dict[str, Any]:
|
||||||
|
"""Analyze the shape of data distribution"""
|
||||||
|
try:
|
||||||
|
mean = stats.get("mean_value", 0)
|
||||||
|
median = percentiles.get("50%", 0)
|
||||||
|
|
||||||
|
if mean is None or median is None:
|
||||||
|
return {"distribution_type": "unknown"}
|
||||||
|
|
||||||
|
# Calculate skewness indicator
|
||||||
|
if abs(mean - median) < 0.01:
|
||||||
|
skew_indicator = "symmetric"
|
||||||
|
elif mean > median:
|
||||||
|
skew_indicator = "right_skewed"
|
||||||
|
else:
|
||||||
|
skew_indicator = "left_skewed"
|
||||||
|
|
||||||
|
# Estimate kurtosis based on percentile spread
|
||||||
|
if "25%" in percentiles and "75%" in percentiles:
|
||||||
|
iqr = percentiles["75%"] - percentiles["25%"]
|
||||||
|
range_90 = percentiles.get("90%", percentiles["75%"]) - percentiles.get("10%", percentiles["25%"])
|
||||||
|
|
||||||
|
if iqr > 0:
|
||||||
|
kurtosis_indicator = "normal" if 2.5 <= range_90/iqr <= 3.5 else ("heavy_tailed" if range_90/iqr > 3.5 else "light_tailed")
|
||||||
|
else:
|
||||||
|
kurtosis_indicator = "unknown"
|
||||||
|
else:
|
||||||
|
kurtosis_indicator = "unknown"
|
||||||
|
|
||||||
|
return {
|
||||||
|
"skewness_indicator": skew_indicator,
|
||||||
|
"kurtosis_indicator": kurtosis_indicator,
|
||||||
|
"distribution_type": self._classify_distribution_type(skew_indicator, kurtosis_indicator),
|
||||||
|
"mean_median_ratio": round(mean / median, 4) if median != 0 else None
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to analyze distribution shape for {col_name}: {str(e)}")
|
||||||
|
return {"distribution_type": "unknown"}
|
||||||
|
|
||||||
|
def _classify_distribution_type(self, skew: str, kurtosis: str) -> str:
|
||||||
|
"""Classify distribution type based on skewness and kurtosis"""
|
||||||
|
if skew == "symmetric" and kurtosis == "normal":
|
||||||
|
return "approximately_normal"
|
||||||
|
elif skew == "right_skewed":
|
||||||
|
return "right_skewed"
|
||||||
|
elif skew == "left_skewed":
|
||||||
|
return "left_skewed"
|
||||||
|
elif kurtosis == "heavy_tailed":
|
||||||
|
return "heavy_tailed"
|
||||||
|
else:
|
||||||
|
return "non_normal"
|
||||||
|
|
||||||
|
async def _analyze_categorical_distributions(self, connection, table_name: str, categorical_columns: List[Dict], sampling_info: Dict) -> Dict[str, Any]:
|
||||||
|
"""Analyze distribution patterns for categorical columns"""
|
||||||
|
categorical_analysis = {}
|
||||||
|
|
||||||
|
for column in categorical_columns:
|
||||||
|
col_name = column["column_name"]
|
||||||
|
try:
|
||||||
|
# Basic cardinality and distribution
|
||||||
|
cardinality_sql = f"""
|
||||||
|
SELECT
|
||||||
|
COUNT(DISTINCT {col_name}) as cardinality,
|
||||||
|
COUNT({col_name}) as non_null_count
|
||||||
|
FROM {table_name}
|
||||||
|
WHERE {col_name} IS NOT NULL
|
||||||
|
{sampling_info.get('sample_query_suffix', '')}
|
||||||
|
"""
|
||||||
|
|
||||||
|
auth_context = get_auth_context()
|
||||||
|
cardinality_result = await connection.execute(cardinality_sql, auth_context=auth_context)
|
||||||
|
|
||||||
|
if cardinality_result.data:
|
||||||
|
cardinality_data = cardinality_result.data[0]
|
||||||
|
cardinality = cardinality_data["cardinality"]
|
||||||
|
non_null_count = cardinality_data["non_null_count"]
|
||||||
|
|
||||||
|
# Value distribution (top values)
|
||||||
|
value_distribution = await self._get_categorical_value_distribution(
|
||||||
|
connection, table_name, col_name, sampling_info, non_null_count
|
||||||
|
)
|
||||||
|
|
||||||
|
# Calculate entropy and concentration
|
||||||
|
entropy = self._calculate_entropy(value_distribution)
|
||||||
|
concentration_ratio = value_distribution[0]["percentage"] if value_distribution else 0
|
||||||
|
|
||||||
|
categorical_analysis[col_name] = {
|
||||||
|
"data_type": column["data_type"],
|
||||||
|
"cardinality": cardinality,
|
||||||
|
"non_null_count": non_null_count,
|
||||||
|
"value_distribution": value_distribution,
|
||||||
|
"entropy": round(entropy, 3),
|
||||||
|
"concentration_ratio": round(concentration_ratio, 4),
|
||||||
|
"diversity_score": round(cardinality / non_null_count, 4) if non_null_count > 0 else 0
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to analyze categorical column {col_name}: {str(e)}")
|
||||||
|
categorical_analysis[col_name] = {"error": str(e)}
|
||||||
|
|
||||||
|
return categorical_analysis
|
||||||
|
|
||||||
|
async def _get_categorical_value_distribution(self, connection, table_name: str, col_name: str, sampling_info: Dict, total_count: int) -> List[Dict]:
|
||||||
|
"""Get value distribution for categorical column"""
|
||||||
|
try:
|
||||||
|
# Use sample table expression if sampling is enabled
|
||||||
|
table_expr = sampling_info.get("sample_table_expression", table_name)
|
||||||
|
|
||||||
|
distribution_sql = f"""
|
||||||
|
SELECT
|
||||||
|
{col_name} as value,
|
||||||
|
COUNT(*) as count
|
||||||
|
FROM {table_expr}
|
||||||
|
WHERE {col_name} IS NOT NULL
|
||||||
|
GROUP BY {col_name}
|
||||||
|
ORDER BY COUNT(*) DESC
|
||||||
|
LIMIT 20
|
||||||
|
"""
|
||||||
|
|
||||||
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(distribution_sql, auth_context=auth_context)
|
||||||
|
|
||||||
|
if result.data:
|
||||||
|
distribution = []
|
||||||
|
for row in result.data:
|
||||||
|
count = row["count"]
|
||||||
|
percentage = count / total_count if total_count > 0 else 0
|
||||||
|
distribution.append({
|
||||||
|
"value": str(row["value"]),
|
||||||
|
"count": count,
|
||||||
|
"percentage": round(percentage, 4)
|
||||||
|
})
|
||||||
|
return distribution
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to get value distribution for {col_name}: {str(e)}")
|
||||||
|
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _calculate_entropy(self, value_distribution: List[Dict]) -> float:
|
||||||
|
"""Calculate Shannon entropy for categorical distribution"""
|
||||||
|
if not value_distribution:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
entropy = 0.0
|
||||||
|
for item in value_distribution:
|
||||||
|
p = item["percentage"]
|
||||||
|
if p > 0:
|
||||||
|
entropy -= p * math.log2(p)
|
||||||
|
|
||||||
|
return entropy
|
||||||
|
|
||||||
|
async def _analyze_temporal_distributions(self, connection, table_name: str, temporal_columns: List[Dict], sampling_info: Dict) -> Dict[str, Any]:
|
||||||
|
"""Analyze distribution patterns for temporal columns"""
|
||||||
|
temporal_analysis = {}
|
||||||
|
|
||||||
|
for column in temporal_columns:
|
||||||
|
col_name = column["column_name"]
|
||||||
|
try:
|
||||||
|
# Date range analysis
|
||||||
|
table_expr = sampling_info.get("sample_table_expression", table_name)
|
||||||
|
range_sql = f"""
|
||||||
|
SELECT
|
||||||
|
MIN({col_name}) as earliest,
|
||||||
|
MAX({col_name}) as latest,
|
||||||
|
COUNT({col_name}) as non_null_count
|
||||||
|
FROM {table_expr}
|
||||||
|
WHERE {col_name} IS NOT NULL
|
||||||
|
"""
|
||||||
|
|
||||||
|
auth_context = get_auth_context()
|
||||||
|
range_result = await connection.execute(range_sql, auth_context=auth_context)
|
||||||
|
|
||||||
|
if range_result.data and range_result.data[0]["non_null_count"] > 0:
|
||||||
|
range_data = range_result.data[0]
|
||||||
|
earliest = range_data["earliest"]
|
||||||
|
latest = range_data["latest"]
|
||||||
|
|
||||||
|
# Calculate span
|
||||||
|
date_span_info = self._calculate_date_span(earliest, latest)
|
||||||
|
|
||||||
|
# Temporal patterns analysis
|
||||||
|
temporal_patterns = await self._analyze_temporal_patterns(
|
||||||
|
connection, table_name, col_name, sampling_info
|
||||||
|
)
|
||||||
|
|
||||||
|
temporal_analysis[col_name] = {
|
||||||
|
"data_type": column["data_type"],
|
||||||
|
"non_null_count": range_data["non_null_count"],
|
||||||
|
"date_range": {
|
||||||
|
"earliest": str(earliest),
|
||||||
|
"latest": str(latest),
|
||||||
|
**date_span_info
|
||||||
|
},
|
||||||
|
"temporal_patterns": temporal_patterns
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to analyze temporal column {col_name}: {str(e)}")
|
||||||
|
temporal_analysis[col_name] = {"error": str(e)}
|
||||||
|
|
||||||
|
return temporal_analysis
|
||||||
|
|
||||||
|
def _calculate_date_span(self, earliest, latest) -> Dict[str, Any]:
|
||||||
|
"""Calculate date span information"""
|
||||||
|
try:
|
||||||
|
if isinstance(earliest, str):
|
||||||
|
earliest = datetime.fromisoformat(earliest.replace('Z', '+00:00'))
|
||||||
|
if isinstance(latest, str):
|
||||||
|
latest = datetime.fromisoformat(latest.replace('Z', '+00:00'))
|
||||||
|
|
||||||
|
span = latest - earliest
|
||||||
|
span_days = span.days
|
||||||
|
|
||||||
|
return {
|
||||||
|
"span_days": span_days,
|
||||||
|
"span_years": round(span_days / 365.25, 2),
|
||||||
|
"span_description": self._describe_time_span(span_days)
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to calculate date span: {str(e)}")
|
||||||
|
return {"span_days": 0}
|
||||||
|
|
||||||
|
def _describe_time_span(self, days: int) -> str:
|
||||||
|
"""Describe time span in human readable format"""
|
||||||
|
if days < 1:
|
||||||
|
return "less_than_day"
|
||||||
|
elif days < 7:
|
||||||
|
return "days"
|
||||||
|
elif days < 30:
|
||||||
|
return "weeks"
|
||||||
|
elif days < 365:
|
||||||
|
return "months"
|
||||||
|
else:
|
||||||
|
return "years"
|
||||||
|
|
||||||
|
async def _analyze_temporal_patterns(self, connection, table_name: str, col_name: str, sampling_info: Dict) -> Dict[str, Any]:
|
||||||
|
"""Analyze temporal patterns like seasonality and trends"""
|
||||||
|
try:
|
||||||
|
table_expr = sampling_info.get("sample_table_expression", table_name)
|
||||||
|
# Weekly pattern analysis
|
||||||
|
weekly_pattern_sql = f"""
|
||||||
|
SELECT
|
||||||
|
DAYOFWEEK({col_name}) as day_of_week,
|
||||||
|
COUNT(*) as count
|
||||||
|
FROM {table_expr}
|
||||||
|
WHERE {col_name} IS NOT NULL
|
||||||
|
GROUP BY DAYOFWEEK({col_name})
|
||||||
|
ORDER BY day_of_week
|
||||||
|
"""
|
||||||
|
|
||||||
|
auth_context = get_auth_context()
|
||||||
|
weekly_result = await connection.execute(weekly_pattern_sql, auth_context=auth_context)
|
||||||
|
|
||||||
|
weekly_pattern = []
|
||||||
|
if weekly_result.data:
|
||||||
|
total_records = sum(row["count"] for row in weekly_result.data)
|
||||||
|
for row in weekly_result.data:
|
||||||
|
percentage = row["count"] / total_records if total_records > 0 else 0
|
||||||
|
weekly_pattern.append(round(percentage, 3))
|
||||||
|
|
||||||
|
# Monthly trend analysis (simplified)
|
||||||
|
monthly_trend_sql = f"""
|
||||||
|
SELECT
|
||||||
|
YEAR({col_name}) as year,
|
||||||
|
MONTH({col_name}) as month,
|
||||||
|
COUNT(*) as count
|
||||||
|
FROM {table_expr}
|
||||||
|
WHERE {col_name} IS NOT NULL
|
||||||
|
GROUP BY YEAR({col_name}), MONTH({col_name})
|
||||||
|
ORDER BY year, month
|
||||||
|
LIMIT 12
|
||||||
|
"""
|
||||||
|
|
||||||
|
monthly_result = await connection.execute(monthly_trend_sql, auth_context=auth_context)
|
||||||
|
monthly_trend = "stable" # Simplified trend analysis
|
||||||
|
|
||||||
|
if monthly_result.data and len(monthly_result.data) > 3:
|
||||||
|
counts = [row["count"] for row in monthly_result.data]
|
||||||
|
if len(counts) > 1:
|
||||||
|
trend_direction = "increasing" if counts[-1] > counts[0] else "decreasing"
|
||||||
|
monthly_trend = trend_direction
|
||||||
|
|
||||||
|
return {
|
||||||
|
"weekly_pattern": weekly_pattern,
|
||||||
|
"monthly_trend": monthly_trend,
|
||||||
|
"seasonal_component": self._estimate_seasonality(weekly_pattern)
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to analyze temporal patterns for {col_name}: {str(e)}")
|
||||||
|
return {"weekly_pattern": [], "monthly_trend": "unknown"}
|
||||||
|
|
||||||
|
def _estimate_seasonality(self, weekly_pattern: List[float]) -> float:
|
||||||
|
"""Estimate seasonality strength based on weekly pattern variance"""
|
||||||
|
if len(weekly_pattern) < 7:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
mean_percentage = sum(weekly_pattern) / len(weekly_pattern)
|
||||||
|
variance = sum((x - mean_percentage) ** 2 for x in weekly_pattern) / len(weekly_pattern)
|
||||||
|
|
||||||
|
# Normalize variance to 0-1 scale as seasonality indicator
|
||||||
|
seasonality = min(variance * 10, 1.0) # Scaling factor
|
||||||
|
return round(seasonality, 3)
|
||||||
|
|
||||||
|
async def _generate_data_quality_insights(self, connection, table_name: str, columns: List[Dict], sampling_info: Dict) -> Dict[str, Any]:
|
||||||
|
"""Generate overall data quality insights"""
|
||||||
|
try:
|
||||||
|
total_columns = len(columns)
|
||||||
|
|
||||||
|
# Calculate null rates across all columns
|
||||||
|
null_analysis = await self._analyze_overall_null_rates(connection, table_name, columns, sampling_info)
|
||||||
|
|
||||||
|
# Identify potential data quality issues
|
||||||
|
quality_issues = []
|
||||||
|
|
||||||
|
# High null rate columns
|
||||||
|
high_null_columns = [col for col, rate in null_analysis["column_null_rates"].items() if rate > 0.2]
|
||||||
|
if high_null_columns:
|
||||||
|
quality_issues.append({
|
||||||
|
"issue_type": "high_null_rates",
|
||||||
|
"severity": "medium",
|
||||||
|
"affected_columns": high_null_columns,
|
||||||
|
"description": f"{len(high_null_columns)} columns have null rates > 20%"
|
||||||
|
})
|
||||||
|
|
||||||
|
# Calculate overall data quality score
|
||||||
|
avg_null_rate = sum(null_analysis["column_null_rates"].values()) / len(null_analysis["column_null_rates"]) if null_analysis["column_null_rates"] else 0
|
||||||
|
data_quality_score = max(0, 1 - avg_null_rate)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total_columns_analyzed": total_columns,
|
||||||
|
"null_analysis": null_analysis,
|
||||||
|
"data_quality_score": round(data_quality_score, 3),
|
||||||
|
"quality_issues": quality_issues,
|
||||||
|
"recommendations": self._generate_quality_recommendations(quality_issues, null_analysis)
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to generate data quality insights: {str(e)}")
|
||||||
|
return {"data_quality_score": 0.0, "error": str(e)}
|
||||||
|
|
||||||
|
async def _analyze_overall_null_rates(self, connection, table_name: str, columns: List[Dict], sampling_info: Dict) -> Dict[str, Any]:
|
||||||
|
"""Analyze null rates across all columns"""
|
||||||
|
column_null_rates = {}
|
||||||
|
total_null_count = 0
|
||||||
|
total_cell_count = 0
|
||||||
|
|
||||||
|
for column in columns:
|
||||||
|
col_name = column["column_name"]
|
||||||
|
try:
|
||||||
|
table_expr = sampling_info.get("sample_table_expression", table_name)
|
||||||
|
null_sql = f"""
|
||||||
|
SELECT
|
||||||
|
COUNT(*) as total_count,
|
||||||
|
COUNT({col_name}) as non_null_count
|
||||||
|
FROM {table_expr}
|
||||||
|
"""
|
||||||
|
|
||||||
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(null_sql, auth_context=auth_context)
|
||||||
|
if result.data:
|
||||||
|
data = result.data[0]
|
||||||
|
total_count = data["total_count"]
|
||||||
|
non_null_count = data["non_null_count"]
|
||||||
|
null_count = total_count - non_null_count
|
||||||
|
null_rate = null_count / total_count if total_count > 0 else 0
|
||||||
|
|
||||||
|
column_null_rates[col_name] = round(null_rate, 4)
|
||||||
|
total_null_count += null_count
|
||||||
|
total_cell_count += total_count
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to analyze null rate for column {col_name}: {str(e)}")
|
||||||
|
column_null_rates[col_name] = 0.0
|
||||||
|
|
||||||
|
overall_null_rate = total_null_count / total_cell_count if total_cell_count > 0 else 0
|
||||||
|
|
||||||
|
return {
|
||||||
|
"column_null_rates": column_null_rates,
|
||||||
|
"overall_null_rate": round(overall_null_rate, 4),
|
||||||
|
"columns_with_nulls": len([rate for rate in column_null_rates.values() if rate > 0])
|
||||||
|
}
|
||||||
|
|
||||||
|
def _generate_quality_recommendations(self, quality_issues: List[Dict], null_analysis: Dict) -> List[Dict]:
|
||||||
|
"""Generate data quality improvement recommendations"""
|
||||||
|
recommendations = []
|
||||||
|
|
||||||
|
# Recommendations based on null analysis
|
||||||
|
overall_null_rate = null_analysis.get("overall_null_rate", 0)
|
||||||
|
if overall_null_rate > 0.1:
|
||||||
|
recommendations.append({
|
||||||
|
"type": "data_completeness",
|
||||||
|
"priority": "high" if overall_null_rate > 0.3 else "medium",
|
||||||
|
"description": f"Overall null rate is {overall_null_rate:.1%}",
|
||||||
|
"action": "Review data collection and validation processes"
|
||||||
|
})
|
||||||
|
|
||||||
|
# Recommendations based on quality issues
|
||||||
|
for issue in quality_issues:
|
||||||
|
if issue["issue_type"] == "high_null_rates":
|
||||||
|
recommendations.append({
|
||||||
|
"type": "column_completeness",
|
||||||
|
"priority": issue["severity"],
|
||||||
|
"description": issue["description"],
|
||||||
|
"action": f"Focus on improving data completeness for: {', '.join(issue['affected_columns'][:3])}"
|
||||||
|
})
|
||||||
|
|
||||||
|
return recommendations
|
||||||
|
|
||||||
|
def _generate_analysis_summary(self, distribution_analysis: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Generate high-level summary of distribution analysis"""
|
||||||
|
summary = {
|
||||||
|
"numeric_columns_count": len(distribution_analysis.get("numeric_columns", {})),
|
||||||
|
"categorical_columns_count": len(distribution_analysis.get("categorical_columns", {})),
|
||||||
|
"temporal_columns_count": len(distribution_analysis.get("temporal_columns", {}))
|
||||||
|
}
|
||||||
|
|
||||||
|
# Identify interesting patterns
|
||||||
|
patterns = []
|
||||||
|
|
||||||
|
# Check for highly skewed numeric columns
|
||||||
|
numeric_cols = distribution_analysis.get("numeric_columns", {})
|
||||||
|
skewed_cols = [
|
||||||
|
col for col, info in numeric_cols.items()
|
||||||
|
if isinstance(info, dict) and
|
||||||
|
info.get("distribution_shape", {}).get("skewness_indicator") in ["right_skewed", "left_skewed"]
|
||||||
|
]
|
||||||
|
|
||||||
|
if skewed_cols:
|
||||||
|
patterns.append(f"Found {len(skewed_cols)} skewed numeric columns")
|
||||||
|
|
||||||
|
# Check for high cardinality categorical columns
|
||||||
|
categorical_cols = distribution_analysis.get("categorical_columns", {})
|
||||||
|
high_cardinality_cols = [
|
||||||
|
col for col, info in categorical_cols.items()
|
||||||
|
if isinstance(info, dict) and info.get("cardinality", 0) > 1000
|
||||||
|
]
|
||||||
|
|
||||||
|
if high_cardinality_cols:
|
||||||
|
patterns.append(f"Found {len(high_cardinality_cols)} high cardinality categorical columns")
|
||||||
|
|
||||||
|
summary["notable_patterns"] = patterns
|
||||||
|
|
||||||
|
return summary
|
||||||
1022
doris_mcp_server/utils/data_governance_tools.py
Normal file
1173
doris_mcp_server/utils/data_quality_tools.py
Normal file
1025
doris_mcp_server/utils/dependency_analysis_tools.py
Normal file
@@ -15,77 +15,573 @@
|
|||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
"""
|
"""
|
||||||
Logging configuration for Doris MCP Server.
|
Enhanced Logging configuration for Doris MCP Server.
|
||||||
|
Features:
|
||||||
|
- Log level-based file separation
|
||||||
|
- Timestamped log entries
|
||||||
|
- Automatic log rotation
|
||||||
|
- Comprehensive logging coverage
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import logging.config
|
import logging.config
|
||||||
|
import logging.handlers
|
||||||
import sys
|
import sys
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any, Optional
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
import threading
|
||||||
|
|
||||||
|
|
||||||
def setup_logging(
|
class TimestampedFormatter(logging.Formatter):
|
||||||
level: str = "INFO",
|
"""Custom formatter with enhanced timestamp and structured format"""
|
||||||
log_file: str | None = None,
|
|
||||||
log_format: str | None = None,
|
def __init__(self, fmt=None, datefmt=None, style='%'):
|
||||||
) -> None:
|
if fmt is None:
|
||||||
"""
|
fmt = "%(asctime)s.%(msecs)03d %(level_aligned)s %(name)s:%(lineno)d - %(message)s"
|
||||||
Setup logging configuration.
|
if datefmt is None:
|
||||||
|
datefmt = "%Y-%m-%d %H:%M:%S"
|
||||||
|
super().__init__(fmt, datefmt, style)
|
||||||
|
|
||||||
|
def format(self, record):
|
||||||
|
"""Format log record with enhanced information and proper alignment"""
|
||||||
|
# Add process info if available
|
||||||
|
if hasattr(record, 'process') and record.process:
|
||||||
|
record.process_info = f"[PID:{record.process}]"
|
||||||
|
else:
|
||||||
|
record.process_info = ""
|
||||||
|
|
||||||
|
# Add thread info if available
|
||||||
|
if hasattr(record, 'thread') and record.thread:
|
||||||
|
record.thread_info = f"[TID:{record.thread}]"
|
||||||
|
else:
|
||||||
|
record.thread_info = ""
|
||||||
|
|
||||||
|
# Format with proper alignment after the level name
|
||||||
|
# Calculate padding needed for alignment
|
||||||
|
level_name = record.levelname
|
||||||
|
max_level_length = 8 # Length of "CRITICAL"
|
||||||
|
padding = max_level_length - len(level_name)
|
||||||
|
record.level_aligned = f"[{level_name}]{' ' * padding}"
|
||||||
|
|
||||||
|
return super().format(record)
|
||||||
|
|
||||||
Args:
|
|
||||||
level: Logging level (DEBUG, INFO, WARNING, ERROR)
|
|
||||||
log_file: Optional log file path
|
|
||||||
log_format: Optional custom log format
|
|
||||||
"""
|
|
||||||
if log_format is None:
|
|
||||||
log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
|
||||||
|
|
||||||
# Base configuration
|
class LevelBasedFileHandler(logging.Handler):
|
||||||
config: dict[str, Any] = {
|
"""Custom handler that writes different log levels to different files"""
|
||||||
"version": 1,
|
|
||||||
"disable_existing_loggers": False,
|
def __init__(self, log_dir: str, base_name: str = "doris_mcp_server",
|
||||||
"formatters": {
|
max_bytes: int = 10*1024*1024, backup_count: int = 5):
|
||||||
"default": {"format": log_format, "datefmt": "%Y-%m-%d %H:%M:%S"}
|
super().__init__()
|
||||||
},
|
self.log_dir = Path(log_dir)
|
||||||
"handlers": {
|
self.base_name = base_name
|
||||||
"console": {
|
self.max_bytes = max_bytes
|
||||||
"class": "logging.StreamHandler",
|
self.backup_count = backup_count
|
||||||
"level": level,
|
|
||||||
"formatter": "default",
|
|
||||||
"stream": sys.stdout,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"root": {"level": level, "handlers": ["console"]},
|
|
||||||
"loggers": {
|
|
||||||
"doris_mcp_server": {
|
|
||||||
"level": level,
|
|
||||||
"handlers": ["console"],
|
|
||||||
"propagate": False,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
# Add file handler if log_file is specified
|
|
||||||
if log_file:
|
|
||||||
# Ensure log directory exists
|
# Ensure log directory exists
|
||||||
log_path = Path(log_file)
|
self.log_dir.mkdir(parents=True, exist_ok=True)
|
||||||
log_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
# Create handlers for different log levels
|
||||||
config["handlers"]["file"] = {
|
self.handlers = {}
|
||||||
"class": "logging.handlers.RotatingFileHandler",
|
self._setup_level_handlers()
|
||||||
"level": level,
|
|
||||||
"formatter": "default",
|
def _setup_level_handlers(self):
|
||||||
"filename": log_file,
|
"""Setup rotating file handlers for different log levels"""
|
||||||
"maxBytes": 10485760, # 10MB
|
level_files = {
|
||||||
"backupCount": 5,
|
'DEBUG': 'debug.log',
|
||||||
|
'INFO': 'info.log',
|
||||||
|
'WARNING': 'warning.log',
|
||||||
|
'ERROR': 'error.log',
|
||||||
|
'CRITICAL': 'critical.log'
|
||||||
}
|
}
|
||||||
|
|
||||||
|
formatter = TimestampedFormatter()
|
||||||
|
|
||||||
|
for level, filename in level_files.items():
|
||||||
|
file_path = self.log_dir / f"{self.base_name}_{filename}"
|
||||||
|
handler = logging.handlers.RotatingFileHandler(
|
||||||
|
file_path,
|
||||||
|
maxBytes=self.max_bytes,
|
||||||
|
backupCount=self.backup_count,
|
||||||
|
encoding='utf-8'
|
||||||
|
)
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
handler.setLevel(getattr(logging, level))
|
||||||
|
self.handlers[level] = handler
|
||||||
|
|
||||||
|
def emit(self, record):
|
||||||
|
"""Emit log record to appropriate level-based file"""
|
||||||
|
level_name = record.levelname
|
||||||
|
if level_name in self.handlers:
|
||||||
|
try:
|
||||||
|
self.handlers[level_name].emit(record)
|
||||||
|
except Exception:
|
||||||
|
self.handleError(record)
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""Close all handlers"""
|
||||||
|
for handler in self.handlers.values():
|
||||||
|
handler.close()
|
||||||
|
super().close()
|
||||||
|
|
||||||
# Add file handler to root and package loggers
|
|
||||||
config["root"]["handlers"].append("file")
|
|
||||||
config["loggers"]["doris_mcp_server"]["handlers"].append("file")
|
|
||||||
|
|
||||||
logging.config.dictConfig(config)
|
class LogCleanupManager:
|
||||||
|
"""Log file cleanup manager for automatic maintenance"""
|
||||||
|
|
||||||
|
def __init__(self, log_dir: str, max_age_days: int = 30, cleanup_interval_hours: int = 24):
|
||||||
|
"""
|
||||||
|
Initialize log cleanup manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
log_dir: Directory containing log files
|
||||||
|
max_age_days: Maximum age of log files in days (default: 30 days)
|
||||||
|
cleanup_interval_hours: Cleanup interval in hours (default: 24 hours)
|
||||||
|
"""
|
||||||
|
self.log_dir = Path(log_dir)
|
||||||
|
self.max_age_days = max_age_days
|
||||||
|
self.cleanup_interval_hours = cleanup_interval_hours
|
||||||
|
self.cleanup_thread = None
|
||||||
|
self.stop_event = threading.Event()
|
||||||
|
self.logger = None
|
||||||
|
|
||||||
|
def start_cleanup_scheduler(self):
|
||||||
|
"""Start the cleanup scheduler in a background thread"""
|
||||||
|
if self.cleanup_thread and self.cleanup_thread.is_alive():
|
||||||
|
return
|
||||||
|
|
||||||
|
self.stop_event.clear()
|
||||||
|
self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True)
|
||||||
|
self.cleanup_thread.start()
|
||||||
|
|
||||||
|
# Get logger for this class
|
||||||
|
if not self.logger:
|
||||||
|
self.logger = logging.getLogger("doris_mcp_server.log_cleanup")
|
||||||
|
|
||||||
|
self.logger.info(f"Log cleanup scheduler started - cleanup every {self.cleanup_interval_hours}h, max age {self.max_age_days} days")
|
||||||
|
|
||||||
|
def stop_cleanup_scheduler(self):
|
||||||
|
"""Stop the cleanup scheduler"""
|
||||||
|
if self.cleanup_thread and self.cleanup_thread.is_alive():
|
||||||
|
self.stop_event.set()
|
||||||
|
self.cleanup_thread.join(timeout=5)
|
||||||
|
if self.logger:
|
||||||
|
self.logger.info("Log cleanup scheduler stopped")
|
||||||
|
|
||||||
|
def _cleanup_loop(self):
|
||||||
|
"""Background loop for periodic cleanup"""
|
||||||
|
while not self.stop_event.is_set():
|
||||||
|
try:
|
||||||
|
self.cleanup_old_logs()
|
||||||
|
# Sleep for the specified interval, but check stop event every 60 seconds
|
||||||
|
for _ in range(self.cleanup_interval_hours * 60): # Convert hours to minutes
|
||||||
|
if self.stop_event.wait(60): # Wait 60 seconds or until stop event
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
if self.logger:
|
||||||
|
self.logger.error(f"Error in log cleanup loop: {e}")
|
||||||
|
# Sleep for 5 minutes before retrying
|
||||||
|
self.stop_event.wait(300)
|
||||||
|
|
||||||
|
def cleanup_old_logs(self):
|
||||||
|
"""Clean up old log files based on age"""
|
||||||
|
if not self.log_dir.exists():
|
||||||
|
return
|
||||||
|
|
||||||
|
current_time = datetime.now()
|
||||||
|
cutoff_time = current_time - timedelta(days=self.max_age_days)
|
||||||
|
|
||||||
|
cleaned_files = []
|
||||||
|
cleaned_size = 0
|
||||||
|
|
||||||
|
# Pattern for log files (including backup files)
|
||||||
|
log_patterns = [
|
||||||
|
"doris_mcp_server_*.log",
|
||||||
|
"doris_mcp_server_*.log.*" # Backup files
|
||||||
|
]
|
||||||
|
|
||||||
|
for pattern in log_patterns:
|
||||||
|
for log_file in self.log_dir.glob(pattern):
|
||||||
|
try:
|
||||||
|
# Get file modification time
|
||||||
|
file_mtime = datetime.fromtimestamp(log_file.stat().st_mtime)
|
||||||
|
|
||||||
|
if file_mtime < cutoff_time:
|
||||||
|
file_size = log_file.stat().st_size
|
||||||
|
log_file.unlink() # Delete the file
|
||||||
|
cleaned_files.append(log_file.name)
|
||||||
|
cleaned_size += file_size
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if self.logger:
|
||||||
|
self.logger.warning(f"Failed to cleanup log file {log_file}: {e}")
|
||||||
|
|
||||||
|
if cleaned_files and self.logger:
|
||||||
|
size_mb = cleaned_size / (1024 * 1024)
|
||||||
|
self.logger.info(f"Cleaned up {len(cleaned_files)} old log files, freed {size_mb:.2f} MB")
|
||||||
|
self.logger.debug(f"Cleaned files: {', '.join(cleaned_files)}")
|
||||||
|
|
||||||
|
def get_cleanup_stats(self) -> dict:
|
||||||
|
"""Get statistics about log files and cleanup status"""
|
||||||
|
if not self.log_dir.exists():
|
||||||
|
return {"error": "Log directory does not exist"}
|
||||||
|
|
||||||
|
stats = {
|
||||||
|
"log_directory": str(self.log_dir.absolute()),
|
||||||
|
"max_age_days": self.max_age_days,
|
||||||
|
"cleanup_interval_hours": self.cleanup_interval_hours,
|
||||||
|
"scheduler_running": self.cleanup_thread and self.cleanup_thread.is_alive(),
|
||||||
|
"total_files": 0,
|
||||||
|
"total_size_mb": 0,
|
||||||
|
"files_by_age": {"recent": 0, "old": 0},
|
||||||
|
"oldest_file": None,
|
||||||
|
"newest_file": None
|
||||||
|
}
|
||||||
|
|
||||||
|
current_time = datetime.now()
|
||||||
|
cutoff_time = current_time - timedelta(days=self.max_age_days)
|
||||||
|
oldest_time = None
|
||||||
|
newest_time = None
|
||||||
|
|
||||||
|
log_patterns = ["doris_mcp_server_*.log", "doris_mcp_server_*.log.*"]
|
||||||
|
|
||||||
|
for pattern in log_patterns:
|
||||||
|
for log_file in self.log_dir.glob(pattern):
|
||||||
|
try:
|
||||||
|
file_stat = log_file.stat()
|
||||||
|
file_mtime = datetime.fromtimestamp(file_stat.st_mtime)
|
||||||
|
|
||||||
|
stats["total_files"] += 1
|
||||||
|
stats["total_size_mb"] += file_stat.st_size / (1024 * 1024)
|
||||||
|
|
||||||
|
if file_mtime < cutoff_time:
|
||||||
|
stats["files_by_age"]["old"] += 1
|
||||||
|
else:
|
||||||
|
stats["files_by_age"]["recent"] += 1
|
||||||
|
|
||||||
|
if oldest_time is None or file_mtime < oldest_time:
|
||||||
|
oldest_time = file_mtime
|
||||||
|
stats["oldest_file"] = {"name": log_file.name, "age_days": (current_time - file_mtime).days}
|
||||||
|
|
||||||
|
if newest_time is None or file_mtime > newest_time:
|
||||||
|
newest_time = file_mtime
|
||||||
|
stats["newest_file"] = {"name": log_file.name, "age_days": (current_time - file_mtime).days}
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
stats["total_size_mb"] = round(stats["total_size_mb"], 2)
|
||||||
|
|
||||||
|
return stats
|
||||||
|
|
||||||
|
|
||||||
|
class DorisLoggerManager:
|
||||||
|
"""Centralized logger manager for Doris MCP Server"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.is_initialized = False
|
||||||
|
self.log_dir = None
|
||||||
|
self.config = None
|
||||||
|
self.loggers = {}
|
||||||
|
self.cleanup_manager = None
|
||||||
|
|
||||||
|
def setup_logging(self,
|
||||||
|
level: str = "INFO",
|
||||||
|
log_dir: str = "logs",
|
||||||
|
enable_console: bool = True,
|
||||||
|
enable_file: bool = True,
|
||||||
|
enable_audit: bool = True,
|
||||||
|
audit_file: Optional[str] = None,
|
||||||
|
max_file_size: int = 10*1024*1024,
|
||||||
|
backup_count: int = 5,
|
||||||
|
enable_cleanup: bool = True,
|
||||||
|
max_age_days: int = 30,
|
||||||
|
cleanup_interval_hours: int = 24) -> None:
|
||||||
|
"""
|
||||||
|
Setup comprehensive logging configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
level: Base logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
|
||||||
|
log_dir: Directory for log files
|
||||||
|
enable_console: Enable console output
|
||||||
|
enable_file: Enable file logging
|
||||||
|
enable_audit: Enable audit logging
|
||||||
|
audit_file: Custom audit log file path
|
||||||
|
max_file_size: Maximum size per log file (bytes)
|
||||||
|
backup_count: Number of backup files to keep
|
||||||
|
enable_cleanup: Enable automatic log cleanup
|
||||||
|
max_age_days: Maximum age of log files in days (default: 30)
|
||||||
|
cleanup_interval_hours: Cleanup interval in hours (default: 24)
|
||||||
|
"""
|
||||||
|
if self.is_initialized:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.log_dir = Path(log_dir)
|
||||||
|
log_dir_writable = True # Initialize the variable
|
||||||
|
|
||||||
|
# Try to create log directory, fallback to console-only if fails
|
||||||
|
try:
|
||||||
|
self.log_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
except (OSError, PermissionError) as e:
|
||||||
|
# If we can't create log directory (e.g., read-only filesystem in stdio mode),
|
||||||
|
# fall back to console-only logging
|
||||||
|
log_dir_writable = False
|
||||||
|
enable_file = False
|
||||||
|
enable_audit = False
|
||||||
|
enable_cleanup = False
|
||||||
|
# Don't use print() in stdio mode as it interferes with MCP JSON protocol
|
||||||
|
# Log the warning through the logging system instead, which will be handled after setup
|
||||||
|
|
||||||
|
# Clear existing handlers
|
||||||
|
root_logger = logging.getLogger()
|
||||||
|
for handler in root_logger.handlers[:]:
|
||||||
|
root_logger.removeHandler(handler)
|
||||||
|
|
||||||
|
# Set root logger level
|
||||||
|
root_logger.setLevel(logging.DEBUG) # Allow all levels, handlers will filter
|
||||||
|
|
||||||
|
handlers = []
|
||||||
|
|
||||||
|
# Console handler
|
||||||
|
if enable_console:
|
||||||
|
console_handler = logging.StreamHandler(sys.stdout)
|
||||||
|
console_handler.setLevel(getattr(logging, level.upper()))
|
||||||
|
console_formatter = TimestampedFormatter(
|
||||||
|
fmt="%(asctime)s.%(msecs)03d %(level_aligned)s %(name)s - %(message)s",
|
||||||
|
datefmt="%Y-%m-%d %H:%M:%S"
|
||||||
|
)
|
||||||
|
console_handler.setFormatter(console_formatter)
|
||||||
|
handlers.append(console_handler)
|
||||||
|
|
||||||
|
# Level-based file handlers
|
||||||
|
if enable_file:
|
||||||
|
level_handler = LevelBasedFileHandler(
|
||||||
|
log_dir=str(self.log_dir),
|
||||||
|
base_name="doris_mcp_server",
|
||||||
|
max_bytes=max_file_size,
|
||||||
|
backup_count=backup_count
|
||||||
|
)
|
||||||
|
level_handler.setLevel(logging.DEBUG) # Accept all levels
|
||||||
|
handlers.append(level_handler)
|
||||||
|
|
||||||
|
# Combined application log (all levels in one file)
|
||||||
|
if enable_file:
|
||||||
|
app_log_file = self.log_dir / "doris_mcp_server_all.log"
|
||||||
|
app_handler = logging.handlers.RotatingFileHandler(
|
||||||
|
app_log_file,
|
||||||
|
maxBytes=max_file_size,
|
||||||
|
backupCount=backup_count,
|
||||||
|
encoding='utf-8'
|
||||||
|
)
|
||||||
|
app_handler.setLevel(getattr(logging, level.upper()))
|
||||||
|
app_formatter = TimestampedFormatter()
|
||||||
|
app_handler.setFormatter(app_formatter)
|
||||||
|
handlers.append(app_handler)
|
||||||
|
|
||||||
|
# Audit logger (separate from main logging)
|
||||||
|
if enable_audit:
|
||||||
|
audit_file_path = audit_file or str(self.log_dir / "doris_mcp_server_audit.log")
|
||||||
|
audit_logger = logging.getLogger("audit")
|
||||||
|
audit_logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
# Clear existing audit handlers
|
||||||
|
for handler in audit_logger.handlers[:]:
|
||||||
|
audit_logger.removeHandler(handler)
|
||||||
|
|
||||||
|
audit_handler = logging.handlers.RotatingFileHandler(
|
||||||
|
audit_file_path,
|
||||||
|
maxBytes=max_file_size,
|
||||||
|
backupCount=backup_count,
|
||||||
|
encoding='utf-8'
|
||||||
|
)
|
||||||
|
audit_formatter = TimestampedFormatter(
|
||||||
|
fmt="%(asctime)s.%(msecs)03d [AUDIT] %(name)s - %(message)s",
|
||||||
|
datefmt="%Y-%m-%d %H:%M:%S"
|
||||||
|
)
|
||||||
|
audit_handler.setFormatter(audit_formatter)
|
||||||
|
audit_logger.addHandler(audit_handler)
|
||||||
|
audit_logger.propagate = False # Don't propagate to root logger
|
||||||
|
|
||||||
|
# Add all handlers to root logger
|
||||||
|
for handler in handlers:
|
||||||
|
root_logger.addHandler(handler)
|
||||||
|
|
||||||
|
# Setup package-specific loggers
|
||||||
|
self._setup_package_loggers(level)
|
||||||
|
|
||||||
|
# Setup log cleanup manager
|
||||||
|
if enable_cleanup and enable_file:
|
||||||
|
self.cleanup_manager = LogCleanupManager(
|
||||||
|
log_dir=str(self.log_dir),
|
||||||
|
max_age_days=max_age_days,
|
||||||
|
cleanup_interval_hours=cleanup_interval_hours
|
||||||
|
)
|
||||||
|
self.cleanup_manager.start_cleanup_scheduler()
|
||||||
|
|
||||||
|
self.is_initialized = True
|
||||||
|
|
||||||
|
# Log initialization message
|
||||||
|
logger = self.get_logger("doris_mcp_server.logger")
|
||||||
|
logger.info("=" * 80)
|
||||||
|
logger.info("Doris MCP Server Logging System Initialized")
|
||||||
|
logger.info(f"Log Level: {level}")
|
||||||
|
if log_dir_writable:
|
||||||
|
logger.info(f"Log Directory: {self.log_dir.absolute()}")
|
||||||
|
else:
|
||||||
|
logger.info("Log Directory: Not available (console-only mode)")
|
||||||
|
logger.info(f"Console Logging: {'Enabled' if enable_console else 'Disabled'}")
|
||||||
|
logger.info(f"File Logging: {'Enabled' if enable_file else 'Disabled (fallback mode)'}")
|
||||||
|
logger.info(f"Audit Logging: {'Enabled' if enable_audit else 'Disabled (fallback mode)'}")
|
||||||
|
logger.info(f"Log Cleanup: {'Enabled' if enable_cleanup and enable_file else 'Disabled (fallback mode)'}")
|
||||||
|
if enable_cleanup and enable_file:
|
||||||
|
logger.info(f"Cleanup Settings: Max age {max_age_days} days, interval {cleanup_interval_hours}h")
|
||||||
|
if not log_dir_writable:
|
||||||
|
logger.warning("Running in console-only logging mode due to filesystem permissions")
|
||||||
|
logger.warning(f"Could not create log directory '{log_dir}' - stdio mode fallback enabled")
|
||||||
|
logger.info("=" * 80)
|
||||||
|
|
||||||
|
def _setup_package_loggers(self, level: str):
|
||||||
|
"""Setup specific loggers for different modules"""
|
||||||
|
package_loggers = [
|
||||||
|
"doris_mcp_server",
|
||||||
|
"doris_mcp_server.main",
|
||||||
|
"doris_mcp_server.utils",
|
||||||
|
"doris_mcp_server.tools",
|
||||||
|
"doris_mcp_client"
|
||||||
|
]
|
||||||
|
|
||||||
|
for logger_name in package_loggers:
|
||||||
|
logger = logging.getLogger(logger_name)
|
||||||
|
logger.setLevel(getattr(logging, level.upper()))
|
||||||
|
# Don't add handlers here - they inherit from root logger
|
||||||
|
|
||||||
|
def get_logger(self, name: str) -> logging.Logger:
|
||||||
|
"""
|
||||||
|
Get a logger instance with proper configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Logger name (usually __name__)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured logger instance
|
||||||
|
"""
|
||||||
|
if name not in self.loggers:
|
||||||
|
logger = logging.getLogger(name)
|
||||||
|
self.loggers[name] = logger
|
||||||
|
|
||||||
|
return self.loggers[name]
|
||||||
|
|
||||||
|
def get_audit_logger(self) -> logging.Logger:
|
||||||
|
"""Get the audit logger"""
|
||||||
|
return logging.getLogger("audit")
|
||||||
|
|
||||||
|
def log_system_info(self):
|
||||||
|
"""Log system information for debugging"""
|
||||||
|
logger = self.get_logger("doris_mcp_server.system")
|
||||||
|
logger.info("System Information:")
|
||||||
|
logger.info(f"Python Version: {sys.version}")
|
||||||
|
logger.info(f"Platform: {sys.platform}")
|
||||||
|
logger.info(f"Working Directory: {os.getcwd()}")
|
||||||
|
logger.info(f"Process ID: {os.getpid()}")
|
||||||
|
|
||||||
|
# Log environment variables (filtered)
|
||||||
|
env_vars = ["LOG_LEVEL", "LOG_FILE_PATH", "ENABLE_AUDIT", "AUDIT_FILE_PATH"]
|
||||||
|
for var in env_vars:
|
||||||
|
value = os.getenv(var, "Not Set")
|
||||||
|
logger.info(f"Environment {var}: {value}")
|
||||||
|
|
||||||
|
def get_cleanup_stats(self) -> dict:
|
||||||
|
"""Get log cleanup statistics"""
|
||||||
|
if self.cleanup_manager:
|
||||||
|
return self.cleanup_manager.get_cleanup_stats()
|
||||||
|
else:
|
||||||
|
return {"error": "Log cleanup is not enabled"}
|
||||||
|
|
||||||
|
def manual_cleanup(self) -> dict:
|
||||||
|
"""Manually trigger log cleanup and return statistics"""
|
||||||
|
if self.cleanup_manager:
|
||||||
|
self.cleanup_manager.cleanup_old_logs()
|
||||||
|
return self.cleanup_manager.get_cleanup_stats()
|
||||||
|
else:
|
||||||
|
return {"error": "Log cleanup is not enabled"}
|
||||||
|
|
||||||
|
def shutdown(self):
|
||||||
|
"""Shutdown logging system"""
|
||||||
|
if not self.is_initialized:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger = self.get_logger("doris_mcp_server.logger")
|
||||||
|
logger.info("Shutting down logging system...")
|
||||||
|
|
||||||
|
# Stop cleanup manager
|
||||||
|
if self.cleanup_manager:
|
||||||
|
self.cleanup_manager.stop_cleanup_scheduler()
|
||||||
|
|
||||||
|
# Close all handlers
|
||||||
|
root_logger = logging.getLogger()
|
||||||
|
for handler in root_logger.handlers[:]:
|
||||||
|
try:
|
||||||
|
handler.close()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error closing handler: {e}")
|
||||||
|
|
||||||
|
# Close audit logger handlers
|
||||||
|
audit_logger = logging.getLogger("audit")
|
||||||
|
for handler in audit_logger.handlers[:]:
|
||||||
|
try:
|
||||||
|
handler.close()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error closing audit handler: {e}")
|
||||||
|
|
||||||
|
self.is_initialized = False
|
||||||
|
|
||||||
|
|
||||||
|
# Global logger manager instance
|
||||||
|
_logger_manager = DorisLoggerManager()
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logging(level: str = "INFO",
|
||||||
|
log_dir: str = "logs",
|
||||||
|
enable_console: bool = True,
|
||||||
|
enable_file: bool = True,
|
||||||
|
enable_audit: bool = True,
|
||||||
|
audit_file: Optional[str] = None,
|
||||||
|
max_file_size: int = 10*1024*1024,
|
||||||
|
backup_count: int = 5,
|
||||||
|
enable_cleanup: bool = True,
|
||||||
|
max_age_days: int = 30,
|
||||||
|
cleanup_interval_hours: int = 24) -> None:
|
||||||
|
"""
|
||||||
|
Setup logging configuration (convenience function).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
|
||||||
|
log_dir: Directory for log files
|
||||||
|
enable_console: Enable console output
|
||||||
|
enable_file: Enable file logging
|
||||||
|
enable_audit: Enable audit logging
|
||||||
|
audit_file: Custom audit log file path
|
||||||
|
max_file_size: Maximum size per log file (bytes)
|
||||||
|
backup_count: Number of backup files to keep
|
||||||
|
enable_cleanup: Enable automatic log cleanup
|
||||||
|
max_age_days: Maximum age of log files in days (default: 30)
|
||||||
|
cleanup_interval_hours: Cleanup interval in hours (default: 24)
|
||||||
|
"""
|
||||||
|
_logger_manager.setup_logging(
|
||||||
|
level=level,
|
||||||
|
log_dir=log_dir,
|
||||||
|
enable_console=enable_console,
|
||||||
|
enable_file=enable_file,
|
||||||
|
enable_audit=enable_audit,
|
||||||
|
audit_file=audit_file,
|
||||||
|
max_file_size=max_file_size,
|
||||||
|
backup_count=backup_count,
|
||||||
|
enable_cleanup=enable_cleanup,
|
||||||
|
max_age_days=max_age_days,
|
||||||
|
cleanup_interval_hours=cleanup_interval_hours
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_logger(name: str) -> logging.Logger:
|
def get_logger(name: str) -> logging.Logger:
|
||||||
@@ -93,9 +589,60 @@ def get_logger(name: str) -> logging.Logger:
|
|||||||
Get a logger instance.
|
Get a logger instance.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: Logger name
|
name: Logger name (usually __name__)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Logger instance
|
Configured logger instance
|
||||||
"""
|
"""
|
||||||
return logging.getLogger(name)
|
return _logger_manager.get_logger(name)
|
||||||
|
|
||||||
|
|
||||||
|
def get_audit_logger() -> logging.Logger:
|
||||||
|
"""Get the audit logger"""
|
||||||
|
return _logger_manager.get_audit_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def log_system_info():
|
||||||
|
"""Log system information for debugging"""
|
||||||
|
_logger_manager.log_system_info()
|
||||||
|
|
||||||
|
|
||||||
|
def get_cleanup_stats() -> dict:
|
||||||
|
"""Get log cleanup statistics"""
|
||||||
|
return _logger_manager.get_cleanup_stats()
|
||||||
|
|
||||||
|
|
||||||
|
def manual_cleanup() -> dict:
|
||||||
|
"""Manually trigger log cleanup and return statistics"""
|
||||||
|
return _logger_manager.manual_cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
def shutdown_logging():
|
||||||
|
"""Shutdown logging system"""
|
||||||
|
_logger_manager.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
# Compatibility function for existing code
|
||||||
|
def setup_logging_old(level: str = "INFO",
|
||||||
|
log_file: str | None = None,
|
||||||
|
log_format: str | None = None) -> None:
|
||||||
|
"""
|
||||||
|
Legacy setup function for backward compatibility.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
level: Logging level (DEBUG, INFO, WARNING, ERROR)
|
||||||
|
log_file: Optional log file path (deprecated - use log_dir instead)
|
||||||
|
log_format: Optional custom log format (deprecated)
|
||||||
|
"""
|
||||||
|
# Extract directory from log_file if provided
|
||||||
|
log_dir = "logs"
|
||||||
|
if log_file:
|
||||||
|
log_dir = str(Path(log_file).parent)
|
||||||
|
|
||||||
|
setup_logging(
|
||||||
|
level=level,
|
||||||
|
log_dir=log_dir,
|
||||||
|
enable_console=True,
|
||||||
|
enable_file=True,
|
||||||
|
enable_audit=True
|
||||||
|
)
|
||||||
|
|||||||
1611
doris_mcp_server/utils/monitoring_tools.py
Normal file
1810
doris_mcp_server/utils/performance_analytics_tools.py
Normal file
@@ -33,7 +33,11 @@ from datetime import datetime, timedelta, date
|
|||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
|
|
||||||
|
import sqlparse
|
||||||
|
|
||||||
from .db import DorisConnectionManager, QueryResult
|
from .db import DorisConnectionManager, QueryResult
|
||||||
|
from .logger import get_logger
|
||||||
|
from .sql_security_utils import get_auth_context
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -92,7 +96,7 @@ class QueryCache:
|
|||||||
self.max_size = max_size
|
self.max_size = max_size
|
||||||
self.default_ttl = default_ttl
|
self.default_ttl = default_ttl
|
||||||
self.cache: dict[str, CachedQuery] = {}
|
self.cache: dict[str, CachedQuery] = {}
|
||||||
self.logger = logging.getLogger(__name__)
|
self.logger = get_logger(__name__)
|
||||||
|
|
||||||
def _generate_cache_key(
|
def _generate_cache_key(
|
||||||
self, sql: str, parameters: dict[str, Any] | None = None
|
self, sql: str, parameters: dict[str, Any] | None = None
|
||||||
@@ -194,7 +198,7 @@ class QueryOptimizer:
|
|||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.logger = logging.getLogger(__name__)
|
self.logger = get_logger(__name__)
|
||||||
self.optimization_rules = self._load_optimization_rules()
|
self.optimization_rules = self._load_optimization_rules()
|
||||||
|
|
||||||
def _load_optimization_rules(self) -> list[dict[str, Any]]:
|
def _load_optimization_rules(self) -> list[dict[str, Any]]:
|
||||||
@@ -318,7 +322,7 @@ class DorisQueryExecutor:
|
|||||||
def __init__(self, connection_manager: DorisConnectionManager, config=None):
|
def __init__(self, connection_manager: DorisConnectionManager, config=None):
|
||||||
self.connection_manager = connection_manager
|
self.connection_manager = connection_manager
|
||||||
self.config = config or self._create_default_config()
|
self.config = config or self._create_default_config()
|
||||||
self.logger = logging.getLogger(__name__)
|
self.logger = get_logger(__name__)
|
||||||
|
|
||||||
# Initialize components
|
# Initialize components
|
||||||
cache_config = getattr(self.config, 'performance', None)
|
cache_config = getattr(self.config, 'performance', None)
|
||||||
@@ -425,27 +429,27 @@ class DorisQueryExecutor:
|
|||||||
self, query_request: QueryRequest, auth_context
|
self, query_request: QueryRequest, auth_context
|
||||||
) -> QueryResult:
|
) -> QueryResult:
|
||||||
"""Internal query execution"""
|
"""Internal query execution"""
|
||||||
|
|
||||||
|
# Database configuration should already be handled during authentication
|
||||||
|
# No need to configure again during query execution
|
||||||
|
|
||||||
# Optimize query
|
# Optimize query
|
||||||
optimized_sql = await self.query_optimizer.optimize_query(
|
optimized_sql = await self.query_optimizer.optimize_query(
|
||||||
query_request.sql, {"user_roles": getattr(auth_context, 'roles', [])}
|
query_request.sql, {"user_roles": getattr(auth_context, 'roles', [])}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Execute query
|
# Execute query
|
||||||
connection = await self.connection_manager.get_connection(
|
|
||||||
query_request.session_id
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set timeout if specified
|
# Set timeout if specified
|
||||||
if query_request.timeout:
|
if query_request.timeout:
|
||||||
try:
|
try:
|
||||||
result = await asyncio.wait_for(
|
result = await asyncio.wait_for(
|
||||||
connection.execute(optimized_sql, query_request.parameters, auth_context),
|
self.connection_manager.execute_query(query_request.session_id, optimized_sql, query_request.parameters, auth_context),
|
||||||
timeout=query_request.timeout
|
timeout=query_request.timeout
|
||||||
)
|
)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
raise Exception(f"Query timeout after {query_request.timeout} seconds")
|
raise Exception(f"Query timeout after {query_request.timeout} seconds")
|
||||||
else:
|
else:
|
||||||
result = await connection.execute(optimized_sql, query_request.parameters, auth_context)
|
result = await self.connection_manager.execute_query(query_request.session_id, optimized_sql, query_request.parameters, auth_context)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -466,6 +470,51 @@ class DorisQueryExecutor:
|
|||||||
f"Slow query detected: {execution_time:.2f}s (threshold: {self.slow_query_threshold}s)"
|
f"Slow query detected: {execution_time:.2f}s (threshold: {self.slow_query_threshold}s)"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def execute_batch_sqls_for_mcp(
|
||||||
|
self, sqls: list[str],
|
||||||
|
timeout: int = 30,
|
||||||
|
session_id: str = "mcp_session",
|
||||||
|
user_id: str = "mcp_user",
|
||||||
|
auth_context=None
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Execute multiple sqls in batch"""
|
||||||
|
if not sqls:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": "SQL query is required",
|
||||||
|
"data": None
|
||||||
|
}
|
||||||
|
query_requests = [
|
||||||
|
QueryRequest(
|
||||||
|
sql=sql,
|
||||||
|
session_id=session_id,
|
||||||
|
user_id=user_id,
|
||||||
|
timeout=timeout,
|
||||||
|
cache_enabled=False
|
||||||
|
)
|
||||||
|
for sql in sqls
|
||||||
|
]
|
||||||
|
query_results = await self.execute_batch_queries(query_requests, auth_context)
|
||||||
|
# Serialize data for JSON response
|
||||||
|
results = [
|
||||||
|
{
|
||||||
|
"data": [self._serialize_row_data(data) for data in result.data],
|
||||||
|
"row_count": result.row_count,
|
||||||
|
"execution_time": result.execution_time,
|
||||||
|
"metadata": {
|
||||||
|
"columns": result.metadata.get("columns", []),
|
||||||
|
"query": result.sql
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for result in query_results
|
||||||
|
]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"multiple_results": True,
|
||||||
|
"results": results
|
||||||
|
}
|
||||||
|
|
||||||
async def execute_batch_queries(
|
async def execute_batch_queries(
|
||||||
self, query_requests: list[QueryRequest], auth_context=None
|
self, query_requests: list[QueryRequest], auth_context=None
|
||||||
) -> list[QueryResult]:
|
) -> list[QueryResult]:
|
||||||
@@ -483,20 +532,24 @@ class DorisQueryExecutor:
|
|||||||
self.execute_query(request, auth_context) for request in query_requests
|
self.execute_query(request, auth_context) for request in query_requests
|
||||||
]
|
]
|
||||||
|
|
||||||
try:
|
query_results = []
|
||||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
except Exception as e:
|
for result in results:
|
||||||
self.logger.error(f"Batch query execution failed: {e}")
|
if isinstance(result, Exception):
|
||||||
raise
|
self.logger.error(f"Batch query execution failed: {result}")
|
||||||
|
raise result
|
||||||
|
else:
|
||||||
|
query_results.append(result)
|
||||||
|
|
||||||
return results
|
return query_results
|
||||||
|
|
||||||
async def explain_query(self, sql: str, session_id: str) -> dict[str, Any]:
|
async def explain_query(self, sql: str, session_id: str) -> dict[str, Any]:
|
||||||
"""Get query execution plan"""
|
"""Get query execution plan"""
|
||||||
explain_sql = f"EXPLAIN {sql}"
|
explain_sql = f"EXPLAIN {sql}"
|
||||||
|
|
||||||
connection = await self.connection_manager.get_connection(session_id)
|
connection = await self.connection_manager.get_connection(session_id)
|
||||||
result = await connection.execute(explain_sql)
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(explain_sql, auth_context=auth_context)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"query": sql,
|
"query": sql,
|
||||||
@@ -540,87 +593,192 @@ class DorisQueryExecutor:
|
|||||||
await self.query_cache.clear_all()
|
await self.query_cache.clear_all()
|
||||||
|
|
||||||
async def execute_sql_for_mcp(
|
async def execute_sql_for_mcp(
|
||||||
self,
|
self,
|
||||||
sql: str,
|
sql: str,
|
||||||
limit: int = 1000,
|
limit: int = 1000,
|
||||||
timeout: int = 30,
|
timeout: int = 30,
|
||||||
session_id: str = "mcp_session",
|
session_id: str = "mcp_session",
|
||||||
user_id: str = "mcp_user"
|
user_id: str = "mcp_user",
|
||||||
|
auth_context = None # FIX for Issue #62 Bug 1: Accept auth_context with token
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Execute SQL query for MCP interface - unified method"""
|
"""Execute SQL query for MCP interface - unified method
|
||||||
try:
|
|
||||||
if not sql:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "SQL query is required",
|
|
||||||
"data": None
|
|
||||||
}
|
|
||||||
|
|
||||||
# Add LIMIT if not present and it's a SELECT query
|
FIX for Issue #62 Bug 1: Now accepts auth_context parameter to support token-bound database configuration
|
||||||
if sql.upper().startswith("SELECT") and "LIMIT" not in sql.upper():
|
"""
|
||||||
if sql.endswith(";"):
|
max_retries = 2
|
||||||
sql = sql[:-1]
|
retry_count = 0
|
||||||
sql = f"{sql} LIMIT {limit}"
|
|
||||||
|
|
||||||
# Create auth context for MCP calls
|
while retry_count <= max_retries:
|
||||||
class MockAuthContext:
|
try:
|
||||||
def __init__(self):
|
if not sql:
|
||||||
self.user_id = user_id
|
return {
|
||||||
self.roles = ["data_analyst"]
|
"success": False,
|
||||||
self.permissions = ["read_data", "execute_query"]
|
"error": "SQL query is required",
|
||||||
self.session_id = session_id
|
"data": None
|
||||||
self.security_level = "internal"
|
}
|
||||||
|
|
||||||
auth_context = MockAuthContext()
|
# Import required security modules
|
||||||
|
from .security import DorisSecurityManager, AuthContext, SecurityLevel
|
||||||
# Create query request
|
|
||||||
query_request = QueryRequest(
|
# FIX: Use provided auth_context if available (contains token for DB config)
|
||||||
sql=sql,
|
# Otherwise create default auth context for backward compatibility
|
||||||
session_id=session_id,
|
if auth_context is None:
|
||||||
user_id=user_id,
|
auth_context = AuthContext(
|
||||||
timeout=timeout,
|
user_id=user_id,
|
||||||
cache_enabled=True
|
roles=["read_only_user"], # Restrictive role for MCP interface
|
||||||
)
|
permissions=["read_data"], # Only read permissions
|
||||||
|
session_id=session_id,
|
||||||
# Execute query
|
security_level=SecurityLevel.INTERNAL,
|
||||||
result = await self.execute_query(query_request, auth_context)
|
token="" # No token in default context
|
||||||
|
)
|
||||||
# Process results
|
else:
|
||||||
processed_data = []
|
# Use provided auth_context (may contain token for database configuration)
|
||||||
if result.data:
|
self.logger.debug(f"Using provided auth_context with token: {bool(hasattr(auth_context, 'token') and auth_context.token)}")
|
||||||
|
|
||||||
|
# Perform SQL security validation if enabled
|
||||||
|
if hasattr(self.connection_manager, 'config') and hasattr(self.connection_manager.config, 'security'):
|
||||||
|
if self.connection_manager.config.security.enable_security_check:
|
||||||
|
try:
|
||||||
|
# 🔧 FIX: Use existing security_manager to avoid creating multiple TokenManager instances
|
||||||
|
# Creating new DorisSecurityManager each time causes multiple hot reload monitors
|
||||||
|
security_manager = getattr(self.connection_manager, 'security_manager', None)
|
||||||
|
if not security_manager:
|
||||||
|
# Fallback: create new one only if not available (should rarely happen)
|
||||||
|
self.logger.warning("No existing security_manager, creating new instance")
|
||||||
|
security_manager = DorisSecurityManager(self.connection_manager.config)
|
||||||
|
validation_result = await security_manager.validate_sql_security(sql, auth_context)
|
||||||
|
|
||||||
|
if not validation_result.is_valid:
|
||||||
|
self.logger.warning(f"SQL security validation failed for query: {sql[:100]}...")
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"SQL security validation failed: {validation_result.error_message}",
|
||||||
|
"error_type": "security_violation",
|
||||||
|
"blocked_operations": validation_result.blocked_operations,
|
||||||
|
"risk_level": validation_result.risk_level,
|
||||||
|
"data": None,
|
||||||
|
"metadata": {
|
||||||
|
"query": sql,
|
||||||
|
"validation_details": {
|
||||||
|
"blocked_operations": validation_result.blocked_operations,
|
||||||
|
"risk_level": validation_result.risk_level
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
self.logger.debug(f"SQL security validation passed for query: {sql[:100]}...")
|
||||||
|
except Exception as security_error:
|
||||||
|
self.logger.error(f"Security validation error: {str(security_error)}")
|
||||||
|
# In case of security validation error, fail safe
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"Security validation system error: {str(security_error)}",
|
||||||
|
"error_type": "security_system_error",
|
||||||
|
"data": None,
|
||||||
|
"metadata": {
|
||||||
|
"query": sql,
|
||||||
|
"security_error": str(security_error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
self.logger.info("SQL security check is disabled in configuration")
|
||||||
|
else:
|
||||||
|
self.logger.warning("Security configuration not found, proceeding without validation")
|
||||||
|
|
||||||
|
# Add LIMIT if not present and it's a SELECT query
|
||||||
|
if sql.upper().startswith("SELECT") and "LIMIT" not in sql.upper():
|
||||||
|
if sql.endswith(";"):
|
||||||
|
sql = sql[:-1]
|
||||||
|
sql = f"{sql} LIMIT {limit}"
|
||||||
|
|
||||||
|
all_statements = [
|
||||||
|
s.strip()
|
||||||
|
for s in sqlparse.split(sql)
|
||||||
|
if s.strip()
|
||||||
|
]
|
||||||
|
if len(all_statements) > 1:
|
||||||
|
return await self.execute_batch_sqls_for_mcp(sqls=all_statements, timeout=timeout,
|
||||||
|
session_id=session_id, user_id=user_id,
|
||||||
|
auth_context=auth_context)
|
||||||
|
# Create query request
|
||||||
|
query_request = QueryRequest(
|
||||||
|
sql=sql,
|
||||||
|
session_id=session_id,
|
||||||
|
user_id=user_id,
|
||||||
|
timeout=timeout,
|
||||||
|
cache_enabled=False # Disable cache for MCP calls to ensure fresh data
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execute query with retry logic
|
||||||
|
result = await self.execute_query(query_request, auth_context)
|
||||||
|
|
||||||
|
# Serialize data for JSON response
|
||||||
|
serialized_data = []
|
||||||
for row in result.data:
|
for row in result.data:
|
||||||
processed_row = self._serialize_row_data(row)
|
serialized_data.append(self._serialize_row_data(row))
|
||||||
processed_data.append(processed_row)
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"data": processed_data,
|
"data": serialized_data,
|
||||||
"metadata": {
|
|
||||||
"row_count": result.row_count,
|
"row_count": result.row_count,
|
||||||
"execution_time": result.execution_time,
|
"execution_time": result.execution_time,
|
||||||
"columns": result.metadata.get("columns", []),
|
"metadata": {
|
||||||
"query": sql
|
"columns": result.metadata.get("columns", []),
|
||||||
},
|
"query": sql
|
||||||
"error": None
|
}
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = str(e)
|
|
||||||
self.logger.error(f"SQL execution error: {error_msg}")
|
|
||||||
|
|
||||||
# Analyze error for better user feedback
|
|
||||||
error_analysis = self._analyze_error(error_msg)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": error_analysis.get("user_message", error_msg),
|
|
||||||
"error_type": error_analysis.get("error_type", "execution_error"),
|
|
||||||
"data": None,
|
|
||||||
"metadata": {
|
|
||||||
"query": sql,
|
|
||||||
"error_details": error_msg
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = str(e)
|
||||||
|
error_str = error_msg.lower()
|
||||||
|
|
||||||
|
# Check if it's a connection-related error that we should retry
|
||||||
|
connection_errors = [
|
||||||
|
"at_eof", "connection", "closed", "nonetype",
|
||||||
|
"transport", "reader", "broken pipe", "connection reset"
|
||||||
|
]
|
||||||
|
|
||||||
|
is_connection_error = any(err in error_str for err in connection_errors)
|
||||||
|
|
||||||
|
if is_connection_error and retry_count < max_retries:
|
||||||
|
retry_count += 1
|
||||||
|
self.logger.warning(f"Connection error detected, retrying ({retry_count}/{max_retries}): {e}")
|
||||||
|
|
||||||
|
# Release the problematic connection
|
||||||
|
try:
|
||||||
|
await self.connection_manager.release_connection(session_id)
|
||||||
|
except Exception:
|
||||||
|
pass # Ignore cleanup errors
|
||||||
|
|
||||||
|
# Wait a bit before retry
|
||||||
|
await asyncio.sleep(0.5 * retry_count)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
# If we've exhausted retries or it's not a connection error, return error
|
||||||
|
error_analysis = self._analyze_error(error_msg)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": error_analysis.get("user_message", error_msg),
|
||||||
|
"error_type": error_analysis.get("error_type", "general_error"),
|
||||||
|
"data": None,
|
||||||
|
"metadata": {
|
||||||
|
"query": sql,
|
||||||
|
"error_details": error_msg,
|
||||||
|
"retry_count": retry_count
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# This should never be reached, but just in case
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": "Maximum retries exceeded",
|
||||||
|
"data": None,
|
||||||
|
"metadata": {
|
||||||
|
"query": sql,
|
||||||
|
"retry_count": retry_count
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
def _serialize_row_data(self, row_data: Dict[str, Any]) -> Dict[str, Any]:
|
def _serialize_row_data(self, row_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Serialize row data for JSON response"""
|
"""Serialize row data for JSON response"""
|
||||||
@@ -649,7 +807,12 @@ class DorisQueryExecutor:
|
|||||||
"""Analyze error message and provide user-friendly feedback"""
|
"""Analyze error message and provide user-friendly feedback"""
|
||||||
error_msg_lower = error_message.lower()
|
error_msg_lower = error_message.lower()
|
||||||
|
|
||||||
if "table" in error_msg_lower and "doesn't exist" in error_msg_lower:
|
if "at_eof" in error_msg_lower or "nonetype" in error_msg_lower and "at_eof" in error_msg_lower:
|
||||||
|
return {
|
||||||
|
"error_type": "connection_lost",
|
||||||
|
"user_message": "Database connection was lost. The query has been automatically retried. If this persists, please restart the server."
|
||||||
|
}
|
||||||
|
elif "table" in error_msg_lower and "doesn't exist" in error_msg_lower:
|
||||||
return {
|
return {
|
||||||
"error_type": "table_not_found",
|
"error_type": "table_not_found",
|
||||||
"user_message": "The specified table does not exist. Please check the table name and database."
|
"user_message": "The specified table does not exist. Please check the table name and database."
|
||||||
@@ -674,6 +837,11 @@ class DorisQueryExecutor:
|
|||||||
"error_type": "timeout",
|
"error_type": "timeout",
|
||||||
"user_message": "Query execution timed out. Try simplifying your query or adding more specific filters."
|
"user_message": "Query execution timed out. Try simplifying your query or adding more specific filters."
|
||||||
}
|
}
|
||||||
|
elif "connection" in error_msg_lower and ("closed" in error_msg_lower or "reset" in error_msg_lower):
|
||||||
|
return {
|
||||||
|
"error_type": "connection_error",
|
||||||
|
"user_message": "Database connection was interrupted. The query has been automatically retried."
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
return {
|
return {
|
||||||
"error_type": "general_error",
|
"error_type": "general_error",
|
||||||
@@ -701,7 +869,7 @@ class QueryPerformanceMonitor:
|
|||||||
|
|
||||||
def __init__(self, query_executor: DorisQueryExecutor):
|
def __init__(self, query_executor: DorisQueryExecutor):
|
||||||
self.query_executor = query_executor
|
self.query_executor = query_executor
|
||||||
self.logger = logging.getLogger(__name__)
|
self.logger = get_logger(__name__)
|
||||||
self.performance_records = []
|
self.performance_records = []
|
||||||
|
|
||||||
async def record_query_performance(
|
async def record_query_performance(
|
||||||
@@ -785,32 +953,51 @@ class QueryPerformanceMonitor:
|
|||||||
|
|
||||||
# Unified convenience function for MCP integration
|
# Unified convenience function for MCP integration
|
||||||
async def execute_sql_query(sql: str, connection_manager: DorisConnectionManager, **kwargs) -> Dict[str, Any]:
|
async def execute_sql_query(sql: str, connection_manager: DorisConnectionManager, **kwargs) -> Dict[str, Any]:
|
||||||
"""Execute SQL query - unified convenience function for MCP tools"""
|
"""Execute SQL query - unified convenience function for MCP tools
|
||||||
|
|
||||||
|
This function now includes security validation to ensure safe query execution.
|
||||||
|
All queries are validated against the configured security policies before execution.
|
||||||
|
|
||||||
|
FIX for Issue #62 Bug 1: Now supports auth_context parameter for token-bound database configuration
|
||||||
|
FIX for Issue #58 Problem 2: Removed executor.close() to prevent ClosedResourceError in multi-worker mode
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
# Create query executor
|
# Create query executor with the connection manager's configuration
|
||||||
executor = DorisQueryExecutor(connection_manager)
|
executor = DorisQueryExecutor(connection_manager)
|
||||||
|
|
||||||
try:
|
# Extract parameters from kwargs or use defaults
|
||||||
# Extract parameters from kwargs or use defaults
|
limit = kwargs.get("limit", 1000)
|
||||||
limit = kwargs.get("limit", 1000)
|
timeout = kwargs.get("timeout", 30)
|
||||||
timeout = kwargs.get("timeout", 30)
|
session_id = kwargs.get("session_id", "mcp_session")
|
||||||
session_id = kwargs.get("session_id", "mcp_session")
|
user_id = kwargs.get("user_id", "mcp_user")
|
||||||
user_id = kwargs.get("user_id", "mcp_user")
|
auth_context = kwargs.get("auth_context", None) # FIX: Extract auth_context
|
||||||
|
|
||||||
result = await executor.execute_sql_for_mcp(
|
# The execute_sql_for_mcp method now includes security validation
|
||||||
sql=sql,
|
result = await executor.execute_sql_for_mcp(
|
||||||
limit=limit,
|
sql=sql,
|
||||||
timeout=timeout,
|
limit=limit,
|
||||||
session_id=session_id,
|
timeout=timeout,
|
||||||
user_id=user_id
|
session_id=session_id,
|
||||||
)
|
user_id=user_id,
|
||||||
return result
|
auth_context=auth_context # FIX: Pass auth_context with token
|
||||||
finally:
|
)
|
||||||
await executor.close()
|
|
||||||
|
# FIX for Issue #58 Problem 2: Do NOT close executor here
|
||||||
|
# In multi-worker mode, closing here causes ClosedResourceError
|
||||||
|
# The executor's resources (cache, background tasks) will be managed
|
||||||
|
# by the connection_manager lifecycle and Python's garbage collection
|
||||||
|
# This prevents premature cleanup while MCP session manager is still processing
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {
|
return {
|
||||||
"success": False,
|
"success": False,
|
||||||
"error": f"Query execution failed: {str(e)}",
|
"error": f"Query execution failed: {str(e)}",
|
||||||
"data": None
|
"error_type": "execution_error",
|
||||||
|
"data": None,
|
||||||
|
"metadata": {
|
||||||
|
"query": sql,
|
||||||
|
"execution_error": str(e)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,18 +20,25 @@ Doris Security Management Module
|
|||||||
Implements enterprise-level authentication, authorization, SQL security validation and data masking functionality
|
Implements enterprise-level authentication, authorization, SQL security validation and data masking functionality
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import hashlib
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from dataclasses import dataclass
|
from contextvars import ContextVar
|
||||||
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any
|
from typing import Any, Optional
|
||||||
|
|
||||||
import sqlparse
|
import sqlparse
|
||||||
from sqlparse.sql import Statement
|
from sqlparse.sql import Statement
|
||||||
from sqlparse.tokens import Keyword, Name
|
from sqlparse.tokens import Keyword, Name
|
||||||
|
|
||||||
|
from .logger import get_logger
|
||||||
|
from .config import DatabaseConfig
|
||||||
|
|
||||||
|
# Global ContextVar for auth_context - must be a single instance shared across all modules
|
||||||
|
# This allows token-bound database configuration to work correctly in concurrent requests
|
||||||
|
mcp_auth_context_var: ContextVar['AuthContext'] = ContextVar('mcp_auth_context', default=None)
|
||||||
|
|
||||||
|
|
||||||
class SecurityLevel(Enum):
|
class SecurityLevel(Enum):
|
||||||
"""Security level enumeration"""
|
"""Security level enumeration"""
|
||||||
@@ -44,15 +51,18 @@ class SecurityLevel(Enum):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AuthContext:
|
class AuthContext:
|
||||||
"""Authentication context"""
|
"""Authentication context for audit and session tracking"""
|
||||||
|
|
||||||
user_id: str
|
token_id: str = "" # Token identifier for audit logging
|
||||||
roles: list[str]
|
user_id: str = "" # User identifier
|
||||||
permissions: list[str]
|
roles: list[str] = field(default_factory=list) # User roles
|
||||||
session_id: str
|
permissions: list[str] = field(default_factory=list) # User permissions
|
||||||
login_time: datetime | None = None
|
security_level: 'SecurityLevel' = field(default_factory=lambda: SecurityLevel.INTERNAL) # Security level
|
||||||
|
client_ip: str = "unknown" # Client IP address
|
||||||
|
session_id: str = "" # Session identifier
|
||||||
|
login_time: datetime = field(default_factory=datetime.utcnow)
|
||||||
last_activity: datetime | None = None
|
last_activity: datetime | None = None
|
||||||
security_level: SecurityLevel = SecurityLevel.INTERNAL
|
token: str = "" # Raw token for token-bound database configuration
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -85,12 +95,13 @@ class DorisSecurityManager:
|
|||||||
Provides complete security control functionality, including authentication, authorization, SQL security validation and data masking
|
Provides complete security control functionality, including authentication, authorization, SQL security validation and data masking
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config, connection_manager=None):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.logger = logging.getLogger(__name__)
|
self.logger = get_logger(__name__)
|
||||||
|
self.connection_manager = connection_manager
|
||||||
|
|
||||||
# Initialize security components
|
# Initialize security components
|
||||||
self.auth_provider = AuthenticationProvider(config)
|
self.auth_provider = AuthenticationProvider(config, self)
|
||||||
self.authz_provider = AuthorizationProvider(config)
|
self.authz_provider = AuthorizationProvider(config)
|
||||||
self.sql_validator = SQLSecurityValidator(config)
|
self.sql_validator = SQLSecurityValidator(config)
|
||||||
self.masking_processor = DataMaskingProcessor(config)
|
self.masking_processor = DataMaskingProcessor(config)
|
||||||
@@ -99,32 +110,56 @@ class DorisSecurityManager:
|
|||||||
self.blocked_keywords = self._load_blocked_keywords()
|
self.blocked_keywords = self._load_blocked_keywords()
|
||||||
self.sensitive_tables = self._load_sensitive_tables()
|
self.sensitive_tables = self._load_sensitive_tables()
|
||||||
self.masking_rules = self._load_masking_rules()
|
self.masking_rules = self._load_masking_rules()
|
||||||
|
|
||||||
|
# Track initialization state
|
||||||
|
self._initialized = False
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
"""Initialize security manager components"""
|
||||||
|
if self._initialized:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Initialize authentication provider (for JWT setup)
|
||||||
|
await self.auth_provider.initialize()
|
||||||
|
|
||||||
|
self._initialized = True
|
||||||
|
self.logger.info("DorisSecurityManager initialized successfully")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Failed to initialize DorisSecurityManager: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def shutdown(self):
|
||||||
|
"""Shutdown security manager components"""
|
||||||
|
try:
|
||||||
|
await self.auth_provider.shutdown()
|
||||||
|
self._initialized = False
|
||||||
|
self.logger.info("DorisSecurityManager shutdown completed")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error during DorisSecurityManager shutdown: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
def _load_blocked_keywords(self) -> set[str]:
|
def _load_blocked_keywords(self) -> set[str]:
|
||||||
"""Load blocked SQL keywords"""
|
"""Load blocked SQL keywords from configuration"""
|
||||||
default_blocked = {
|
# Load keywords from configuration, unified source of truth
|
||||||
"DROP",
|
|
||||||
"DELETE",
|
|
||||||
"TRUNCATE",
|
|
||||||
"ALTER",
|
|
||||||
"CREATE",
|
|
||||||
"INSERT",
|
|
||||||
"UPDATE",
|
|
||||||
"GRANT",
|
|
||||||
"REVOKE",
|
|
||||||
"EXEC",
|
|
||||||
"EXECUTE",
|
|
||||||
"SHUTDOWN",
|
|
||||||
"KILL",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Load custom rules from configuration file
|
|
||||||
if hasattr(self.config, 'get'):
|
if hasattr(self.config, 'get'):
|
||||||
custom_blocked = set(self.config.get("blocked_keywords", []))
|
# Dictionary-style configuration
|
||||||
|
blocked_keywords = self.config.get("blocked_keywords", [])
|
||||||
|
elif hasattr(self.config, 'security') and hasattr(self.config.security, 'blocked_keywords'):
|
||||||
|
# DorisConfig object, get through security.blocked_keywords
|
||||||
|
blocked_keywords = self.config.security.blocked_keywords
|
||||||
else:
|
else:
|
||||||
custom_blocked = set()
|
# Fallback to default if no configuration available
|
||||||
|
blocked_keywords = [
|
||||||
|
"DROP", "CREATE", "ALTER", "TRUNCATE",
|
||||||
|
"DELETE", "INSERT", "UPDATE",
|
||||||
|
"GRANT", "REVOKE",
|
||||||
|
"EXEC", "EXECUTE", "SHUTDOWN", "KILL"
|
||||||
|
]
|
||||||
|
|
||||||
return default_blocked.union(custom_blocked)
|
return set(blocked_keywords)
|
||||||
|
|
||||||
def _load_sensitive_tables(self) -> dict[str, SecurityLevel]:
|
def _load_sensitive_tables(self) -> dict[str, SecurityLevel]:
|
||||||
"""Load sensitive table configuration"""
|
"""Load sensitive table configuration"""
|
||||||
@@ -189,8 +224,59 @@ class DorisSecurityManager:
|
|||||||
return default_rules
|
return default_rules
|
||||||
|
|
||||||
async def authenticate_request(self, auth_info: dict[str, Any]) -> AuthContext:
|
async def authenticate_request(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||||
"""Validate request authentication information"""
|
"""Validate request authentication information
|
||||||
return await self.auth_provider.authenticate(auth_info)
|
|
||||||
|
Tries authentication methods in order: Token -> JWT -> OAuth
|
||||||
|
Any one method succeeding allows access
|
||||||
|
If all methods are disabled, returns anonymous context
|
||||||
|
"""
|
||||||
|
# Check if any authentication method is enabled
|
||||||
|
if not (self.config.security.enable_token_auth or
|
||||||
|
self.config.security.enable_jwt_auth or
|
||||||
|
self.config.security.enable_oauth_auth):
|
||||||
|
self.logger.debug("All authentication methods are disabled")
|
||||||
|
# Return anonymous context when no authentication is enabled
|
||||||
|
return AuthContext(
|
||||||
|
token_id="anonymous",
|
||||||
|
user_id="anonymous",
|
||||||
|
roles=["anonymous"],
|
||||||
|
permissions=["read"],
|
||||||
|
security_level=SecurityLevel.PUBLIC,
|
||||||
|
client_ip=auth_info.get("client_ip", "unknown"),
|
||||||
|
session_id="anonymous_session"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try authentication methods in order of preference
|
||||||
|
last_error = None
|
||||||
|
|
||||||
|
# 1. Try Token authentication first (most common)
|
||||||
|
if self.config.security.enable_token_auth:
|
||||||
|
try:
|
||||||
|
return await self.auth_provider.authenticate_token(auth_info)
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.debug(f"Token authentication failed: {e}")
|
||||||
|
last_error = e
|
||||||
|
|
||||||
|
# 2. Try JWT authentication
|
||||||
|
if self.config.security.enable_jwt_auth:
|
||||||
|
try:
|
||||||
|
return await self.auth_provider.authenticate_jwt(auth_info)
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.debug(f"JWT authentication failed: {e}")
|
||||||
|
last_error = e
|
||||||
|
|
||||||
|
# 3. Try OAuth authentication
|
||||||
|
if self.config.security.enable_oauth_auth:
|
||||||
|
try:
|
||||||
|
return await self.auth_provider.authenticate_oauth(auth_info)
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.debug(f"OAuth authentication failed: {e}")
|
||||||
|
last_error = e
|
||||||
|
|
||||||
|
# All enabled authentication methods failed
|
||||||
|
error_message = f"Authentication failed: {str(last_error)}" if last_error else "No authentication method succeeded"
|
||||||
|
self.logger.warning(f"Authentication failed for client {auth_info.get('client_ip', 'unknown')}: {error_message}")
|
||||||
|
raise ValueError(error_message)
|
||||||
|
|
||||||
async def authorize_resource_access(
|
async def authorize_resource_access(
|
||||||
self, auth_context: AuthContext, resource_uri: str
|
self, auth_context: AuthContext, resource_uri: str
|
||||||
@@ -212,43 +298,362 @@ class DorisSecurityManager:
|
|||||||
"""Apply data masking processing"""
|
"""Apply data masking processing"""
|
||||||
return await self.masking_processor.process(data, auth_context)
|
return await self.masking_processor.process(data, auth_context)
|
||||||
|
|
||||||
|
# OAuth-specific methods
|
||||||
|
def get_oauth_authorization_url(self) -> tuple[str, str]:
|
||||||
|
"""Get OAuth authorization URL
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (authorization_url, state)
|
||||||
|
"""
|
||||||
|
if not self.auth_provider.oauth_provider:
|
||||||
|
raise ValueError("OAuth is not enabled")
|
||||||
|
return self.auth_provider.oauth_provider.get_authorization_url()
|
||||||
|
|
||||||
|
async def handle_oauth_callback(self, code: str, state: str) -> AuthContext:
|
||||||
|
"""Handle OAuth callback
|
||||||
|
|
||||||
|
Args:
|
||||||
|
code: Authorization code from OAuth provider
|
||||||
|
state: State parameter for CSRF protection
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AuthContext for authenticated user
|
||||||
|
"""
|
||||||
|
if not self.auth_provider.oauth_provider:
|
||||||
|
raise ValueError("OAuth is not enabled")
|
||||||
|
return await self.auth_provider.oauth_provider.handle_callback(code, state)
|
||||||
|
|
||||||
|
def get_oauth_provider_info(self) -> dict[str, Any]:
|
||||||
|
"""Get OAuth provider information
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OAuth provider information
|
||||||
|
"""
|
||||||
|
if not self.auth_provider.oauth_provider:
|
||||||
|
return {"enabled": False}
|
||||||
|
return self.auth_provider.oauth_provider.get_provider_info()
|
||||||
|
|
||||||
|
# Token management methods
|
||||||
|
async def create_token(
|
||||||
|
self,
|
||||||
|
token_id: str,
|
||||||
|
expires_hours: Optional[int] = None,
|
||||||
|
description: str = "",
|
||||||
|
custom_token: Optional[str] = None,
|
||||||
|
database_config: Optional[DatabaseConfig] = None
|
||||||
|
) -> str:
|
||||||
|
"""Create a new API access token
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_id: Unique token identifier for audit and management
|
||||||
|
expires_hours: Token expiration in hours (None for no expiration)
|
||||||
|
description: Token description for management purposes
|
||||||
|
custom_token: Custom token string (if None, generates random token)
|
||||||
|
database_config: Optional database configuration for this token
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Generated token string
|
||||||
|
"""
|
||||||
|
if not self.auth_provider.token_manager:
|
||||||
|
raise ValueError("Token manager not initialized")
|
||||||
|
|
||||||
|
return await self.auth_provider.token_manager.create_token(
|
||||||
|
token_id=token_id,
|
||||||
|
expires_hours=expires_hours,
|
||||||
|
description=description,
|
||||||
|
custom_token=custom_token,
|
||||||
|
database_config=database_config
|
||||||
|
)
|
||||||
|
|
||||||
|
async def revoke_token(self, token_id: str) -> bool:
|
||||||
|
"""Revoke a token by token ID
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_id: Token ID to revoke
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if token was revoked successfully
|
||||||
|
"""
|
||||||
|
if not self.auth_provider.token_manager:
|
||||||
|
raise ValueError("Token manager not initialized")
|
||||||
|
|
||||||
|
return await self.auth_provider.token_manager.revoke_token(token_id)
|
||||||
|
|
||||||
|
async def list_tokens(self) -> list[dict[str, Any]]:
|
||||||
|
"""List all tokens (without sensitive data)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of token information
|
||||||
|
"""
|
||||||
|
if not self.auth_provider.token_manager:
|
||||||
|
raise ValueError("Token manager not initialized")
|
||||||
|
|
||||||
|
return await self.auth_provider.token_manager.list_tokens()
|
||||||
|
|
||||||
|
async def cleanup_expired_tokens(self) -> int:
|
||||||
|
"""Remove expired tokens and return count
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of expired tokens removed
|
||||||
|
"""
|
||||||
|
if not self.auth_provider.token_manager:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
return await self.auth_provider.token_manager.cleanup_expired_tokens()
|
||||||
|
|
||||||
|
def get_token_stats(self) -> dict[str, Any]:
|
||||||
|
"""Get token statistics
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Token statistics dictionary
|
||||||
|
"""
|
||||||
|
if not self.auth_provider.token_manager:
|
||||||
|
return {"error": "Token manager not initialized"}
|
||||||
|
|
||||||
|
return self.auth_provider.token_manager.get_token_stats()
|
||||||
|
|
||||||
|
async def _validate_token_database_config(self, token: str, token_info) -> None:
|
||||||
|
"""Validate database configuration for token immediately during authentication
|
||||||
|
|
||||||
|
This ensures database connectivity issues are caught at authentication time,
|
||||||
|
not during query execution, providing better user experience.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: Raw authentication token
|
||||||
|
token_info: TokenInfo object from token validation
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If database configuration is invalid or connection fails
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not self.connection_manager:
|
||||||
|
self.logger.warning("Connection manager not available for immediate database validation")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Configure and test database connection for this token
|
||||||
|
success, config_source = await self.connection_manager.configure_for_token(token)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
self.logger.info(f"Database configuration validated successfully for token {token_info.token_id} (source: {config_source})")
|
||||||
|
else:
|
||||||
|
raise ValueError("Database configuration validation failed")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Database configuration validation failed for token {token_info.token_id}: {str(e)}"
|
||||||
|
self.logger.error(error_msg)
|
||||||
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
|
|
||||||
class AuthenticationProvider:
|
class AuthenticationProvider:
|
||||||
"""Authentication provider"""
|
"""Authentication provider"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config, security_manager=None):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.logger = logging.getLogger(__name__)
|
self.logger = get_logger(__name__)
|
||||||
self.session_cache = {}
|
self.session_cache = {}
|
||||||
|
self.jwt_manager = None
|
||||||
async def authenticate(self, auth_info: dict[str, Any]) -> AuthContext:
|
self.oauth_provider = None
|
||||||
"""Perform identity authentication"""
|
self.token_manager = None
|
||||||
auth_type = auth_info.get("type", "token")
|
self.security_manager = security_manager
|
||||||
|
|
||||||
if auth_type == "token":
|
# Initialize authentication providers based on individual switches
|
||||||
return await self._authenticate_token(auth_info)
|
auth_methods_enabled = []
|
||||||
elif auth_type == "basic":
|
|
||||||
return await self._authenticate_basic(auth_info)
|
# Initialize Token manager if enabled
|
||||||
|
if config.security.enable_token_auth:
|
||||||
|
self._initialize_token_manager()
|
||||||
|
auth_methods_enabled.append("Token")
|
||||||
|
|
||||||
|
# Initialize JWT manager if enabled
|
||||||
|
if config.security.enable_jwt_auth:
|
||||||
|
self._initialize_jwt_manager()
|
||||||
|
auth_methods_enabled.append("JWT")
|
||||||
|
|
||||||
|
# Initialize OAuth provider if enabled
|
||||||
|
if config.security.enable_oauth_auth or (hasattr(config.security, 'oauth_enabled') and config.security.oauth_enabled):
|
||||||
|
self._initialize_oauth_provider()
|
||||||
|
auth_methods_enabled.append("OAuth")
|
||||||
|
|
||||||
|
if auth_methods_enabled:
|
||||||
|
self.logger.info(f"Authentication enabled with methods: {', '.join(auth_methods_enabled)}")
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported authentication type: {auth_type}")
|
self.logger.info("All authentication methods are disabled - anonymous access allowed")
|
||||||
|
|
||||||
|
def _initialize_jwt_manager(self):
|
||||||
|
"""Initialize JWT manager"""
|
||||||
|
try:
|
||||||
|
from ..auth.jwt_manager import JWTManager
|
||||||
|
self.jwt_manager = JWTManager(self.config)
|
||||||
|
self.logger.info("JWT manager initialized")
|
||||||
|
except ImportError as e:
|
||||||
|
self.logger.error(f"Failed to import JWT manager: {e}")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Failed to initialize JWT manager: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _initialize_token_manager(self):
|
||||||
|
"""Initialize Token manager"""
|
||||||
|
try:
|
||||||
|
from ..auth.token_manager import TokenManager
|
||||||
|
self.token_manager = TokenManager(self.config)
|
||||||
|
self.logger.info("Token manager initialized")
|
||||||
|
except ImportError as e:
|
||||||
|
self.logger.error(f"Failed to import Token manager: {e}")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Failed to initialize Token manager: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _initialize_oauth_provider(self):
|
||||||
|
"""Initialize OAuth provider"""
|
||||||
|
try:
|
||||||
|
from ..auth.oauth_provider import OAuthAuthenticationProvider
|
||||||
|
self.oauth_provider = OAuthAuthenticationProvider(self.config)
|
||||||
|
self.logger.info("OAuth provider initialized")
|
||||||
|
except ImportError as e:
|
||||||
|
self.logger.error(f"Failed to import OAuth provider: {e}")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Failed to initialize OAuth provider: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
"""Initialize authentication provider asynchronously"""
|
||||||
|
if self.jwt_manager:
|
||||||
|
success = await self.jwt_manager.initialize()
|
||||||
|
if not success:
|
||||||
|
raise RuntimeError("Failed to initialize JWT manager")
|
||||||
|
self.logger.info("JWT authentication provider initialized successfully")
|
||||||
|
|
||||||
|
if self.token_manager:
|
||||||
|
# Token manager doesn't need async initialization, just log success
|
||||||
|
self.logger.info("Token authentication provider initialized successfully")
|
||||||
|
|
||||||
|
if self.oauth_provider:
|
||||||
|
success = await self.oauth_provider.initialize()
|
||||||
|
if not success:
|
||||||
|
raise RuntimeError("Failed to initialize OAuth provider")
|
||||||
|
self.logger.info("OAuth authentication provider initialized successfully")
|
||||||
|
|
||||||
|
async def shutdown(self):
|
||||||
|
"""Shutdown authentication provider"""
|
||||||
|
if self.jwt_manager:
|
||||||
|
await self.jwt_manager.shutdown()
|
||||||
|
self.logger.info("JWT authentication provider shutdown completed")
|
||||||
|
|
||||||
|
if self.token_manager:
|
||||||
|
# Token manager doesn't need async shutdown, just log
|
||||||
|
self.logger.info("Token authentication provider shutdown completed")
|
||||||
|
|
||||||
|
if self.oauth_provider:
|
||||||
|
await self.oauth_provider.shutdown()
|
||||||
|
self.logger.info("OAuth authentication provider shutdown completed")
|
||||||
|
|
||||||
|
async def authenticate_token(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||||
|
"""Perform token authentication"""
|
||||||
|
if not self.config.security.enable_token_auth:
|
||||||
|
raise ValueError("Token authentication is not enabled")
|
||||||
|
return await self._authenticate_token(auth_info)
|
||||||
|
|
||||||
|
async def authenticate_jwt(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||||
|
"""Perform JWT authentication"""
|
||||||
|
if not self.config.security.enable_jwt_auth:
|
||||||
|
raise ValueError("JWT authentication is not enabled")
|
||||||
|
return await self._authenticate_jwt(auth_info)
|
||||||
|
|
||||||
|
async def authenticate_oauth(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||||
|
"""Perform OAuth authentication"""
|
||||||
|
if not self.config.security.enable_oauth_auth:
|
||||||
|
raise ValueError("OAuth authentication is not enabled")
|
||||||
|
return await self._authenticate_oauth(auth_info)
|
||||||
|
|
||||||
|
async def _authenticate_jwt(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||||
|
"""JWT authentication"""
|
||||||
|
if not self.jwt_manager:
|
||||||
|
raise ValueError("JWT manager not initialized")
|
||||||
|
|
||||||
|
token = auth_info.get("token")
|
||||||
|
if not token:
|
||||||
|
# Try to extract from Authorization header
|
||||||
|
authorization = auth_info.get("authorization")
|
||||||
|
if authorization and authorization.startswith('Bearer '):
|
||||||
|
token = authorization[7:]
|
||||||
|
|
||||||
|
if not token:
|
||||||
|
raise ValueError("Missing JWT token")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Use JWT middleware for authentication
|
||||||
|
from ..auth.auth_middleware import AuthMiddleware
|
||||||
|
middleware = AuthMiddleware(self.jwt_manager)
|
||||||
|
return await middleware.authenticate_request(auth_info)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"JWT authentication failed: {e}")
|
||||||
|
raise ValueError(f"JWT authentication failed: {str(e)}")
|
||||||
|
|
||||||
|
async def _authenticate_oauth(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||||
|
"""OAuth authentication"""
|
||||||
|
if not self.oauth_provider:
|
||||||
|
raise ValueError("OAuth provider not initialized")
|
||||||
|
|
||||||
|
# Handle different OAuth authentication scenarios
|
||||||
|
if "access_token" in auth_info:
|
||||||
|
# Direct OAuth access token authentication
|
||||||
|
return await self.oauth_provider.authenticate_with_token(auth_info["access_token"])
|
||||||
|
elif "code" in auth_info and "state" in auth_info:
|
||||||
|
# OAuth callback authentication
|
||||||
|
return await self.oauth_provider.handle_callback(auth_info["code"], auth_info["state"])
|
||||||
|
else:
|
||||||
|
raise ValueError("OAuth authentication requires either access_token or code+state")
|
||||||
|
|
||||||
async def _authenticate_token(self, auth_info: dict[str, Any]) -> AuthContext:
|
async def _authenticate_token(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||||
"""Token authentication"""
|
"""Token authentication"""
|
||||||
|
if not self.token_manager:
|
||||||
|
raise ValueError("Token manager not initialized")
|
||||||
|
|
||||||
token = auth_info.get("token")
|
token = auth_info.get("token")
|
||||||
|
if not token:
|
||||||
|
# Try to extract from Authorization header
|
||||||
|
authorization = auth_info.get("authorization")
|
||||||
|
if authorization and authorization.startswith('Bearer '):
|
||||||
|
token = authorization[7:]
|
||||||
|
elif authorization and authorization.startswith('Token '):
|
||||||
|
token = authorization[6:]
|
||||||
|
|
||||||
if not token:
|
if not token:
|
||||||
raise ValueError("Missing authentication token")
|
raise ValueError("Missing authentication token")
|
||||||
|
|
||||||
# Validate token (simplified implementation, should validate JWT or query authentication service in practice)
|
try:
|
||||||
user_info = await self._validate_token(token)
|
# Validate token using TokenManager
|
||||||
|
validation_result = await self.token_manager.validate_token(token)
|
||||||
return AuthContext(
|
|
||||||
user_id=user_info["user_id"],
|
if not validation_result.is_valid:
|
||||||
roles=user_info["roles"],
|
raise ValueError(f"Token validation failed: {validation_result.error_message}")
|
||||||
permissions=user_info["permissions"],
|
|
||||||
session_id=auth_info.get("session_id", "default"),
|
token_info = validation_result.token_info
|
||||||
login_time=datetime.utcnow(),
|
|
||||||
security_level=SecurityLevel(user_info.get("security_level", "internal")),
|
# Immediately validate database configuration for this token
|
||||||
)
|
if self.security_manager:
|
||||||
|
await self.security_manager._validate_token_database_config(token, token_info)
|
||||||
|
|
||||||
|
return AuthContext(
|
||||||
|
token_id=token_info.token_id,
|
||||||
|
user_id=token_info.token_id, # Use token_id as user_id for token auth
|
||||||
|
roles=["token_user"], # Default role for token users
|
||||||
|
permissions=["read", "write"], # Default permissions for token users
|
||||||
|
security_level=SecurityLevel.INTERNAL,
|
||||||
|
client_ip=auth_info.get("client_ip", "unknown"),
|
||||||
|
session_id=auth_info.get("session_id", f"session_{token_info.token_id}"),
|
||||||
|
login_time=datetime.utcnow(),
|
||||||
|
last_activity=token_info.last_used,
|
||||||
|
token=token # Store raw token for token-bound database configuration
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Token authentication failed: {e}")
|
||||||
|
raise ValueError(f"Token authentication failed: {str(e)}")
|
||||||
|
|
||||||
async def _authenticate_basic(self, auth_info: dict[str, Any]) -> AuthContext:
|
async def _authenticate_basic(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||||
"""Basic authentication (username password)"""
|
"""Basic authentication (username password)"""
|
||||||
@@ -328,7 +733,7 @@ class AuthorizationProvider:
|
|||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.logger = logging.getLogger(__name__)
|
self.logger = get_logger(__name__)
|
||||||
self.permission_cache = {}
|
self.permission_cache = {}
|
||||||
|
|
||||||
# Load sensitive tables configuration
|
# Load sensitive tables configuration
|
||||||
@@ -471,43 +876,80 @@ class SQLSecurityValidator:
|
|||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.logger = logging.getLogger(__name__)
|
self.logger = get_logger(__name__)
|
||||||
|
|
||||||
# Handle DorisConfig object or dictionary configuration
|
# Handle DorisConfig object or dictionary configuration
|
||||||
if hasattr(config, 'get'):
|
if hasattr(config, 'get'):
|
||||||
# Dictionary configuration
|
# Dictionary configuration
|
||||||
self.blocked_keywords = set(config.get("blocked_keywords", []))
|
self.blocked_keywords = set(config.get("blocked_keywords", []))
|
||||||
self.max_query_complexity = config.get("max_query_complexity", 100)
|
self.max_query_complexity = config.get("max_query_complexity", 100)
|
||||||
|
self.enable_security_check = config.get("enable_security_check", True)
|
||||||
|
elif hasattr(config, 'security'):
|
||||||
|
# DorisConfig object with security attribute - unified source from config
|
||||||
|
self.blocked_keywords = set(config.security.blocked_keywords)
|
||||||
|
self.max_query_complexity = config.security.max_query_complexity
|
||||||
|
self.enable_security_check = getattr(config.security, 'enable_security_check', True)
|
||||||
else:
|
else:
|
||||||
# DorisConfig object, use default values
|
# Fallback to default if no configuration available
|
||||||
self.blocked_keywords = set(["DROP", "DELETE", "TRUNCATE", "ALTER", "CREATE", "INSERT", "UPDATE"])
|
self.blocked_keywords = set([
|
||||||
|
"DROP", "CREATE", "ALTER", "TRUNCATE",
|
||||||
|
"DELETE", "INSERT", "UPDATE",
|
||||||
|
"GRANT", "REVOKE",
|
||||||
|
"EXEC", "EXECUTE", "SHUTDOWN", "KILL"
|
||||||
|
])
|
||||||
self.max_query_complexity = 100
|
self.max_query_complexity = 100
|
||||||
|
self.enable_security_check = True
|
||||||
|
|
||||||
async def validate(self, sql: str, auth_context: AuthContext) -> ValidationResult:
|
async def validate(self, sql: str, auth_context: AuthContext) -> ValidationResult:
|
||||||
"""Validate SQL query security"""
|
"""Validate SQL query security"""
|
||||||
|
# If security check is disabled, always return valid
|
||||||
|
if not self.enable_security_check:
|
||||||
|
self.logger.debug("SQL security check is disabled, allowing all queries")
|
||||||
|
return ValidationResult(is_valid=True)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Parse SQL statement
|
# SECURITY FIX: Parse ALL SQL statements, not just the first one
|
||||||
parsed = sqlparse.parse(sql)[0]
|
# This prevents bypassing security checks by injecting additional statements
|
||||||
|
all_statements = sqlparse.parse(sql)
|
||||||
|
|
||||||
# Check blocked operations first (more specific)
|
if not all_statements:
|
||||||
keyword_result = await self._check_blocked_keywords(parsed)
|
return ValidationResult(
|
||||||
if not keyword_result.is_valid:
|
is_valid=False,
|
||||||
return keyword_result
|
error_message="Empty or invalid SQL statement",
|
||||||
|
risk_level="medium"
|
||||||
|
)
|
||||||
|
|
||||||
# Check SQL injection risks
|
# SECURITY FIX: Validate each statement individually
|
||||||
injection_result = await self._check_sql_injection(sql, parsed)
|
for idx, parsed in enumerate(all_statements):
|
||||||
if not injection_result.is_valid:
|
# Skip empty statements (e.g., from trailing semicolons)
|
||||||
return injection_result
|
if not parsed.tokens or str(parsed).strip() == '':
|
||||||
|
continue
|
||||||
|
|
||||||
# Check query complexity
|
self.logger.debug(f"Validating SQL statement {idx + 1}/{len(all_statements)}: {str(parsed)[:100]}...")
|
||||||
complexity_result = await self._check_query_complexity(parsed)
|
|
||||||
if not complexity_result.is_valid:
|
|
||||||
return complexity_result
|
|
||||||
|
|
||||||
# Check table access permissions
|
# Check blocked operations first (more specific)
|
||||||
table_result = await self._check_table_access(parsed, auth_context)
|
keyword_result = await self._check_blocked_keywords(parsed)
|
||||||
if not table_result.is_valid:
|
if not keyword_result.is_valid:
|
||||||
return table_result
|
keyword_result.error_message = f"Statement {idx + 1}: {keyword_result.error_message}"
|
||||||
|
return keyword_result
|
||||||
|
|
||||||
|
# Check SQL injection risks
|
||||||
|
injection_result = await self._check_sql_injection(sql, parsed)
|
||||||
|
if not injection_result.is_valid:
|
||||||
|
injection_result.error_message = f"Statement {idx + 1}: {injection_result.error_message}"
|
||||||
|
return injection_result
|
||||||
|
|
||||||
|
# Check query complexity
|
||||||
|
complexity_result = await self._check_query_complexity(parsed)
|
||||||
|
if not complexity_result.is_valid:
|
||||||
|
complexity_result.error_message = f"Statement {idx + 1}: {complexity_result.error_message}"
|
||||||
|
return complexity_result
|
||||||
|
|
||||||
|
# Check table access permissions
|
||||||
|
table_result = await self._check_table_access(parsed, auth_context)
|
||||||
|
if not table_result.is_valid:
|
||||||
|
table_result.error_message = f"Statement {idx + 1}: {table_result.error_message}"
|
||||||
|
return table_result
|
||||||
|
|
||||||
return ValidationResult(is_valid=True)
|
return ValidationResult(is_valid=True)
|
||||||
|
|
||||||
@@ -522,28 +964,69 @@ class SQLSecurityValidator:
|
|||||||
async def _check_sql_injection(
|
async def _check_sql_injection(
|
||||||
self, sql: str, parsed: Statement
|
self, sql: str, parsed: Statement
|
||||||
) -> ValidationResult:
|
) -> ValidationResult:
|
||||||
"""Check SQL injection risks"""
|
"""Check SQL injection risks with improved pattern detection
|
||||||
# Check common SQL injection patterns
|
|
||||||
|
FIX for Issue #62 Bug 2: Improved patterns to reduce false positives
|
||||||
|
Now better distinguishes between legitimate SQL (like BETWEEN...AND) and injection attempts
|
||||||
|
"""
|
||||||
|
# Improved injection patterns that are more specific and less prone to false positives
|
||||||
injection_patterns = [
|
injection_patterns = [
|
||||||
r"(\s|^)(union|select|insert|update|delete|drop|create|alter)\s+.*\s+(union|select|insert|update|delete|drop|create|alter)",
|
# Stacked queries with dangerous operations (true injection risk)
|
||||||
r"(\s|^)(or|and)\s+\d+\s*=\s*\d+",
|
r";\s*(DROP|DELETE|TRUNCATE|ALTER|CREATE|INSERT|UPDATE)\s+",
|
||||||
r"(\s|^)(or|and)\s+['\"].*['\"]",
|
|
||||||
r";\s*(drop|delete|truncate|alter|create)",
|
# UNION-based injection (but allow legitimate UNION queries)
|
||||||
r"(exec|execute|sp_|xp_)",
|
# Only flag if UNION is followed by suspicious patterns like SELECT with WHERE 1=1
|
||||||
r"(script|javascript|vbscript)",
|
r"UNION\s+(ALL\s+)?SELECT\s+.*\s+(WHERE|AND|OR)\s+\d+\s*=\s*\d+",
|
||||||
r"(char|ascii|substring|concat)\s*\(",
|
|
||||||
|
# Boolean-based blind injection with comments (true injection pattern)
|
||||||
|
r"(WHERE|AND|OR)\s+\d+\s*=\s*\d+\s*(--|#|/\*)",
|
||||||
|
|
||||||
|
# Quote-based injection attempts (but not in legitimate strings)
|
||||||
|
r"(WHERE|AND|OR)\s+(['\"])[^\2]*\2\s*=\s*\2[^\2]*\2",
|
||||||
|
|
||||||
|
# Time-based blind injection
|
||||||
|
r"(SLEEP|WAITFOR|BENCHMARK)\s*\(",
|
||||||
|
|
||||||
|
# System stored procedure injection
|
||||||
|
r"(EXEC|EXECUTE|SP_|XP_)\s*\(",
|
||||||
|
|
||||||
|
# Script injection attempts
|
||||||
|
r"<\s*(SCRIPT|JAVASCRIPT|VBSCRIPT)",
|
||||||
]
|
]
|
||||||
|
|
||||||
sql_lower = sql.lower()
|
# FIX: Don't flag legitimate SQL functions and keywords
|
||||||
|
# These patterns are too broad and cause false positives:
|
||||||
|
# - REMOVED: r"(char|ascii|substring|concat)\s*\(" - These are legitimate SQL functions
|
||||||
|
# - REMOVED: r"(\s|^)(or|and)\s+\d+\s*=\s*\d+" - This flags BETWEEN...AND constructs
|
||||||
|
# - REMOVED: r"(\s|^)(or|and)\s+['\"].*['\"]" - This is too broad
|
||||||
|
|
||||||
|
sql_upper = sql.upper()
|
||||||
|
|
||||||
|
# Special case: Allow BETWEEN...AND which is legitimate SQL
|
||||||
|
# This prevents false positives like "WHERE dt BETWEEN '2025-01-01' AND '2025-01-31'"
|
||||||
|
if "BETWEEN" in sql_upper and "AND" in sql_upper:
|
||||||
|
# This is likely a BETWEEN clause, not injection
|
||||||
|
# Check if AND appears in a BETWEEN context
|
||||||
|
between_pattern = r"BETWEEN\s+[^\s]+\s+AND\s+[^\s]+"
|
||||||
|
if re.search(between_pattern, sql_upper, re.IGNORECASE):
|
||||||
|
# Remove BETWEEN clauses before checking other patterns
|
||||||
|
sql_cleaned = re.sub(between_pattern, "BETWEEN_CLAUSE", sql_upper, flags=re.IGNORECASE)
|
||||||
|
sql_to_check = sql_cleaned
|
||||||
|
else:
|
||||||
|
sql_to_check = sql_upper
|
||||||
|
else:
|
||||||
|
sql_to_check = sql_upper
|
||||||
|
|
||||||
for pattern in injection_patterns:
|
for pattern in injection_patterns:
|
||||||
if re.search(pattern, sql_lower, re.IGNORECASE):
|
if re.search(pattern, sql_to_check, re.IGNORECASE):
|
||||||
|
self.logger.warning(f"Potential SQL injection pattern detected: {pattern}")
|
||||||
return ValidationResult(
|
return ValidationResult(
|
||||||
is_valid=False,
|
is_valid=False,
|
||||||
error_message="Potential SQL injection risk detected",
|
error_message="Potential SQL injection risk detected",
|
||||||
risk_level="high",
|
risk_level="high",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check suspicious quotes and comments
|
# Check suspicious quotes and comments (with improved detection)
|
||||||
if self._has_suspicious_quotes_or_comments(sql):
|
if self._has_suspicious_quotes_or_comments(sql):
|
||||||
return ValidationResult(
|
return ValidationResult(
|
||||||
is_valid=False,
|
is_valid=False,
|
||||||
@@ -554,19 +1037,67 @@ class SQLSecurityValidator:
|
|||||||
return ValidationResult(is_valid=True)
|
return ValidationResult(is_valid=True)
|
||||||
|
|
||||||
def _has_suspicious_quotes_or_comments(self, sql: str) -> bool:
|
def _has_suspicious_quotes_or_comments(self, sql: str) -> bool:
|
||||||
"""Check suspicious quote and comment patterns"""
|
"""Check suspicious quote and comment patterns with improved detection
|
||||||
# Check unmatched quotes
|
|
||||||
single_quotes = sql.count("'")
|
|
||||||
double_quotes = sql.count('"')
|
|
||||||
|
|
||||||
if single_quotes % 2 != 0 or double_quotes % 2 != 0:
|
FIX for Issue #62 Bug 2: Improved detection to reduce false positives
|
||||||
return True
|
Now distinguishes between legitimate comments/strings and injection attempts
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Use sqlparse to parse the SQL and distinguish between code and comments/strings
|
||||||
|
import sqlparse
|
||||||
|
from sqlparse.tokens import Comment, String
|
||||||
|
|
||||||
# Check SQL comments
|
# Parse the SQL
|
||||||
if "--" in sql or "/*" in sql:
|
parsed = sqlparse.parse(sql)
|
||||||
return True
|
if not parsed:
|
||||||
|
# If parsing fails, be conservative
|
||||||
|
return True
|
||||||
|
|
||||||
return False
|
statement = parsed[0]
|
||||||
|
|
||||||
|
# Check for unmatched quotes ONLY in non-string tokens
|
||||||
|
# This prevents false positives from legitimate string content
|
||||||
|
non_string_content = []
|
||||||
|
has_string_tokens = False
|
||||||
|
|
||||||
|
for token in statement.flatten():
|
||||||
|
if token.ttype in (String.Single, String.Double):
|
||||||
|
has_string_tokens = True
|
||||||
|
# Skip string content - quotes inside strings are legitimate
|
||||||
|
continue
|
||||||
|
elif token.ttype in (Comment.Single, Comment.Multi):
|
||||||
|
# Comments are generally OK, but check for suspicious injection patterns
|
||||||
|
comment_value = str(token).lower()
|
||||||
|
# Check if comment contains dangerous SQL keywords
|
||||||
|
dangerous_in_comments = ['drop', 'delete', 'insert', 'update', 'union', 'exec', 'execute']
|
||||||
|
if any(keyword in comment_value for keyword in dangerous_in_comments):
|
||||||
|
self.logger.warning(f"Suspicious SQL keyword in comment: {token}")
|
||||||
|
return True
|
||||||
|
# Normal comments are OK
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
# Accumulate non-string, non-comment content
|
||||||
|
non_string_content.append(str(token))
|
||||||
|
|
||||||
|
# Check for unmatched quotes in non-string content
|
||||||
|
non_string_text = ''.join(non_string_content)
|
||||||
|
single_quotes = non_string_text.count("'")
|
||||||
|
double_quotes = non_string_text.count('"')
|
||||||
|
|
||||||
|
# Only flag if there are unmatched quotes in actual SQL code (not in strings)
|
||||||
|
if single_quotes % 2 != 0 or double_quotes % 2 != 0:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# FIX: Don't flag legitimate SQL comments
|
||||||
|
# Comments are OK as long as they don't contain dangerous patterns (already checked above)
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.debug(f"SQL parsing error in quote/comment check: {e}")
|
||||||
|
# On parsing error, fall back to conservative check
|
||||||
|
# But be more lenient than before
|
||||||
|
return False # Don't flag on parse errors to reduce false positives
|
||||||
|
|
||||||
async def _check_blocked_keywords(self, parsed: Statement) -> ValidationResult:
|
async def _check_blocked_keywords(self, parsed: Statement) -> ValidationResult:
|
||||||
"""Check blocked keywords"""
|
"""Check blocked keywords"""
|
||||||
@@ -628,6 +1159,10 @@ class SQLSecurityValidator:
|
|||||||
self, parsed: Statement, auth_context: AuthContext
|
self, parsed: Statement, auth_context: AuthContext
|
||||||
) -> ValidationResult:
|
) -> ValidationResult:
|
||||||
"""Check table access permissions"""
|
"""Check table access permissions"""
|
||||||
|
# If no auth_context, skip table access checks (rely on other security checks)
|
||||||
|
if auth_context is None:
|
||||||
|
return ValidationResult(is_valid=True)
|
||||||
|
|
||||||
# Extract table names from query
|
# Extract table names from query
|
||||||
tables = self._extract_table_names(parsed)
|
tables = self._extract_table_names(parsed)
|
||||||
|
|
||||||
@@ -676,7 +1211,7 @@ class DataMaskingProcessor:
|
|||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.logger = logging.getLogger(__name__)
|
self.logger = get_logger(__name__)
|
||||||
self.masking_algorithms = self._init_masking_algorithms()
|
self.masking_algorithms = self._init_masking_algorithms()
|
||||||
self.masking_rules = self._load_masking_rules()
|
self.masking_rules = self._load_masking_rules()
|
||||||
|
|
||||||
|
|||||||
788
doris_mcp_server/utils/security_analytics_tools.py
Normal file
@@ -0,0 +1,788 @@
|
|||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
"""
|
||||||
|
Security Analytics Tools Module
|
||||||
|
Provides data access analysis, user behavior monitoring, and security insights
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
from collections import Counter, defaultdict
|
||||||
|
|
||||||
|
from .db import DorisConnectionManager
|
||||||
|
from .logger import get_logger
|
||||||
|
from .sql_security_utils import get_auth_context
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SecurityAnalyticsTools:
|
||||||
|
"""Security analytics tools for access pattern analysis and user monitoring"""
|
||||||
|
|
||||||
|
def __init__(self, connection_manager: DorisConnectionManager):
|
||||||
|
self.connection_manager = connection_manager
|
||||||
|
logger.info("SecurityAnalyticsTools initialized")
|
||||||
|
|
||||||
|
async def analyze_data_access_patterns(
|
||||||
|
self,
|
||||||
|
days: int = 7,
|
||||||
|
include_system_users: bool = False,
|
||||||
|
min_query_threshold: int = 5
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Analyze data access patterns for users and roles
|
||||||
|
|
||||||
|
Args:
|
||||||
|
days: Number of days to analyze
|
||||||
|
include_system_users: Whether to include system/service users
|
||||||
|
min_query_threshold: Minimum queries for a user to be included in analysis
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Comprehensive access pattern analysis
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# 🚀 PROGRESS: Initialize security analysis
|
||||||
|
logger.info("=" * 70)
|
||||||
|
logger.info(f"🔒 Starting Data Access Pattern Analysis")
|
||||||
|
logger.info(f"📅 Analysis period: {days} days")
|
||||||
|
logger.info(f"👥 Include system users: {include_system_users}")
|
||||||
|
logger.info(f"🎯 Min query threshold: {min_query_threshold}")
|
||||||
|
logger.info("=" * 70)
|
||||||
|
|
||||||
|
connection = await self.connection_manager.get_connection("query")
|
||||||
|
|
||||||
|
# Define analysis period
|
||||||
|
end_date = datetime.now()
|
||||||
|
start_date = end_date - timedelta(days=days)
|
||||||
|
|
||||||
|
logger.info(f"📊 Period: {start_date.strftime('%Y-%m-%d')} to {end_date.strftime('%Y-%m-%d')}")
|
||||||
|
|
||||||
|
# 🚀 PROGRESS: Step 1 - Get audit log data
|
||||||
|
logger.info("📋 Step 1/5: Retrieving audit log data...")
|
||||||
|
audit_start = time.time()
|
||||||
|
audit_data = await self._get_audit_log_data(connection, start_date, end_date, include_system_users)
|
||||||
|
audit_time = time.time() - audit_start
|
||||||
|
|
||||||
|
if not audit_data:
|
||||||
|
logger.warning("⚠️ No audit data available for the specified period")
|
||||||
|
return {
|
||||||
|
"error": "No audit data available for the specified period",
|
||||||
|
"analysis_period": {
|
||||||
|
"start_date": start_date.isoformat(),
|
||||||
|
"end_date": end_date.isoformat(),
|
||||||
|
"days": days
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(f"✅ Retrieved {len(audit_data)} audit records in {audit_time:.2f}s")
|
||||||
|
|
||||||
|
# 🚀 PROGRESS: Step 2 - Analyze user access patterns
|
||||||
|
logger.info("👤 Step 2/5: Analyzing user access patterns...")
|
||||||
|
user_start = time.time()
|
||||||
|
user_access_analysis = await self._analyze_user_access_patterns(
|
||||||
|
audit_data, min_query_threshold
|
||||||
|
)
|
||||||
|
user_time = time.time() - user_start
|
||||||
|
logger.info(f"✅ Analyzed {len(user_access_analysis)} users in {user_time:.2f}s")
|
||||||
|
|
||||||
|
# 🚀 PROGRESS: Step 3 - Analyze role-based access
|
||||||
|
logger.info("🎭 Step 3/5: Analyzing role-based access patterns...")
|
||||||
|
role_start = time.time()
|
||||||
|
role_access_analysis = await self._analyze_role_access_patterns(
|
||||||
|
connection, user_access_analysis
|
||||||
|
)
|
||||||
|
role_time = time.time() - role_start
|
||||||
|
logger.info(f"✅ Role analysis completed in {role_time:.2f}s")
|
||||||
|
|
||||||
|
# 🚀 PROGRESS: Step 4 - Detect security anomalies
|
||||||
|
logger.info("🚨 Step 4/5: Detecting security anomalies...")
|
||||||
|
anomaly_start = time.time()
|
||||||
|
security_alerts = await self._detect_security_anomalies(
|
||||||
|
audit_data, user_access_analysis
|
||||||
|
)
|
||||||
|
anomaly_time = time.time() - anomaly_start
|
||||||
|
logger.info(f"✅ Found {len(security_alerts)} security alerts in {anomaly_time:.2f}s")
|
||||||
|
|
||||||
|
# Log alert summary
|
||||||
|
if security_alerts:
|
||||||
|
high_alerts = sum(1 for alert in security_alerts if alert.get("severity") == "high")
|
||||||
|
medium_alerts = sum(1 for alert in security_alerts if alert.get("severity") == "medium")
|
||||||
|
logger.info(f"🚨 Alert breakdown: {high_alerts} high, {medium_alerts} medium")
|
||||||
|
|
||||||
|
# 🚀 PROGRESS: Step 5 - Generate access insights
|
||||||
|
logger.info("💡 Step 5/5: Generating access insights...")
|
||||||
|
insights_start = time.time()
|
||||||
|
access_insights = await self._generate_access_insights(
|
||||||
|
user_access_analysis, role_access_analysis
|
||||||
|
)
|
||||||
|
insights_time = time.time() - insights_start
|
||||||
|
logger.info(f"✅ Access insights generated in {insights_time:.2f}s")
|
||||||
|
|
||||||
|
execution_time = time.time() - start_time
|
||||||
|
|
||||||
|
return {
|
||||||
|
"analysis_period": {
|
||||||
|
"start_date": start_date.isoformat(),
|
||||||
|
"end_date": end_date.isoformat(),
|
||||||
|
"days": days
|
||||||
|
},
|
||||||
|
"analysis_timestamp": datetime.now().isoformat(),
|
||||||
|
"execution_time_seconds": round(execution_time, 3),
|
||||||
|
"user_access_summary": self._generate_user_access_summary(user_access_analysis),
|
||||||
|
"user_access_details": user_access_analysis,
|
||||||
|
"role_analysis": role_access_analysis,
|
||||||
|
"security_alerts": security_alerts,
|
||||||
|
"access_insights": access_insights,
|
||||||
|
"recommendations": self._generate_security_recommendations(security_alerts, access_insights)
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Data access pattern analysis failed: {str(e)}")
|
||||||
|
return {
|
||||||
|
"error": str(e),
|
||||||
|
"analysis_timestamp": datetime.now().isoformat()
|
||||||
|
}
|
||||||
|
|
||||||
|
# ==================== Private Helper Methods ====================
|
||||||
|
|
||||||
|
async def _get_audit_log_data(self, connection, start_date: datetime, end_date: datetime, include_system_users: bool) -> List[Dict]:
|
||||||
|
"""Retrieve audit log data for the specified period"""
|
||||||
|
try:
|
||||||
|
# System users filter
|
||||||
|
system_user_filter = ""
|
||||||
|
if not include_system_users:
|
||||||
|
system_users = ['root', 'admin', 'system', 'doris', 'information_schema']
|
||||||
|
user_list = ','.join([f'"{user}"' for user in system_users])
|
||||||
|
system_user_filter = f"AND `user` NOT IN ({user_list})"
|
||||||
|
|
||||||
|
audit_sql = f"""
|
||||||
|
SELECT
|
||||||
|
`user` as user_name,
|
||||||
|
`client_ip` as host,
|
||||||
|
`time` as query_time,
|
||||||
|
`stmt` as sql_statement,
|
||||||
|
`state` as query_status,
|
||||||
|
`scan_bytes` as scan_bytes,
|
||||||
|
`scan_rows` as scan_rows,
|
||||||
|
`return_rows` as return_rows,
|
||||||
|
`query_time` as execution_time_ms
|
||||||
|
FROM internal.__internal_schema.audit_log
|
||||||
|
WHERE `time` >= '{start_date.strftime('%Y-%m-%d %H:%M:%S')}'
|
||||||
|
AND `time` <= '{end_date.strftime('%Y-%m-%d %H:%M:%S')}'
|
||||||
|
AND `stmt` IS NOT NULL
|
||||||
|
AND `stmt` != ''
|
||||||
|
{system_user_filter}
|
||||||
|
ORDER BY `time` DESC
|
||||||
|
LIMIT 10000
|
||||||
|
"""
|
||||||
|
|
||||||
|
# SECURITY FIX: Pass auth_context to execute
|
||||||
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(audit_sql, auth_context=auth_context)
|
||||||
|
return result.data if result.data else []
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to get audit log data: {str(e)}")
|
||||||
|
# Try alternative method without detailed metrics
|
||||||
|
try:
|
||||||
|
simple_audit_sql = f"""
|
||||||
|
SELECT
|
||||||
|
`user` as user_name,
|
||||||
|
`client_ip` as host,
|
||||||
|
`time` as query_time,
|
||||||
|
`stmt` as sql_statement,
|
||||||
|
`state` as query_status
|
||||||
|
FROM internal.__internal_schema.audit_log
|
||||||
|
WHERE `time` >= '{start_date.strftime('%Y-%m-%d %H:%M:%S')}'
|
||||||
|
AND `time` <= '{end_date.strftime('%Y-%m-%d %H:%M:%S')}'
|
||||||
|
AND `stmt` IS NOT NULL
|
||||||
|
{system_user_filter}
|
||||||
|
ORDER BY `time` DESC
|
||||||
|
LIMIT 10000
|
||||||
|
"""
|
||||||
|
|
||||||
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(simple_audit_sql, auth_context=auth_context)
|
||||||
|
return result.data if result.data else []
|
||||||
|
|
||||||
|
except Exception as e2:
|
||||||
|
logger.error(f"Failed to get simplified audit log data: {str(e2)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def _analyze_user_access_patterns(self, audit_data: List[Dict], min_query_threshold: int) -> List[Dict]:
|
||||||
|
"""Analyze access patterns for individual users"""
|
||||||
|
user_stats = defaultdict(lambda: {
|
||||||
|
"total_queries": 0,
|
||||||
|
"unique_tables_accessed": set(),
|
||||||
|
"hosts": set(),
|
||||||
|
"query_types": Counter(),
|
||||||
|
"query_times": [],
|
||||||
|
"failed_queries": 0,
|
||||||
|
"data_volume_read_bytes": 0,
|
||||||
|
"data_volume_read_rows": 0,
|
||||||
|
"hourly_pattern": [0] * 24,
|
||||||
|
"daily_pattern": [0] * 7,
|
||||||
|
"query_statements": []
|
||||||
|
})
|
||||||
|
|
||||||
|
# Process audit data
|
||||||
|
for entry in audit_data:
|
||||||
|
user_name = entry.get("user_name", "unknown")
|
||||||
|
query_time = entry.get("query_time")
|
||||||
|
sql_statement = entry.get("sql_statement", "")
|
||||||
|
query_status = entry.get("query_status", "")
|
||||||
|
|
||||||
|
stats = user_stats[user_name]
|
||||||
|
stats["total_queries"] += 1
|
||||||
|
|
||||||
|
# Extract table names from SQL
|
||||||
|
tables = self._extract_table_names_from_sql(sql_statement)
|
||||||
|
stats["unique_tables_accessed"].update(tables)
|
||||||
|
|
||||||
|
# Host tracking
|
||||||
|
if entry.get("host"):
|
||||||
|
stats["hosts"].add(entry["host"])
|
||||||
|
|
||||||
|
# Query type analysis
|
||||||
|
query_type = self._classify_query_type(sql_statement)
|
||||||
|
stats["query_types"][query_type] += 1
|
||||||
|
|
||||||
|
# Query time patterns
|
||||||
|
if query_time:
|
||||||
|
try:
|
||||||
|
if isinstance(query_time, str):
|
||||||
|
query_dt = datetime.fromisoformat(query_time.replace('Z', '+00:00'))
|
||||||
|
else:
|
||||||
|
query_dt = query_time
|
||||||
|
|
||||||
|
stats["query_times"].append(query_dt)
|
||||||
|
stats["hourly_pattern"][query_dt.hour] += 1
|
||||||
|
stats["daily_pattern"][query_dt.weekday()] += 1
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Error tracking
|
||||||
|
if query_status and "error" in query_status.lower():
|
||||||
|
stats["failed_queries"] += 1
|
||||||
|
|
||||||
|
# Data volume tracking
|
||||||
|
if entry.get("scan_bytes"):
|
||||||
|
try:
|
||||||
|
stats["data_volume_read_bytes"] += int(entry["scan_bytes"])
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
if entry.get("scan_rows"):
|
||||||
|
try:
|
||||||
|
stats["data_volume_read_rows"] += int(entry["scan_rows"])
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Store sample queries
|
||||||
|
if len(stats["query_statements"]) < 10:
|
||||||
|
stats["query_statements"].append({
|
||||||
|
"sql": sql_statement[:200] + "..." if len(sql_statement) > 200 else sql_statement,
|
||||||
|
"timestamp": str(query_time),
|
||||||
|
"type": query_type
|
||||||
|
})
|
||||||
|
|
||||||
|
# Convert to analysis results
|
||||||
|
user_analysis = []
|
||||||
|
for user_name, stats in user_stats.items():
|
||||||
|
if stats["total_queries"] >= min_query_threshold:
|
||||||
|
# Calculate patterns and insights
|
||||||
|
access_pattern = self._classify_access_pattern(stats["hourly_pattern"])
|
||||||
|
table_access_frequency = dict(Counter(
|
||||||
|
table for entry in audit_data
|
||||||
|
if entry.get("user_name") == user_name
|
||||||
|
for table in self._extract_table_names_from_sql(entry.get("sql_statement", ""))
|
||||||
|
).most_common(10))
|
||||||
|
|
||||||
|
user_analysis.append({
|
||||||
|
"user_name": user_name,
|
||||||
|
"access_stats": {
|
||||||
|
"total_queries": stats["total_queries"],
|
||||||
|
"unique_tables_accessed": len(stats["unique_tables_accessed"]),
|
||||||
|
"unique_hosts": len(stats["hosts"]),
|
||||||
|
"data_volume_read_gb": round(stats["data_volume_read_bytes"] / (1024**3), 3),
|
||||||
|
"data_volume_read_rows": stats["data_volume_read_rows"],
|
||||||
|
"failed_queries": stats["failed_queries"],
|
||||||
|
"success_rate": round((stats["total_queries"] - stats["failed_queries"]) / stats["total_queries"], 3) if stats["total_queries"] > 0 else 0,
|
||||||
|
"peak_access_hour": stats["hourly_pattern"].index(max(stats["hourly_pattern"])) if max(stats["hourly_pattern"]) > 0 else None,
|
||||||
|
"access_pattern": access_pattern
|
||||||
|
},
|
||||||
|
"query_type_distribution": dict(stats["query_types"]),
|
||||||
|
"table_access_frequency": table_access_frequency,
|
||||||
|
"hosts_used": list(stats["hosts"]),
|
||||||
|
"sample_queries": stats["query_statements"],
|
||||||
|
"temporal_patterns": {
|
||||||
|
"hourly_distribution": stats["hourly_pattern"],
|
||||||
|
"daily_distribution": stats["daily_pattern"]
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return sorted(user_analysis, key=lambda x: x["access_stats"]["total_queries"], reverse=True)
|
||||||
|
|
||||||
|
def _extract_table_names_from_sql(self, sql: str) -> List[str]:
|
||||||
|
"""Extract table names from SQL statement (simplified implementation)"""
|
||||||
|
if not sql:
|
||||||
|
return []
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
# Simple regex patterns to match table names
|
||||||
|
patterns = [
|
||||||
|
r'\bFROM\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)',
|
||||||
|
r'\bJOIN\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)',
|
||||||
|
r'\bINTO\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)',
|
||||||
|
r'\bUPDATE\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)',
|
||||||
|
r'\bDELETE\s+FROM\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)'
|
||||||
|
]
|
||||||
|
|
||||||
|
tables = []
|
||||||
|
for pattern in patterns:
|
||||||
|
matches = re.findall(pattern, sql, re.IGNORECASE)
|
||||||
|
tables.extend(matches)
|
||||||
|
|
||||||
|
# Clean up table names (remove quotes, aliases, etc.)
|
||||||
|
cleaned_tables = []
|
||||||
|
for table in tables:
|
||||||
|
# Remove backticks, quotes, and get just the table name
|
||||||
|
clean_table = table.strip('`"\'').split(' ')[0]
|
||||||
|
if clean_table and not clean_table.upper() in ['SELECT', 'WHERE', 'AND', 'OR']:
|
||||||
|
cleaned_tables.append(clean_table)
|
||||||
|
|
||||||
|
return list(set(cleaned_tables))
|
||||||
|
|
||||||
|
def _classify_query_type(self, sql: str) -> str:
|
||||||
|
"""Classify SQL query type"""
|
||||||
|
if not sql:
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
|
sql_upper = sql.upper().strip()
|
||||||
|
|
||||||
|
if sql_upper.startswith('SELECT'):
|
||||||
|
return "SELECT"
|
||||||
|
elif sql_upper.startswith('INSERT'):
|
||||||
|
return "INSERT"
|
||||||
|
elif sql_upper.startswith('UPDATE'):
|
||||||
|
return "UPDATE"
|
||||||
|
elif sql_upper.startswith('DELETE'):
|
||||||
|
return "DELETE"
|
||||||
|
elif sql_upper.startswith('CREATE'):
|
||||||
|
return "CREATE"
|
||||||
|
elif sql_upper.startswith('ALTER'):
|
||||||
|
return "ALTER"
|
||||||
|
elif sql_upper.startswith('DROP'):
|
||||||
|
return "DROP"
|
||||||
|
elif sql_upper.startswith('SHOW'):
|
||||||
|
return "SHOW"
|
||||||
|
elif sql_upper.startswith('DESCRIBE') or sql_upper.startswith('DESC'):
|
||||||
|
return "DESCRIBE"
|
||||||
|
else:
|
||||||
|
return "OTHER"
|
||||||
|
|
||||||
|
def _classify_access_pattern(self, hourly_pattern: List[int]) -> str:
|
||||||
|
"""Classify user access pattern based on hourly distribution"""
|
||||||
|
if not hourly_pattern or max(hourly_pattern) == 0:
|
||||||
|
return "no_pattern"
|
||||||
|
|
||||||
|
# Find peak hours
|
||||||
|
max_queries = max(hourly_pattern)
|
||||||
|
peak_hours = [i for i, count in enumerate(hourly_pattern) if count == max_queries]
|
||||||
|
|
||||||
|
# Business hours: 9-17
|
||||||
|
business_hours = set(range(9, 18))
|
||||||
|
peak_in_business_hours = any(hour in business_hours for hour in peak_hours)
|
||||||
|
|
||||||
|
# Night hours: 22-6
|
||||||
|
night_hours = set(list(range(22, 24)) + list(range(0, 7)))
|
||||||
|
peak_in_night_hours = any(hour in night_hours for hour in peak_hours)
|
||||||
|
|
||||||
|
if peak_in_business_hours and not peak_in_night_hours:
|
||||||
|
return "regular_business_hours"
|
||||||
|
elif peak_in_night_hours:
|
||||||
|
return "night_shift_or_batch"
|
||||||
|
elif len(peak_hours) > 6: # Distributed throughout day
|
||||||
|
return "distributed_access"
|
||||||
|
else:
|
||||||
|
return "irregular_pattern"
|
||||||
|
|
||||||
|
async def _analyze_role_access_patterns(self, connection, user_access_analysis: List[Dict]) -> Dict[str, Any]:
|
||||||
|
"""Analyze access patterns by role"""
|
||||||
|
try:
|
||||||
|
# Get user roles information
|
||||||
|
user_roles = await self._get_user_roles(connection)
|
||||||
|
|
||||||
|
# Group users by roles
|
||||||
|
role_stats = defaultdict(lambda: {
|
||||||
|
"user_count": 0,
|
||||||
|
"total_queries": 0,
|
||||||
|
"unique_tables": set(),
|
||||||
|
"query_types": Counter(),
|
||||||
|
"avg_queries_per_user": 0,
|
||||||
|
"users": []
|
||||||
|
})
|
||||||
|
|
||||||
|
# Process user access data
|
||||||
|
for user_data in user_access_analysis:
|
||||||
|
user_name = user_data["user_name"]
|
||||||
|
user_stats = user_data["access_stats"]
|
||||||
|
query_types = user_data["query_type_distribution"]
|
||||||
|
|
||||||
|
# Get user roles (default to 'unknown' if not found)
|
||||||
|
roles = user_roles.get(user_name, ["unknown"])
|
||||||
|
|
||||||
|
for role in roles:
|
||||||
|
stats = role_stats[role]
|
||||||
|
stats["user_count"] += 1
|
||||||
|
stats["total_queries"] += user_stats["total_queries"]
|
||||||
|
stats["users"].append(user_name)
|
||||||
|
|
||||||
|
# Aggregate query types
|
||||||
|
for query_type, count in query_types.items():
|
||||||
|
stats["query_types"][query_type] += count
|
||||||
|
|
||||||
|
# Calculate role analysis
|
||||||
|
role_analysis = {}
|
||||||
|
for role, stats in role_stats.items():
|
||||||
|
if stats["user_count"] > 0:
|
||||||
|
avg_queries = stats["total_queries"] / stats["user_count"]
|
||||||
|
|
||||||
|
# Calculate privilege usage (simplified)
|
||||||
|
total_role_queries = sum(stats["query_types"].values())
|
||||||
|
privilege_usage = {}
|
||||||
|
if total_role_queries > 0:
|
||||||
|
privilege_usage = {
|
||||||
|
query_type: round(count / total_role_queries, 3)
|
||||||
|
for query_type, count in stats["query_types"].items()
|
||||||
|
}
|
||||||
|
|
||||||
|
role_analysis[role] = {
|
||||||
|
"user_count": stats["user_count"],
|
||||||
|
"users": stats["users"],
|
||||||
|
"total_queries": stats["total_queries"],
|
||||||
|
"avg_queries_per_user": round(avg_queries, 1),
|
||||||
|
"query_type_distribution": dict(stats["query_types"]),
|
||||||
|
"privilege_usage": privilege_usage,
|
||||||
|
"activity_level": self._classify_role_activity_level(avg_queries)
|
||||||
|
}
|
||||||
|
|
||||||
|
return role_analysis
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to analyze role access patterns: {str(e)}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def _get_user_roles(self, connection) -> Dict[str, List[str]]:
|
||||||
|
"""Get user roles mapping"""
|
||||||
|
try:
|
||||||
|
# Try to get user role information
|
||||||
|
roles_sql = """
|
||||||
|
SELECT
|
||||||
|
User as user_name,
|
||||||
|
COALESCE(Default_role, 'default') as role_name
|
||||||
|
FROM mysql.user
|
||||||
|
"""
|
||||||
|
|
||||||
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(roles_sql, auth_context=auth_context)
|
||||||
|
|
||||||
|
user_roles = defaultdict(list)
|
||||||
|
if result.data:
|
||||||
|
for row in result.data:
|
||||||
|
user_name = row.get("user_name", "")
|
||||||
|
role_name = row.get("role_name", "default")
|
||||||
|
if user_name:
|
||||||
|
user_roles[user_name].append(role_name)
|
||||||
|
|
||||||
|
return dict(user_roles)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to get user roles: {str(e)}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def _classify_role_activity_level(self, avg_queries: float) -> str:
|
||||||
|
"""Classify role activity level based on average queries"""
|
||||||
|
if avg_queries > 100:
|
||||||
|
return "high"
|
||||||
|
elif avg_queries > 20:
|
||||||
|
return "medium"
|
||||||
|
elif avg_queries > 5:
|
||||||
|
return "low"
|
||||||
|
else:
|
||||||
|
return "minimal"
|
||||||
|
|
||||||
|
async def _detect_security_anomalies(self, audit_data: List[Dict], user_access_analysis: List[Dict]) -> List[Dict]:
|
||||||
|
"""Detect potential security anomalies"""
|
||||||
|
alerts = []
|
||||||
|
|
||||||
|
# 1. Detect unusual access times
|
||||||
|
for user_data in user_access_analysis:
|
||||||
|
user_name = user_data["user_name"]
|
||||||
|
hourly_pattern = user_data["temporal_patterns"]["hourly_distribution"]
|
||||||
|
|
||||||
|
# Check for significant night-time activity
|
||||||
|
night_queries = sum(hourly_pattern[22:24]) + sum(hourly_pattern[0:6])
|
||||||
|
total_queries = sum(hourly_pattern)
|
||||||
|
|
||||||
|
if total_queries > 0 and night_queries / total_queries > 0.3: # >30% night activity
|
||||||
|
alerts.append({
|
||||||
|
"alert_type": "unusual_access_time",
|
||||||
|
"severity": "medium",
|
||||||
|
"user": user_name,
|
||||||
|
"description": f"User {user_name} has {night_queries/total_queries:.1%} of queries during night hours",
|
||||||
|
"night_query_percentage": round(night_queries/total_queries, 3),
|
||||||
|
"timestamp": datetime.now().isoformat()
|
||||||
|
})
|
||||||
|
|
||||||
|
# 2. Detect users with high failure rates
|
||||||
|
for user_data in user_access_analysis:
|
||||||
|
user_name = user_data["user_name"]
|
||||||
|
success_rate = user_data["access_stats"]["success_rate"]
|
||||||
|
total_queries = user_data["access_stats"]["total_queries"]
|
||||||
|
|
||||||
|
if total_queries > 10 and success_rate < 0.8: # <80% success rate
|
||||||
|
alerts.append({
|
||||||
|
"alert_type": "high_failure_rate",
|
||||||
|
"severity": "medium",
|
||||||
|
"user": user_name,
|
||||||
|
"description": f"User {user_name} has low query success rate ({success_rate:.1%})",
|
||||||
|
"success_rate": success_rate,
|
||||||
|
"total_queries": total_queries,
|
||||||
|
"timestamp": datetime.now().isoformat()
|
||||||
|
})
|
||||||
|
|
||||||
|
# 3. Detect unusual data volume access
|
||||||
|
data_volumes = [user["access_stats"]["data_volume_read_gb"] for user in user_access_analysis]
|
||||||
|
if data_volumes:
|
||||||
|
avg_volume = sum(data_volumes) / len(data_volumes)
|
||||||
|
std_dev = (sum((x - avg_volume) ** 2 for x in data_volumes) / len(data_volumes)) ** 0.5
|
||||||
|
threshold = avg_volume + 2 * std_dev # 2 standard deviations above mean
|
||||||
|
|
||||||
|
for user_data in user_access_analysis:
|
||||||
|
user_name = user_data["user_name"]
|
||||||
|
volume = user_data["access_stats"]["data_volume_read_gb"]
|
||||||
|
|
||||||
|
if volume > threshold and volume > 1.0: # >1GB and above threshold
|
||||||
|
alerts.append({
|
||||||
|
"alert_type": "unusual_data_volume",
|
||||||
|
"severity": "high" if volume > threshold * 2 else "medium",
|
||||||
|
"user": user_name,
|
||||||
|
"description": f"User {user_name} read {volume:.2f}GB (threshold: {threshold:.2f}GB)",
|
||||||
|
"data_volume_gb": volume,
|
||||||
|
"threshold_gb": round(threshold, 2),
|
||||||
|
"timestamp": datetime.now().isoformat()
|
||||||
|
})
|
||||||
|
|
||||||
|
# 4. Detect users accessing many different tables
|
||||||
|
for user_data in user_access_analysis:
|
||||||
|
user_name = user_data["user_name"]
|
||||||
|
unique_tables = user_data["access_stats"]["unique_tables_accessed"]
|
||||||
|
total_queries = user_data["access_stats"]["total_queries"]
|
||||||
|
|
||||||
|
# High table diversity might indicate privilege escalation or data mining
|
||||||
|
if unique_tables > 20 and total_queries > 50:
|
||||||
|
alerts.append({
|
||||||
|
"alert_type": "broad_table_access",
|
||||||
|
"severity": "medium",
|
||||||
|
"user": user_name,
|
||||||
|
"description": f"User {user_name} accessed {unique_tables} different tables",
|
||||||
|
"unique_tables_count": unique_tables,
|
||||||
|
"total_queries": total_queries,
|
||||||
|
"timestamp": datetime.now().isoformat()
|
||||||
|
})
|
||||||
|
|
||||||
|
return sorted(alerts, key=lambda x: {"high": 3, "medium": 2, "low": 1}.get(x["severity"], 0), reverse=True)
|
||||||
|
|
||||||
|
async def _generate_access_insights(self, user_access_analysis: List[Dict], role_analysis: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Generate access insights and patterns"""
|
||||||
|
insights = {
|
||||||
|
"user_behavior_patterns": {},
|
||||||
|
"role_effectiveness": {},
|
||||||
|
"security_posture": {}
|
||||||
|
}
|
||||||
|
|
||||||
|
# User behavior patterns
|
||||||
|
if user_access_analysis:
|
||||||
|
total_users = len(user_access_analysis)
|
||||||
|
active_users = len([u for u in user_access_analysis if u["access_stats"]["total_queries"] > 10])
|
||||||
|
power_users = len([u for u in user_access_analysis if u["access_stats"]["total_queries"] > 100])
|
||||||
|
|
||||||
|
# Access pattern distribution
|
||||||
|
pattern_distribution = Counter(
|
||||||
|
user["access_stats"]["access_pattern"] for user in user_access_analysis
|
||||||
|
)
|
||||||
|
|
||||||
|
insights["user_behavior_patterns"] = {
|
||||||
|
"total_users_analyzed": total_users,
|
||||||
|
"active_users": active_users,
|
||||||
|
"power_users": power_users,
|
||||||
|
"access_pattern_distribution": dict(pattern_distribution),
|
||||||
|
"avg_queries_per_user": round(
|
||||||
|
sum(u["access_stats"]["total_queries"] for u in user_access_analysis) / total_users, 1
|
||||||
|
) if total_users > 0 else 0
|
||||||
|
}
|
||||||
|
|
||||||
|
# Role effectiveness
|
||||||
|
if role_analysis:
|
||||||
|
most_active_role = max(role_analysis.items(), key=lambda x: x[1]["total_queries"])
|
||||||
|
least_active_role = min(role_analysis.items(), key=lambda x: x[1]["total_queries"])
|
||||||
|
|
||||||
|
insights["role_effectiveness"] = {
|
||||||
|
"total_roles": len(role_analysis),
|
||||||
|
"most_active_role": {
|
||||||
|
"role": most_active_role[0],
|
||||||
|
"total_queries": most_active_role[1]["total_queries"],
|
||||||
|
"user_count": most_active_role[1]["user_count"]
|
||||||
|
},
|
||||||
|
"least_active_role": {
|
||||||
|
"role": least_active_role[0],
|
||||||
|
"total_queries": least_active_role[1]["total_queries"],
|
||||||
|
"user_count": least_active_role[1]["user_count"]
|
||||||
|
},
|
||||||
|
"avg_users_per_role": round(
|
||||||
|
sum(role_info["user_count"] for role_info in role_analysis.values()) / len(role_analysis), 1
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
# Security posture assessment
|
||||||
|
if user_access_analysis:
|
||||||
|
users_with_failures = len([u for u in user_access_analysis if u["access_stats"]["failed_queries"] > 0])
|
||||||
|
users_night_access = len([
|
||||||
|
u for u in user_access_analysis
|
||||||
|
if any(u["temporal_patterns"]["hourly_distribution"][hour] > 0 for hour in list(range(22, 24)) + list(range(0, 6)))
|
||||||
|
])
|
||||||
|
|
||||||
|
insights["security_posture"] = {
|
||||||
|
"users_with_query_failures": users_with_failures,
|
||||||
|
"users_with_night_access": users_night_access,
|
||||||
|
"security_score": self._calculate_security_score(user_access_analysis),
|
||||||
|
"risk_level": self._assess_overall_risk_level(user_access_analysis)
|
||||||
|
}
|
||||||
|
|
||||||
|
return insights
|
||||||
|
|
||||||
|
def _calculate_security_score(self, user_access_analysis: List[Dict]) -> float:
|
||||||
|
"""Calculate overall security score (0-1, higher is better)"""
|
||||||
|
if not user_access_analysis:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
total_users = len(user_access_analysis)
|
||||||
|
|
||||||
|
# Factors that contribute to security score
|
||||||
|
users_with_high_success_rate = len([u for u in user_access_analysis if u["access_stats"]["success_rate"] > 0.9])
|
||||||
|
users_with_normal_patterns = len([u for u in user_access_analysis if u["access_stats"]["access_pattern"] == "regular_business_hours"])
|
||||||
|
|
||||||
|
success_rate_score = users_with_high_success_rate / total_users
|
||||||
|
pattern_score = users_with_normal_patterns / total_users
|
||||||
|
|
||||||
|
# Combined score
|
||||||
|
overall_score = (success_rate_score * 0.6 + pattern_score * 0.4)
|
||||||
|
return round(overall_score, 3)
|
||||||
|
|
||||||
|
def _assess_overall_risk_level(self, user_access_analysis: List[Dict]) -> str:
|
||||||
|
"""Assess overall security risk level"""
|
||||||
|
security_score = self._calculate_security_score(user_access_analysis)
|
||||||
|
|
||||||
|
if security_score > 0.8:
|
||||||
|
return "low"
|
||||||
|
elif security_score > 0.6:
|
||||||
|
return "medium"
|
||||||
|
else:
|
||||||
|
return "high"
|
||||||
|
|
||||||
|
def _generate_user_access_summary(self, user_access_analysis: List[Dict]) -> Dict[str, Any]:
|
||||||
|
"""Generate summary statistics for user access"""
|
||||||
|
if not user_access_analysis:
|
||||||
|
return {
|
||||||
|
"total_users": 0,
|
||||||
|
"active_users": 0,
|
||||||
|
"high_activity_users": 0,
|
||||||
|
"dormant_users": 0
|
||||||
|
}
|
||||||
|
|
||||||
|
total_users = len(user_access_analysis)
|
||||||
|
active_users = len([u for u in user_access_analysis if u["access_stats"]["total_queries"] > 10])
|
||||||
|
high_activity_users = len([u for u in user_access_analysis if u["access_stats"]["total_queries"] > 100])
|
||||||
|
dormant_users = total_users - active_users
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total_users": total_users,
|
||||||
|
"active_users": active_users,
|
||||||
|
"high_activity_users": high_activity_users,
|
||||||
|
"dormant_users": dormant_users,
|
||||||
|
"activity_distribution": {
|
||||||
|
"high": high_activity_users,
|
||||||
|
"medium": active_users - high_activity_users,
|
||||||
|
"low": dormant_users
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def _generate_security_recommendations(self, security_alerts: List[Dict], access_insights: Dict[str, Any]) -> List[Dict]:
|
||||||
|
"""Generate security recommendations based on analysis"""
|
||||||
|
recommendations = []
|
||||||
|
|
||||||
|
# Recommendations based on alerts
|
||||||
|
if security_alerts:
|
||||||
|
high_severity_alerts = [alert for alert in security_alerts if alert["severity"] == "high"]
|
||||||
|
if high_severity_alerts:
|
||||||
|
recommendations.append({
|
||||||
|
"type": "urgent_security_review",
|
||||||
|
"priority": "high",
|
||||||
|
"description": f"Found {len(high_severity_alerts)} high-severity security alerts",
|
||||||
|
"action": "Immediate review of flagged users and access patterns required",
|
||||||
|
"affected_users": list(set(alert["user"] for alert in high_severity_alerts if "user" in alert))
|
||||||
|
})
|
||||||
|
|
||||||
|
# Night access recommendations
|
||||||
|
night_access_alerts = [alert for alert in security_alerts if alert["alert_type"] == "unusual_access_time"]
|
||||||
|
if night_access_alerts:
|
||||||
|
recommendations.append({
|
||||||
|
"type": "access_time_policy",
|
||||||
|
"priority": "medium",
|
||||||
|
"description": f"{len(night_access_alerts)} users have significant night-time access",
|
||||||
|
"action": "Review access time policies and consider time-based restrictions",
|
||||||
|
"affected_users": [alert["user"] for alert in night_access_alerts]
|
||||||
|
})
|
||||||
|
|
||||||
|
# Recommendations based on insights
|
||||||
|
security_posture = access_insights.get("security_posture", {})
|
||||||
|
risk_level = security_posture.get("risk_level", "unknown")
|
||||||
|
|
||||||
|
if risk_level == "high":
|
||||||
|
recommendations.append({
|
||||||
|
"type": "overall_security_improvement",
|
||||||
|
"priority": "high",
|
||||||
|
"description": "Overall security posture indicates high risk",
|
||||||
|
"action": "Comprehensive security audit and policy review recommended"
|
||||||
|
})
|
||||||
|
|
||||||
|
# Role-based recommendations
|
||||||
|
role_effectiveness = access_insights.get("role_effectiveness", {})
|
||||||
|
if role_effectiveness and role_effectiveness.get("total_roles", 0) < 3:
|
||||||
|
recommendations.append({
|
||||||
|
"type": "role_management",
|
||||||
|
"priority": "medium",
|
||||||
|
"description": "Limited role diversity detected",
|
||||||
|
"action": "Consider implementing more granular role-based access control"
|
||||||
|
})
|
||||||
|
|
||||||
|
return recommendations
|
||||||
301
doris_mcp_server/utils/sql_security_utils.py
Normal file
@@ -0,0 +1,301 @@
|
|||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
"""
|
||||||
|
SQL Security Utilities Module
|
||||||
|
|
||||||
|
Provides SQL identifier validation, escaping, and safe query building utilities
|
||||||
|
to prevent SQL injection attacks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from contextvars import ContextVar
|
||||||
|
from typing import Optional, Tuple, List, Any
|
||||||
|
|
||||||
|
from .logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
# Context variable for auth_context (set by HTTP middleware)
|
||||||
|
auth_context_var: ContextVar = ContextVar('mcp_auth_context', default=None)
|
||||||
|
|
||||||
|
|
||||||
|
class SQLSecurityError(Exception):
|
||||||
|
"""Exception raised for SQL security validation failures"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SQLSecurityUtils:
|
||||||
|
"""
|
||||||
|
SQL Security Utilities for preventing SQL injection attacks.
|
||||||
|
|
||||||
|
Provides:
|
||||||
|
- Identifier validation (database names, table names, column names)
|
||||||
|
- Safe identifier quoting with backticks
|
||||||
|
- Safe table reference building
|
||||||
|
- Auth context retrieval from context variables
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Valid SQL identifier pattern: letters, numbers, underscores
|
||||||
|
# Must start with letter or underscore, not a number
|
||||||
|
# Supports Unicode letters for international database/table names
|
||||||
|
IDENTIFIER_PATTERN = re.compile(r'^[a-zA-Z_\u4e00-\u9fff][a-zA-Z0-9_\u4e00-\u9fff]*$')
|
||||||
|
|
||||||
|
# Maximum identifier length (MySQL/Doris standard)
|
||||||
|
MAX_IDENTIFIER_LENGTH = 64
|
||||||
|
|
||||||
|
# SQL reserved keywords that should be quoted
|
||||||
|
SQL_KEYWORDS = {
|
||||||
|
'SELECT', 'FROM', 'WHERE', 'INSERT', 'UPDATE', 'DELETE', 'DROP',
|
||||||
|
'CREATE', 'ALTER', 'TABLE', 'DATABASE', 'INDEX', 'VIEW', 'AND',
|
||||||
|
'OR', 'NOT', 'NULL', 'TRUE', 'FALSE', 'IN', 'LIKE', 'BETWEEN',
|
||||||
|
'JOIN', 'LEFT', 'RIGHT', 'INNER', 'OUTER', 'ON', 'AS', 'ORDER',
|
||||||
|
'BY', 'GROUP', 'HAVING', 'LIMIT', 'OFFSET', 'UNION', 'ALL',
|
||||||
|
'DISTINCT', 'INTO', 'VALUES', 'SET', 'DEFAULT', 'PRIMARY', 'KEY',
|
||||||
|
'FOREIGN', 'REFERENCES', 'CHECK', 'UNIQUE', 'CONSTRAINT'
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate_identifier(cls, name: str, identifier_type: str = "identifier") -> str:
|
||||||
|
"""
|
||||||
|
Validate a SQL identifier (database name, table name, column name, etc.)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The identifier to validate
|
||||||
|
identifier_type: Type description for error messages (e.g., "database name", "table name")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The validated identifier (unchanged if valid)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SQLSecurityError: If the identifier is invalid
|
||||||
|
"""
|
||||||
|
if not name:
|
||||||
|
raise SQLSecurityError(f"Empty {identifier_type} is not allowed")
|
||||||
|
|
||||||
|
if not isinstance(name, str):
|
||||||
|
raise SQLSecurityError(f"Invalid {identifier_type}: must be a string, got {type(name).__name__}")
|
||||||
|
|
||||||
|
# Strip whitespace
|
||||||
|
name = name.strip()
|
||||||
|
|
||||||
|
if not name:
|
||||||
|
raise SQLSecurityError(f"Empty {identifier_type} is not allowed")
|
||||||
|
|
||||||
|
# Check length
|
||||||
|
if len(name) > cls.MAX_IDENTIFIER_LENGTH:
|
||||||
|
raise SQLSecurityError(
|
||||||
|
f"Invalid {identifier_type}: '{name[:20]}...' exceeds maximum length of {cls.MAX_IDENTIFIER_LENGTH} characters"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for dangerous characters that could be SQL injection
|
||||||
|
dangerous_chars = ["'", '"', ';', '--', '/*', '*/', '\\', '\x00']
|
||||||
|
for char in dangerous_chars:
|
||||||
|
if char in name:
|
||||||
|
raise SQLSecurityError(
|
||||||
|
f"Invalid {identifier_type}: '{name}' contains forbidden character '{char}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate pattern
|
||||||
|
if not cls.IDENTIFIER_PATTERN.match(name):
|
||||||
|
raise SQLSecurityError(
|
||||||
|
f"Invalid {identifier_type}: '{name}' contains invalid characters. "
|
||||||
|
f"Only letters, numbers, and underscores are allowed, and must start with a letter or underscore."
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"Validated {identifier_type}: {name}")
|
||||||
|
return name
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def quote_identifier(cls, name: str, identifier_type: str = "identifier") -> str:
|
||||||
|
"""
|
||||||
|
Safely quote a SQL identifier using backticks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The identifier to quote
|
||||||
|
identifier_type: Type description for error messages
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The quoted identifier (e.g., `table_name`)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SQLSecurityError: If the identifier is invalid
|
||||||
|
"""
|
||||||
|
# First validate the identifier
|
||||||
|
validated_name = cls.validate_identifier(name, identifier_type)
|
||||||
|
|
||||||
|
# Escape any backticks within the name (double them)
|
||||||
|
escaped_name = validated_name.replace('`', '``')
|
||||||
|
|
||||||
|
return f"`{escaped_name}`"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_table_reference(
|
||||||
|
cls,
|
||||||
|
table_name: str,
|
||||||
|
db_name: Optional[str] = None,
|
||||||
|
catalog_name: Optional[str] = None,
|
||||||
|
quote: bool = True
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Build a safe, fully-qualified table reference.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table_name: The table name (required)
|
||||||
|
db_name: The database name (optional)
|
||||||
|
catalog_name: The catalog name (optional)
|
||||||
|
quote: Whether to quote identifiers with backticks (default: True)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A safe table reference string (e.g., `catalog`.`db`.`table`)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SQLSecurityError: If any identifier is invalid
|
||||||
|
"""
|
||||||
|
parts = []
|
||||||
|
|
||||||
|
if catalog_name:
|
||||||
|
if quote:
|
||||||
|
parts.append(cls.quote_identifier(catalog_name, "catalog name"))
|
||||||
|
else:
|
||||||
|
parts.append(cls.validate_identifier(catalog_name, "catalog name"))
|
||||||
|
|
||||||
|
if db_name:
|
||||||
|
if quote:
|
||||||
|
parts.append(cls.quote_identifier(db_name, "database name"))
|
||||||
|
else:
|
||||||
|
parts.append(cls.validate_identifier(db_name, "database name"))
|
||||||
|
|
||||||
|
if quote:
|
||||||
|
parts.append(cls.quote_identifier(table_name, "table name"))
|
||||||
|
else:
|
||||||
|
parts.append(cls.validate_identifier(table_name, "table name"))
|
||||||
|
|
||||||
|
return '.'.join(parts)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_column_reference(
|
||||||
|
cls,
|
||||||
|
column_name: str,
|
||||||
|
table_name: Optional[str] = None,
|
||||||
|
quote: bool = True
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Build a safe column reference.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
column_name: The column name (required)
|
||||||
|
table_name: The table name (optional, for qualified references)
|
||||||
|
quote: Whether to quote identifiers with backticks (default: True)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A safe column reference string (e.g., `table`.`column`)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SQLSecurityError: If any identifier is invalid
|
||||||
|
"""
|
||||||
|
parts = []
|
||||||
|
|
||||||
|
if table_name:
|
||||||
|
if quote:
|
||||||
|
parts.append(cls.quote_identifier(table_name, "table name"))
|
||||||
|
else:
|
||||||
|
parts.append(cls.validate_identifier(table_name, "table name"))
|
||||||
|
|
||||||
|
if quote:
|
||||||
|
parts.append(cls.quote_identifier(column_name, "column name"))
|
||||||
|
else:
|
||||||
|
parts.append(cls.validate_identifier(column_name, "column name"))
|
||||||
|
|
||||||
|
return '.'.join(parts)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate_and_build_where_condition(
|
||||||
|
cls,
|
||||||
|
column_name: str,
|
||||||
|
operator: str = "=",
|
||||||
|
use_param: bool = True
|
||||||
|
) -> Tuple[str, bool]:
|
||||||
|
"""
|
||||||
|
Build a safe WHERE condition for a column.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
column_name: The column name
|
||||||
|
operator: The comparison operator (=, !=, <, >, <=, >=, LIKE, IN)
|
||||||
|
use_param: Whether to use parameterized placeholder (%s)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (condition_string, needs_param)
|
||||||
|
e.g., ("`column` = %s", True) or ("`column` = DATABASE()", False)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SQLSecurityError: If column name is invalid or operator is not allowed
|
||||||
|
"""
|
||||||
|
# Validate column name
|
||||||
|
quoted_column = cls.quote_identifier(column_name, "column name")
|
||||||
|
|
||||||
|
# Validate operator
|
||||||
|
allowed_operators = {'=', '!=', '<>', '<', '>', '<=', '>=', 'LIKE', 'IN', 'IS'}
|
||||||
|
if operator.upper() not in allowed_operators:
|
||||||
|
raise SQLSecurityError(f"Invalid operator: '{operator}'. Allowed: {allowed_operators}")
|
||||||
|
|
||||||
|
if use_param:
|
||||||
|
return f"{quoted_column} {operator} %s", True
|
||||||
|
else:
|
||||||
|
return f"{quoted_column} {operator}", False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_auth_context():
|
||||||
|
"""
|
||||||
|
Get auth_context from the context variable.
|
||||||
|
|
||||||
|
This retrieves the auth_context that was set by the HTTP middleware
|
||||||
|
during request processing.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The auth_context object, or None if not available
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
auth_context = auth_context_var.get()
|
||||||
|
if auth_context:
|
||||||
|
logger.debug(f"Retrieved auth_context from context variable")
|
||||||
|
return auth_context
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Could not retrieve auth_context: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_auth_context(auth_context):
|
||||||
|
"""
|
||||||
|
Set auth_context in the context variable.
|
||||||
|
|
||||||
|
This is typically called by the HTTP middleware during request processing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
auth_context: The auth_context object to set
|
||||||
|
"""
|
||||||
|
auth_context_var.set(auth_context)
|
||||||
|
logger.debug("Set auth_context in context variable")
|
||||||
|
|
||||||
|
|
||||||
|
# Convenience functions for direct use
|
||||||
|
validate_identifier = SQLSecurityUtils.validate_identifier
|
||||||
|
quote_identifier = SQLSecurityUtils.quote_identifier
|
||||||
|
build_table_reference = SQLSecurityUtils.build_table_reference
|
||||||
|
build_column_reference = SQLSecurityUtils.build_column_reference
|
||||||
|
get_auth_context = SQLSecurityUtils.get_auth_context
|
||||||
|
set_auth_context = SQLSecurityUtils.set_auth_context
|
||||||
|
|
||||||
147
examples/cursor/README.md
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
<!--
|
||||||
|
Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
or more contributor license agreements. See the NOTICE file
|
||||||
|
distributed with this work for additional information
|
||||||
|
regarding copyright ownership. The ASF licenses this file
|
||||||
|
to you under the Apache License, Version 2.0 (the
|
||||||
|
"License"); you may not use this file except in compliance
|
||||||
|
with the License. You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing,
|
||||||
|
software distributed under the License is distributed on an
|
||||||
|
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
KIND, either express or implied. See the License for the
|
||||||
|
specific language governing permissions and limitations
|
||||||
|
under the License.
|
||||||
|
-->
|
||||||
|
|
||||||
|
|
||||||
|
# Cursor Example: Integrating Doris MCP Server
|
||||||
|
|
||||||
|
This guide provides step-by-step instructions on how to integrate the `doris-mcp-server` with the [Cursor](https://cursor.sh/) IDE. This integration allows you to interact with your Apache Doris database using natural language queries directly within Cursor's AI chat.
|
||||||
|
|
||||||
|
## Table of Contents
|
||||||
|
|
||||||
|
* [Prerequisites](#prerequisites)
|
||||||
|
* [Step 1: Set Up the Project](#step-1-set-up-the-project)
|
||||||
|
* [Step 2: Configure the MCP Server in Cursor](#step-2-configure-the-mcp-server-in-cursor)
|
||||||
|
* [Step 3: Verify the Integration](#step-3-verify-the-integration)
|
||||||
|
* [Step 4: Query Your Database](#step-4-query-your-database)
|
||||||
|
|
||||||
|
* [Example 1: List Tables](#example-1-list-tables)
|
||||||
|
* [Example 2: Analyze Sales Trends](#example-2-analyze-sales-trends)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Prerequisites
|
||||||
|
|
||||||
|
Before you begin, ensure you have the following installed and configured:
|
||||||
|
|
||||||
|
* The **Cursor** IDE
|
||||||
|
* **Git** for cloning the repository
|
||||||
|
* Access to an **Apache Doris** cluster (FE host, port, username, and password)
|
||||||
|
* **uv**, a fast Python package installer and runner
|
||||||
|
|
||||||
|
You can install `uv` with one of the following commands:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# For macOS (recommended)
|
||||||
|
brew install uv
|
||||||
|
|
||||||
|
# For other systems using pipx
|
||||||
|
pipx install uv
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Step 1: Set Up the Project
|
||||||
|
|
||||||
|
First, clone the `doris-mcp-server` repository to your local machine:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/apache/doris-mcp-server.git
|
||||||
|
|
||||||
|
cd doris-mcp-server
|
||||||
|
```
|
||||||
|
|
||||||
|
The necessary dependencies are listed in `requirements.txt` and will be managed automatically by `uv` in the next step.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Step 2: Configure the MCP Server in Cursor
|
||||||
|
|
||||||
|
1. Open the cloned `doris-mcp-server` directory in Cursor.
|
||||||
|
2. Click the ⚙️ icon (top-right), then go to **Tools & Integrations**.
|
||||||
|

|
||||||
|
3. Click **Add a custom MCP Server**.
|
||||||
|
4. Paste the following JSON configuration:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"mcpServers": {
|
||||||
|
"doris-mcp": {
|
||||||
|
"command": "uv",
|
||||||
|
"args": [
|
||||||
|
"run",
|
||||||
|
"--project",
|
||||||
|
"/path/to/your/doris-mcp-server",
|
||||||
|
"doris-mcp-server"
|
||||||
|
],
|
||||||
|
"env": {
|
||||||
|
"DORIS_HOST": "your_doris_fe_host",
|
||||||
|
"DORIS_PORT": "9030",
|
||||||
|
"DORIS_USER": "your_username",
|
||||||
|
"DORIS_PASSWORD": "your_password",
|
||||||
|
"DORIS_DATABASE": "ssb"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
> ⚠️ **Important:**
|
||||||
|
>
|
||||||
|
> * Replace `"/path/to/your/doris-mcp-server"` with the **absolute path** to your local project directory.
|
||||||
|
> * Fill in your actual Doris FE host, username, password, and database name.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Step 3: Verify the Integration
|
||||||
|
|
||||||
|
Once saved, go back to the **Settings** panel. If everything is configured correctly, you’ll see a green status dot next to `doris-mcp-server`, along with available tools like `exec_query`.
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Step 4: Query Your Database
|
||||||
|
|
||||||
|
You can now chat with Cursor Agent to run SQL queries against your Doris database.
|
||||||
|
|
||||||
|
1. Open the chat panel using `Cmd + K` (macOS) or `Ctrl + K` (Windows/Linux), or click the chat icon in the top-right.
|
||||||
|
2. Switch to **Agent Mode**.
|
||||||
|
3. Start asking questions using natural language.
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
#### Example 1: List Tables
|
||||||
|
|
||||||
|
> **Prompt:** What tables are in the `ssb` database?
|
||||||
|
|
||||||
|
The agent will call the `get_db_table_list` tool and return the results.
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
#### Example 2: Analyze Sales Trends
|
||||||
|
|
||||||
|
> **Prompt:** What has been the sales trend over the past ten years in the `ssb` database, and which year had the fastest growth?
|
||||||
|
|
||||||
|
The agent will generate an appropriate SQL query, send it to the MCP server, and interpret the results to give you growth trends and highlights.
|
||||||
|
|
||||||
|

|
||||||
156
examples/dify/README.md
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
<!--
|
||||||
|
Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
or more contributor license agreements. See the NOTICE file
|
||||||
|
distributed with this work for additional information
|
||||||
|
regarding copyright ownership. The ASF licenses this file
|
||||||
|
to you under the Apache License, Version 2.0 (the
|
||||||
|
"License"); you may not use this file except in compliance
|
||||||
|
with the License. You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing,
|
||||||
|
software distributed under the License is distributed on an
|
||||||
|
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
KIND, either express or implied. See the License for the
|
||||||
|
specific language governing permissions and limitations
|
||||||
|
under the License.
|
||||||
|
-->
|
||||||
|
|
||||||
|
|
||||||
|
# Dify Example: Integrating Doris MCP Server
|
||||||
|
|
||||||
|
This document demonstrates how to integrate and use `doris-mcp-server` in Dify to perform Doris SQL calls via MCP.
|
||||||
|
|
||||||
|
## Table of Contents
|
||||||
|
|
||||||
|
- [Prerequisites](#prerequisites)
|
||||||
|
- [Starting the MCP Server](#starting-the-mcp-server)
|
||||||
|
- [Ngrok Tunnel (Optional)](#ngrok-tunnel-optional)
|
||||||
|
- [Installing & Configuring the Plugin in Dify](#installing--configuring-the-plugin-in-dify)
|
||||||
|
- [Creating a Dify App](#creating-a-dify-app)
|
||||||
|
- [Adding MCP Tools](#adding-mcp-tools)
|
||||||
|
- [Example Calls](#example-calls)
|
||||||
|
|
||||||
|
|
||||||
|
-----
|
||||||
|
|
||||||
|
### Prerequisites
|
||||||
|
|
||||||
|
First, install `mcp-doris-server`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install mcp-doris-server
|
||||||
|
```
|
||||||
|
|
||||||
|
## Starting the MCP Server
|
||||||
|
|
||||||
|
Run the startup script:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Full configuration with database connection
|
||||||
|
doris-mcp-server \
|
||||||
|
--transport http \
|
||||||
|
--host 0.0.0.0 \
|
||||||
|
--port 3000 \
|
||||||
|
--db-host 127.0.0.1 \
|
||||||
|
--db-port 9030 \
|
||||||
|
--db-user root \
|
||||||
|
--db-password your_password
|
||||||
|
```
|
||||||
|
|
||||||
|
If successful, you'll see logs similar to this:
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
-----
|
||||||
|
|
||||||
|
## Ngrok Tunnel (Optional)
|
||||||
|
|
||||||
|
If your Dify deployment requires a publicly accessible endpoint, you can use the **ngrok** tool. Ngrok is a third-party service that securely exposes local servers to the internet.
|
||||||
|
|
||||||
|
|
||||||
|
-----
|
||||||
|
|
||||||
|
## Installing & Configuring the Plugin in Dify
|
||||||
|
|
||||||
|
1. In the Dify console, go to **Plugin Marketplace**, search for, and install **MCP‑SSE / StreamableHTTP**:
|
||||||
|

|
||||||
|
|
||||||
|
2. After installation, click **Configure** and set the URL to your public or local address. For example, if you're using `ngrok`, this should be the public URL `ngrok` provides, in the format `https://<your-domain>/mcp`. If Dify can directly access your local server, use `http://localhost:3000/mcp`.
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"doris_mcp_server": {
|
||||||
|
"transport": "streamable_http",
|
||||||
|
"url": "https://<your-domain>/mcp"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|

|
||||||
|
|
||||||
|
3. Click **Save**. If configured correctly, you'll see a green **Authorized** indicator:
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
-----
|
||||||
|
|
||||||
|
## Creating a Dify App
|
||||||
|
|
||||||
|
1. In the Dify console, click **New App** → **Blank App**.
|
||||||
|

|
||||||
|
|
||||||
|
2. Select **Agent** as the template and set the **App Name** (e.g., `Doris ChatBI`).
|
||||||
|

|
||||||
|
|
||||||
|
|
||||||
|
3. Import from DSL,[dify_doris_dsl.yml](dify_doris_dsl.yml)
|
||||||
|
|
||||||
|
-----
|
||||||
|
|
||||||
|
## Instructions & Tool Configuration
|
||||||
|
|
||||||
|
### Instruction Block
|
||||||
|
|
||||||
|
Paste the following into the **Instruction** field:
|
||||||
|
|
||||||
|
```
|
||||||
|
<instruction>
|
||||||
|
Use MCP tools to complete tasks as much as possible. Carefully read the annotations, method names, and parameter descriptions of each tool. Please follow these steps:
|
||||||
|
1. Analyze the user's question and match the most appropriate tool.
|
||||||
|
2. Use tool names and parameters exactly as defined; do not invent new ones.
|
||||||
|
3. Pass parameters in the required JSON format.
|
||||||
|
4. When calling tools, use:
|
||||||
|
{"mcp_sse_call_tool": {"tool_name": "<tool_name>", "arguments": "{}"}}
|
||||||
|
5. Output plain text only—no XML tags.
|
||||||
|
<input>
|
||||||
|
User question: user_query
|
||||||
|
</input>
|
||||||
|
<output>
|
||||||
|
Return tool results or a final answer, including analysis.
|
||||||
|
</output>
|
||||||
|
</instruction>
|
||||||
|
```
|
||||||
|
|
||||||
|
### Adding MCP Tools
|
||||||
|
|
||||||
|
In the **Tools** pane, click **Add** twice to add two entries, both named `mcp_sse` (they will inherit the transport and URL from the plugin):
|
||||||
|

|
||||||
|
|
||||||
|
-----
|
||||||
|
|
||||||
|
## Example Calls
|
||||||
|
|
||||||
|
### List Tables in Database
|
||||||
|
|
||||||
|
* **User**: What tables are in the database?
|
||||||
|
|
||||||
|
* **Result**: Dify will call the MCP tool to run `SHOW TABLES` and return the list.
|
||||||
|

|
||||||
|
|
||||||
|
### Sales Trend Over Ten Years
|
||||||
|
|
||||||
|
* **User**: What has been the sales trend over the past ten years in the ssb database, and which year had the fastest growth?
|
||||||
|
|
||||||
|
* **Result**: The tool will execute the SQL, calculate growth rates, and return data.
|
||||||
|

|
||||||
127
examples/dify/dify_doris_dsl.yml
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
app:
|
||||||
|
description: ''
|
||||||
|
icon: 🤖
|
||||||
|
icon_background: '#FFEAD5'
|
||||||
|
mode: agent-chat
|
||||||
|
name: doris
|
||||||
|
use_icon_as_answer_icon: false
|
||||||
|
dependencies:
|
||||||
|
- current_identifier: null
|
||||||
|
type: marketplace
|
||||||
|
value:
|
||||||
|
marketplace_plugin_unique_identifier: langgenius/deepseek:0.0.5@21408d5c48cd9f18d66b08883d0999fe89e6d049c891324c2229dea23b9665d5
|
||||||
|
- current_identifier: null
|
||||||
|
type: marketplace
|
||||||
|
value:
|
||||||
|
marketplace_plugin_unique_identifier: junjiem/mcp_sse:0.2.1@53cc613667fcf91dd7208dd5f6d2c8df3c7ff0af8b79e8f3c0a430f1b39bda4c
|
||||||
|
kind: app
|
||||||
|
model_config:
|
||||||
|
agent_mode:
|
||||||
|
enabled: true
|
||||||
|
max_iteration: 10
|
||||||
|
prompt: null
|
||||||
|
strategy: function_call
|
||||||
|
tools:
|
||||||
|
- enabled: true
|
||||||
|
isDeleted: false
|
||||||
|
notAuthor: false
|
||||||
|
provider_id: junjiem/mcp_sse/mcp_sse
|
||||||
|
provider_name: junjiem/mcp_sse/mcp_sse
|
||||||
|
provider_type: builtin
|
||||||
|
tool_label: 获取 MCP 工具列表
|
||||||
|
tool_name: mcp_sse_list_tools
|
||||||
|
tool_parameters:
|
||||||
|
prompts_as_tools: 1
|
||||||
|
resources_as_tools: 1
|
||||||
|
servers_config: null
|
||||||
|
- enabled: true
|
||||||
|
isDeleted: false
|
||||||
|
notAuthor: false
|
||||||
|
provider_id: junjiem/mcp_sse/mcp_sse
|
||||||
|
provider_name: junjiem/mcp_sse/mcp_sse
|
||||||
|
provider_type: builtin
|
||||||
|
tool_label: 调用 MCP 工具
|
||||||
|
tool_name: mcp_sse_call_tool
|
||||||
|
tool_parameters:
|
||||||
|
arguments: ''
|
||||||
|
prompts_as_tools: ''
|
||||||
|
resources_as_tools: ''
|
||||||
|
servers_config: ''
|
||||||
|
tool_name: ''
|
||||||
|
annotation_reply:
|
||||||
|
enabled: false
|
||||||
|
chat_prompt_config: {}
|
||||||
|
completion_prompt_config: {}
|
||||||
|
dataset_configs:
|
||||||
|
datasets:
|
||||||
|
datasets: []
|
||||||
|
reranking_enable: true
|
||||||
|
reranking_mode: reranking_model
|
||||||
|
reranking_model:
|
||||||
|
reranking_model_name: ''
|
||||||
|
reranking_provider_name: ''
|
||||||
|
retrieval_model: multiple
|
||||||
|
top_k: 4
|
||||||
|
dataset_query_variable: ''
|
||||||
|
external_data_tools: []
|
||||||
|
file_upload:
|
||||||
|
allowed_file_extensions:
|
||||||
|
- .JPG
|
||||||
|
- .JPEG
|
||||||
|
- .PNG
|
||||||
|
- .GIF
|
||||||
|
- .WEBP
|
||||||
|
- .SVG
|
||||||
|
- .MP4
|
||||||
|
- .MOV
|
||||||
|
- .MPEG
|
||||||
|
- .WEBM
|
||||||
|
allowed_file_types: []
|
||||||
|
allowed_file_upload_methods:
|
||||||
|
- remote_url
|
||||||
|
- local_file
|
||||||
|
enabled: false
|
||||||
|
image:
|
||||||
|
detail: high
|
||||||
|
enabled: false
|
||||||
|
number_limits: 3
|
||||||
|
transfer_methods:
|
||||||
|
- remote_url
|
||||||
|
- local_file
|
||||||
|
number_limits: 3
|
||||||
|
model:
|
||||||
|
completion_params:
|
||||||
|
stop: []
|
||||||
|
mode: chat
|
||||||
|
name: deepseek-chat
|
||||||
|
provider: langgenius/deepseek/deepseek
|
||||||
|
more_like_this:
|
||||||
|
enabled: false
|
||||||
|
opening_statement: ''
|
||||||
|
pre_prompt: "<instruction>\nUse MCP tools to complete tasks as much as possible.\
|
||||||
|
\ Carefully read the annotations, method names, and parameter descriptions of\
|
||||||
|
\ each tool. Please follow these steps:\n1. Analyze the user's question and match\
|
||||||
|
\ the most appropriate tool.\n2. Use tool names and parameters exactly as defined;\
|
||||||
|
\ do not invent new ones.\n3. Pass parameters in the required JSON format.\n4.\
|
||||||
|
\ When calling tools, use:\n {\"mcp_sse_call_tool\": {\"tool_name\": \"<tool_name>\"\
|
||||||
|
, \"arguments\": \"{}\"}}\n5. Output plain text only—no XML tags.\n<input>\nUser\
|
||||||
|
\ question: user_query\n</input>\n<output>\nReturn tool results or a final answer,\
|
||||||
|
\ including analysis.\n</output>\n</instruction>"
|
||||||
|
prompt_type: simple
|
||||||
|
retriever_resource:
|
||||||
|
enabled: true
|
||||||
|
sensitive_word_avoidance:
|
||||||
|
configs: []
|
||||||
|
enabled: false
|
||||||
|
type: ''
|
||||||
|
speech_to_text:
|
||||||
|
enabled: false
|
||||||
|
suggested_questions: []
|
||||||
|
suggested_questions_after_answer:
|
||||||
|
enabled: false
|
||||||
|
text_to_speech:
|
||||||
|
enabled: false
|
||||||
|
language: ''
|
||||||
|
voice: ''
|
||||||
|
user_input_form: []
|
||||||
|
version: 0.3.0
|
||||||
BIN
examples/images/cursor_add_mcp.png
Normal file
|
After Width: | Height: | Size: 323 KiB |
BIN
examples/images/cursor_agent.png
Normal file
|
After Width: | Height: | Size: 673 KiB |
BIN
examples/images/cursor_ask1.png
Normal file
|
After Width: | Height: | Size: 118 KiB |
BIN
examples/images/cursor_ask2.png
Normal file
|
After Width: | Height: | Size: 232 KiB |
BIN
examples/images/cursor_doris-mcp.png
Normal file
|
After Width: | Height: | Size: 80 KiB |
BIN
examples/images/dify_add_tools.png
Normal file
|
After Width: | Height: | Size: 17 KiB |
BIN
examples/images/dify_agent_setup.png
Normal file
|
After Width: | Height: | Size: 258 KiB |
BIN
examples/images/dify_authorized.png
Normal file
|
After Width: | Height: | Size: 44 KiB |
BIN
examples/images/dify_config_mcp.png
Normal file
|
After Width: | Height: | Size: 66 KiB |
BIN
examples/images/dify_create_app.png
Normal file
|
After Width: | Height: | Size: 127 KiB |
BIN
examples/images/dify_install_plugin.png
Normal file
|
After Width: | Height: | Size: 317 KiB |
BIN
examples/images/dify_query_tabels.png
Normal file
|
After Width: | Height: | Size: 369 KiB |
BIN
examples/images/dify_sale_trend.png
Normal file
|
After Width: | Height: | Size: 272 KiB |
BIN
examples/images/dify_start_server.png
Normal file
|
After Width: | Height: | Size: 73 KiB |
@@ -20,7 +20,7 @@ build-backend = "hatchling.build"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "doris-mcp-server"
|
name = "doris-mcp-server"
|
||||||
version = "0.3.0"
|
version = "0.6.1"
|
||||||
description = "Enterprise-grade Model Context Protocol (MCP) server implementation for Apache Doris"
|
description = "Enterprise-grade Model Context Protocol (MCP) server implementation for Apache Doris"
|
||||||
authors = [
|
authors = [
|
||||||
{name = "Yijia Su", email = "freeoneplus@apache.org"}
|
{name = "Yijia Su", email = "freeoneplus@apache.org"}
|
||||||
@@ -42,10 +42,14 @@ classifiers = [
|
|||||||
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
# Core MCP dependencies
|
# Core MCP dependencies
|
||||||
"mcp>=1.0.0",
|
"mcp>=1.8.0,<2.0.0",
|
||||||
# Database drivers
|
# Database drivers
|
||||||
"aiomysql>=0.2.0",
|
"aiomysql>=0.2.0",
|
||||||
"PyMySQL>=1.1.0",
|
"PyMySQL>=1.1.0",
|
||||||
|
# ADBC (Arrow Flight SQL) dependencies
|
||||||
|
"adbc-driver-manager>=0.8.0",
|
||||||
|
"adbc-driver-flightsql>=0.8.0",
|
||||||
|
"pyarrow>=14.0.0",
|
||||||
# Async and utility libraries
|
# Async and utility libraries
|
||||||
"asyncio-mqtt>=0.16.0",
|
"asyncio-mqtt>=0.16.0",
|
||||||
"aiofiles>=23.0.0",
|
"aiofiles>=23.0.0",
|
||||||
|
|||||||
@@ -1,21 +1,5 @@
|
|||||||
# Licensed to the Apache Software Foundation (ASF) under one
|
# Development dependencies - auto-generated from pyproject.toml
|
||||||
# or more contributor license agreements. See the NOTICE file
|
# Installation command: pip install -r requirements-dev.txt
|
||||||
# distributed with this work for additional information
|
|
||||||
# regarding copyright ownership. The ASF licenses this file
|
|
||||||
# to you under the Apache License, Version 2.0 (the
|
|
||||||
# "License"); you may not use this file except in compliance
|
|
||||||
# with the License. You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing,
|
|
||||||
# software distributed under the License is distributed on an
|
|
||||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
||||||
# KIND, either express or implied. See the License for the
|
|
||||||
# specific language governing permissions and limitations
|
|
||||||
# under the License.
|
|
||||||
# 开发依赖 - 从 pyproject.toml 自动生成
|
|
||||||
# 安装命令: pip install -r requirements-dev.txt
|
|
||||||
|
|
||||||
pytest>=7.4.0
|
pytest>=7.4.0
|
||||||
pytest-asyncio>=0.23.0
|
pytest-asyncio>=0.23.0
|
||||||
|
|||||||
@@ -1,26 +1,13 @@
|
|||||||
# Licensed to the Apache Software Foundation (ASF) under one
|
# Main dependencies - auto-generated from pyproject.toml
|
||||||
# or more contributor license agreements. See the NOTICE file
|
# Do not edit this file manually, use 'python generate_requirements.py' to regenerate
|
||||||
# distributed with this work for additional information
|
|
||||||
# regarding copyright ownership. The ASF licenses this file
|
|
||||||
# to you under the Apache License, Version 2.0 (the
|
|
||||||
# "License"); you may not use this file except in compliance
|
|
||||||
# with the License. You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing,
|
|
||||||
# software distributed under the License is distributed on an
|
|
||||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
||||||
# KIND, either express or implied. See the License for the
|
|
||||||
# specific language governing permissions and limitations
|
|
||||||
# under the License.
|
|
||||||
# 主要依赖 - 从 pyproject.toml 自动生成
|
|
||||||
# 请不要手动编辑此文件,使用 python generate_requirements.py 重新生成
|
|
||||||
|
|
||||||
# === 核心依赖 ===
|
# === Core Dependencies ===
|
||||||
mcp>=1.0.0
|
mcp>=1.8.0,<2.0.0
|
||||||
aiomysql>=0.2.0
|
aiomysql>=0.2.0
|
||||||
PyMySQL>=1.1.0
|
PyMySQL>=1.1.0
|
||||||
|
adbc-driver-manager>=0.8.0
|
||||||
|
adbc-driver-flightsql>=0.8.0
|
||||||
|
pyarrow>=14.0.0
|
||||||
asyncio-mqtt>=0.16.0
|
asyncio-mqtt>=0.16.0
|
||||||
aiofiles>=23.0.0
|
aiofiles>=23.0.0
|
||||||
aiohttp>=3.9.0
|
aiohttp>=3.9.0
|
||||||
@@ -53,8 +40,11 @@ click>=8.1.0
|
|||||||
typer>=0.9.0
|
typer>=0.9.0
|
||||||
requests>=2.31.0
|
requests>=2.31.0
|
||||||
tqdm>=4.66.0
|
tqdm>=4.66.0
|
||||||
|
pytest>=8.4.0
|
||||||
|
pytest-asyncio>=1.0.0
|
||||||
|
pytest-cov>=6.1.1
|
||||||
|
|
||||||
# === 开发依赖 ===
|
# === Development Dependencies ===
|
||||||
pytest>=7.4.0
|
pytest>=7.4.0
|
||||||
pytest-asyncio>=0.23.0
|
pytest-asyncio>=0.23.0
|
||||||
pytest-cov>=4.1.0
|
pytest-cov>=4.1.0
|
||||||
|
|||||||
@@ -64,9 +64,11 @@ else
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
# Set HTTP-specific environment variables
|
# Set HTTP-specific environment variables
|
||||||
|
# FIX for Issue #62 Bug 4: Use SERVER_PORT instead of MCP_PORT for consistency with code
|
||||||
export MCP_TRANSPORT_TYPE="http"
|
export MCP_TRANSPORT_TYPE="http"
|
||||||
export MCP_HOST="${MCP_HOST:-0.0.0.0}"
|
export MCP_HOST="${MCP_HOST:-0.0.0.0}"
|
||||||
export MCP_PORT="${MCP_PORT:-3000}"
|
export SERVER_PORT="${SERVER_PORT:-3000}" # Changed from MCP_PORT to SERVER_PORT
|
||||||
|
export WORKERS="${WORKERS:-1}"
|
||||||
export ALLOWED_ORIGINS="${ALLOWED_ORIGINS:-*}"
|
export ALLOWED_ORIGINS="${ALLOWED_ORIGINS:-*}"
|
||||||
export LOG_LEVEL="${LOG_LEVEL:-info}"
|
export LOG_LEVEL="${LOG_LEVEL:-info}"
|
||||||
export MCP_ALLOW_CREDENTIALS="${MCP_ALLOW_CREDENTIALS:-false}"
|
export MCP_ALLOW_CREDENTIALS="${MCP_ALLOW_CREDENTIALS:-false}"
|
||||||
@@ -76,14 +78,15 @@ export MCP_DEBUG_ADAPTER="true"
|
|||||||
export PYTHONPATH="$(pwd):$PYTHONPATH"
|
export PYTHONPATH="$(pwd):$PYTHONPATH"
|
||||||
|
|
||||||
echo -e "${GREEN}Starting MCP server (Streamable HTTP mode)...${NC}"
|
echo -e "${GREEN}Starting MCP server (Streamable HTTP mode)...${NC}"
|
||||||
echo -e "${YELLOW}Service will run on http://${MCP_HOST}:${MCP_PORT}/mcp${NC}"
|
echo -e "${YELLOW}Service will run on http://${MCP_HOST}:${SERVER_PORT}/mcp${NC}"
|
||||||
echo -e "${YELLOW}Health Check: http://${MCP_HOST}:${MCP_PORT}/health${NC}"
|
echo -e "${YELLOW}Health Check: http://${MCP_HOST}:${SERVER_PORT}/health${NC}"
|
||||||
echo -e "${YELLOW}MCP Endpoint: http://${MCP_HOST}:${MCP_PORT}/mcp${NC}"
|
echo -e "${YELLOW}MCP Endpoint: http://${MCP_HOST}:${SERVER_PORT}/mcp${NC}"
|
||||||
echo -e "${YELLOW}Local access: http://localhost:${MCP_PORT}/mcp${NC}"
|
echo -e "${YELLOW}Local access: http://localhost:${SERVER_PORT}/mcp${NC}"
|
||||||
|
echo -e "${YELLOW}Workers: ${WORKERS}${NC}"
|
||||||
echo -e "${YELLOW}Use Ctrl+C to stop the service${NC}"
|
echo -e "${YELLOW}Use Ctrl+C to stop the service${NC}"
|
||||||
|
|
||||||
# Start the server in HTTP mode (Streamable HTTP)
|
# Start the server in HTTP mode (Streamable HTTP)
|
||||||
python -m doris_mcp_server.main --transport http --host ${MCP_HOST} --port ${MCP_PORT}
|
python -m doris_mcp_server.main --transport http --host ${MCP_HOST} --port ${SERVER_PORT} --workers ${WORKERS}
|
||||||
|
|
||||||
# Check exit status
|
# Check exit status
|
||||||
if [ $? -ne 0 ]; then
|
if [ $? -ne 0 ]; then
|
||||||
@@ -95,4 +98,4 @@ fi
|
|||||||
echo -e "${YELLOW}Tip: If the page displays abnormally, please clear your browser cache or use incognito mode${NC}"
|
echo -e "${YELLOW}Tip: If the page displays abnormally, please clear your browser cache or use incognito mode${NC}"
|
||||||
echo -e "${YELLOW}Chrome browser clear cache shortcut: Ctrl+Shift+Del (Windows) or Cmd+Shift+Del (Mac)${NC}"
|
echo -e "${YELLOW}Chrome browser clear cache shortcut: Ctrl+Shift+Del (Windows) or Cmd+Shift+Del (Mac)${NC}"
|
||||||
echo -e "${CYAN}For testing HTTP endpoints, you can use:${NC}"
|
echo -e "${CYAN}For testing HTTP endpoints, you can use:${NC}"
|
||||||
echo -e "${CYAN} curl -X POST http://localhost:${MCP_PORT}/mcp -H 'Content-Type: application/json' -d '{\"method\":\"tools/list\"}'${NC}"
|
echo -e "${CYAN} curl -X POST http://localhost:${SERVER_PORT}/mcp -H 'Content-Type: application/json' -d '{\"method\":\"tools/list\"}'${NC}"
|
||||||
@@ -47,22 +47,29 @@ def event_loop():
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def test_config():
|
def test_config():
|
||||||
"""Provide test configuration"""
|
"""Test configuration fixture"""
|
||||||
return {
|
from doris_mcp_server.utils.config import DorisConfig, DatabaseConfig, SecurityConfig
|
||||||
"doris_host": "localhost",
|
|
||||||
"doris_port": 9030,
|
config = DorisConfig()
|
||||||
"doris_user": "test_user",
|
|
||||||
"doris_password": "test_password",
|
# Database configuration
|
||||||
"doris_database": "test_db",
|
config.database.host = "localhost"
|
||||||
"blocked_keywords": ["DROP", "DELETE", "TRUNCATE", "ALTER", "CREATE", "INSERT", "UPDATE"],
|
config.database.port = 9030
|
||||||
"sensitive_tables": {
|
config.database.user = "test_user"
|
||||||
"user_info": "confidential",
|
config.database.password = "test_password"
|
||||||
"payment_records": "secret",
|
config.database.database = "test_db"
|
||||||
"employee_data": "confidential",
|
config.database.health_check_interval = 60
|
||||||
"public_reports": "public"
|
config.database.max_connections = 20
|
||||||
},
|
config.database.connection_timeout = 30
|
||||||
"max_query_complexity": 100
|
config.database.max_connection_age = 3600
|
||||||
}
|
|
||||||
|
# Security configuration
|
||||||
|
config.security.enable_masking = True
|
||||||
|
config.security.auth_type = "token"
|
||||||
|
config.security.token_secret = "test_secret"
|
||||||
|
config.security.token_expiry = 3600
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|||||||
@@ -34,17 +34,9 @@ class TestEndToEndIntegration:
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_config(self):
|
def mock_config(self):
|
||||||
"""Create mock configuration"""
|
"""Create mock configuration"""
|
||||||
from doris_mcp_server.utils.config import DatabaseConfig, SecurityConfig
|
from doris_mcp_server.utils.config import ADBCConfig, DatabaseConfig, SecurityConfig
|
||||||
|
|
||||||
config = Mock(spec=DorisConfig)
|
config = Mock(spec=DorisConfig)
|
||||||
config.doris_host = "localhost"
|
|
||||||
config.doris_port = 9030
|
|
||||||
config.doris_user = "test_user"
|
|
||||||
config.doris_password = "test_password"
|
|
||||||
config.doris_database = "test_db"
|
|
||||||
config.server_host = "localhost"
|
|
||||||
config.server_port = 8000
|
|
||||||
config.enable_security = True
|
|
||||||
|
|
||||||
# Add database config
|
# Add database config
|
||||||
config.database = Mock(spec=DatabaseConfig)
|
config.database = Mock(spec=DatabaseConfig)
|
||||||
@@ -54,7 +46,6 @@ class TestEndToEndIntegration:
|
|||||||
config.database.password = "test_password"
|
config.database.password = "test_password"
|
||||||
config.database.database = "test_db"
|
config.database.database = "test_db"
|
||||||
config.database.health_check_interval = 60
|
config.database.health_check_interval = 60
|
||||||
config.database.min_connections = 5
|
|
||||||
config.database.max_connections = 20
|
config.database.max_connections = 20
|
||||||
config.database.connection_timeout = 30
|
config.database.connection_timeout = 30
|
||||||
config.database.max_connection_age = 3600
|
config.database.max_connection_age = 3600
|
||||||
@@ -65,7 +56,12 @@ class TestEndToEndIntegration:
|
|||||||
config.security.auth_type = "token"
|
config.security.auth_type = "token"
|
||||||
config.security.token_secret = "test_secret"
|
config.security.token_secret = "test_secret"
|
||||||
config.security.token_expiry = 3600
|
config.security.token_expiry = 3600
|
||||||
|
config.security.blocked_keywords = ["DROP"]
|
||||||
|
|
||||||
|
# Add adbc config
|
||||||
|
config.adbc = Mock(spec=ADBCConfig)
|
||||||
|
config.adbc.enabled = True
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -239,7 +235,7 @@ class TestEndToEndIntegration:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_tool_execution_with_security(self, doris_server):
|
async def test_tool_execution_with_security(self, doris_server):
|
||||||
"""Test tool execution with security checks"""
|
"""Test tool execution with security checks"""
|
||||||
with patch.object(doris_server.tools_manager.query_executor, 'execute_query') as mock_execute:
|
with patch.object(doris_server.tools_manager.connection_manager, 'execute_query') as mock_execute:
|
||||||
mock_execute.return_value = [{"Database": "test_db"}]
|
mock_execute.return_value = [{"Database": "test_db"}]
|
||||||
|
|
||||||
# Test tool execution through tools manager
|
# Test tool execution through tools manager
|
||||||
@@ -266,7 +262,7 @@ class TestEndToEndIntegration:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_performance_monitoring_integration(self, doris_server):
|
async def test_performance_monitoring_integration(self, doris_server):
|
||||||
"""Test performance monitoring integration"""
|
"""Test performance monitoring integration"""
|
||||||
with patch.object(doris_server.tools_manager.query_executor, 'execute_query') as mock_execute:
|
with patch.object(doris_server.tools_manager.connection_manager, 'execute_query') as mock_execute:
|
||||||
mock_execute.return_value = [
|
mock_execute.return_value = [
|
||||||
{
|
{
|
||||||
"query_count": 1500,
|
"query_count": 1500,
|
||||||
@@ -277,10 +273,7 @@ class TestEndToEndIntegration:
|
|||||||
]
|
]
|
||||||
|
|
||||||
# Test performance stats tool
|
# Test performance stats tool
|
||||||
result = await doris_server.tools_manager.call_tool("performance_stats", {
|
result = await doris_server.tools_manager.call_tool("get_db_list", {})
|
||||||
"metric_type": "queries",
|
|
||||||
"time_range": "1h"
|
|
||||||
})
|
|
||||||
result_data = json.loads(result)
|
result_data = json.loads(result)
|
||||||
|
|
||||||
# Accept either success result or error (due to mock environment)
|
# Accept either success result or error (due to mock environment)
|
||||||
@@ -296,4 +289,4 @@ class TestEndToEndIntegration:
|
|||||||
# Verify tools are available - use list_tools instead
|
# Verify tools are available - use list_tools instead
|
||||||
import asyncio
|
import asyncio
|
||||||
tools = asyncio.run(doris_server.tools_manager.list_tools())
|
tools = asyncio.run(doris_server.tools_manager.list_tools())
|
||||||
assert len(tools) > 0
|
assert len(tools) > 0
|
||||||
|
|||||||
367
test/security/test_sql_injection.py
Normal file
@@ -0,0 +1,367 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
SQL Security Test Suite for Apache Doris MCP Server
|
||||||
|
|
||||||
|
Tests for:
|
||||||
|
1. SQL injection prevention via identifier validation
|
||||||
|
2. Multi-statement SQL parsing in security validator
|
||||||
|
3. auth_context enforcement
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import MagicMock, AsyncMock, patch
|
||||||
|
|
||||||
|
|
||||||
|
class TestSQLSecurityUtils:
|
||||||
|
"""Test cases for sql_security_utils module"""
|
||||||
|
|
||||||
|
def test_validate_identifier_accepts_valid_names(self):
|
||||||
|
"""Test that valid identifiers are accepted"""
|
||||||
|
from doris_mcp_server.utils.sql_security_utils import validate_identifier
|
||||||
|
|
||||||
|
valid_names = [
|
||||||
|
"users",
|
||||||
|
"my_table",
|
||||||
|
"Table123",
|
||||||
|
"_private_table",
|
||||||
|
"CamelCaseTable",
|
||||||
|
"table_with_numbers_123",
|
||||||
|
]
|
||||||
|
|
||||||
|
for name in valid_names:
|
||||||
|
result = validate_identifier(name, "table")
|
||||||
|
assert result == name, f"Valid identifier '{name}' should be accepted"
|
||||||
|
|
||||||
|
def test_validate_identifier_rejects_sql_injection(self):
|
||||||
|
"""Test that SQL injection attempts are rejected"""
|
||||||
|
from doris_mcp_server.utils.sql_security_utils import (
|
||||||
|
validate_identifier,
|
||||||
|
SQLSecurityError
|
||||||
|
)
|
||||||
|
|
||||||
|
injection_attempts = [
|
||||||
|
# Basic SQL injection
|
||||||
|
"'; DROP TABLE users; --",
|
||||||
|
"table' OR '1'='1",
|
||||||
|
"table'; DELETE FROM users; --",
|
||||||
|
|
||||||
|
# Union-based injection
|
||||||
|
"table' UNION SELECT * FROM passwords --",
|
||||||
|
|
||||||
|
# Comment injection
|
||||||
|
"table/**/OR/**/1=1",
|
||||||
|
"table--comment",
|
||||||
|
|
||||||
|
# Special characters
|
||||||
|
"table`; DROP TABLE users;",
|
||||||
|
'table"; DROP TABLE users;',
|
||||||
|
"table\"; DELETE FROM",
|
||||||
|
|
||||||
|
# Backtick escape attempt
|
||||||
|
"analytics`; SELECT * FROM sensitive_table;--",
|
||||||
|
|
||||||
|
# Whitespace injection
|
||||||
|
"table name with spaces",
|
||||||
|
"table\ttab",
|
||||||
|
"table\nnewline",
|
||||||
|
]
|
||||||
|
|
||||||
|
for injection in injection_attempts:
|
||||||
|
with pytest.raises(SQLSecurityError):
|
||||||
|
validate_identifier(injection, "table")
|
||||||
|
|
||||||
|
def test_validate_identifier_rejects_empty(self):
|
||||||
|
"""Test that empty identifiers are rejected"""
|
||||||
|
from doris_mcp_server.utils.sql_security_utils import (
|
||||||
|
validate_identifier,
|
||||||
|
SQLSecurityError
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(SQLSecurityError):
|
||||||
|
validate_identifier("", "table")
|
||||||
|
|
||||||
|
with pytest.raises(SQLSecurityError):
|
||||||
|
validate_identifier(None, "table")
|
||||||
|
|
||||||
|
def test_validate_identifier_rejects_too_long(self):
|
||||||
|
"""Test that identifiers exceeding max length are rejected"""
|
||||||
|
from doris_mcp_server.utils.sql_security_utils import (
|
||||||
|
validate_identifier,
|
||||||
|
SQLSecurityError
|
||||||
|
)
|
||||||
|
|
||||||
|
# Doris identifier max length is typically 64 characters
|
||||||
|
long_name = "a" * 100
|
||||||
|
with pytest.raises(SQLSecurityError):
|
||||||
|
validate_identifier(long_name, "table")
|
||||||
|
|
||||||
|
def test_quote_identifier_adds_backticks(self):
|
||||||
|
"""Test that quote_identifier properly escapes identifiers"""
|
||||||
|
from doris_mcp_server.utils.sql_security_utils import quote_identifier
|
||||||
|
|
||||||
|
assert quote_identifier("my_table", "table") == "`my_table`"
|
||||||
|
assert quote_identifier("users", "table") == "`users`"
|
||||||
|
assert quote_identifier("Table123", "table") == "`Table123`"
|
||||||
|
|
||||||
|
def test_quote_identifier_validates_first(self):
|
||||||
|
"""Test that quote_identifier validates before quoting"""
|
||||||
|
from doris_mcp_server.utils.sql_security_utils import (
|
||||||
|
quote_identifier,
|
||||||
|
SQLSecurityError
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(SQLSecurityError):
|
||||||
|
quote_identifier("'; DROP TABLE users; --", "table")
|
||||||
|
|
||||||
|
|
||||||
|
class TestSQLSecurityValidator:
|
||||||
|
"""Test cases for SQLSecurityValidator multi-statement parsing"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def dict_config(self):
|
||||||
|
"""Create dictionary configuration"""
|
||||||
|
return {
|
||||||
|
"blocked_keywords": [
|
||||||
|
"DROP", "CREATE", "ALTER", "TRUNCATE",
|
||||||
|
"DELETE", "INSERT", "UPDATE",
|
||||||
|
"GRANT", "REVOKE", "EXEC", "EXECUTE"
|
||||||
|
],
|
||||||
|
"max_query_complexity": 100,
|
||||||
|
"enable_security_check": True
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_auth_context(self):
|
||||||
|
"""Create mock auth context"""
|
||||||
|
from doris_mcp_server.utils.security import AuthContext, SecurityLevel
|
||||||
|
return AuthContext(
|
||||||
|
user_id="test_user",
|
||||||
|
roles=["user"],
|
||||||
|
security_level=SecurityLevel.INTERNAL
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_validates_all_statements(self, dict_config, mock_auth_context):
|
||||||
|
"""Test that validator checks ALL SQL statements, not just the first"""
|
||||||
|
from doris_mcp_server.utils.security import SQLSecurityValidator
|
||||||
|
|
||||||
|
validator = SQLSecurityValidator(dict_config)
|
||||||
|
|
||||||
|
# Multi-statement with injection in second statement
|
||||||
|
# This should be BLOCKED
|
||||||
|
malicious_sql = "SELECT 1; DROP TABLE users; SELECT 2"
|
||||||
|
|
||||||
|
result = await validator.validate(malicious_sql, mock_auth_context)
|
||||||
|
|
||||||
|
assert not result.is_valid, "Multi-statement injection should be blocked"
|
||||||
|
# Check for either DROP keyword detection or SQL injection detection
|
||||||
|
error_upper = result.error_message.upper()
|
||||||
|
assert ("DROP" in error_upper or
|
||||||
|
"INJECTION" in error_upper or
|
||||||
|
"BLOCKED" in error_upper), f"Expected DROP/injection/blocked in: {result.error_message}"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_blocks_hidden_dangerous_statement(self, dict_config, mock_auth_context):
|
||||||
|
"""Test that dangerous statements hidden after safe ones are blocked"""
|
||||||
|
from doris_mcp_server.utils.security import SQLSecurityValidator
|
||||||
|
|
||||||
|
validator = SQLSecurityValidator(dict_config)
|
||||||
|
|
||||||
|
# Safe statement followed by dangerous one
|
||||||
|
malicious_sql = """
|
||||||
|
SELECT * FROM users WHERE id = 1;
|
||||||
|
DELETE FROM audit_log;
|
||||||
|
SELECT 1;
|
||||||
|
"""
|
||||||
|
|
||||||
|
result = await validator.validate(malicious_sql, mock_auth_context)
|
||||||
|
|
||||||
|
assert not result.is_valid, "Hidden DELETE statement should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_allows_safe_multi_statement(self, dict_config, mock_auth_context):
|
||||||
|
"""Test that multiple safe SELECT statements are allowed"""
|
||||||
|
from doris_mcp_server.utils.security import SQLSecurityValidator
|
||||||
|
|
||||||
|
validator = SQLSecurityValidator(dict_config)
|
||||||
|
|
||||||
|
safe_sql = """
|
||||||
|
SELECT * FROM users;
|
||||||
|
SELECT COUNT(*) FROM orders;
|
||||||
|
SELECT id, name FROM products;
|
||||||
|
"""
|
||||||
|
|
||||||
|
result = await validator.validate(safe_sql, mock_auth_context)
|
||||||
|
|
||||||
|
assert result.is_valid, f"Multiple safe SELECT statements should be allowed, got: {result.error_message}"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_context_switch_injection_blocked(self, dict_config, mock_auth_context):
|
||||||
|
"""Test that context switch SQL injection is blocked"""
|
||||||
|
from doris_mcp_server.utils.security import SQLSecurityValidator
|
||||||
|
|
||||||
|
validator = SQLSecurityValidator(dict_config)
|
||||||
|
|
||||||
|
# Simulating the exec_query_for_mcp attack vector
|
||||||
|
injected_sql = """
|
||||||
|
USE `analytics`; SELECT * FROM sensitive_table;-- `;
|
||||||
|
SELECT * FROM public_table;
|
||||||
|
"""
|
||||||
|
|
||||||
|
result = await validator.validate(injected_sql, mock_auth_context)
|
||||||
|
|
||||||
|
# The validator should process all statements
|
||||||
|
# Even if USE is allowed, subsequent unauthorized access should be caught
|
||||||
|
# by table access checks (if configured)
|
||||||
|
|
||||||
|
|
||||||
|
class TestExecQueryForMCP:
|
||||||
|
"""Test cases for exec_query_for_mcp function"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rejects_malicious_db_name(self):
|
||||||
|
"""Test that malicious db_name is rejected"""
|
||||||
|
from doris_mcp_server.utils.sql_security_utils import (
|
||||||
|
validate_identifier,
|
||||||
|
SQLSecurityError
|
||||||
|
)
|
||||||
|
|
||||||
|
# The attack vector from security report
|
||||||
|
malicious_db_name = "analytics`; SELECT * FROM sensitive_table;--"
|
||||||
|
|
||||||
|
with pytest.raises(SQLSecurityError):
|
||||||
|
validate_identifier(malicious_db_name, "database name")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rejects_malicious_catalog_name(self):
|
||||||
|
"""Test that malicious catalog_name is rejected"""
|
||||||
|
from doris_mcp_server.utils.sql_security_utils import (
|
||||||
|
validate_identifier,
|
||||||
|
SQLSecurityError
|
||||||
|
)
|
||||||
|
|
||||||
|
malicious_catalog_name = "internal'; DROP DATABASE production;--"
|
||||||
|
|
||||||
|
with pytest.raises(SQLSecurityError):
|
||||||
|
validate_identifier(malicious_catalog_name, "catalog name")
|
||||||
|
|
||||||
|
|
||||||
|
class TestDependencyAnalysisTools:
|
||||||
|
"""Test cases for dependency_analysis_tools security fixes"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_tables_metadata_rejects_injection(self):
|
||||||
|
"""Test that _get_tables_metadata rejects SQL injection"""
|
||||||
|
from doris_mcp_server.utils.sql_security_utils import (
|
||||||
|
validate_identifier,
|
||||||
|
SQLSecurityError
|
||||||
|
)
|
||||||
|
|
||||||
|
# The attack vector from security report
|
||||||
|
injection_db_name = "test_db' OR '1'='1' --"
|
||||||
|
|
||||||
|
with pytest.raises(SQLSecurityError):
|
||||||
|
validate_identifier(injection_db_name, "database name")
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuthContextEnforcement:
|
||||||
|
"""Test cases for auth_context enforcement"""
|
||||||
|
|
||||||
|
def test_execute_requires_auth_context_for_security(self):
|
||||||
|
"""Test that security checks require auth_context"""
|
||||||
|
# This test documents the expected behavior:
|
||||||
|
# When auth_context is None, security checks are skipped
|
||||||
|
# When auth_context is provided, security checks are performed
|
||||||
|
|
||||||
|
# The fix ensures all execute() calls pass auth_context
|
||||||
|
pass
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_auth_context_returns_context(self):
|
||||||
|
"""Test that get_auth_context retrieves context from ContextVar"""
|
||||||
|
from doris_mcp_server.utils.sql_security_utils import get_auth_context
|
||||||
|
|
||||||
|
# When no context is set, should return None
|
||||||
|
result = get_auth_context()
|
||||||
|
# This is expected - context is set by HTTP middleware
|
||||||
|
assert result is None or hasattr(result, 'user_id')
|
||||||
|
|
||||||
|
|
||||||
|
class TestIntegrationScenarios:
|
||||||
|
"""Integration test scenarios for security fixes"""
|
||||||
|
|
||||||
|
def test_attack_scenario_1_permission_bypass(self):
|
||||||
|
"""
|
||||||
|
Attack Scenario 1: Permission Bypass in Multi-Tenant Environment
|
||||||
|
|
||||||
|
Expected: User can only query their own database (db_name="tenant_a_db")
|
||||||
|
Attack: Inject "tenant_a_db' OR '1'='1' --" to query ALL databases
|
||||||
|
Result: Should be BLOCKED by validate_identifier()
|
||||||
|
"""
|
||||||
|
from doris_mcp_server.utils.sql_security_utils import (
|
||||||
|
validate_identifier,
|
||||||
|
SQLSecurityError
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(SQLSecurityError):
|
||||||
|
validate_identifier("tenant_a_db' OR '1'='1' --", "database name")
|
||||||
|
|
||||||
|
def test_attack_scenario_2_union_injection(self):
|
||||||
|
"""
|
||||||
|
Attack Scenario 2: UNION-based Information Disclosure
|
||||||
|
|
||||||
|
Attack: Inject UNION SELECT to extract sensitive data
|
||||||
|
Result: Should be BLOCKED
|
||||||
|
"""
|
||||||
|
from doris_mcp_server.utils.sql_security_utils import (
|
||||||
|
validate_identifier,
|
||||||
|
SQLSecurityError
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(SQLSecurityError):
|
||||||
|
validate_identifier(
|
||||||
|
"test' UNION SELECT password FROM users --",
|
||||||
|
"database name"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_attack_scenario_3_backtick_escape(self):
|
||||||
|
"""
|
||||||
|
Attack Scenario 3: Backtick Escape Attempt
|
||||||
|
|
||||||
|
Attack: Use backticks to break out of quoted identifier
|
||||||
|
Result: Should be BLOCKED
|
||||||
|
"""
|
||||||
|
from doris_mcp_server.utils.sql_security_utils import (
|
||||||
|
validate_identifier,
|
||||||
|
SQLSecurityError
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(SQLSecurityError):
|
||||||
|
validate_identifier(
|
||||||
|
"analytics`; SELECT * FROM sensitive_table;--",
|
||||||
|
"database name"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Run tests with: pytest tests/test_sql_security.py -v
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__, "-v", "--tb=short"])
|
||||||
|
|
||||||
871
test/security/test_sql_injection_api.py
Normal file
@@ -0,0 +1,871 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
SQL Injection API Integration Tests
|
||||||
|
|
||||||
|
This module tests SQL injection prevention through the MCP HTTP API.
|
||||||
|
It sends malicious payloads and verifies they are properly blocked.
|
||||||
|
|
||||||
|
Prerequisites:
|
||||||
|
- MCP server running on localhost:3000
|
||||||
|
- Run with: pytest test/security/test_sql_injection_api.py -v
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
# Start server first
|
||||||
|
bash start_server.sh
|
||||||
|
|
||||||
|
# Run tests
|
||||||
|
pytest test/security/test_sql_injection_api.py -v --no-cov
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import httpx
|
||||||
|
import json
|
||||||
|
import asyncio
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
# Server configuration
|
||||||
|
MCP_BASE_URL = "http://localhost:3000"
|
||||||
|
MCP_ENDPOINT = f"{MCP_BASE_URL}/mcp"
|
||||||
|
HEALTH_ENDPOINT = f"{MCP_BASE_URL}/health"
|
||||||
|
TIMEOUT = 30.0
|
||||||
|
|
||||||
|
|
||||||
|
class MCPClient:
|
||||||
|
"""Simple MCP HTTP client for testing"""
|
||||||
|
|
||||||
|
def __init__(self, base_url: str = MCP_BASE_URL):
|
||||||
|
self.base_url = base_url
|
||||||
|
self.mcp_endpoint = f"{base_url}/mcp"
|
||||||
|
self.session_id: Optional[str] = None
|
||||||
|
self.request_id = 0
|
||||||
|
self.client = httpx.AsyncClient(timeout=TIMEOUT)
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
await self.client.aclose()
|
||||||
|
|
||||||
|
def _next_id(self) -> int:
|
||||||
|
self.request_id += 1
|
||||||
|
return self.request_id
|
||||||
|
|
||||||
|
async def initialize(self) -> dict:
|
||||||
|
"""Initialize MCP session"""
|
||||||
|
response = await self.client.post(
|
||||||
|
self.mcp_endpoint,
|
||||||
|
headers={
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Accept": "application/json, text/event-stream"
|
||||||
|
},
|
||||||
|
json={
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"method": "initialize",
|
||||||
|
"params": {
|
||||||
|
"protocolVersion": "2024-11-05",
|
||||||
|
"capabilities": {},
|
||||||
|
"clientInfo": {
|
||||||
|
"name": "sql-injection-test",
|
||||||
|
"version": "1.0.0"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": self._next_id()
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract session ID from response header
|
||||||
|
self.session_id = response.headers.get("mcp-session-id")
|
||||||
|
return self._parse_response(response.text)
|
||||||
|
|
||||||
|
async def call_tool(self, tool_name: str, arguments: dict) -> dict:
|
||||||
|
"""Call an MCP tool"""
|
||||||
|
if not self.session_id:
|
||||||
|
await self.initialize()
|
||||||
|
|
||||||
|
response = await self.client.post(
|
||||||
|
self.mcp_endpoint,
|
||||||
|
headers={
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Accept": "application/json, text/event-stream",
|
||||||
|
"mcp-session-id": self.session_id
|
||||||
|
},
|
||||||
|
json={
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"method": "tools/call",
|
||||||
|
"params": {
|
||||||
|
"name": tool_name,
|
||||||
|
"arguments": arguments
|
||||||
|
},
|
||||||
|
"id": self._next_id()
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._parse_response(response.text)
|
||||||
|
|
||||||
|
def _parse_response(self, text: str) -> dict:
|
||||||
|
"""Parse JSON response"""
|
||||||
|
try:
|
||||||
|
return json.loads(text)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# Try SSE format
|
||||||
|
lines = text.strip().split("\n")
|
||||||
|
for line in lines:
|
||||||
|
if line.startswith("data: "):
|
||||||
|
try:
|
||||||
|
return json.loads(line[6:])
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
return {"raw": text}
|
||||||
|
|
||||||
|
|
||||||
|
def print_result(test_name: str, payload: dict, result: dict):
|
||||||
|
"""Print test result in a readable format"""
|
||||||
|
print(f"\n{'='*60}")
|
||||||
|
print(f"TEST: {test_name}")
|
||||||
|
print(f"{'='*60}")
|
||||||
|
print(f"PAYLOAD: {json.dumps(payload, ensure_ascii=False)}")
|
||||||
|
print(f"{'-'*60}")
|
||||||
|
|
||||||
|
# Extract inner result content
|
||||||
|
if "result" in result and "content" in result.get("result", {}):
|
||||||
|
for item in result["result"]["content"]:
|
||||||
|
if item.get("type") == "text":
|
||||||
|
try:
|
||||||
|
inner = json.loads(item["text"])
|
||||||
|
print("RESPONSE:")
|
||||||
|
print(f" success: {inner.get('success')}")
|
||||||
|
if inner.get('error'):
|
||||||
|
print(f" error: {inner.get('error')}")
|
||||||
|
if inner.get('error_type'):
|
||||||
|
print(f" error_type: {inner.get('error_type')}")
|
||||||
|
if inner.get('risk_level'):
|
||||||
|
print(f" risk_level: {inner.get('risk_level')}")
|
||||||
|
if inner.get('message'):
|
||||||
|
print(f" message: {inner.get('message')}")
|
||||||
|
if inner.get('data') is not None and inner.get('success'):
|
||||||
|
data_str = json.dumps(inner.get('data'), ensure_ascii=False)
|
||||||
|
if len(data_str) > 200:
|
||||||
|
data_str = data_str[:200] + "..."
|
||||||
|
print(f" data: {data_str}")
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
print(f"RESPONSE (raw): {item.get('text', '')[:500]}")
|
||||||
|
elif "error" in result:
|
||||||
|
print(f"RESPONSE ERROR: {result['error']}")
|
||||||
|
else:
|
||||||
|
print(f"RESPONSE (raw): {json.dumps(result, ensure_ascii=False)[:500]}")
|
||||||
|
|
||||||
|
print(f"{'='*60}\n")
|
||||||
|
|
||||||
|
|
||||||
|
class TestSQLInjectionAPI:
|
||||||
|
"""Test SQL injection prevention through MCP API"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def mcp_client(self):
|
||||||
|
"""Create MCP client instance"""
|
||||||
|
client = MCPClient()
|
||||||
|
yield client
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def is_server_running(self):
|
||||||
|
"""Check if MCP server is running"""
|
||||||
|
import httpx
|
||||||
|
try:
|
||||||
|
response = httpx.get(HEALTH_ENDPOINT, timeout=5.0)
|
||||||
|
return response.status_code == 200
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_server_health(self):
|
||||||
|
"""Test that MCP server is running and healthy"""
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.get(HEALTH_ENDPOINT)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["status"] == "healthy"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_exec_query_with_drop_injection(self, mcp_client):
|
||||||
|
"""Test exec_query rejects DROP TABLE injection"""
|
||||||
|
# Classic SQL injection: append DROP TABLE
|
||||||
|
payload = {"sql": "SELECT * FROM users; DROP TABLE users; --"}
|
||||||
|
result = await mcp_client.call_tool("exec_query", payload)
|
||||||
|
print_result("DROP TABLE Injection", payload, result)
|
||||||
|
|
||||||
|
# Should return error, not execute the DROP
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
f"DROP TABLE injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_exec_query_with_union_injection(self, mcp_client):
|
||||||
|
"""Test exec_query blocks UNION-based injection attempts"""
|
||||||
|
# UNION injection to extract data from other tables
|
||||||
|
payload = {"sql": "SELECT id FROM users UNION SELECT password FROM admin_users"}
|
||||||
|
result = await mcp_client.call_tool("exec_query", payload)
|
||||||
|
print_result("UNION Injection", payload, result)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_exec_query_with_delete_injection(self, mcp_client):
|
||||||
|
"""Test exec_query rejects DELETE injection"""
|
||||||
|
payload = {"sql": "SELECT 1; DELETE FROM users WHERE 1=1; SELECT 2"}
|
||||||
|
result = await mcp_client.call_tool("exec_query", payload)
|
||||||
|
print_result("DELETE Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
f"DELETE injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_exec_query_with_update_injection(self, mcp_client):
|
||||||
|
"""Test exec_query rejects UPDATE injection"""
|
||||||
|
payload = {"sql": "SELECT 1; UPDATE users SET role='admin' WHERE id=1; --"}
|
||||||
|
result = await mcp_client.call_tool("exec_query", payload)
|
||||||
|
print_result("UPDATE Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
f"UPDATE injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_exec_query_db_name_injection(self, mcp_client):
|
||||||
|
"""Test exec_query rejects SQL injection via db_name parameter"""
|
||||||
|
# Attack vector: inject SQL via db_name parameter
|
||||||
|
payload = {"sql": "SELECT 1", "db_name": "test'; DROP TABLE users; --"}
|
||||||
|
result = await mcp_client.call_tool("exec_query", payload)
|
||||||
|
print_result("db_name Parameter Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
f"db_name injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_exec_query_catalog_name_injection(self, mcp_client):
|
||||||
|
"""Test exec_query rejects SQL injection via catalog_name parameter"""
|
||||||
|
# Attack vector: inject SQL via catalog_name parameter
|
||||||
|
payload = {"sql": "SELECT 1", "catalog_name": "internal`; SELECT * FROM mysql.user; --"}
|
||||||
|
result = await mcp_client.call_tool("exec_query", payload)
|
||||||
|
print_result("catalog_name Parameter Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
f"catalog_name injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_table_schema_injection(self, mcp_client):
|
||||||
|
"""Test get_table_schema rejects SQL injection via table_name"""
|
||||||
|
# Attack vector: inject SQL via table_name parameter
|
||||||
|
payload = {"table_name": "users'; DROP TABLE users; --"}
|
||||||
|
result = await mcp_client.call_tool("get_table_schema", payload)
|
||||||
|
print_result("table_name Injection (get_table_schema)", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
f"table_name injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_table_schema_db_injection(self, mcp_client):
|
||||||
|
"""Test get_table_schema rejects SQL injection via db_name"""
|
||||||
|
payload = {"table_name": "users", "db_name": "test' OR '1'='1"}
|
||||||
|
result = await mcp_client.call_tool("get_table_schema", payload)
|
||||||
|
print_result("db_name Injection (get_table_schema)", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
f"db_name injection in get_table_schema should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_analyze_dependencies_injection(self, mcp_client):
|
||||||
|
"""Test analyze_dependencies rejects SQL injection"""
|
||||||
|
# This was the original vulnerability reported
|
||||||
|
payload = {"table_name": "users", "db_name": "test_db' OR '1'='1' --"}
|
||||||
|
result = await mcp_client.call_tool("analyze_dependencies", payload)
|
||||||
|
print_result("analyze_dependencies Injection (Original Report)", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
f"analyze_dependencies db_name injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stacked_queries_injection(self, mcp_client):
|
||||||
|
"""Test that stacked queries (multiple statements) are blocked"""
|
||||||
|
# Multiple statements injection
|
||||||
|
payload = {"sql": "SELECT * FROM users WHERE id = 1; INSERT INTO audit_log VALUES (NULL, 'hacked', NOW()); SELECT 1;"}
|
||||||
|
result = await mcp_client.call_tool("exec_query", payload)
|
||||||
|
print_result("Stacked Queries (INSERT) Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
f"Stacked queries with INSERT should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_comment_based_injection(self, mcp_client):
|
||||||
|
"""Test that comment-based injection is blocked"""
|
||||||
|
# Using comments to bypass filters
|
||||||
|
payload = {"sql": "SELECT * FROM users WHERE id = 1/**/OR/**/1=1"}
|
||||||
|
result = await mcp_client.call_tool("exec_query", payload)
|
||||||
|
print_result("Comment-based Injection", payload, result)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_hex_encoded_injection(self, mcp_client):
|
||||||
|
"""Test that hex-encoded injection attempts are handled"""
|
||||||
|
# Hex-encoded 'DROP' attempt
|
||||||
|
payload = {"sql": "SELECT 0x44524F50205441424C4520757365727320"}
|
||||||
|
result = await mcp_client.call_tool("exec_query", payload)
|
||||||
|
print_result("Hex Encoded Injection", payload, result)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_backtick_escape_injection(self, mcp_client):
|
||||||
|
"""Test backtick escape injection is blocked"""
|
||||||
|
# Attempt to escape backtick quoting
|
||||||
|
payload = {"sql": "SELECT 1", "db_name": "analytics`; SELECT * FROM sensitive_table;--"}
|
||||||
|
result = await mcp_client.call_tool("exec_query", payload)
|
||||||
|
print_result("Backtick Escape Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
f"Backtick escape injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_valid_query_succeeds(self, mcp_client):
|
||||||
|
"""Test that valid queries still work"""
|
||||||
|
# Simple valid query should work
|
||||||
|
payload = {"sql": "SELECT 1 AS test_value"}
|
||||||
|
result = await mcp_client.call_tool("exec_query", payload)
|
||||||
|
print_result("Valid Query (should succeed)", payload, result)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_valid_show_databases(self, mcp_client):
|
||||||
|
"""Test that SHOW DATABASES works"""
|
||||||
|
payload = {"sql": "SHOW DATABASES"}
|
||||||
|
result = await mcp_client.call_tool("exec_query", payload)
|
||||||
|
print_result("SHOW DATABASES (should succeed)", payload, result)
|
||||||
|
|
||||||
|
def _is_blocked_or_error(self, result: dict) -> bool:
|
||||||
|
"""Check if result indicates blocked or error"""
|
||||||
|
if not result:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check for JSON-RPC error
|
||||||
|
if "error" in result:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check for error in result content
|
||||||
|
if "result" in result:
|
||||||
|
result_content = result.get("result", {})
|
||||||
|
if isinstance(result_content, dict):
|
||||||
|
# Check for isError flag
|
||||||
|
if result_content.get("isError"):
|
||||||
|
return True
|
||||||
|
# Check content array for error messages
|
||||||
|
content = result_content.get("content", [])
|
||||||
|
for item in content:
|
||||||
|
if isinstance(item, dict):
|
||||||
|
text = item.get("text", "")
|
||||||
|
# Parse the JSON text content
|
||||||
|
try:
|
||||||
|
text_data = json.loads(text)
|
||||||
|
# Check for success: false
|
||||||
|
if text_data.get("success") is False:
|
||||||
|
return True
|
||||||
|
# Check for error field
|
||||||
|
if text_data.get("error"):
|
||||||
|
return True
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
pass
|
||||||
|
# Check text for security keywords
|
||||||
|
if any(keyword in text.lower() for keyword in [
|
||||||
|
"error", "blocked", "invalid", "security",
|
||||||
|
"injection", "denied", "forbidden", "not allowed",
|
||||||
|
"security_violation", "risk_level"
|
||||||
|
]):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check raw text response
|
||||||
|
raw = result.get("raw", "")
|
||||||
|
if isinstance(raw, str) and any(keyword in raw.lower() for keyword in [
|
||||||
|
"error", "blocked", "invalid", "security"
|
||||||
|
]):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class TestIdentifierInjectionAPI:
|
||||||
|
"""Test identifier-based SQL injection prevention"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def mcp_client(self):
|
||||||
|
"""Create MCP client instance"""
|
||||||
|
client = MCPClient()
|
||||||
|
yield client
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_table_name_with_semicolon(self, mcp_client):
|
||||||
|
"""Test table name containing semicolon is rejected"""
|
||||||
|
payload = {"table_name": "users; DROP TABLE users"}
|
||||||
|
result = await mcp_client.call_tool("get_table_schema", payload)
|
||||||
|
print_result("Table Name with Semicolon", payload, result)
|
||||||
|
|
||||||
|
# Should be blocked by identifier validation
|
||||||
|
assert self._contains_error_indicator(result), \
|
||||||
|
f"Table name with semicolon should be rejected"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_table_name_with_quotes(self, mcp_client):
|
||||||
|
"""Test table name containing quotes is rejected"""
|
||||||
|
payload = {"table_name": "users' OR '1'='1"}
|
||||||
|
result = await mcp_client.call_tool("get_table_schema", payload)
|
||||||
|
print_result("Table Name with Quotes", payload, result)
|
||||||
|
|
||||||
|
assert self._contains_error_indicator(result), \
|
||||||
|
f"Table name with quotes should be rejected"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_db_name_with_special_chars(self, mcp_client):
|
||||||
|
"""Test database name with special characters is rejected"""
|
||||||
|
special_chars = [
|
||||||
|
"test;db",
|
||||||
|
"test'db",
|
||||||
|
"test\"db",
|
||||||
|
"test`db",
|
||||||
|
"test--db",
|
||||||
|
"test/*db*/",
|
||||||
|
]
|
||||||
|
|
||||||
|
for db_name in special_chars:
|
||||||
|
payload = {"table_name": "users", "db_name": db_name}
|
||||||
|
result = await mcp_client.call_tool("get_table_schema", payload)
|
||||||
|
print_result(f"Special Char in db_name: {db_name}", payload, result)
|
||||||
|
|
||||||
|
assert self._contains_error_indicator(result), \
|
||||||
|
f"db_name '{db_name}' should be rejected"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_valid_identifiers_accepted(self, mcp_client):
|
||||||
|
"""Test that valid identifiers are accepted"""
|
||||||
|
valid_names = [
|
||||||
|
"users",
|
||||||
|
"my_table",
|
||||||
|
"Table123",
|
||||||
|
"_internal_table",
|
||||||
|
]
|
||||||
|
|
||||||
|
for table_name in valid_names:
|
||||||
|
payload = {"table_name": table_name}
|
||||||
|
result = await mcp_client.call_tool("get_table_schema", payload)
|
||||||
|
print_result(f"Valid Identifier: {table_name}", payload, result)
|
||||||
|
|
||||||
|
def _contains_error_indicator(self, result: dict) -> bool:
|
||||||
|
"""Check if result contains error indicators"""
|
||||||
|
if not result:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check for JSON-RPC error
|
||||||
|
if "error" in result:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check result content
|
||||||
|
result_str = json.dumps(result).lower()
|
||||||
|
error_keywords = [
|
||||||
|
"error", "invalid", "illegal", "blocked",
|
||||||
|
"security", "injection", "denied", "forbidden"
|
||||||
|
]
|
||||||
|
|
||||||
|
return any(keyword in result_str for keyword in error_keywords)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMultiStatementInjectionAPI:
|
||||||
|
"""Test multi-statement SQL injection prevention"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def mcp_client(self):
|
||||||
|
"""Create MCP client instance"""
|
||||||
|
client = MCPClient()
|
||||||
|
yield client
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_hidden_drop_after_select(self, mcp_client):
|
||||||
|
"""Test DROP hidden after legitimate SELECT is blocked"""
|
||||||
|
payload = {"sql": "SELECT id, name FROM users WHERE status = 'active'; DROP TABLE audit_log; SELECT 1;"}
|
||||||
|
result = await mcp_client.call_tool("exec_query", payload)
|
||||||
|
print_result("Hidden DROP after SELECT", payload, result)
|
||||||
|
|
||||||
|
assert self._is_dangerous_blocked(result), \
|
||||||
|
f"Hidden DROP statement should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_hidden_truncate_after_select(self, mcp_client):
|
||||||
|
"""Test TRUNCATE hidden after SELECT is blocked"""
|
||||||
|
payload = {"sql": "SELECT 1; TRUNCATE TABLE users"}
|
||||||
|
result = await mcp_client.call_tool("exec_query", payload)
|
||||||
|
print_result("Hidden TRUNCATE after SELECT", payload, result)
|
||||||
|
|
||||||
|
assert self._is_dangerous_blocked(result), \
|
||||||
|
f"Hidden TRUNCATE should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_hidden_grant_after_select(self, mcp_client):
|
||||||
|
"""Test GRANT hidden after SELECT is blocked"""
|
||||||
|
payload = {"sql": "SELECT 1; GRANT ALL ON *.* TO 'hacker'@'%'"}
|
||||||
|
result = await mcp_client.call_tool("exec_query", payload)
|
||||||
|
print_result("Hidden GRANT after SELECT", payload, result)
|
||||||
|
|
||||||
|
assert self._is_dangerous_blocked(result), \
|
||||||
|
f"Hidden GRANT should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_multiple_safe_selects_allowed(self, mcp_client):
|
||||||
|
"""Test that multiple SELECT statements may be allowed"""
|
||||||
|
payload = {"sql": "SELECT 1; SELECT 2; SELECT 3;"}
|
||||||
|
result = await mcp_client.call_tool("exec_query", payload)
|
||||||
|
print_result("Multiple Safe SELECTs", payload, result)
|
||||||
|
|
||||||
|
def _is_dangerous_blocked(self, result: dict) -> bool:
|
||||||
|
"""Check if dangerous operation was blocked"""
|
||||||
|
if not result:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check for error
|
||||||
|
if "error" in result:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check result content for blocking indicators
|
||||||
|
result_str = json.dumps(result).lower()
|
||||||
|
block_indicators = [
|
||||||
|
"drop", "truncate", "grant", "revoke",
|
||||||
|
"blocked", "denied", "forbidden", "not allowed",
|
||||||
|
"security", "error"
|
||||||
|
]
|
||||||
|
|
||||||
|
return any(indicator in result_str for indicator in block_indicators)
|
||||||
|
|
||||||
|
|
||||||
|
class TestADBCQueryInjectionAPI:
|
||||||
|
"""Test ADBC query SQL injection prevention"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def mcp_client(self):
|
||||||
|
"""Create MCP client instance"""
|
||||||
|
client = MCPClient()
|
||||||
|
yield client
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_exec_adbc_query_drop_injection(self, mcp_client):
|
||||||
|
"""Test exec_adbc_query rejects DROP TABLE injection"""
|
||||||
|
payload = {"sql": "SELECT * FROM users; DROP TABLE users; --"}
|
||||||
|
result = await mcp_client.call_tool("exec_adbc_query", payload)
|
||||||
|
print_result("ADBC DROP TABLE Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
"ADBC DROP TABLE injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_exec_adbc_query_delete_injection(self, mcp_client):
|
||||||
|
"""Test exec_adbc_query rejects DELETE injection"""
|
||||||
|
payload = {"sql": "SELECT 1; DELETE FROM users; --"}
|
||||||
|
result = await mcp_client.call_tool("exec_adbc_query", payload)
|
||||||
|
print_result("ADBC DELETE Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
"ADBC DELETE injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_exec_adbc_query_valid(self, mcp_client):
|
||||||
|
"""Test exec_adbc_query allows valid queries"""
|
||||||
|
payload = {"sql": "SELECT 1 AS test"}
|
||||||
|
result = await mcp_client.call_tool("exec_adbc_query", payload)
|
||||||
|
print_result("ADBC Valid Query", payload, result)
|
||||||
|
|
||||||
|
def _is_blocked_or_error(self, result: dict) -> bool:
|
||||||
|
"""Check if result indicates blocked or error"""
|
||||||
|
if not result:
|
||||||
|
return True
|
||||||
|
if "error" in result:
|
||||||
|
return True
|
||||||
|
result_str = json.dumps(result).lower()
|
||||||
|
return any(kw in result_str for kw in ["error", "blocked", "invalid", "security", "injection"])
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetadataToolsInjectionAPI:
|
||||||
|
"""Test metadata tools SQL injection prevention"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def mcp_client(self):
|
||||||
|
"""Create MCP client instance"""
|
||||||
|
client = MCPClient()
|
||||||
|
yield client
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_db_table_list_db_injection(self, mcp_client):
|
||||||
|
"""Test get_db_table_list rejects db_name injection"""
|
||||||
|
payload = {"db_name": "test'; DROP TABLE users; --"}
|
||||||
|
result = await mcp_client.call_tool("get_db_table_list", payload)
|
||||||
|
print_result("get_db_table_list db_name Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
"db_name injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_db_table_list_catalog_injection(self, mcp_client):
|
||||||
|
"""Test get_db_table_list rejects catalog_name injection"""
|
||||||
|
payload = {"catalog_name": "internal`; SELECT * FROM mysql.user; --"}
|
||||||
|
result = await mcp_client.call_tool("get_db_table_list", payload)
|
||||||
|
print_result("get_db_table_list catalog_name Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
"catalog_name injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_table_comment_injection(self, mcp_client):
|
||||||
|
"""Test get_table_comment rejects table_name injection"""
|
||||||
|
payload = {"table_name": "users'; DROP TABLE users; --"}
|
||||||
|
result = await mcp_client.call_tool("get_table_comment", payload)
|
||||||
|
print_result("get_table_comment table_name Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
"table_name injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_table_column_comments_injection(self, mcp_client):
|
||||||
|
"""Test get_table_column_comments rejects injection"""
|
||||||
|
payload = {"table_name": "users'; DROP TABLE users; --", "db_name": "test"}
|
||||||
|
result = await mcp_client.call_tool("get_table_column_comments", payload)
|
||||||
|
print_result("get_table_column_comments Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
"table_name injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_table_indexes_injection(self, mcp_client):
|
||||||
|
"""Test get_table_indexes rejects table_name injection"""
|
||||||
|
payload = {"table_name": "users; DROP TABLE users", "db_name": "test"}
|
||||||
|
result = await mcp_client.call_tool("get_table_indexes", payload)
|
||||||
|
print_result("get_table_indexes Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
"table_name injection should be blocked"
|
||||||
|
|
||||||
|
def _is_blocked_or_error(self, result: dict) -> bool:
|
||||||
|
"""Check if result indicates blocked or error"""
|
||||||
|
if not result:
|
||||||
|
return True
|
||||||
|
if "error" in result:
|
||||||
|
return True
|
||||||
|
result_str = json.dumps(result).lower()
|
||||||
|
return any(kw in result_str for kw in ["error", "blocked", "invalid", "security", "injection"])
|
||||||
|
|
||||||
|
|
||||||
|
class TestAnalyticsToolsInjectionAPI:
|
||||||
|
"""Test analytics tools SQL injection prevention"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def mcp_client(self):
|
||||||
|
"""Create MCP client instance"""
|
||||||
|
client = MCPClient()
|
||||||
|
yield client
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_analyze_columns_table_injection(self, mcp_client):
|
||||||
|
"""Test analyze_columns rejects table_name injection"""
|
||||||
|
payload = {"table_name": "users'; DROP TABLE users; --"}
|
||||||
|
result = await mcp_client.call_tool("analyze_columns", payload)
|
||||||
|
print_result("analyze_columns table_name Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
"table_name injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_analyze_columns_db_injection(self, mcp_client):
|
||||||
|
"""Test analyze_columns rejects db_name injection"""
|
||||||
|
payload = {"table_name": "users", "db_name": "test' OR '1'='1"}
|
||||||
|
result = await mcp_client.call_tool("analyze_columns", payload)
|
||||||
|
print_result("analyze_columns db_name Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
"db_name injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_table_basic_info_injection(self, mcp_client):
|
||||||
|
"""Test get_table_basic_info rejects injection"""
|
||||||
|
payload = {"table_name": "users; DROP TABLE audit_log"}
|
||||||
|
result = await mcp_client.call_tool("get_table_basic_info", payload)
|
||||||
|
print_result("get_table_basic_info Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
"table_name injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_analyze_table_storage_injection(self, mcp_client):
|
||||||
|
"""Test analyze_table_storage rejects injection"""
|
||||||
|
payload = {"table_name": "users`; SELECT * FROM sensitive; --"}
|
||||||
|
result = await mcp_client.call_tool("analyze_table_storage", payload)
|
||||||
|
print_result("analyze_table_storage Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
"table_name injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_sql_explain_injection(self, mcp_client):
|
||||||
|
"""Test get_sql_explain rejects SQL injection"""
|
||||||
|
payload = {"sql": "SELECT 1; DROP TABLE users; --"}
|
||||||
|
result = await mcp_client.call_tool("get_sql_explain", payload)
|
||||||
|
print_result("get_sql_explain SQL Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
"SQL injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_sql_profile_injection(self, mcp_client):
|
||||||
|
"""Test get_sql_profile rejects SQL injection"""
|
||||||
|
payload = {"sql": "SELECT 1; DELETE FROM audit_log; --"}
|
||||||
|
result = await mcp_client.call_tool("get_sql_profile", payload)
|
||||||
|
print_result("get_sql_profile SQL Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
"SQL injection should be blocked"
|
||||||
|
|
||||||
|
def _is_blocked_or_error(self, result: dict) -> bool:
|
||||||
|
"""Check if result indicates blocked or error"""
|
||||||
|
if not result:
|
||||||
|
return True
|
||||||
|
if "error" in result:
|
||||||
|
return True
|
||||||
|
result_str = json.dumps(result).lower()
|
||||||
|
return any(kw in result_str for kw in ["error", "blocked", "invalid", "security", "injection"])
|
||||||
|
|
||||||
|
|
||||||
|
class TestGovernanceToolsInjectionAPI:
|
||||||
|
"""Test data governance tools SQL injection prevention"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def mcp_client(self):
|
||||||
|
"""Create MCP client instance"""
|
||||||
|
client = MCPClient()
|
||||||
|
yield client
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_trace_column_lineage_table_injection(self, mcp_client):
|
||||||
|
"""Test trace_column_lineage rejects table_name injection"""
|
||||||
|
payload = {"table_name": "users'; DROP TABLE users; --", "column_name": "id"}
|
||||||
|
result = await mcp_client.call_tool("trace_column_lineage", payload)
|
||||||
|
print_result("trace_column_lineage table_name Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
"table_name injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_trace_column_lineage_column_injection(self, mcp_client):
|
||||||
|
"""Test trace_column_lineage rejects column_name injection"""
|
||||||
|
payload = {"table_name": "users", "column_name": "id; DROP TABLE users"}
|
||||||
|
result = await mcp_client.call_tool("trace_column_lineage", payload)
|
||||||
|
print_result("trace_column_lineage column_name Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
"column_name injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_monitor_data_freshness_injection(self, mcp_client):
|
||||||
|
"""Test monitor_data_freshness rejects table_name injection"""
|
||||||
|
payload = {"table_name": "users`; SELECT * FROM passwords; --"}
|
||||||
|
result = await mcp_client.call_tool("monitor_data_freshness", payload)
|
||||||
|
print_result("monitor_data_freshness Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
"table_name injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_analyze_data_access_patterns_injection(self, mcp_client):
|
||||||
|
"""Test analyze_data_access_patterns rejects injection"""
|
||||||
|
payload = {"table_name": "users' UNION SELECT password FROM admin --"}
|
||||||
|
result = await mcp_client.call_tool("analyze_data_access_patterns", payload)
|
||||||
|
print_result("analyze_data_access_patterns Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
"table_name injection should be blocked"
|
||||||
|
|
||||||
|
def _is_blocked_or_error(self, result: dict) -> bool:
|
||||||
|
"""Check if result indicates blocked or error"""
|
||||||
|
if not result:
|
||||||
|
return True
|
||||||
|
if "error" in result:
|
||||||
|
return True
|
||||||
|
result_str = json.dumps(result).lower()
|
||||||
|
return any(kw in result_str for kw in ["error", "blocked", "invalid", "security", "injection"])
|
||||||
|
|
||||||
|
|
||||||
|
class TestPerformanceToolsInjectionAPI:
|
||||||
|
"""Test performance analytics tools SQL injection prevention"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def mcp_client(self):
|
||||||
|
"""Create MCP client instance"""
|
||||||
|
client = MCPClient()
|
||||||
|
yield client
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_analyze_slow_queries_db_injection(self, mcp_client):
|
||||||
|
"""Test analyze_slow_queries_topn rejects db_name injection"""
|
||||||
|
payload = {"db_name": "test'; DROP TABLE audit_log; --"}
|
||||||
|
result = await mcp_client.call_tool("analyze_slow_queries_topn", payload)
|
||||||
|
print_result("analyze_slow_queries_topn db_name Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
"db_name injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_analyze_resource_growth_db_injection(self, mcp_client):
|
||||||
|
"""Test analyze_resource_growth_curves rejects db_name injection"""
|
||||||
|
payload = {"db_name": "test`; GRANT ALL ON *.* TO 'hacker'; --"}
|
||||||
|
result = await mcp_client.call_tool("analyze_resource_growth_curves", payload)
|
||||||
|
print_result("analyze_resource_growth_curves db_name Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
"db_name injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_table_data_size_injection(self, mcp_client):
|
||||||
|
"""Test get_table_data_size rejects table_name injection"""
|
||||||
|
payload = {"table_name": "users; TRUNCATE TABLE logs"}
|
||||||
|
result = await mcp_client.call_tool("get_table_data_size", payload)
|
||||||
|
print_result("get_table_data_size Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
"table_name injection should be blocked"
|
||||||
|
|
||||||
|
def _is_blocked_or_error(self, result: dict) -> bool:
|
||||||
|
"""Check if result indicates blocked or error"""
|
||||||
|
if not result:
|
||||||
|
return True
|
||||||
|
if "error" in result:
|
||||||
|
return True
|
||||||
|
result_str = json.dumps(result).lower()
|
||||||
|
return any(kw in result_str for kw in ["error", "blocked", "invalid", "security", "injection"])
|
||||||
|
|
||||||
|
|
||||||
|
# Pytest configuration for async tests
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def event_loop():
|
||||||
|
"""Create event loop for async tests"""
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
yield loop
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__, "-v", "--tb=short", "-x"])
|
||||||
|
|
||||||
@@ -44,17 +44,31 @@
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"expected_tools": [
|
"expected_tools": [
|
||||||
|
"analyze_columns",
|
||||||
|
"analyze_data_access_patterns",
|
||||||
|
"analyze_data_flow_dependencies",
|
||||||
|
"analyze_resource_growth_curves",
|
||||||
|
"analyze_slow_queries_topn",
|
||||||
|
"analyze_table_storage",
|
||||||
|
"exec_adbc_query",
|
||||||
"exec_query",
|
"exec_query",
|
||||||
"get_db_list",
|
"get_adbc_connection_info",
|
||||||
|
"get_catalog_list",
|
||||||
|
"get_db_list",
|
||||||
"get_db_table_list",
|
"get_db_table_list",
|
||||||
"get_table_schema",
|
"get_memory_stats",
|
||||||
"get_table_comment",
|
"get_monitoring_metrics",
|
||||||
"get_table_column_comments",
|
|
||||||
"get_table_indexes",
|
|
||||||
"column_analysis",
|
|
||||||
"performance_stats",
|
|
||||||
"get_recent_audit_logs",
|
"get_recent_audit_logs",
|
||||||
"get_catalog_list"
|
"get_sql_explain",
|
||||||
|
"get_sql_profile",
|
||||||
|
"get_table_basic_info",
|
||||||
|
"get_table_column_comments",
|
||||||
|
"get_table_comment",
|
||||||
|
"get_table_data_size",
|
||||||
|
"get_table_indexes",
|
||||||
|
"get_table_schema",
|
||||||
|
"monitor_data_freshness",
|
||||||
|
"trace_column_lineage"
|
||||||
],
|
],
|
||||||
"expected_resources": [
|
"expected_resources": [
|
||||||
"database",
|
"database",
|
||||||
@@ -66,4 +80,4 @@
|
|||||||
"data_analysis_helper",
|
"data_analysis_helper",
|
||||||
"schema_explorer"
|
"schema_explorer"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -185,8 +185,9 @@ async def test_server_connectivity(transport: Optional[str] = None) -> bool:
|
|||||||
logger.error(f"Connectivity test failed: {e}")
|
logger.error(f"Connectivity test failed: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
result = await client.connect_and_run(test_connection)
|
await client.connect_and_run(test_connection)
|
||||||
return result
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to test server connectivity: {e}")
|
logger.error(f"Failed to test server connectivity: {e}")
|
||||||
return False
|
return False
|
||||||
@@ -211,4 +212,4 @@ if __name__ == "__main__":
|
|||||||
stdio_ok = await test_server_connectivity("stdio")
|
stdio_ok = await test_server_connectivity("stdio")
|
||||||
print(f" Stdio connectivity: {'✓' if stdio_ok else '✗'}")
|
print(f" Stdio connectivity: {'✓' if stdio_ok else '✗'}")
|
||||||
|
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
|||||||
@@ -72,8 +72,7 @@ class TestToolsClientServer:
|
|||||||
|
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
result = await client.connect_and_run(test_callback)
|
await client.connect_and_run(test_callback)
|
||||||
assert len(result) > 0
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_call_tool_exec_query_via_client(self, client, test_config):
|
async def test_call_tool_exec_query_via_client(self, client, test_config):
|
||||||
@@ -91,14 +90,13 @@ class TestToolsClientServer:
|
|||||||
assert "success" in result, "Result should contain 'success' field"
|
assert "success" in result, "Result should contain 'success' field"
|
||||||
|
|
||||||
if result["success"]:
|
if result["success"]:
|
||||||
assert "result" in result, "Successful result should contain 'result' field"
|
assert "data" in result, "Successful result should contain 'data' field"
|
||||||
else:
|
else:
|
||||||
assert "error" in result, "Failed result should contain 'error' field"
|
assert "error" in result, "Failed result should contain 'error' field"
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
result = await client.connect_and_run(test_callback)
|
await client.connect_and_run(test_callback)
|
||||||
# Don't assert success=True as it depends on actual server state
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_call_tool_get_db_list_via_client(self, client, test_config):
|
async def test_call_tool_get_db_list_via_client(self, client, test_config):
|
||||||
@@ -115,8 +113,7 @@ class TestToolsClientServer:
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
result = await client.connect_and_run(test_callback)
|
await client.connect_and_run(test_callback)
|
||||||
assert "success" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_call_tool_get_table_schema_via_client(self, client, test_config):
|
async def test_call_tool_get_table_schema_via_client(self, client, test_config):
|
||||||
@@ -133,27 +130,7 @@ class TestToolsClientServer:
|
|||||||
assert "success" in result, "Result should contain 'success' field"
|
assert "success" in result, "Result should contain 'success' field"
|
||||||
return result
|
return result
|
||||||
|
|
||||||
result = await client.connect_and_run(test_callback)
|
await client.connect_and_run(test_callback)
|
||||||
assert "success" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_call_tool_performance_stats_via_client(self, client, test_config):
|
|
||||||
"""Test calling performance_stats tool through client"""
|
|
||||||
if not test_config.is_performance_tests_enabled():
|
|
||||||
pytest.skip("Performance tests are disabled")
|
|
||||||
|
|
||||||
async def test_callback(client_instance):
|
|
||||||
result = await client_instance.call_tool("performance_stats", {
|
|
||||||
"metric_type": "queries",
|
|
||||||
"time_range": "1h"
|
|
||||||
})
|
|
||||||
|
|
||||||
# Verify result structure
|
|
||||||
assert "success" in result, "Result should contain 'success' field"
|
|
||||||
return result
|
|
||||||
|
|
||||||
result = await client.connect_and_run(test_callback)
|
|
||||||
assert "success" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_tool_error_handling_via_client(self, client, test_config):
|
async def test_tool_error_handling_via_client(self, client, test_config):
|
||||||
@@ -168,8 +145,7 @@ class TestToolsClientServer:
|
|||||||
assert "success" in result, "Result should contain 'success' field"
|
assert "success" in result, "Result should contain 'success' field"
|
||||||
return result
|
return result
|
||||||
|
|
||||||
result = await client.connect_and_run(test_callback)
|
await client.connect_and_run(test_callback)
|
||||||
assert "success" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_tool_with_auth_token_via_client(self, client, test_config):
|
async def test_tool_with_auth_token_via_client(self, client, test_config):
|
||||||
@@ -188,5 +164,4 @@ class TestToolsClientServer:
|
|||||||
assert "success" in result, "Result should contain 'success' field"
|
assert "success" in result, "Result should contain 'success' field"
|
||||||
return result
|
return result
|
||||||
|
|
||||||
result = await client.connect_and_run(test_callback)
|
await client.connect_and_run(test_callback)
|
||||||
assert "success" in result
|
|
||||||
|
|||||||
@@ -36,11 +36,6 @@ class TestDorisToolsManager:
|
|||||||
from doris_mcp_server.utils.config import DatabaseConfig, SecurityConfig
|
from doris_mcp_server.utils.config import DatabaseConfig, SecurityConfig
|
||||||
|
|
||||||
config = Mock(spec=DorisConfig)
|
config = Mock(spec=DorisConfig)
|
||||||
config.doris_host = "localhost"
|
|
||||||
config.doris_port = 9030
|
|
||||||
config.doris_user = "test_user"
|
|
||||||
config.doris_password = "test_password"
|
|
||||||
config.doris_database = "test_db"
|
|
||||||
|
|
||||||
# Add database config
|
# Add database config
|
||||||
config.database = Mock(spec=DatabaseConfig)
|
config.database = Mock(spec=DatabaseConfig)
|
||||||
@@ -50,7 +45,6 @@ class TestDorisToolsManager:
|
|||||||
config.database.password = "test_password"
|
config.database.password = "test_password"
|
||||||
config.database.database = "test_db"
|
config.database.database = "test_db"
|
||||||
config.database.health_check_interval = 60
|
config.database.health_check_interval = 60
|
||||||
config.database.min_connections = 5
|
|
||||||
config.database.max_connections = 20
|
config.database.max_connections = 20
|
||||||
config.database.connection_timeout = 30
|
config.database.connection_timeout = 30
|
||||||
config.database.max_connection_age = 3600
|
config.database.max_connection_age = 3600
|
||||||
@@ -235,62 +229,7 @@ class TestDorisToolsManager:
|
|||||||
elif "result" in result_data:
|
elif "result" in result_data:
|
||||||
assert len(result_data["result"]) >= 0 # May be empty if no catalogs
|
assert len(result_data["result"]) >= 0 # May be empty if no catalogs
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_column_analysis_tool(self, tools_manager):
|
|
||||||
"""Test column_analysis tool"""
|
|
||||||
with patch.object(tools_manager.query_executor, 'execute_query') as mock_execute:
|
|
||||||
# Mock basic analysis result
|
|
||||||
mock_execute.return_value = [
|
|
||||||
{
|
|
||||||
"total_count": 1000,
|
|
||||||
"null_count": 10,
|
|
||||||
"distinct_count": 950,
|
|
||||||
"min_value": 1,
|
|
||||||
"max_value": 1000
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
arguments = {
|
|
||||||
"table_name": "users",
|
|
||||||
"column_name": "id",
|
|
||||||
"analysis_type": "basic"
|
|
||||||
}
|
|
||||||
|
|
||||||
result = await tools_manager.call_tool("column_analysis", arguments)
|
|
||||||
result_data = json.loads(result) if isinstance(result, str) else result
|
|
||||||
|
|
||||||
# Check if result has analysis field or result field
|
|
||||||
if "analysis" in result_data:
|
|
||||||
assert result_data["analysis"]["total_count"] == 1000
|
|
||||||
elif "result" in result_data:
|
|
||||||
assert "result" in result_data # Just check result exists
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_performance_stats_tool(self, tools_manager):
|
|
||||||
"""Test performance_stats tool"""
|
|
||||||
with patch.object(tools_manager.query_executor, 'execute_query') as mock_execute:
|
|
||||||
mock_execute.return_value = [
|
|
||||||
{
|
|
||||||
"query_count": 1500,
|
|
||||||
"avg_execution_time": 0.25,
|
|
||||||
"slow_query_count": 5,
|
|
||||||
"error_count": 2
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
arguments = {
|
|
||||||
"metric_type": "queries",
|
|
||||||
"time_range": "1h"
|
|
||||||
}
|
|
||||||
|
|
||||||
result = await tools_manager.call_tool("performance_stats", arguments)
|
|
||||||
result_data = json.loads(result) if isinstance(result, str) else result
|
|
||||||
|
|
||||||
# Check if result has stats field or result field
|
|
||||||
if "stats" in result_data:
|
|
||||||
assert result_data["stats"]["query_count"] == 1500
|
|
||||||
elif "result" in result_data:
|
|
||||||
assert "result" in result_data # Just check result exists
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_invalid_tool_name(self, tools_manager):
|
async def test_invalid_tool_name(self, tools_manager):
|
||||||
@@ -328,4 +267,4 @@ class TestDorisToolsManager:
|
|||||||
|
|
||||||
# Required fields should be defined
|
# Required fields should be defined
|
||||||
if 'required' in tool.inputSchema:
|
if 'required' in tool.inputSchema:
|
||||||
assert isinstance(tool.inputSchema['required'], list)
|
assert isinstance(tool.inputSchema['required'], list)
|
||||||
|
|||||||
78
test/utils/test_db.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
from unittest.mock import MagicMock
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from doris_mcp_server.utils.db import DorisConnection, DorisSessionCache
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def session_cache():
|
||||||
|
"""Provides a DorisSessionCache instance with a mock connection manager."""
|
||||||
|
connection_manager = MagicMock()
|
||||||
|
cache = DorisSessionCache(connection_manager=connection_manager)
|
||||||
|
yield cache, connection_manager
|
||||||
|
|
||||||
|
|
||||||
|
class TestDorisSessionCache:
|
||||||
|
|
||||||
|
def test_initialization(self, session_cache):
|
||||||
|
cache, _ = session_cache
|
||||||
|
assert cache.cache_system_session is True
|
||||||
|
assert cache.cache_user_session is False
|
||||||
|
assert not cache.cached
|
||||||
|
|
||||||
|
def test_should_cache(self, session_cache):
|
||||||
|
cache, _ = session_cache
|
||||||
|
assert cache._should_cache("query") is True
|
||||||
|
assert cache._should_cache("system") is True
|
||||||
|
assert cache._should_cache("user-test-session-id") is False
|
||||||
|
|
||||||
|
cache.cache_user_session = True
|
||||||
|
assert cache._should_cache("user-test-session-id") is True
|
||||||
|
|
||||||
|
def test_save_and_get_session(self, session_cache):
|
||||||
|
cache, _ = session_cache
|
||||||
|
mock_connection = MagicMock(spec=DorisConnection)
|
||||||
|
mock_connection.session_id = "query"
|
||||||
|
|
||||||
|
cache.save(mock_connection)
|
||||||
|
retrieved_conn = cache.get("query")
|
||||||
|
assert retrieved_conn is mock_connection
|
||||||
|
|
||||||
|
mock_user_connection = MagicMock(spec=DorisConnection)
|
||||||
|
mock_user_connection.session_id = "user-test-session-id"
|
||||||
|
cache.save(mock_user_connection)
|
||||||
|
assert cache.get("user-test-session-id") is None
|
||||||
|
|
||||||
|
cache.cache_user_session = True
|
||||||
|
cache.save(mock_user_connection)
|
||||||
|
retrieved_user_conn = cache.get("user-test-session-id")
|
||||||
|
assert retrieved_user_conn is mock_user_connection
|
||||||
|
|
||||||
|
def test_remove_session(self, session_cache):
|
||||||
|
cache, _ = session_cache
|
||||||
|
mock_connection = MagicMock(spec=DorisConnection)
|
||||||
|
mock_connection.session_id = "system"
|
||||||
|
|
||||||
|
cache.save(mock_connection)
|
||||||
|
assert cache.get("system") is not None
|
||||||
|
|
||||||
|
cache.remove("system")
|
||||||
|
assert cache.get("system") is None
|
||||||
|
|
||||||
|
def test_clear_cache(self, session_cache):
|
||||||
|
cache, connection_manager = session_cache
|
||||||
|
mock_conn1 = MagicMock(spec=DorisConnection)
|
||||||
|
mock_conn1.session_id = "query"
|
||||||
|
mock_conn2 = MagicMock(spec=DorisConnection)
|
||||||
|
mock_conn2.session_id = "system"
|
||||||
|
|
||||||
|
cache.save(mock_conn1)
|
||||||
|
cache.save(mock_conn2)
|
||||||
|
assert len(cache.cached) == 2
|
||||||
|
|
||||||
|
cache.clear()
|
||||||
|
|
||||||
|
assert not cache.cached
|
||||||
|
connection_manager.release_connection.assert_any_call("query", mock_conn1)
|
||||||
|
connection_manager.release_connection.assert_any_call("system", mock_conn2)
|
||||||
|
assert connection_manager.release_connection.call_count == 2
|
||||||
@@ -35,11 +35,6 @@ class TestDorisQueryExecutor:
|
|||||||
from doris_mcp_server.utils.config import DatabaseConfig, SecurityConfig
|
from doris_mcp_server.utils.config import DatabaseConfig, SecurityConfig
|
||||||
|
|
||||||
config = Mock(spec=DorisConfig)
|
config = Mock(spec=DorisConfig)
|
||||||
config.doris_host = "localhost"
|
|
||||||
config.doris_port = 9030
|
|
||||||
config.doris_user = "test_user"
|
|
||||||
config.doris_password = "test_password"
|
|
||||||
config.doris_database = "test_db"
|
|
||||||
|
|
||||||
# Add database config
|
# Add database config
|
||||||
config.database = Mock(spec=DatabaseConfig)
|
config.database = Mock(spec=DatabaseConfig)
|
||||||
@@ -49,11 +44,17 @@ class TestDorisQueryExecutor:
|
|||||||
config.database.password = "test_password"
|
config.database.password = "test_password"
|
||||||
config.database.database = "test_db"
|
config.database.database = "test_db"
|
||||||
config.database.health_check_interval = 60
|
config.database.health_check_interval = 60
|
||||||
config.database.min_connections = 5
|
|
||||||
config.database.max_connections = 20
|
config.database.max_connections = 20
|
||||||
config.database.connection_timeout = 30
|
config.database.connection_timeout = 30
|
||||||
config.database.max_connection_age = 3600
|
config.database.max_connection_age = 3600
|
||||||
|
|
||||||
|
# Add security config
|
||||||
|
config.security = Mock(spec=SecurityConfig)
|
||||||
|
config.security.enable_masking = True
|
||||||
|
config.security.auth_type = "token"
|
||||||
|
config.security.token_secret = "test_secret"
|
||||||
|
config.security.token_expiry = 3600
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -199,4 +200,74 @@ class TestDorisQueryExecutor:
|
|||||||
assert "success" in result
|
assert "success" in result
|
||||||
if result["success"]:
|
if result["success"]:
|
||||||
assert "data" in result
|
assert "data" in result
|
||||||
assert "row_count" in result
|
assert "row_count" in result
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_multi_sql_statements(self, query_executor):
|
||||||
|
"""Test execution of multiple SQL statements"""
|
||||||
|
from doris_mcp_server.utils.query_executor import QueryResult
|
||||||
|
|
||||||
|
# Disable security check for this test
|
||||||
|
query_executor.connection_manager.config.security.enable_security_check = False
|
||||||
|
|
||||||
|
with patch.object(query_executor, 'execute_query') as mock_execute:
|
||||||
|
# Mock results for three SQL statements
|
||||||
|
mock_execute.side_effect = [
|
||||||
|
QueryResult(
|
||||||
|
data=[{"id": 1, "name": "张三"}],
|
||||||
|
row_count=1,
|
||||||
|
execution_time=0.1,
|
||||||
|
sql="SELECT id, name FROM users WHERE id = 1",
|
||||||
|
metadata={"columns": ["id", "name"]}
|
||||||
|
),
|
||||||
|
QueryResult(
|
||||||
|
data=[{"id": 2, "name": "李四"}],
|
||||||
|
row_count=1,
|
||||||
|
execution_time=0.12,
|
||||||
|
sql="SELECT id, name FROM users WHERE id = 2",
|
||||||
|
metadata={"columns": ["id", "name"]}
|
||||||
|
),
|
||||||
|
QueryResult(
|
||||||
|
data=[{"count": 100}],
|
||||||
|
row_count=1,
|
||||||
|
execution_time=0.08,
|
||||||
|
sql="SELECT COUNT(*) as count FROM users",
|
||||||
|
metadata={"columns": ["count"]}
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Execute multiple SQL statements separated by semicolons
|
||||||
|
multi_sql = """
|
||||||
|
SELECT id, name FROM users WHERE id = 1;
|
||||||
|
SELECT id, name FROM users WHERE id = 2;
|
||||||
|
SELECT COUNT(*) as count FROM users;
|
||||||
|
"""
|
||||||
|
|
||||||
|
result = await query_executor.execute_sql_for_mcp(multi_sql)
|
||||||
|
|
||||||
|
# Verify the result structure for multiple statements
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["multiple_results"] is True
|
||||||
|
assert "results" in result
|
||||||
|
assert len(result["results"]) == 3
|
||||||
|
|
||||||
|
# Verify first query result
|
||||||
|
assert result["results"][0]["data"] == [{"id": 1, "name": "张三"}]
|
||||||
|
assert result["results"][0]["row_count"] == 1
|
||||||
|
assert result["results"][0]["metadata"]["columns"] == ["id", "name"]
|
||||||
|
assert result["results"][0]["metadata"]["query"] == "SELECT id, name FROM users WHERE id = 1"
|
||||||
|
|
||||||
|
# Verify second query result
|
||||||
|
assert result["results"][1]["data"] == [{"id": 2, "name": "李四"}]
|
||||||
|
assert result["results"][1]["row_count"] == 1
|
||||||
|
assert result["results"][1]["metadata"]["columns"] == ["id", "name"]
|
||||||
|
assert result["results"][1]["metadata"]["query"] == "SELECT id, name FROM users WHERE id = 2"
|
||||||
|
|
||||||
|
# Verify third query result
|
||||||
|
assert result["results"][2]["data"] == [{"count": 100}]
|
||||||
|
assert result["results"][2]["row_count"] == 1
|
||||||
|
assert result["results"][2]["metadata"]["columns"] == ["count"]
|
||||||
|
assert result["results"][2]["metadata"]["query"] == "SELECT COUNT(*) as count FROM users"
|
||||||
|
|
||||||
|
# Verify execute_query was called three times
|
||||||
|
assert mock_execute.call_count == 3
|
||||||
|
|||||||
@@ -21,8 +21,6 @@ Tests the query execution functionality through actual MCP client-server communi
|
|||||||
Assumes the server is already running and configured properly
|
Assumes the server is already running and configured properly
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import pytest
|
import pytest
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
@@ -66,14 +64,13 @@ class TestQueryExecutorClientServer:
|
|||||||
assert "success" in result, "Result should contain 'success' field"
|
assert "success" in result, "Result should contain 'success' field"
|
||||||
|
|
||||||
if result["success"]:
|
if result["success"]:
|
||||||
assert "result" in result, "Successful result should contain 'result' field"
|
assert "data" in result, "Successful result should contain 'data' field"
|
||||||
else:
|
else:
|
||||||
assert "error" in result, "Failed result should contain 'error' field"
|
assert "error" in result, "Failed result should contain 'error' field"
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
result = await client.connect_and_run(test_callback)
|
await client.connect_and_run(test_callback)
|
||||||
assert "success" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_show_databases_query_via_client(self, client, test_config):
|
async def test_show_databases_query_via_client(self, client, test_config):
|
||||||
@@ -87,8 +84,7 @@ class TestQueryExecutorClientServer:
|
|||||||
assert "success" in result, "Result should contain 'success' field"
|
assert "success" in result, "Result should contain 'success' field"
|
||||||
return result
|
return result
|
||||||
|
|
||||||
result = await client.connect_and_run(test_callback)
|
await client.connect_and_run(test_callback)
|
||||||
assert "success" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_information_schema_query_via_client(self, client, test_config):
|
async def test_information_schema_query_via_client(self, client, test_config):
|
||||||
@@ -102,8 +98,7 @@ class TestQueryExecutorClientServer:
|
|||||||
assert "success" in result, "Result should contain 'success' field"
|
assert "success" in result, "Result should contain 'success' field"
|
||||||
return result
|
return result
|
||||||
|
|
||||||
result = await client.connect_and_run(test_callback)
|
await client.connect_and_run(test_callback)
|
||||||
assert "success" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_query_with_max_rows_parameter_via_client(self, client, test_config):
|
async def test_query_with_max_rows_parameter_via_client(self, client, test_config):
|
||||||
@@ -118,8 +113,7 @@ class TestQueryExecutorClientServer:
|
|||||||
assert "success" in result, "Result should contain 'success' field"
|
assert "success" in result, "Result should contain 'success' field"
|
||||||
return result
|
return result
|
||||||
|
|
||||||
result = await client.connect_and_run(test_callback)
|
await client.connect_and_run(test_callback)
|
||||||
assert "success" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_query_error_handling_via_client(self, client, test_config):
|
async def test_query_error_handling_via_client(self, client, test_config):
|
||||||
@@ -131,8 +125,7 @@ class TestQueryExecutorClientServer:
|
|||||||
assert "success" in result, "Result should contain 'success' field"
|
assert "success" in result, "Result should contain 'success' field"
|
||||||
return result
|
return result
|
||||||
|
|
||||||
result = await client.connect_and_run(test_callback)
|
await client.connect_and_run(test_callback)
|
||||||
assert "success" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_query_with_auth_token_via_client(self, client, test_config):
|
async def test_query_with_auth_token_via_client(self, client, test_config):
|
||||||
@@ -152,5 +145,4 @@ class TestQueryExecutorClientServer:
|
|||||||
assert "success" in result, "Result should contain 'success' field"
|
assert "success" in result, "Result should contain 'success' field"
|
||||||
return result
|
return result
|
||||||
|
|
||||||
result = await client.connect_and_run(test_callback)
|
await client.connect_and_run(test_callback)
|
||||||
assert "success" in result
|
|
||||||
|
|||||||
64
tokens.json
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
{
|
||||||
|
"version": "1.0",
|
||||||
|
"description": "Doris MCP Server Token configuration file",
|
||||||
|
"created_at": "2025-09-01T00:00:00Z",
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"token_id": "admin-token",
|
||||||
|
"token": "doris_admin_token_123456",
|
||||||
|
"description": "Doris admin API access token",
|
||||||
|
"expires_hours": null,
|
||||||
|
"is_active": true,
|
||||||
|
"database_config": {
|
||||||
|
"host": "127.0.0.1",
|
||||||
|
"port": 9030,
|
||||||
|
"user": "root",
|
||||||
|
"password": "",
|
||||||
|
"database": "information_schema",
|
||||||
|
"charset": "UTF8",
|
||||||
|
"fe_http_port": 8030
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"token_id": "analyst-token",
|
||||||
|
"token": "doris_analyst_token_123456",
|
||||||
|
"description": "Doris analyst API access token",
|
||||||
|
"expires_hours": 8760,
|
||||||
|
"is_active": true,
|
||||||
|
"database_config": {
|
||||||
|
"host": "127.0.0.1",
|
||||||
|
"port": 9030,
|
||||||
|
"user": "root",
|
||||||
|
"password": "",
|
||||||
|
"database": "information_schema",
|
||||||
|
"charset": "UTF8",
|
||||||
|
"fe_http_port": 8030
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"token_id": "readonly-token",
|
||||||
|
"token": "doris_readonly_token_123456",
|
||||||
|
"description": "Doris readonly API access token",
|
||||||
|
"expires_hours": 4320,
|
||||||
|
"is_active": true,
|
||||||
|
"database_config": {
|
||||||
|
"host": "127.0.0.1",
|
||||||
|
"port": 9030,
|
||||||
|
"user": "root",
|
||||||
|
"password": "",
|
||||||
|
"database": "information_schema",
|
||||||
|
"charset": "UTF8",
|
||||||
|
"fe_http_port": 8030
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"notes": [
|
||||||
|
"The admin_token, analyst_token, readonly_token is default token,Please change the token before using in production!",
|
||||||
|
"The token_id is the key of the token,Please use the token_id to identify the token",
|
||||||
|
"The token is the value of the token,Please use the token to identify the token",
|
||||||
|
"The description is the description of the token,Please use the description to identify the token",
|
||||||
|
"The expires_hours is the expires hours of the token,Please use the expires_hours to identify the token",
|
||||||
|
"The is_active is the is active of the token,Please use the is_active to identify the token",
|
||||||
|
"The token_id, token, description, expires_hours, is_active is the metadata of the token,Please use the metadata to identify the token"
|
||||||
|
]
|
||||||
|
}
|
||||||
112
uv.lock
generated
@@ -1,19 +1,3 @@
|
|||||||
# Licensed to the Apache Software Foundation (ASF) under one
|
|
||||||
# or more contributor license agreements. See the NOTICE file
|
|
||||||
# distributed with this work for additional information
|
|
||||||
# regarding copyright ownership. The ASF licenses this file
|
|
||||||
# to you under the Apache License, Version 2.0 (the
|
|
||||||
# "License"); you may not use this file except in compliance
|
|
||||||
# with the License. You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing,
|
|
||||||
# software distributed under the License is distributed on an
|
|
||||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
||||||
# KIND, either express or implied. See the License for the
|
|
||||||
# specific language governing permissions and limitations
|
|
||||||
# under the License.
|
|
||||||
version = 1
|
version = 1
|
||||||
revision = 1
|
revision = 1
|
||||||
requires-python = ">=3.12"
|
requires-python = ">=3.12"
|
||||||
@@ -22,6 +6,48 @@ resolution-markers = [
|
|||||||
"python_full_version < '3.13'",
|
"python_full_version < '3.13'",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "adbc-driver-flightsql"
|
||||||
|
version = "1.7.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "adbc-driver-manager" },
|
||||||
|
{ name = "importlib-resources" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/b8/d4/ebd3eed981c771565677084474cdf465141455b5deb1ca409c616609bfd7/adbc_driver_flightsql-1.7.0.tar.gz", hash = "sha256:5dca460a2c66e45b29208eaf41a7206f252177435fa48b16f19833b12586f7a0", size = 21247 }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/36/20/807fca9d904b7e0d3020439828d6410db7fd7fd635824a80cab113d9fad1/adbc_driver_flightsql-1.7.0-py3-none-macosx_10_15_x86_64.whl", hash = "sha256:a5658f9bc3676bd122b26138e9b9ce56b8bf37387efe157b4c66d56f942361c6", size = 7749664 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/cd/e6/9e50f6497819c911b9cc1962ffde610b60f7d8e951d6bb3fa145dcfb50a7/adbc_driver_flightsql-1.7.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:65e21df86b454d8db422c8ee22db31be217d88c42d9d6dd89119f06813037c91", size = 7302476 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/27/82/e51af85e7cc8c87bc8ce4fae8ca7ee1d3cf39c926be0aeab789cedc93f0a/adbc_driver_flightsql-1.7.0-py3-none-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:3282fdc7b73c712780cc777975288c88b1e3a555355bbe09df101aa954f8f105", size = 7686056 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/8b/c9/591c8ecbaf010ba3f4b360db602050ee5880cd077a573c9e90fcb270ab71/adbc_driver_flightsql-1.7.0-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e0c5737ae6ee3bbfba44dcbc28ba1ff8cf3ab6521888c4b0f10dd6a482482161", size = 7050275 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/10/14/f339e9a5d8dbb3e3040215514cea9cca0a58640964aaccc6532f18003a03/adbc_driver_flightsql-1.7.0-py3-none-win_amd64.whl", hash = "sha256:f8b5290b322304b7d944ca823754e6354c1868dbbe94ddf84236f3e0329545da", size = 14312858 },
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "adbc-driver-manager"
|
||||||
|
version = "1.7.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "typing-extensions" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/bb/bf/2986a2cd3e1af658d2597f7e2308564e5c11e036f9736d5c256f1e00d578/adbc_driver_manager-1.7.0.tar.gz", hash = "sha256:e3edc5d77634b5925adf6eb4fbcd01676b54acb2f5b1d6864b6a97c6a899591a", size = 198128 }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/74/3a/72bd9c45d55f1f5f4c549e206de8cfe3313b31f7b95fbcb180da05c81044/adbc_driver_manager-1.7.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:8da1ac4c19bcbf30b3bd54247ec889dfacc9b44147c70b4da79efe2e9ba93600", size = 524210 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/33/29/e1a8d8dde713a287f8021f3207127f133ddce578711a4575218bdf78ef27/adbc_driver_manager-1.7.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:408bc23bad1a6823b364e2388f85f96545e82c3b2db97d7828a4b94839d3f29e", size = 505902 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/59/00/773ece64a58c0ade797ab4577e7cdc4c71ebf800b86d2d5637e3bfe605e9/adbc_driver_manager-1.7.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cf38294320c23e47ed3455348e910031ad8289c3f9167ae35519ac957b7add01", size = 2974883 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/7c/ad/1568da6ae9ab70983f1438503d3906c6b1355601230e891d16e272376a04/adbc_driver_manager-1.7.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:689f91b62c18a9f86f892f112786fb157cacc4729b4d81666db4ca778eade2a8", size = 2997781 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/19/66/2b6ea5afded25a3fa009873c2bbebcd9283910877cc10b9453d680c00b9a/adbc_driver_manager-1.7.0-cp312-cp312-win_amd64.whl", hash = "sha256:f936cfc8d098898a47ef60396bd7a73926ec3068f2d6d92a2be4e56e4aaf3770", size = 690041 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/b2/3b/91154c83a98f103a3d97c9e2cb838c3842aef84ca4f4b219164b182d9516/adbc_driver_manager-1.7.0-cp313-cp313-macosx_10_15_x86_64.whl", hash = "sha256:ab9ee36683fd54f61b0db0f4a96f70fe1932223e61df9329290370b145abb0a9", size = 522737 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/9c/52/4bc80c3388d5e2a3b6e504ba9656dd9eb3d8dbe822d07af38db1b8c96fb1/adbc_driver_manager-1.7.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:4ec03d94177f71a8d3a149709f4111e021f9950229b35c0a803aadb1a1855a4b", size = 503896 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/e1/f3/46052ca11224f661cef4721e19138bc73e750ba6aea54f22606950491606/adbc_driver_manager-1.7.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:700c79dac08a620018c912ede45a6dc7851819bc569a53073ab652dc0bd0c92f", size = 2972586 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/a2/22/44738b41bb5ca30f94b5f4c00c71c20be86d7eb4ddc389d4cf3c7b8b69ef/adbc_driver_manager-1.7.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:98db0f5d0aa1635475f63700a7b6f677390beb59c69c7ba9d388bc8ce3779388", size = 2992001 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/1b/2b/5184fe5a529feb019582cc90d0f65e0021d52c34ca20620551532340645a/adbc_driver_manager-1.7.0-cp313-cp313-win_amd64.whl", hash = "sha256:4b7e5e9a163acb21804647cc7894501df51cdcd780ead770557112a26ca01ca6", size = 688789 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/3f/e0/b283544e1bb7864bf5a5ac9cd330f111009eff9180ec5000420510cf9342/adbc_driver_manager-1.7.0-cp313-cp313t-macosx_10_15_x86_64.whl", hash = "sha256:ac83717965b83367a8ad6c0536603acdcfa66e0592d783f8940f55fda47d963e", size = 538625 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/77/5a/dc244264bd8d0c331a418d2bdda5cb6e26c30493ff075d706aa81d4e3b30/adbc_driver_manager-1.7.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:4c234cf81b00eaf7e7c65dbd0f0ddf7bdae93dfcf41e9d8543f9ecf4b10590f6", size = 523627 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/e9/ff/a499a00367fd092edb20dc6e36c81e3c7a437671c70481cae97f46c8156a/adbc_driver_manager-1.7.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ad8aa4b039cc50722a700b544773388c6b1dea955781a01f79cd35d0a1e6edbf", size = 3037517 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/25/6e/9dfdb113294dcb24b4f53924cd4a9c9af3fbe45a9790c1327048df731246/adbc_driver_manager-1.7.0-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4409ff53578e01842a8f57787ebfbfee790c1da01a6bd57fcb7701ed5d4dd4f7", size = 3016543 },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "aiofiles"
|
name = "aiofiles"
|
||||||
version = "24.1.0"
|
version = "24.1.0"
|
||||||
@@ -536,9 +562,11 @@ wheels = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "doris-mcp-server"
|
name = "doris-mcp-server"
|
||||||
version = "0.3.0"
|
version = "0.6.1"
|
||||||
source = { editable = "." }
|
source = { editable = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
{ name = "adbc-driver-flightsql" },
|
||||||
|
{ name = "adbc-driver-manager" },
|
||||||
{ name = "aiofiles" },
|
{ name = "aiofiles" },
|
||||||
{ name = "aiohttp" },
|
{ name = "aiohttp" },
|
||||||
{ name = "aiomysql" },
|
{ name = "aiomysql" },
|
||||||
@@ -555,6 +583,7 @@ dependencies = [
|
|||||||
{ name = "pandas" },
|
{ name = "pandas" },
|
||||||
{ name = "passlib", extra = ["bcrypt"] },
|
{ name = "passlib", extra = ["bcrypt"] },
|
||||||
{ name = "prometheus-client" },
|
{ name = "prometheus-client" },
|
||||||
|
{ name = "pyarrow" },
|
||||||
{ name = "pydantic" },
|
{ name = "pydantic" },
|
||||||
{ name = "pydantic-settings" },
|
{ name = "pydantic-settings" },
|
||||||
{ name = "pyjwt" },
|
{ name = "pyjwt" },
|
||||||
@@ -625,6 +654,8 @@ dev = [
|
|||||||
|
|
||||||
[package.metadata]
|
[package.metadata]
|
||||||
requires-dist = [
|
requires-dist = [
|
||||||
|
{ name = "adbc-driver-flightsql", specifier = ">=0.8.0" },
|
||||||
|
{ name = "adbc-driver-manager", specifier = ">=0.8.0" },
|
||||||
{ name = "aiofiles", specifier = ">=23.0.0" },
|
{ name = "aiofiles", specifier = ">=23.0.0" },
|
||||||
{ name = "aiohttp", specifier = ">=3.9.0" },
|
{ name = "aiohttp", specifier = ">=3.9.0" },
|
||||||
{ name = "aiomysql", specifier = ">=0.2.0" },
|
{ name = "aiomysql", specifier = ">=0.2.0" },
|
||||||
@@ -642,7 +673,7 @@ requires-dist = [
|
|||||||
{ name = "httpx", specifier = ">=0.26.0" },
|
{ name = "httpx", specifier = ">=0.26.0" },
|
||||||
{ name = "isort", marker = "extra == 'dev'", specifier = ">=5.13.0" },
|
{ name = "isort", marker = "extra == 'dev'", specifier = ">=5.13.0" },
|
||||||
{ name = "jaeger-client", marker = "extra == 'monitoring'", specifier = ">=4.8.0" },
|
{ name = "jaeger-client", marker = "extra == 'monitoring'", specifier = ">=4.8.0" },
|
||||||
{ name = "mcp", specifier = ">=1.0.0" },
|
{ name = "mcp", specifier = ">=1.8.0,<2.0.0" },
|
||||||
{ name = "mypy", marker = "extra == 'dev'", specifier = ">=1.8.0" },
|
{ name = "mypy", marker = "extra == 'dev'", specifier = ">=1.8.0" },
|
||||||
{ name = "myst-parser", marker = "extra == 'dev'", specifier = ">=2.0.0" },
|
{ name = "myst-parser", marker = "extra == 'dev'", specifier = ">=2.0.0" },
|
||||||
{ name = "myst-parser", marker = "extra == 'docs'", specifier = ">=2.0.0" },
|
{ name = "myst-parser", marker = "extra == 'docs'", specifier = ">=2.0.0" },
|
||||||
@@ -656,6 +687,7 @@ requires-dist = [
|
|||||||
{ name = "pre-commit", marker = "extra == 'dev'", specifier = ">=3.6.0" },
|
{ name = "pre-commit", marker = "extra == 'dev'", specifier = ">=3.6.0" },
|
||||||
{ name = "prometheus-client", specifier = ">=0.19.0" },
|
{ name = "prometheus-client", specifier = ">=0.19.0" },
|
||||||
{ name = "prometheus-client", marker = "extra == 'monitoring'", specifier = ">=0.19.0" },
|
{ name = "prometheus-client", marker = "extra == 'monitoring'", specifier = ">=0.19.0" },
|
||||||
|
{ name = "pyarrow", specifier = ">=14.0.0" },
|
||||||
{ name = "pydantic", specifier = ">=2.5.0" },
|
{ name = "pydantic", specifier = ">=2.5.0" },
|
||||||
{ name = "pydantic-settings", specifier = ">=2.1.0" },
|
{ name = "pydantic-settings", specifier = ">=2.1.0" },
|
||||||
{ name = "pyjwt", specifier = ">=2.8.0" },
|
{ name = "pyjwt", specifier = ">=2.8.0" },
|
||||||
@@ -948,6 +980,15 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/20/b0/36bd937216ec521246249be3bf9855081de4c5e06a0c9b4219dbeda50373/importlib_metadata-8.7.0-py3-none-any.whl", hash = "sha256:e5dd1551894c77868a30651cef00984d50e1002d06942a7101d34870c5f02afd", size = 27656 },
|
{ url = "https://files.pythonhosted.org/packages/20/b0/36bd937216ec521246249be3bf9855081de4c5e06a0c9b4219dbeda50373/importlib_metadata-8.7.0-py3-none-any.whl", hash = "sha256:e5dd1551894c77868a30651cef00984d50e1002d06942a7101d34870c5f02afd", size = 27656 },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "importlib-resources"
|
||||||
|
version = "6.5.2"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/cf/8c/f834fbf984f691b4f7ff60f50b514cc3de5cc08abfc3295564dd89c5e2e7/importlib_resources-6.5.2.tar.gz", hash = "sha256:185f87adef5bcc288449d98fb4fba07cea78bc036455dd44c5fc4a2fe78fed2c", size = 44693 }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/a4/ed/1f1afb2e9e7f38a545d628f864d562a5ae64fe6f7a10e28ffb9b185b4e89/importlib_resources-6.5.2-py3-none-any.whl", hash = "sha256:789cfdc3ed28c78b67a06acb8126751ced69a3d5f79c095a98298cd8a760ccec", size = 37461 },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "iniconfig"
|
name = "iniconfig"
|
||||||
version = "2.1.0"
|
version = "2.1.0"
|
||||||
@@ -1621,6 +1662,41 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/7b/d7/7831438e6c3ebbfa6e01a927127a6cb42ad3ab844247f3c5b96bea25d73d/psutil-6.1.1-cp37-abi3-win_amd64.whl", hash = "sha256:f35cfccb065fff93529d2afb4a2e89e363fe63ca1e4a5da22b603a85833c2649", size = 254444 },
|
{ url = "https://files.pythonhosted.org/packages/7b/d7/7831438e6c3ebbfa6e01a927127a6cb42ad3ab844247f3c5b96bea25d73d/psutil-6.1.1-cp37-abi3-win_amd64.whl", hash = "sha256:f35cfccb065fff93529d2afb4a2e89e363fe63ca1e4a5da22b603a85833c2649", size = 254444 },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pyarrow"
|
||||||
|
version = "20.0.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/a2/ee/a7810cb9f3d6e9238e61d312076a9859bf3668fd21c69744de9532383912/pyarrow-20.0.0.tar.gz", hash = "sha256:febc4a913592573c8d5805091a6c2b5064c8bd6e002131f01061797d91c783c1", size = 1125187 }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/a1/d6/0c10e0d54f6c13eb464ee9b67a68b8c71bcf2f67760ef5b6fbcddd2ab05f/pyarrow-20.0.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:75a51a5b0eef32727a247707d4755322cb970be7e935172b6a3a9f9ae98404ba", size = 30815067 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/7e/e2/04e9874abe4094a06fd8b0cbb0f1312d8dd7d707f144c2ec1e5e8f452ffa/pyarrow-20.0.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:211d5e84cecc640c7a3ab900f930aaff5cd2702177e0d562d426fb7c4f737781", size = 32297128 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/31/fd/c565e5dcc906a3b471a83273039cb75cb79aad4a2d4a12f76cc5ae90a4b8/pyarrow-20.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ba3cf4182828be7a896cbd232aa8dd6a31bd1f9e32776cc3796c012855e1199", size = 41334890 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/af/a9/3bdd799e2c9b20c1ea6dc6fa8e83f29480a97711cf806e823f808c2316ac/pyarrow-20.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c3a01f313ffe27ac4126f4c2e5ea0f36a5fc6ab51f8726cf41fee4b256680bd", size = 42421775 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/10/f7/da98ccd86354c332f593218101ae56568d5dcedb460e342000bd89c49cc1/pyarrow-20.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:a2791f69ad72addd33510fec7bb14ee06c2a448e06b649e264c094c5b5f7ce28", size = 40687231 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/bb/1b/2168d6050e52ff1e6cefc61d600723870bf569cbf41d13db939c8cf97a16/pyarrow-20.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:4250e28a22302ce8692d3a0e8ec9d9dde54ec00d237cff4dfa9c1fbf79e472a8", size = 42295639 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/b2/66/2d976c0c7158fd25591c8ca55aee026e6d5745a021915a1835578707feb3/pyarrow-20.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:89e030dc58fc760e4010148e6ff164d2f44441490280ef1e97a542375e41058e", size = 42908549 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/31/a9/dfb999c2fc6911201dcbf348247f9cc382a8990f9ab45c12eabfd7243a38/pyarrow-20.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:6102b4864d77102dbbb72965618e204e550135a940c2534711d5ffa787df2a5a", size = 44557216 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/a0/8e/9adee63dfa3911be2382fb4d92e4b2e7d82610f9d9f668493bebaa2af50f/pyarrow-20.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:96d6a0a37d9c98be08f5ed6a10831d88d52cac7b13f5287f1e0f625a0de8062b", size = 25660496 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/9b/aa/daa413b81446d20d4dad2944110dcf4cf4f4179ef7f685dd5a6d7570dc8e/pyarrow-20.0.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:a15532e77b94c61efadde86d10957950392999503b3616b2ffcef7621a002893", size = 30798501 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/ff/75/2303d1caa410925de902d32ac215dc80a7ce7dd8dfe95358c165f2adf107/pyarrow-20.0.0-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:dd43f58037443af715f34f1322c782ec463a3c8a94a85fdb2d987ceb5658e061", size = 32277895 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/92/41/fe18c7c0b38b20811b73d1bdd54b1fccba0dab0e51d2048878042d84afa8/pyarrow-20.0.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aa0d288143a8585806e3cc7c39566407aab646fb9ece164609dac1cfff45f6ae", size = 41327322 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/da/ab/7dbf3d11db67c72dbf36ae63dcbc9f30b866c153b3a22ef728523943eee6/pyarrow-20.0.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b6953f0114f8d6f3d905d98e987d0924dabce59c3cda380bdfaa25a6201563b4", size = 42411441 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/90/c3/0c7da7b6dac863af75b64e2f827e4742161128c350bfe7955b426484e226/pyarrow-20.0.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:991f85b48a8a5e839b2128590ce07611fae48a904cae6cab1f089c5955b57eb5", size = 40677027 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/be/27/43a47fa0ff9053ab5203bb3faeec435d43c0d8bfa40179bfd076cdbd4e1c/pyarrow-20.0.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:97c8dc984ed09cb07d618d57d8d4b67a5100a30c3818c2fb0b04599f0da2de7b", size = 42281473 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/bc/0b/d56c63b078876da81bbb9ba695a596eabee9b085555ed12bf6eb3b7cab0e/pyarrow-20.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9b71daf534f4745818f96c214dbc1e6124d7daf059167330b610fc69b6f3d3e3", size = 42893897 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/92/ac/7d4bd020ba9145f354012838692d48300c1b8fe5634bfda886abcada67ed/pyarrow-20.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e8b88758f9303fa5a83d6c90e176714b2fd3852e776fc2d7e42a22dd6c2fb368", size = 44543847 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/9d/07/290f4abf9ca702c5df7b47739c1b2c83588641ddfa2cc75e34a301d42e55/pyarrow-20.0.0-cp313-cp313-win_amd64.whl", hash = "sha256:30b3051b7975801c1e1d387e17c588d8ab05ced9b1e14eec57915f79869b5031", size = 25653219 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/95/df/720bb17704b10bd69dde086e1400b8eefb8f58df3f8ac9cff6c425bf57f1/pyarrow-20.0.0-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:ca151afa4f9b7bc45bcc791eb9a89e90a9eb2772767d0b1e5389609c7d03db63", size = 30853957 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/d9/72/0d5f875efc31baef742ba55a00a25213a19ea64d7176e0fe001c5d8b6e9a/pyarrow-20.0.0-cp313-cp313t-macosx_12_0_x86_64.whl", hash = "sha256:4680f01ecd86e0dd63e39eb5cd59ef9ff24a9d166db328679e36c108dc993d4c", size = 32247972 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/d5/bc/e48b4fa544d2eea72f7844180eb77f83f2030b84c8dad860f199f94307ed/pyarrow-20.0.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7f4c8534e2ff059765647aa69b75d6543f9fef59e2cd4c6d18015192565d2b70", size = 41256434 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/c3/01/974043a29874aa2cf4f87fb07fd108828fc7362300265a2a64a94965e35b/pyarrow-20.0.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e1f8a47f4b4ae4c69c4d702cfbdfe4d41e18e5c7ef6f1bb1c50918c1e81c57b", size = 42353648 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/68/95/cc0d3634cde9ca69b0e51cbe830d8915ea32dda2157560dda27ff3b3337b/pyarrow-20.0.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:a1f60dc14658efaa927f8214734f6a01a806d7690be4b3232ba526836d216122", size = 40619853 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/29/c2/3ad40e07e96a3e74e7ed7cc8285aadfa84eb848a798c98ec0ad009eb6bcc/pyarrow-20.0.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:204a846dca751428991346976b914d6d2a82ae5b8316a6ed99789ebf976551e6", size = 42241743 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/eb/cb/65fa110b483339add6a9bc7b6373614166b14e20375d4daa73483755f830/pyarrow-20.0.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:f3b117b922af5e4c6b9a9115825726cac7d8b1421c37c2b5e24fbacc8930612c", size = 42839441 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/98/7b/f30b1954589243207d7a0fbc9997401044bf9a033eec78f6cb50da3f304a/pyarrow-20.0.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:e724a3fd23ae5b9c010e7be857f4405ed5e679db5c93e66204db1a69f733936a", size = 44503279 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/37/40/ad395740cd641869a13bcf60851296c89624662575621968dcfafabaa7f6/pyarrow-20.0.0-cp313-cp313t-win_amd64.whl", hash = "sha256:82f1ee5133bd8f49d31be1299dc07f585136679666b502540db854968576faf9", size = 25944982 },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pyasn1"
|
name = "pyasn1"
|
||||||
version = "0.6.1"
|
version = "0.6.1"
|
||||||
|
|||||||