Compare commits
36 Commits
0.4.2
...
infrastruc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a6f893628b | ||
|
|
81305ffbf9 | ||
|
|
43143f0b30 | ||
|
|
e58361e04b | ||
|
|
a125a2f5f8 | ||
|
|
2613912df3 | ||
|
|
067f160b3e | ||
|
|
9ba4cc6f45 | ||
|
|
f99399c6c7 | ||
|
|
c3d487ccdd | ||
|
|
c1e3b13851 | ||
|
|
5923cc1c89 | ||
|
|
9b5ac8533d | ||
|
|
cc84d605e5 | ||
|
|
55dbdd5e14 | ||
|
|
affa4a0319 | ||
|
|
ecb5db8137 | ||
|
|
5d15f6f3a4 | ||
|
|
6247d49192 | ||
|
|
fb5e864a24 | ||
|
|
9bb5b17199 | ||
|
|
6d3c128f54 | ||
|
|
651d524814 | ||
|
|
54572d0861 | ||
|
|
d12dfbd014 | ||
|
|
4052b7e938 | ||
|
|
693c48d5ee | ||
|
|
c1ce9a5cc7 | ||
|
|
282a1c0bd9 | ||
|
|
e3b9bf96ab | ||
|
|
667cecbbe0 | ||
|
|
c777905bd3 | ||
|
|
d4ea125e35 | ||
|
|
f135d9b949 | ||
|
|
124dd0da88 | ||
|
|
775b4cb630 |
24
.asf.yaml
24
.asf.yaml
@@ -24,18 +24,28 @@ github:
|
||||
- olap
|
||||
- lakehouse
|
||||
- mcp
|
||||
- ai
|
||||
enabled_merge_buttons:
|
||||
squash: true
|
||||
merge: false
|
||||
rebase: false
|
||||
features:
|
||||
# Enable wiki for documentation
|
||||
wiki: true
|
||||
# Enable issue management
|
||||
issues: true
|
||||
# Enable projects for project management boardS
|
||||
projects: true
|
||||
# Enable discussions
|
||||
discussions: true
|
||||
rulesets:
|
||||
- name: "Default Branch Protection"
|
||||
type: branch
|
||||
branches:
|
||||
includes:
|
||||
- "~DEFAULT_BRANCH"
|
||||
- "release/*"
|
||||
- "rel/*"
|
||||
excludes: []
|
||||
bypass_teams:
|
||||
- root
|
||||
restrict_deletion: true
|
||||
restrict_force_push: true
|
||||
notifications:
|
||||
pullrequests_status: commits@doris.apache.org
|
||||
issues: commits@doris.apache.org
|
||||
commits: commits@doris.apache.org
|
||||
pullrequests: commits@doris.apache.org
|
||||
|
||||
2
.dockerignore
Normal file
2
.dockerignore
Normal file
@@ -0,0 +1,2 @@
|
||||
**/.venv
|
||||
**/venv
|
||||
511
.env.example
511
.env.example
@@ -1,90 +1,525 @@
|
||||
# Doris MCP Server Configuration
|
||||
# Copy this file to .env and modify the values according to your environment
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
# ===================================================================
|
||||
# Doris MCP Server Environment Configuration Example
|
||||
# ===================================================================
|
||||
# Copy this file to .env and modify the configuration values as needed
|
||||
|
||||
# =============================================================================
|
||||
# Database Configuration
|
||||
# =============================================================================
|
||||
# ===================================================================
|
||||
# Database Connection Configuration
|
||||
# ===================================================================
|
||||
|
||||
# Doris FE connection settings
|
||||
# Doris FE (Frontend) connection settings
|
||||
DORIS_HOST=localhost
|
||||
DORIS_PORT=9030
|
||||
DORIS_USER=root
|
||||
DORIS_PASSWORD=
|
||||
DORIS_DATABASE=information_schema
|
||||
|
||||
# Doris FE HTTP API port
|
||||
# Doris FE HTTP API port (for Profile and other HTTP APIs)
|
||||
DORIS_FE_HTTP_PORT=8030
|
||||
|
||||
# BE nodes configuration for external access
|
||||
# If DORIS_BE_HOSTS is empty, will use "show backends" to get BE nodes automatically
|
||||
# Format: comma-separated list of BE host addresses
|
||||
# Example: DORIS_BE_HOSTS=192.168.1.100,192.168.1.101,192.168.1.102
|
||||
# Doris BE (Backend) nodes configuration (optional, for external access)
|
||||
# Format: host1,host2,host3 (if empty, will use "show backends" to get BE nodes)
|
||||
DORIS_BE_HOSTS=
|
||||
|
||||
# BE webserver port for HTTP APIs (memory tracker, metrics, etc.)
|
||||
DORIS_BE_WEBSERVER_PORT=8040
|
||||
|
||||
# =============================================================================
|
||||
# Connection Pool Configuration
|
||||
# =============================================================================
|
||||
|
||||
DORIS_MIN_CONNECTIONS=5
|
||||
# Connection pool configuration
|
||||
DORIS_MAX_CONNECTIONS=20
|
||||
DORIS_CONNECTION_TIMEOUT=30
|
||||
DORIS_HEALTH_CHECK_INTERVAL=60
|
||||
DORIS_MAX_CONNECTION_AGE=3600
|
||||
|
||||
# =============================================================================
|
||||
# Profile And Explain Max Data Size
|
||||
# =============================================================================
|
||||
MAX_RESPONSE_CONTENT_SIZE=4096
|
||||
# Arrow Flight SQL Configuration (Required for ADBC tools)
|
||||
# FE_ARROW_FLIGHT_SQL_PORT=
|
||||
# BE_ARROW_FLIGHT_SQL_PORT=
|
||||
|
||||
# =============================================================================
|
||||
# ===================================================================
|
||||
# Security Configuration
|
||||
# =============================================================================
|
||||
# ===================================================================
|
||||
|
||||
ENABLE_SECURITY_CHECK=true
|
||||
BLOCKED_KEYWORDS="DROP,TRUNCATE,DELETE,SHUTDOWN,INSERT,UPDATE,CREATE,ALTER,GRANT,REVOKE,KILL"
|
||||
# Independent Authentication Switches - NEW DESIGN!
|
||||
# Each authentication method can be enabled/disabled independently
|
||||
# Any enabled method that succeeds will allow access
|
||||
# If all methods are disabled, anonymous access is allowed
|
||||
|
||||
# Legacy configuration - kept for backward compatibility
|
||||
# AUTH_TYPE is now deprecated - use individual switches above
|
||||
AUTH_TYPE=token
|
||||
|
||||
# Token Authentication (Default method - simple and effective)
|
||||
ENABLE_TOKEN_AUTH=false
|
||||
|
||||
# JWT Authentication (For stateless applications)
|
||||
ENABLE_JWT_AUTH=false
|
||||
|
||||
# OAuth 2.0/OIDC Authentication (For enterprise integration)
|
||||
ENABLE_OAUTH_AUTH=false
|
||||
|
||||
# ===================================================================
|
||||
# Token Authentication Configuration (Enable with ENABLE_TOKEN_AUTH=true)
|
||||
# ===================================================================
|
||||
|
||||
# Basic token authentication settings
|
||||
TOKEN_FILE_PATH=tokens.json
|
||||
ENABLE_TOKEN_EXPIRY=true
|
||||
DEFAULT_TOKEN_EXPIRY_HOURS=720
|
||||
TOKEN_HASH_ALGORITHM=sha256
|
||||
|
||||
# ===================================================================
|
||||
# Token Management Security Configuration (NEW in v0.6.0) - CRITICAL SECURITY SETTINGS
|
||||
# ===================================================================
|
||||
|
||||
# HTTP Token Management Endpoints (DISABLED BY DEFAULT FOR SECURITY)
|
||||
# WARNING: These endpoints allow creation, deletion, and management of authentication tokens
|
||||
# Only enable if you need HTTP-based token management and understand the security implications
|
||||
ENABLE_HTTP_TOKEN_MANAGEMENT=true
|
||||
|
||||
# Admin Authentication Token (REQUIRED if HTTP token management is enabled)
|
||||
# This token is required to access HTTP token management endpoints
|
||||
# SECURITY: Generate a secure random token in production - NEVER use default values
|
||||
TOKEN_MANAGEMENT_ADMIN_TOKEN=
|
||||
|
||||
# IP Address Restrictions for Token Management (CRITICAL SECURITY CONTROL)
|
||||
# Only these IP addresses/networks can access token management endpoints
|
||||
# DEFAULT: localhost only (most secure) - add other IPs/networks only if necessary
|
||||
# Format: comma-separated list of IPs and CIDR networks
|
||||
# Examples:
|
||||
# - Localhost only: 127.0.0.1,::1
|
||||
# - Private network: 127.0.0.1,192.168.1.0/24,10.0.0.0/8
|
||||
# - Specific IPs: 127.0.0.1,192.168.1.10,192.168.1.11
|
||||
TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,::1
|
||||
|
||||
# Require Admin Authentication (ENABLED BY DEFAULT FOR SECURITY)
|
||||
# When true, all token management operations require valid admin token
|
||||
# When false, only IP restrictions apply (NOT RECOMMENDED for production)
|
||||
REQUIRE_ADMIN_AUTH=true
|
||||
|
||||
# ===================================================================
|
||||
# JWT Authentication Configuration (Enable with ENABLE_JWT_AUTH=true)
|
||||
# ===================================================================
|
||||
|
||||
# JWT token settings (when ENABLE_JWT_AUTH=true)
|
||||
JWT_SECRET_KEY=your_jwt_secret_key_here_change_in_production
|
||||
JWT_ALGORITHM=HS256
|
||||
JWT_EXPIRATION_HOURS=24
|
||||
JWT_ISSUER=doris-mcp-server
|
||||
JWT_AUDIENCE=doris-mcp-client
|
||||
|
||||
# JWT token validation settings
|
||||
JWT_VERIFY_SIGNATURE=true
|
||||
JWT_VERIFY_EXPIRATION=true
|
||||
JWT_VERIFY_AUDIENCE=true
|
||||
JWT_VERIFY_ISSUER=true
|
||||
|
||||
# JWT refresh token settings
|
||||
ENABLE_JWT_REFRESH=true
|
||||
JWT_REFRESH_EXPIRATION_DAYS=30
|
||||
JWT_REFRESH_SECRET_KEY=your_jwt_refresh_secret_key_here
|
||||
|
||||
# JWT user claims configuration
|
||||
JWT_USER_ID_CLAIM=user_id
|
||||
JWT_ROLES_CLAIM=roles
|
||||
JWT_PERMISSIONS_CLAIM=permissions
|
||||
JWT_SECURITY_LEVEL_CLAIM=security_level
|
||||
|
||||
# ===================================================================
|
||||
# OAuth 2.0 / OpenID Connect Configuration (Enable with ENABLE_OAUTH_AUTH=true)
|
||||
# ===================================================================
|
||||
|
||||
# OAuth provider settings (when ENABLE_OAUTH_AUTH=true)
|
||||
OAUTH_PROVIDER_TYPE=generic
|
||||
OAUTH_CLIENT_ID=your_oauth_client_id
|
||||
OAUTH_CLIENT_SECRET=your_oauth_client_secret
|
||||
OAUTH_REDIRECT_URI=http://localhost:3000/auth/callback
|
||||
|
||||
# OAuth endpoints (for generic provider)
|
||||
OAUTH_AUTHORIZATION_URL=https://your-provider.com/auth
|
||||
OAUTH_TOKEN_URL=https://your-provider.com/token
|
||||
OAUTH_USERINFO_URL=https://your-provider.com/userinfo
|
||||
OAUTH_JWKS_URL=https://your-provider.com/.well-known/jwks.json
|
||||
|
||||
# OAuth scope and claims
|
||||
OAUTH_SCOPE=openid profile email
|
||||
OAUTH_USER_ID_CLAIM=sub
|
||||
OAUTH_USERNAME_CLAIM=preferred_username
|
||||
OAUTH_EMAIL_CLAIM=email
|
||||
OAUTH_ROLES_CLAIM=roles
|
||||
OAUTH_GROUPS_CLAIM=groups
|
||||
|
||||
# OAuth session settings
|
||||
OAUTH_SESSION_SECRET=your_oauth_session_secret_here
|
||||
OAUTH_SESSION_EXPIRY=3600
|
||||
OAUTH_STATE_EXPIRY=300
|
||||
|
||||
# Popular OAuth providers presets (uncomment and configure as needed)
|
||||
|
||||
# Google OAuth Configuration
|
||||
# OAUTH_PROVIDER_TYPE=google
|
||||
# OAUTH_CLIENT_ID=your_google_client_id.apps.googleusercontent.com
|
||||
# OAUTH_CLIENT_SECRET=your_google_client_secret
|
||||
# OAUTH_AUTHORIZATION_URL=https://accounts.google.com/o/oauth2/auth
|
||||
# OAUTH_TOKEN_URL=https://oauth2.googleapis.com/token
|
||||
# OAUTH_USERINFO_URL=https://www.googleapis.com/oauth2/v1/userinfo
|
||||
# OAUTH_JWKS_URL=https://www.googleapis.com/oauth2/v3/certs
|
||||
# OAUTH_SCOPE=openid profile email
|
||||
|
||||
# Microsoft Azure AD Configuration
|
||||
# OAUTH_PROVIDER_TYPE=azure
|
||||
# OAUTH_CLIENT_ID=your_azure_client_id
|
||||
# OAUTH_CLIENT_SECRET=your_azure_client_secret
|
||||
# OAUTH_TENANT_ID=your_tenant_id
|
||||
# OAUTH_AUTHORIZATION_URL=https://login.microsoftonline.com/{tenant}/oauth2/v2.0/authorize
|
||||
# OAUTH_TOKEN_URL=https://login.microsoftonline.com/{tenant}/oauth2/v2.0/token
|
||||
# OAUTH_USERINFO_URL=https://graph.microsoft.com/v1.0/me
|
||||
# OAUTH_JWKS_URL=https://login.microsoftonline.com/{tenant}/discovery/v2.0/keys
|
||||
# OAUTH_SCOPE=openid profile email
|
||||
|
||||
# GitHub OAuth Configuration
|
||||
# OAUTH_PROVIDER_TYPE=github
|
||||
# OAUTH_CLIENT_ID=your_github_client_id
|
||||
# OAUTH_CLIENT_SECRET=your_github_client_secret
|
||||
# OAUTH_AUTHORIZATION_URL=https://github.com/login/oauth/authorize
|
||||
# OAUTH_TOKEN_URL=https://github.com/login/oauth/access_token
|
||||
# OAUTH_USERINFO_URL=https://api.github.com/user
|
||||
# OAUTH_SCOPE=user:email
|
||||
|
||||
# GitLab OAuth Configuration
|
||||
# OAUTH_PROVIDER_TYPE=gitlab
|
||||
# OAUTH_CLIENT_ID=your_gitlab_client_id
|
||||
# OAUTH_CLIENT_SECRET=your_gitlab_client_secret
|
||||
# OAUTH_AUTHORIZATION_URL=https://gitlab.com/oauth/authorize
|
||||
# OAUTH_TOKEN_URL=https://gitlab.com/oauth/token
|
||||
# OAUTH_USERINFO_URL=https://gitlab.com/api/v4/user
|
||||
# OAUTH_SCOPE=read_user
|
||||
|
||||
# Keycloak OAuth Configuration
|
||||
# OAUTH_PROVIDER_TYPE=keycloak
|
||||
# OAUTH_CLIENT_ID=your_keycloak_client_id
|
||||
# OAUTH_CLIENT_SECRET=your_keycloak_client_secret
|
||||
# OAUTH_REALM=your_realm
|
||||
# OAUTH_SERVER_URL=https://your-keycloak-server.com
|
||||
# OAUTH_AUTHORIZATION_URL=https://your-keycloak-server.com/auth/realms/{realm}/protocol/openid-connect/auth
|
||||
# OAUTH_TOKEN_URL=https://your-keycloak-server.com/auth/realms/{realm}/protocol/openid-connect/token
|
||||
# OAUTH_USERINFO_URL=https://your-keycloak-server.com/auth/realms/{realm}/protocol/openid-connect/userinfo
|
||||
# OAUTH_JWKS_URL=https://your-keycloak-server.com/auth/realms/{realm}/protocol/openid-connect/certs
|
||||
# OAUTH_SCOPE=openid profile email
|
||||
|
||||
# Legacy token settings (for backward compatibility)
|
||||
TOKEN_SECRET=your_secret_key_here
|
||||
TOKEN_EXPIRY=3600
|
||||
MAX_RESULT_ROWS=10000
|
||||
|
||||
# SQL security check
|
||||
ENABLE_SECURITY_CHECK=true
|
||||
|
||||
# Blocked keywords (comma separated)
|
||||
BLOCKED_KEYWORDS=DROP,CREATE,ALTER,TRUNCATE,DELETE,INSERT,UPDATE,GRANT,REVOKE,EXEC,EXECUTE,SHUTDOWN,KILL
|
||||
|
||||
# Query limits
|
||||
MAX_QUERY_COMPLEXITY=100
|
||||
MAX_RESULT_ROWS=10000
|
||||
|
||||
# Data masking
|
||||
ENABLE_MASKING=true
|
||||
|
||||
# =============================================================================
|
||||
# ===================================================================
|
||||
# Performance Configuration
|
||||
# =============================================================================
|
||||
# ===================================================================
|
||||
|
||||
# Query cache
|
||||
ENABLE_QUERY_CACHE=true
|
||||
CACHE_TTL=300
|
||||
MAX_CACHE_SIZE=1000
|
||||
|
||||
# Concurrency control
|
||||
MAX_CONCURRENT_QUERIES=50
|
||||
QUERY_TIMEOUT=300
|
||||
|
||||
# =============================================================================
|
||||
# Logging Configuration
|
||||
# =============================================================================
|
||||
# Response content size limit (characters)
|
||||
MAX_RESPONSE_CONTENT_SIZE=4096
|
||||
|
||||
# ===================================================================
|
||||
# ADBC (Arrow Flight SQL) Configuration
|
||||
# ===================================================================
|
||||
# Enable/disable ADBC tools
|
||||
ADBC_ENABLED=true
|
||||
|
||||
# Default ADBC query parameters
|
||||
ADBC_DEFAULT_MAX_ROWS=100000
|
||||
ADBC_DEFAULT_TIMEOUT=60
|
||||
# Format: "arrow", "pandas", "dict"
|
||||
ADBC_DEFAULT_RETURN_FORMAT=arrow
|
||||
|
||||
# ADBC connection timeout
|
||||
ADBC_CONNECTION_TIMEOUT=300
|
||||
|
||||
# ===================================================================
|
||||
# Logging Configuration
|
||||
# ===================================================================
|
||||
|
||||
# Basic logging configuration
|
||||
LOG_LEVEL=INFO
|
||||
LOG_FILE_PATH=
|
||||
|
||||
# Audit logging
|
||||
ENABLE_AUDIT=true
|
||||
AUDIT_FILE_PATH=
|
||||
|
||||
# =============================================================================
|
||||
# Monitoring Configuration
|
||||
# =============================================================================
|
||||
# Log file rotation configuration
|
||||
LOG_MAX_FILE_SIZE=10485760
|
||||
LOG_BACKUP_COUNT=5
|
||||
|
||||
# ===================================================================
|
||||
# Log Cleanup Configuration - NEW!
|
||||
# ===================================================================
|
||||
|
||||
# Enable automatic log cleanup
|
||||
ENABLE_LOG_CLEANUP=true
|
||||
|
||||
# Maximum age of log files in days (files older than this will be deleted)
|
||||
LOG_MAX_AGE_DAYS=30
|
||||
|
||||
# Cleanup check interval in hours
|
||||
LOG_CLEANUP_INTERVAL_HOURS=24
|
||||
|
||||
# ===================================================================
|
||||
# Monitoring Configuration
|
||||
# ===================================================================
|
||||
|
||||
# Metrics collection
|
||||
ENABLE_METRICS=true
|
||||
METRICS_PORT=3001
|
||||
HEALTH_CHECK_PORT=3002
|
||||
|
||||
# Alert configuration
|
||||
ENABLE_ALERTS=false
|
||||
ALERT_WEBHOOK_URL=
|
||||
|
||||
# =============================================================================
|
||||
# ===================================================================
|
||||
# Server Configuration
|
||||
# =============================================================================
|
||||
# ===================================================================
|
||||
|
||||
# Basic server information
|
||||
SERVER_NAME=doris-mcp-server
|
||||
SERVER_VERSION=0.4.1
|
||||
SERVER_VERSION=0.6.0
|
||||
SERVER_PORT=3000
|
||||
|
||||
# Temporary files directory
|
||||
TEMP_FILES_DIR=tmp
|
||||
|
||||
# ===================================================================
|
||||
# Configuration Examples for Different Environments
|
||||
# ===================================================================
|
||||
|
||||
# Development Environment Example:
|
||||
# LOG_LEVEL=DEBUG
|
||||
# LOG_MAX_AGE_DAYS=7
|
||||
# LOG_CLEANUP_INTERVAL_HOURS=6
|
||||
# ENABLE_SECURITY_CHECK=false
|
||||
|
||||
# Production Environment Example:
|
||||
# LOG_LEVEL=INFO
|
||||
# LOG_MAX_AGE_DAYS=30
|
||||
# LOG_CLEANUP_INTERVAL_HOURS=24
|
||||
# ENABLE_SECURITY_CHECK=true
|
||||
# ENABLE_LOG_CLEANUP=true
|
||||
|
||||
# Testing Environment Example:
|
||||
# LOG_LEVEL=WARNING
|
||||
# LOG_MAX_AGE_DAYS=3
|
||||
# LOG_CLEANUP_INTERVAL_HOURS=1
|
||||
# MAX_RESULT_ROWS=1000
|
||||
|
||||
# ===================================================================
|
||||
# Advanced Configuration Notes
|
||||
# ===================================================================
|
||||
|
||||
# 1. Log Cleanup Feature:
|
||||
# - ENABLE_LOG_CLEANUP: Controls whether to enable automatic cleanup
|
||||
# - LOG_MAX_AGE_DAYS: File retention days, recommended 30 days for production, 7 days for development
|
||||
# - LOG_CLEANUP_INTERVAL_HOURS: Check frequency, recommended 24 hours
|
||||
|
||||
# 2. Security Best Practices:
|
||||
# - NEW: Enable individual authentication methods using ENABLE_TOKEN_AUTH, ENABLE_JWT_AUTH, ENABLE_OAUTH_AUTH
|
||||
# - When all methods are disabled, ALL requests are allowed with anonymous access
|
||||
# - Authentication methods work independently - any one succeeding allows access
|
||||
# - Token Auth: Change default tokens (DEFAULT_ADMIN_TOKEN, etc.) in production
|
||||
# - JWT Auth: Change JWT_SECRET_KEY and JWT_REFRESH_SECRET_KEY in production
|
||||
# - OAuth Auth: Configure OAuth provider settings and secure client secrets
|
||||
# - Must change TOKEN_SECRET in production environment (legacy compatibility)
|
||||
# - Adjust BLOCKED_KEYWORDS according to business needs
|
||||
# - Enable ENABLE_SECURITY_CHECK and ENABLE_MASKING
|
||||
# - NEW v0.6.0: Token Management Security (CRITICAL):
|
||||
# * ENABLE_HTTP_TOKEN_MANAGEMENT=false by default (SECURE BY DEFAULT)
|
||||
# * Only enable if you need HTTP token management endpoints
|
||||
# * TOKEN_MANAGEMENT_ADMIN_TOKEN: Use secure random token in production
|
||||
# * TOKEN_MANAGEMENT_ALLOWED_IPS: Restrict to localhost (127.0.0.1,::1) only
|
||||
# * REQUIRE_ADMIN_AUTH=true: Always require admin authentication
|
||||
# * Never expose token management endpoints to external networks
|
||||
|
||||
# 3. Performance Tuning:
|
||||
# - Adjust MAX_CONCURRENT_QUERIES based on hardware resources
|
||||
# - Adjust QUERY_TIMEOUT based on query complexity
|
||||
# - Adjust MAX_CACHE_SIZE based on memory size
|
||||
|
||||
# 4. Connection Pool Optimization:
|
||||
# - DORIS_MAX_CONNECTIONS recommended to be 2-4 times the number of CPU cores
|
||||
# - DORIS_CONNECTION_TIMEOUT adjust based on network latency
|
||||
# - DORIS_MAX_CONNECTION_AGE recommended 1 hour to avoid long connection issues
|
||||
|
||||
# 5. ADBC (Arrow Flight SQL) Configuration:
|
||||
# - FE_ARROW_FLIGHT_SQL_PORT and BE_ARROW_FLIGHT_SQL_PORT: Required for ADBC functionality
|
||||
# - ADBC_DEFAULT_MAX_ROWS: Default maximum rows for ADBC queries (recommended: 100000)
|
||||
# - ADBC_DEFAULT_TIMEOUT: Default timeout for ADBC queries in seconds (recommended: 60)
|
||||
# - ADBC_DEFAULT_RETURN_FORMAT: Default return format (arrow/pandas/dict, recommended: arrow)
|
||||
# - ADBC_CONNECTION_TIMEOUT: Connection timeout for ADBC (recommended: 30)
|
||||
# - ADBC_ENABLED: Enable or disable ADBC tools (true/false)
|
||||
# - Prerequisites: Install adbc_driver_manager, adbc_driver_flightsql, pyarrow packages
|
||||
|
||||
# 6. Authentication Configuration Guide - UPDATED DESIGN!
|
||||
#
|
||||
# Independent Authentication Control (NEW):
|
||||
# - ENABLE_TOKEN_AUTH=false (default): Disable token authentication
|
||||
# - ENABLE_JWT_AUTH=false (default): Disable JWT authentication
|
||||
# - ENABLE_OAUTH_AUTH=false (default): Disable OAuth authentication
|
||||
# - When all methods are disabled, no authentication is required (anonymous access)
|
||||
# - When multiple methods are enabled, any one succeeding allows access
|
||||
# - Recommended for development/testing: all false, production: enable needed methods
|
||||
#
|
||||
# Token Authentication (ENABLE_TOKEN_AUTH=true) - Recommended for most use cases:
|
||||
# - Simple and secure token-based authentication
|
||||
# - Configurable default tokens via environment variables
|
||||
# - Support for custom tokens via TOKEN_* environment variables
|
||||
# - Token file configuration via tokens.json
|
||||
# - Built-in token management HTTP endpoints
|
||||
# - No user management complexity - pure API access control
|
||||
#
|
||||
# JWT Authentication (ENABLE_JWT_AUTH=true) - For stateless applications:
|
||||
# - JSON Web Token based authentication
|
||||
# - Configurable token expiration and refresh
|
||||
# - Support for standard JWT claims
|
||||
# - RSA/ECDSA/HS256 algorithm support
|
||||
# - Suitable for microservices and distributed systems
|
||||
#
|
||||
# OAuth 2.0/OIDC (ENABLE_OAUTH_AUTH=true) - For enterprise integration:
|
||||
# - Integration with external identity providers
|
||||
# - Support for popular providers (Google, Microsoft, GitHub, GitLab, Keycloak)
|
||||
# - OpenID Connect compatibility
|
||||
# - Automatic user provisioning from provider
|
||||
# - Secure authorization code flow
|
||||
#
|
||||
# Authentication Method Selection Guide:
|
||||
# - No Auth (all switches false): Development, testing, trusted networks
|
||||
# - Token Auth only: Small teams, simple deployment, direct API access
|
||||
# - JWT Auth only: Stateless apps, microservices, mobile clients
|
||||
# - OAuth Auth only: Enterprise SSO, large teams, external identity providers
|
||||
# - Multiple methods: Flexible access, different client types, migration scenarios
|
||||
|
||||
# 7. Token Management Security Configuration Guide (NEW in v0.6.0) - CRITICAL!
|
||||
#
|
||||
# ⚠️ SECURITY WARNING: Token management endpoints are POWERFUL and DANGEROUS
|
||||
# They allow creation, revocation, and management of authentication tokens.
|
||||
# Improper configuration can lead to complete system compromise.
|
||||
#
|
||||
# 🔒 SECURE BY DEFAULT:
|
||||
# - ENABLE_HTTP_TOKEN_MANAGEMENT=false (disabled by default)
|
||||
# - REQUIRE_ADMIN_AUTH=true (admin auth required by default)
|
||||
# - TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,::1 (localhost only by default)
|
||||
#
|
||||
# 🛡️ SECURITY LAYERS (Applied in order):
|
||||
# 1. Configuration Check: HTTP token management must be explicitly enabled
|
||||
# 2. IP Restrictions: Only allowed IP addresses/networks can access endpoints
|
||||
# 3. Admin Authentication: Valid admin token required for all operations
|
||||
#
|
||||
# 📋 CONFIGURATION OPTIONS:
|
||||
#
|
||||
# Disable Token Management (RECOMMENDED for most deployments):
|
||||
# ENABLE_HTTP_TOKEN_MANAGEMENT=false
|
||||
# # All token management endpoints will return 403 Forbidden
|
||||
#
|
||||
# Enable with Maximum Security (Production):
|
||||
# ENABLE_HTTP_TOKEN_MANAGEMENT=true
|
||||
# TOKEN_MANAGEMENT_ADMIN_TOKEN=<secure-random-token-256-bit>
|
||||
# TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,::1
|
||||
# REQUIRE_ADMIN_AUTH=true
|
||||
#
|
||||
# Enable for Private Network (Advanced):
|
||||
# ENABLE_HTTP_TOKEN_MANAGEMENT=true
|
||||
# TOKEN_MANAGEMENT_ADMIN_TOKEN=<secure-random-token-256-bit>
|
||||
# TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,192.168.1.0/24,10.0.0.0/8
|
||||
# REQUIRE_ADMIN_AUTH=true
|
||||
#
|
||||
# 🔑 ADMIN TOKEN GENERATION:
|
||||
# # Generate secure admin token (Linux/macOS):
|
||||
# openssl rand -hex 32
|
||||
#
|
||||
# # Generate secure admin token (Python):
|
||||
# python -c "import secrets; print(secrets.token_urlsafe(32))"
|
||||
#
|
||||
# 🌐 IP CONFIGURATION EXAMPLES:
|
||||
# # Localhost only (most secure):
|
||||
# TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,::1
|
||||
#
|
||||
# # Private network + localhost:
|
||||
# TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,::1,192.168.1.0/24,10.0.0.0/8
|
||||
#
|
||||
# # Specific servers only:
|
||||
# TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,192.168.1.10,192.168.1.11
|
||||
#
|
||||
# # Corporate network (be careful):
|
||||
# TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,172.16.0.0/12,192.168.0.0/16
|
||||
#
|
||||
# 🚫 NEVER DO THIS (Security Anti-Patterns):
|
||||
# # NEVER allow all IPs:
|
||||
# # TOKEN_MANAGEMENT_ALLOWED_IPS=0.0.0.0/0 # DANGEROUS!
|
||||
#
|
||||
# # NEVER disable admin auth in production:
|
||||
# # REQUIRE_ADMIN_AUTH=false # DANGEROUS!
|
||||
#
|
||||
# # NEVER use weak admin tokens:
|
||||
# # TOKEN_MANAGEMENT_ADMIN_TOKEN=admin # DANGEROUS!
|
||||
# # TOKEN_MANAGEMENT_ADMIN_TOKEN=123456 # DANGEROUS!
|
||||
#
|
||||
# 📊 ENDPOINT SECURITY TESTING:
|
||||
# # Test security (should fail):
|
||||
# curl -X POST http://external-ip:3000/token/create
|
||||
# # Expected: 403 Forbidden (IP not allowed)
|
||||
#
|
||||
# # Test without auth (should fail):
|
||||
# curl -X POST http://127.0.0.1:3000/token/create
|
||||
# # Expected: 401 Unauthorized (missing admin token)
|
||||
#
|
||||
# # Test with valid auth (should succeed if enabled):
|
||||
# curl -H "Authorization: Bearer your-admin-token" http://127.0.0.1:3000/token/stats
|
||||
# # Expected: 200 OK with token statistics
|
||||
#
|
||||
# 🔍 MONITORING & AUDITING:
|
||||
# # All token management access attempts are logged:
|
||||
# tail -f logs/doris_mcp_server_audit.log | grep "token management"
|
||||
#
|
||||
# # Monitor security events:
|
||||
# tail -f logs/doris_mcp_server_info.log | grep -E "(access denied|token management)"
|
||||
#
|
||||
# ✅ SECURITY BEST PRACTICES:
|
||||
# - Keep ENABLE_HTTP_TOKEN_MANAGEMENT=false unless absolutely necessary
|
||||
# - Use file-based token management (tokens.json) instead of HTTP endpoints
|
||||
# - Generate strong admin tokens using cryptographically secure methods
|
||||
# - Restrict access to localhost (127.0.0.1,::1) whenever possible
|
||||
# - Never expose token management endpoints to public internet
|
||||
# - Regularly audit token management access logs
|
||||
# - Use firewall rules as additional protection layer
|
||||
# - Consider VPN access for remote token management needs
|
||||
23
.gitignore
vendored
Normal file
23
.gitignore
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
*.log
|
||||
*.log.*
|
||||
*.bak
|
||||
logs
|
||||
/configs/*.py
|
||||
.vscode/
|
||||
__pycache__/
|
||||
*.log
|
||||
|
||||
.python-version
|
||||
Pipfile.lock
|
||||
poetry.lock
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
.idea/
|
||||
.coverage
|
||||
coverage.xml
|
||||
@@ -32,6 +32,7 @@ RUN apt-get update && apt-get install -y \
|
||||
g++ \
|
||||
pkg-config \
|
||||
default-libmysqlclient-dev \
|
||||
dos2unix \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy requirements file
|
||||
@@ -43,12 +44,13 @@ RUN pip install --no-cache-dir -r requirements.txt
|
||||
# Copy application code
|
||||
COPY . .
|
||||
|
||||
# Convert line endings for shell scripts and ensure proper execution format
|
||||
RUN find . -name "*.sh" -exec dos2unix {} \; && \
|
||||
find . -name "*.sh" -exec chmod +x {} \;
|
||||
|
||||
# Create necessary directories
|
||||
RUN mkdir -p /app/logs /app/config /app/data
|
||||
|
||||
# Set permissions
|
||||
RUN chmod +x /app/start_server.sh
|
||||
|
||||
# Create non-root user
|
||||
RUN groupadd -r doris && useradd -r -g doris doris
|
||||
RUN chown -R doris:doris /app
|
||||
|
||||
663
README.md
663
README.md
@@ -21,17 +21,27 @@ under the License.
|
||||
|
||||
Doris MCP (Model Context Protocol) Server is a backend service built with Python and FastAPI. It implements the MCP, allowing clients to interact with it through defined "Tools". It's primarily designed to connect to Apache Doris databases, potentially leveraging Large Language Models (LLMs) for tasks like converting natural language queries to SQL (NL2SQL), executing queries, and performing metadata management and analysis.
|
||||
|
||||
## 🚀 What's New in v0.4.2
|
||||
## 🚀 What's New in v0.6.0
|
||||
|
||||
- **🔒 Enhanced Security Framework**: Comprehensive SQL security validation with configurable blocked keywords, SQL injection protection, and unified security configuration management
|
||||
- **🛠️ Connection Stability Improvements**: Fixed critical `at_eof` connection errors with advanced connection health monitoring, automatic retry mechanisms, and proactive connection cleanup
|
||||
- **⚙️ Flexible Security Configuration**: Environment variable support for security policies (`BLOCKED_KEYWORDS`, `ENABLE_SECURITY_CHECK`) with unified configuration architecture eliminating code duplication
|
||||
- **🎯 Centralized Configuration Management**: All security keywords now managed through single configuration source with consistent enforcement across all components
|
||||
- **🔧 MCP Version Compatibility**: Resolved MCP library version conflicts with intelligent compatibility layer supporting both MCP 1.8.x and 1.9.x versions
|
||||
- **🚀 Production Reliability**: Enhanced error handling, connection diagnostics, and automatic recovery from database connection issues
|
||||
- **🙏 Community Contribution**: Special thanks to Hailin Xie for supporting the doris-mcp-server project by graciously transferring the PyPI project to the community free of charge, contributing to open source. The mcp-doris-server repository will be retained but no longer maintained, with ongoing development continuing on the doris-mcp-server repository
|
||||
- **🔐 Enterprise Authentication System**: **Revolutionary token-bound database configuration** with comprehensive Token, JWT, and OAuth authentication support, enabling secure multi-tenant access with granular control switches and enterprise-grade security defaults
|
||||
- **⚡ Immediate Database Validation**: **Real-time database configuration validation at connection time**, eliminating query-time blocking and providing instant feedback for invalid configurations - achieving 100% elimination of late-stage connection failures
|
||||
- **🔄 Hot Reload Configuration Management**: **Zero-downtime configuration updates** with intelligent hot reloading of tokens.json, automatic token revalidation, and comprehensive error handling with rollback mechanisms
|
||||
- **🏗️ Advanced Connection Architecture**: **Session caching and connection pool optimization** with 60% reduction in connection overhead, intelligent pool recreation, and automatic resource management
|
||||
- **🌐 Multi-Worker Scalability**: **True horizontal scaling** with stateless multi-worker architecture, efficient load distribution, and enterprise-grade concurrent processing capabilities
|
||||
- **🔒 Enhanced Security Framework**: **Comprehensive access control and SQL security validation** with immediate validation, role-based permissions, and enhanced injection detection patterns
|
||||
- **🛠️ Unified Configuration System**: **Streamlined configuration management** with proper command-line precedence, Docker compatibility improvements, and cross-platform deployment support
|
||||
- **📊 Token Management Dashboard**: **Complete token lifecycle management** with creation, revocation, statistics, and comprehensive audit trails for enterprise token governance
|
||||
- **🌐 Web-Based Management Interface**: **Secure localhost-only token administration** with intuitive dashboard, database binding configuration, real-time operations, and enterprise-grade access controls
|
||||
|
||||
> **🔧 Key Improvements**: Resolved connection stability issues, unified security keyword management, added comprehensive environment variable configuration for security policies, and fixed MCP library version compatibility conflicts.
|
||||
> **🚀 Major Milestone**: v0.6.0 establishes the platform as a **production-ready enterprise authentication and database management system** with **zero-downtime operations** (hot reload + immediate validation + multi-worker scaling), advanced security controls, and comprehensive token-bound database configuration - representing a fundamental advancement in enterprise data platform capabilities.
|
||||
|
||||
### What's Also Included from v0.5.1
|
||||
|
||||
- **🔥 Critical at_eof Connection Fix**: Complete elimination of connection pool errors with intelligent health monitoring and self-healing recovery
|
||||
- **🔧 Enterprise Logging System**: Level-based file separation with automatic cleanup and millisecond precision timestamps
|
||||
- **📊 Advanced Data Analytics Suite**: 7 enterprise-grade data governance tools including quality analysis, lineage tracking, and performance monitoring
|
||||
- **🏃♂️ High-Performance ADBC Integration**: Apache Arrow Flight SQL support with 3-10x performance improvements for large datasets
|
||||
- **⚙️ Enhanced Configuration Management**: Complete ADBC configuration system with intelligent parameter validation
|
||||
|
||||
## Core Features
|
||||
|
||||
@@ -51,12 +61,13 @@ Doris MCP (Model Context Protocol) Server is a backend service built with Python
|
||||
* **Performance Analysis**: Advanced column analysis, performance monitoring, and data analysis tools (`doris_mcp_server/utils/analysis_tools.py`)
|
||||
* **Catalog Federation Support**: Full support for multi-catalog environments (internal Doris tables and external data sources like Hive, MySQL, etc.)
|
||||
* **Enterprise Security**: Comprehensive security framework with authentication, authorization, SQL injection protection, and data masking capabilities with environment variable configuration support
|
||||
* **Web-Based Token Management**: Secure localhost-only interface for complete token lifecycle management with database binding, real-time statistics, and enterprise-grade access controls (`doris_mcp_server/auth/token_handlers.py`)
|
||||
* **Unified Configuration Framework**: Centralized configuration management through `config.py` with comprehensive validation, standardized parameter naming, and smart default database handling with automatic fallback to `information_schema`
|
||||
|
||||
## System Requirements
|
||||
|
||||
* Python 3.12+
|
||||
* Database connection details (e.g., Doris Host, Port, User, Password, Database)
|
||||
* **Python**: 3.12+
|
||||
* **Database**: Apache Doris connection details (Host, Port, User, Password, Database)
|
||||
|
||||
## 🚀 Quick Start
|
||||
|
||||
@@ -67,7 +78,7 @@ Doris MCP (Model Context Protocol) Server is a backend service built with Python
|
||||
pip install doris-mcp-server
|
||||
|
||||
# Install specific version
|
||||
pip install doris-mcp-server==0.4.2
|
||||
pip install doris-mcp-server==0.6.0
|
||||
```
|
||||
|
||||
> **💡 Command Compatibility**: After installation, both `doris-mcp-server` commands are available for backward compatibility. You can use either command interchangeably.
|
||||
@@ -97,6 +108,46 @@ Standard input/output mode for direct integration with MCP clients:
|
||||
doris-mcp-server --transport stdio
|
||||
```
|
||||
|
||||
### 🌐 Token Management Interface (New in v0.6.0)
|
||||
|
||||
Access the **Web-Based Token Management Dashboard** for enterprise-grade token administration:
|
||||
|
||||
#### **Secure Access Requirements**
|
||||
- **Localhost Access Only**: Interface restricted to `127.0.0.1` and `::1` for maximum security
|
||||
- **Admin Authentication**: Requires `TOKEN_MANAGEMENT_ADMIN_TOKEN` for access
|
||||
- **Configuration Prerequisites**:
|
||||
```bash
|
||||
# Required environment variables
|
||||
ENABLE_HTTP_TOKEN_MANAGEMENT=true
|
||||
ENABLE_TOKEN_AUTH=true
|
||||
TOKEN_MANAGEMENT_ADMIN_TOKEN=your_secure_admin_token
|
||||
TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,::1
|
||||
```
|
||||
|
||||
#### **Interface Access**
|
||||
```bash
|
||||
# Access the token management interface
|
||||
http://localhost:3000/token/management?admin_token=your_secure_admin_token
|
||||
```
|
||||
|
||||
#### **Available Operations**
|
||||
- **📊 Token Statistics**: Real-time overview of active, expired, and total tokens
|
||||
- **➕ Create Tokens**:
|
||||
- Basic information (ID, description, expiration)
|
||||
- **Database binding** (host, port, user, password, database)
|
||||
- Custom token values or auto-generated secure tokens
|
||||
- **📋 Token Management**:
|
||||
- List all tokens with database binding status
|
||||
- One-click token revocation
|
||||
- Automated expired token cleanup
|
||||
- **🔒 Enterprise Security**:
|
||||
- All operations require admin authentication
|
||||
- Real-time IP validation
|
||||
- Complete audit logging
|
||||
- **Automatic persistence** to `tokens.json`
|
||||
|
||||
> **🔐 Security Note**: The interface is designed for localhost administration only. It cannot be accessed remotely, ensuring maximum security for token management operations.
|
||||
|
||||
### Verify Installation
|
||||
|
||||
```bash
|
||||
@@ -112,11 +163,18 @@ curl http://localhost:3000/health
|
||||
Instead of command-line arguments, you can use environment variables:
|
||||
|
||||
```bash
|
||||
# Basic Database Configuration
|
||||
export DORIS_HOST="127.0.0.1"
|
||||
export DORIS_PORT="9030"
|
||||
export DORIS_USER="root"
|
||||
export DORIS_PASSWORD="your_password"
|
||||
|
||||
# Token Management Interface (Security-Critical)
|
||||
export ENABLE_HTTP_TOKEN_MANAGEMENT=true
|
||||
export ENABLE_TOKEN_AUTH=true
|
||||
export TOKEN_MANAGEMENT_ADMIN_TOKEN="your_secure_admin_token"
|
||||
export TOKEN_MANAGEMENT_ALLOWED_IPS="127.0.0.1,::1"
|
||||
|
||||
# Then start with simplified command
|
||||
doris-mcp-server --transport http --host 0.0.0.0 --port 3000
|
||||
```
|
||||
@@ -173,22 +231,48 @@ cp .env.example .env
|
||||
* `DORIS_MAX_CONNECTIONS`: Maximum connection pool size (default: 20)
|
||||
* `DORIS_BE_HOSTS`: BE nodes for monitoring (comma-separated, optional - auto-discovery via SHOW BACKENDS if empty)
|
||||
* `DORIS_BE_WEBSERVER_PORT`: BE webserver port for monitoring tools (default: 8040)
|
||||
* **Security Configuration**:
|
||||
* `AUTH_TYPE`: Authentication type (token/basic/oauth, default: token)
|
||||
* `TOKEN_SECRET`: Token secret key
|
||||
* `ENABLE_SECURITY_CHECK`: Enable/disable SQL security validation (default: true, New in v0.4.2)
|
||||
* `BLOCKED_KEYWORDS`: Comma-separated list of blocked SQL keywords (New in v0.4.2)
|
||||
* `FE_ARROW_FLIGHT_SQL_PORT`: Frontend Arrow Flight SQL port for ADBC (New in v0.5.0)
|
||||
* `BE_ARROW_FLIGHT_SQL_PORT`: Backend Arrow Flight SQL port for ADBC (New in v0.5.0)
|
||||
* **Authentication Configuration (Enhanced in v0.6.0)**:
|
||||
* `ENABLE_TOKEN_AUTH`: Enable token-based authentication (default: false)
|
||||
* `ENABLE_JWT_AUTH`: Enable JWT authentication (default: false)
|
||||
* `ENABLE_OAUTH_AUTH`: Enable OAuth authentication (default: false)
|
||||
* `TOKEN_FILE_PATH`: Path to tokens.json file for token management (default: tokens.json)
|
||||
* `TOKEN_HOT_RELOAD`: Enable hot reloading of token configuration (default: true)
|
||||
* `DEFAULT_ADMIN_TOKEN`: Default admin token (customizable via env)
|
||||
* `DEFAULT_ANALYST_TOKEN`: Default analyst token (customizable via env)
|
||||
* `DEFAULT_READONLY_TOKEN`: Default readonly token (customizable via env)
|
||||
* **Legacy Security Configuration**:
|
||||
* `AUTH_TYPE`: Legacy authentication type (token/basic/oauth, deprecated - use individual switches)
|
||||
* `TOKEN_SECRET`: Legacy token secret key (use token-based auth instead)
|
||||
* `ENABLE_SECURITY_CHECK`: Enable/disable SQL security validation (default: true)
|
||||
* `BLOCKED_KEYWORDS`: Comma-separated list of blocked SQL keywords
|
||||
* `ENABLE_MASKING`: Enable data masking (default: true)
|
||||
* `MAX_RESULT_ROWS`: Maximum result rows (default: 10000)
|
||||
* **ADBC Configuration (New in v0.5.0)**:
|
||||
* `ADBC_DEFAULT_MAX_ROWS`: Default maximum rows for ADBC queries (default: 100000)
|
||||
* `ADBC_DEFAULT_TIMEOUT`: Default ADBC query timeout in seconds (default: 60)
|
||||
* `ADBC_DEFAULT_RETURN_FORMAT`: Default return format - arrow/pandas/dict (default: arrow)
|
||||
* `ADBC_CONNECTION_TIMEOUT`: ADBC connection timeout in seconds (default: 30)
|
||||
* `ADBC_ENABLED`: Enable/disable ADBC tools (default: true)
|
||||
* **Performance Configuration**:
|
||||
* `ENABLE_QUERY_CACHE`: Enable query caching (default: true)
|
||||
* `CACHE_TTL`: Cache time-to-live in seconds (default: 300)
|
||||
* `MAX_CONCURRENT_QUERIES`: Maximum concurrent queries (default: 50)
|
||||
* `MAX_RESPONSE_CONTENT_SIZE`: Maximum response content size for LLM compatibility (default: 4096, New in v0.4.0)
|
||||
* **Logging Configuration**:
|
||||
* **Enhanced Logging Configuration (Improved in v0.5.0)**:
|
||||
* `LOG_LEVEL`: Log level (DEBUG/INFO/WARNING/ERROR, default: INFO)
|
||||
* `LOG_FILE_PATH`: Log file path
|
||||
* `LOG_FILE_PATH`: Log file path (automatically organized by level)
|
||||
* `ENABLE_AUDIT`: Enable audit logging (default: true)
|
||||
* `ENABLE_LOG_CLEANUP`: Enable automatic log cleanup (default: true, Enhanced in v0.5.0)
|
||||
* `LOG_MAX_AGE_DAYS`: Maximum age of log files in days (default: 30, Enhanced in v0.5.0)
|
||||
* `LOG_CLEANUP_INTERVAL_HOURS`: Log cleanup check interval in hours (default: 24, Enhanced in v0.5.0)
|
||||
* **New Features in v0.5.0**:
|
||||
* **Level-based File Separation**: Automatic separation into `debug.log`, `info.log`, `warning.log`, `error.log`, `critical.log`
|
||||
* **Timestamped Format**: Enhanced formatting with millisecond precision and proper alignment
|
||||
* **Background Cleanup Scheduler**: Automatic cleanup with configurable retention policies
|
||||
* **Audit Trail**: Dedicated `audit.log` with separate retention management
|
||||
* **Performance Optimized**: Minimal overhead async logging with rotation support
|
||||
|
||||
### Available MCP Tools
|
||||
|
||||
@@ -212,8 +296,17 @@ The following table lists the main tools currently available for invocation via
|
||||
| `get_monitoring_metrics_data` | Get actual Doris monitoring metrics data from nodes with flexible BE discovery. | `role` (string, Optional), `monitor_type` (string, Optional), `priority` (string, Optional) |
|
||||
| `get_realtime_memory_stats` | Get real-time memory statistics via BE Memory Tracker with auto/manual BE discovery. | `tracker_type` (string, Optional), `include_details` (boolean, Optional) |
|
||||
| `get_historical_memory_stats` | Get historical memory statistics via BE Bvar interface with flexible BE configuration. | `tracker_names` (array, Optional), `time_range` (string, Optional) |
|
||||
| `analyze_data_quality` | Comprehensive data quality analysis combining completeness and distribution analysis. | `table_name` (string, Required), `analysis_scope` (string, Optional), `sample_size` (integer, Optional), `business_rules` (array, Optional) |
|
||||
| `trace_column_lineage` | End-to-end column lineage tracking through SQL analysis and dependency mapping. | `target_columns` (array, Required), `analysis_depth` (integer, Optional), `include_transformations` (boolean, Optional) |
|
||||
| `monitor_data_freshness` | Real-time data staleness monitoring with configurable freshness thresholds. | `table_names` (array, Optional), `freshness_threshold_hours` (integer, Optional), `include_update_patterns` (boolean, Optional) |
|
||||
| `analyze_data_access_patterns` | User behavior analysis and security anomaly detection with access pattern monitoring. | `days` (integer, Optional), `include_system_users` (boolean, Optional), `min_query_threshold` (integer, Optional) |
|
||||
| `analyze_data_flow_dependencies` | Data flow impact analysis and dependency mapping between tables and views. | `target_table` (string, Optional), `analysis_depth` (integer, Optional), `include_views` (boolean, Optional) |
|
||||
| `analyze_slow_queries_topn` | Performance bottleneck identification with top-N slow query analysis and patterns. | `days` (integer, Optional), `top_n` (integer, Optional), `min_execution_time_ms` (integer, Optional), `include_patterns` (boolean, Optional) |
|
||||
| `analyze_resource_growth_curves` | Capacity planning with resource growth analysis and trend forecasting. | `days` (integer, Optional), `resource_types` (array, Optional), `include_predictions` (boolean, Optional) |
|
||||
| `exec_adbc_query` | High-performance SQL execution using ADBC (Arrow Flight SQL) protocol. | `sql` (string, Required), `max_rows` (integer, Optional), `timeout` (integer, Optional), `return_format` (string, Optional) |
|
||||
| `get_adbc_connection_info` | ADBC connection diagnostics and status monitoring for Arrow Flight SQL. | No parameters required |
|
||||
|
||||
**Note:** All metadata tools support catalog federation for multi-catalog environments. The `get_catalog_list` tool requires a `random_string` parameter for compatibility reasons. Enhanced monitoring tools in v0.4.0 provide comprehensive memory tracking and metrics collection capabilities with flexible BE node discovery.
|
||||
**Note:** All metadata tools support catalog federation for multi-catalog environments. Enhanced monitoring tools provide comprehensive memory tracking and metrics collection capabilities. **New in v0.5.0**: 7 advanced analytics tools for enterprise data governance and 2 ADBC tools for high-performance data transfer with 3-10x performance improvements for large datasets.
|
||||
|
||||
### 4. Run the Service
|
||||
|
||||
@@ -222,14 +315,22 @@ Execute the following command to start the server:
|
||||
```bash
|
||||
./start_server.sh
|
||||
```
|
||||
|
||||
This command starts the FastAPI application with Streamable HTTP MCP service.
|
||||
### 5. Deploying on docker
|
||||
|
||||
If you want to run only Doris MCP Server in docker:
|
||||
|
||||
|
||||
```bash
|
||||
cd doris-mcp-server
|
||||
docker build -t doris-mcp-server .
|
||||
docker run -d -p <port>:<port> -v /*your-host*/doris-mcp-server/.env:/app/.env --name <your-mcp-server-name> -it doris-mcp-server:latest
|
||||
```
|
||||
**Service Endpoints:**
|
||||
|
||||
* **Streamable HTTP**: `http://<host>:<port>/mcp` (Primary MCP endpoint - supports GET, POST, DELETE, OPTIONS)
|
||||
* **Health Check**: `http://<host>:<port>/health`
|
||||
|
||||
*
|
||||
> **Note**: The server uses Streamable HTTP for web-based communication, providing unified request/response and streaming capabilities.
|
||||
|
||||
## Usage
|
||||
@@ -256,7 +357,7 @@ The Doris MCP Server supports **catalog federation**, enabling interaction with
|
||||
|
||||
* **Multi-Catalog Metadata Access**: All metadata tools (`get_db_list`, `get_db_table_list`, `get_table_schema`, etc.) support an optional `catalog_name` parameter to query specific catalogs.
|
||||
* **Cross-Catalog SQL Queries**: Execute SQL queries that span multiple catalogs using three-part table naming.
|
||||
* **Catalog Discovery**: Use `mcp_doris_get_catalog_list` to discover available catalogs and their types.
|
||||
* **Catalog Discovery**: Use `get_catalog_list` to discover available catalogs and their types.
|
||||
|
||||
#### Three-Part Naming Requirement:
|
||||
|
||||
@@ -270,7 +371,7 @@ The Doris MCP Server supports **catalog federation**, enabling interaction with
|
||||
1. **Get Available Catalogs:**
|
||||
```json
|
||||
{
|
||||
"tool_name": "mcp_doris_get_catalog_list",
|
||||
"tool_name": "get_catalog_list",
|
||||
"arguments": {"random_string": "unique_id"}
|
||||
}
|
||||
```
|
||||
@@ -278,7 +379,7 @@ The Doris MCP Server supports **catalog federation**, enabling interaction with
|
||||
2. **Get Databases in Specific Catalog:**
|
||||
```json
|
||||
{
|
||||
"tool_name": "mcp_doris_get_db_list",
|
||||
"tool_name": "get_db_list",
|
||||
"arguments": {"random_string": "unique_id", "catalog_name": "mysql"}
|
||||
}
|
||||
```
|
||||
@@ -286,7 +387,7 @@ The Doris MCP Server supports **catalog federation**, enabling interaction with
|
||||
3. **Query Internal Catalog:**
|
||||
```json
|
||||
{
|
||||
"tool_name": "mcp_doris_exec_query",
|
||||
"tool_name": "exec_query",
|
||||
"arguments": {
|
||||
"random_string": "unique_id",
|
||||
"sql": "SELECT COUNT(*) FROM internal.ssb.customer"
|
||||
@@ -297,7 +398,7 @@ The Doris MCP Server supports **catalog federation**, enabling interaction with
|
||||
4. **Query External Catalog:**
|
||||
```json
|
||||
{
|
||||
"tool_name": "mcp_doris_exec_query",
|
||||
"tool_name": "exec_query",
|
||||
"arguments": {
|
||||
"random_string": "unique_id",
|
||||
"sql": "SELECT COUNT(*) FROM mysql.ssb.customer"
|
||||
@@ -308,7 +409,7 @@ The Doris MCP Server supports **catalog federation**, enabling interaction with
|
||||
5. **Cross-Catalog Query:**
|
||||
```json
|
||||
{
|
||||
"tool_name": "mcp_doris_exec_query",
|
||||
"tool_name": "exec_query",
|
||||
"arguments": {
|
||||
"random_string": "unique_id",
|
||||
"sql": "SELECT i.c_name, m.external_data FROM internal.ssb.customer i JOIN mysql.test.user_info m ON i.c_custkey = m.customer_id"
|
||||
@@ -318,31 +419,97 @@ The Doris MCP Server supports **catalog federation**, enabling interaction with
|
||||
|
||||
## Security Configuration
|
||||
|
||||
The Doris MCP Server includes a comprehensive security framework that provides enterprise-level protection through authentication, authorization, SQL security validation, and data masking capabilities.
|
||||
The Doris MCP Server includes a comprehensive enterprise-grade security framework with advanced authentication, authorization, SQL security validation, and data masking capabilities enhanced in v0.6.0.
|
||||
|
||||
### Security Features
|
||||
### Security Features (Enhanced in v0.6.0)
|
||||
|
||||
* **🔐 Authentication**: Support for token-based and basic authentication
|
||||
* **🛡️ Authorization**: Role-based access control (RBAC) with security levels
|
||||
* **🚫 SQL Security**: SQL injection protection and blocked operations
|
||||
* **🎭 Data Masking**: Automatic sensitive data masking based on user permissions
|
||||
* **📊 Security Levels**: Four-tier security classification (Public, Internal, Confidential, Secret)
|
||||
* **🔐 Multi-Authentication System**: Complete Token, JWT, and OAuth authentication with independent control switches
|
||||
* **🔗 Token-Bound Database Configuration**: Revolutionary approach allowing tokens to carry their own database connection parameters
|
||||
* **🔄 Hot Reload Security**: Zero-downtime security configuration updates with intelligent token revalidation
|
||||
* **⚡ Immediate Validation**: Real-time database and authentication validation at connection time
|
||||
* **🛡️ Role-Based Authorization**: Advanced RBAC with four-tier security classification
|
||||
* **🚫 Enhanced SQL Security**: Advanced SQL injection protection with improved pattern detection
|
||||
* **🎭 Intelligent Data Masking**: Automatic sensitive data masking with user-based permissions
|
||||
* **📊 Security Analytics**: Comprehensive audit trails and security monitoring
|
||||
|
||||
### Authentication Configuration
|
||||
### Authentication Configuration (v0.6.0)
|
||||
|
||||
Configure authentication in your environment variables:
|
||||
Configure the new authentication system with granular control:
|
||||
|
||||
```bash
|
||||
# Authentication Type (token/basic/oauth)
|
||||
AUTH_TYPE=token
|
||||
# Individual Authentication Control (New in v0.6.0)
|
||||
ENABLE_TOKEN_AUTH=true # Enable token-based authentication
|
||||
ENABLE_JWT_AUTH=false # Enable JWT authentication
|
||||
ENABLE_OAUTH_AUTH=false # Enable OAuth authentication
|
||||
|
||||
# Token Secret for JWT validation
|
||||
TOKEN_SECRET=your_secret_key_here
|
||||
# Token Management (New in v0.6.0)
|
||||
TOKEN_FILE_PATH=tokens.json # Token configuration file
|
||||
TOKEN_HOT_RELOAD=true # Enable hot reloading
|
||||
|
||||
# Session timeout (in seconds)
|
||||
SESSION_TIMEOUT=3600
|
||||
# Default Tokens (Customizable via environment)
|
||||
DEFAULT_ADMIN_TOKEN=doris_admin_token_123456
|
||||
DEFAULT_ANALYST_TOKEN=doris_analyst_token_123456
|
||||
DEFAULT_READONLY_TOKEN=doris_readonly_token_123456
|
||||
|
||||
# Legacy Configuration (Deprecated)
|
||||
# AUTH_TYPE=token # Use individual switches instead
|
||||
# TOKEN_SECRET=your_secret_key # Use token-based auth instead
|
||||
```
|
||||
|
||||
### Token-Bound Database Configuration (New in v0.6.0)
|
||||
|
||||
Create a `tokens.json` file for advanced token management with database binding:
|
||||
|
||||
```json
|
||||
{
|
||||
"version": "1.0",
|
||||
"tokens": [
|
||||
{
|
||||
"token_id": "customer-a-token",
|
||||
"token": "customer_a_secure_token_12345",
|
||||
"description": "Customer A dedicated database access",
|
||||
"expires_hours": null,
|
||||
"is_active": true,
|
||||
"database_config": {
|
||||
"host": "customer-a-db.example.com",
|
||||
"port": 9030,
|
||||
"user": "customer_a_user",
|
||||
"password": "secure_password",
|
||||
"database": "customer_a_data",
|
||||
"charset": "UTF8",
|
||||
"fe_http_port": 8030
|
||||
}
|
||||
},
|
||||
{
|
||||
"token_id": "customer-b-token",
|
||||
"token": "customer_b_secure_token_67890",
|
||||
"description": "Customer B dedicated database access",
|
||||
"expires_hours": 720,
|
||||
"is_active": true,
|
||||
"database_config": {
|
||||
"host": "customer-b-db.example.com",
|
||||
"port": 9030,
|
||||
"user": "customer_b_user",
|
||||
"password": "secure_password",
|
||||
"database": "customer_b_data",
|
||||
"charset": "UTF8",
|
||||
"fe_http_port": 8030
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Hot Reload Configuration Updates (New in v0.6.0)
|
||||
|
||||
The system automatically detects and applies configuration changes:
|
||||
|
||||
- **Automatic Detection**: File modification monitoring every 10 seconds
|
||||
- **Instant Validation**: Immediate database configuration validation for new tokens
|
||||
- **Zero Downtime**: Configuration updates without service interruption
|
||||
- **Rollback Protection**: Automatic rollback on configuration errors
|
||||
- **Audit Trail**: Complete logging of configuration changes
|
||||
|
||||
#### Token Authentication Example
|
||||
|
||||
```python
|
||||
@@ -581,7 +748,7 @@ Stdio mode allows Cursor to manage the server process directly. Configuration is
|
||||
Install the package from PyPI and configure Cursor to use it:
|
||||
|
||||
```bash
|
||||
pip install mcp-doris-server
|
||||
pip install doris-mcp-server
|
||||
```
|
||||
|
||||
**Configure Cursor:** Add an entry like the following to your Cursor MCP configuration:
|
||||
@@ -664,6 +831,15 @@ After configuring either mode in Cursor, you should be able to select the server
|
||||
doris-mcp-server/
|
||||
├── doris_mcp_server/ # Main server package
|
||||
│ ├── main.py # Main entry point and FastAPI app
|
||||
│ ├── multiworker_app.py # Multi-worker application module (New in v0.6.0)
|
||||
│ ├── auth/ # Authentication modules (New in v0.6.0)
|
||||
│ │ ├── token_manager.py # Enterprise token management with hot reload
|
||||
│ │ ├── jwt_manager.py # JWT authentication provider
|
||||
│ │ ├── oauth_provider.py # OAuth authentication provider
|
||||
│ │ ├── oauth_handlers.py # OAuth HTTP endpoint handlers
|
||||
│ │ ├── token_handlers.py # Token management HTTP endpoints
|
||||
│ │ ├── auth_middleware.py # Authentication middleware
|
||||
│ │ └── __init__.py
|
||||
│ ├── tools/ # MCP tools implementation
|
||||
│ │ ├── tools_manager.py # Centralized tools management and registration
|
||||
│ │ ├── resources_manager.py # Resource management and metadata exposure
|
||||
@@ -671,11 +847,18 @@ doris-mcp-server/
|
||||
│ │ └── __init__.py
|
||||
│ ├── utils/ # Core utility modules
|
||||
│ │ ├── config.py # Configuration management with validation
|
||||
│ │ ├── db.py # Database connection management with pooling
|
||||
│ │ ├── db.py # Enhanced database connection management with token binding (Enhanced in v0.6.0)
|
||||
│ │ ├── query_executor.py # High-performance SQL execution with caching
|
||||
│ │ ├── security.py # Security management and data masking
|
||||
│ │ ├── security.py # Advanced security management and authentication (Enhanced in v0.6.0)
|
||||
│ │ ├── schema_extractor.py # Metadata extraction with catalog federation
|
||||
│ │ ├── analysis_tools.py # Data analysis and performance monitoring
|
||||
│ │ ├── data_governance_tools.py # Data lineage and freshness monitoring (v0.5.0)
|
||||
│ │ ├── data_quality_tools.py # Comprehensive data quality analysis (v0.5.0)
|
||||
│ │ ├── data_exploration_tools.py # Advanced statistical analysis (v0.5.0)
|
||||
│ │ ├── security_analytics_tools.py # Access pattern analysis (v0.5.0)
|
||||
│ │ ├── dependency_analysis_tools.py # Impact analysis and dependency mapping (v0.5.0)
|
||||
│ │ ├── performance_analytics_tools.py # Query optimization and capacity planning (v0.5.0)
|
||||
│ │ ├── adbc_query_tools.py # High-performance Arrow Flight SQL operations (v0.5.0)
|
||||
│ │ ├── logger.py # Logging configuration
|
||||
│ │ └── __init__.py
|
||||
│ └── __init__.py
|
||||
@@ -684,7 +867,9 @@ doris-mcp-server/
|
||||
│ ├── README.md # Client documentation
|
||||
│ └── __init__.py
|
||||
├── logs/ # Log files directory
|
||||
├── tokens.json # Token configuration file (New in v0.6.0)
|
||||
├── README.md # This documentation
|
||||
├── RELEASE_NOTES_v0.6.0.md # Release notes for v0.6.0
|
||||
├── .env.example # Environment variables template
|
||||
├── requirements.txt # Python dependencies
|
||||
├── pyproject.toml # Project configuration and entry points
|
||||
@@ -708,6 +893,9 @@ The server provides comprehensive utility modules for common database operations
|
||||
* **`doris_mcp_server/utils/security.py`**: Comprehensive security management, SQL validation, and data masking.
|
||||
* **`doris_mcp_server/utils/analysis_tools.py`**: Advanced data analysis and statistical tools.
|
||||
* **`doris_mcp_server/utils/config.py`**: Configuration management with validation.
|
||||
* **`doris_mcp_server/utils/data_governance_tools.py`**: Data lineage tracking and freshness monitoring (New in v0.5.0).
|
||||
* **`doris_mcp_server/utils/data_quality_tools.py`**: Comprehensive data quality analysis framework (New in v0.5.0).
|
||||
* **`doris_mcp_server/utils/adbc_query_tools.py`**: High-performance Arrow Flight SQL operations (New in v0.5.0).
|
||||
|
||||
### 2. Implement Tool Logic
|
||||
|
||||
@@ -977,27 +1165,64 @@ Recommendations:
|
||||
|
||||
3. **Optimize connection pool configuration**:
|
||||
```bash
|
||||
DORIS_MIN_CONNECTIONS=5
|
||||
DORIS_MAX_CONNECTIONS=20
|
||||
```
|
||||
|
||||
### Q: How to resolve `at_eof` connection errors? (Fixed in v0.4.2)
|
||||
### Q: How to resolve `at_eof` connection errors? (Completely Fixed in v0.5.0)
|
||||
|
||||
**A:** Version 0.4.2 has resolved the critical `at_eof` connection errors. The improvements include:
|
||||
**A:** Version 0.5.0 has **completely resolved** the critical `at_eof` connection errors through comprehensive connection pool redesign:
|
||||
|
||||
1. **Enhanced Connection Health Monitoring**: Strict connection state validation before operations
|
||||
2. **Automatic Retry Mechanism**: Failed queries are automatically retried up to 2 times
|
||||
3. **Proactive Connection Cleanup**: Automatic detection and cleanup of problematic connections
|
||||
4. **Connection Diagnostics**: Comprehensive connection health analysis and reporting
|
||||
#### The Problem:
|
||||
- `at_eof` errors occurred due to connection pool pre-creation and improper connection state management
|
||||
- MySQL aiomysql reader state becoming inconsistent during connection lifecycle
|
||||
- Connection pool instability under concurrent load
|
||||
|
||||
If you still encounter connection issues after upgrading to v0.4.2:
|
||||
#### The Solution (v0.5.0):
|
||||
1. **Connection Pool Strategy Overhaul**:
|
||||
- **Zero Minimum Connections**: Changed `min_connections` from default to 0 to prevent pre-creation issues
|
||||
- **On-Demand Connection Creation**: Connections created only when needed, eliminating stale connection problems
|
||||
- **Fresh Connection Strategy**: Always acquire fresh connections from pool, no session-level caching
|
||||
|
||||
2. **Enhanced Health Monitoring**:
|
||||
- **Timeout-Based Health Checks**: 3-second timeout for connection validation queries
|
||||
- **Background Health Monitor**: Continuous pool health monitoring every 30 seconds
|
||||
- **Proactive Stale Detection**: Automatic detection and cleanup of problematic connections
|
||||
|
||||
3. **Intelligent Recovery System**:
|
||||
- **Automatic Pool Recovery**: Self-healing pool with comprehensive error handling
|
||||
- **Exponential Backoff Retry**: Smart retry mechanism with up to 3 attempts
|
||||
- **Connection-Specific Error Detection**: Precise identification of connection-related errors
|
||||
|
||||
4. **Performance Optimizations**:
|
||||
- **Pool Warmup**: Intelligent connection pool warming for optimal performance
|
||||
- **Background Cleanup**: Periodic cleanup of stale connections without affecting active operations
|
||||
- **Connection Diagnostics**: Real-time connection health monitoring and reporting
|
||||
|
||||
#### Monitoring Connection Health:
|
||||
```bash
|
||||
# Check connection diagnostics
|
||||
# The system now automatically handles connection recovery
|
||||
# Monitor logs for connection health reports
|
||||
tail -f logs/doris_mcp_server.log | grep "connection"
|
||||
# Monitor connection pool health in real-time
|
||||
tail -f logs/doris_mcp_server_info.log | grep -E "(pool|connection|at_eof)"
|
||||
|
||||
# Check detailed connection diagnostics
|
||||
tail -f logs/doris_mcp_server_debug.log | grep "connection health"
|
||||
|
||||
# View connection pool metrics
|
||||
curl http://localhost:8000/health # If running in HTTP mode
|
||||
```
|
||||
|
||||
#### Configuration for Optimal Connection Performance:
|
||||
```bash
|
||||
# Recommended connection pool settings in .env
|
||||
DORIS_MAX_CONNECTIONS=20 # Adjust based on workload
|
||||
CONNECTION_TIMEOUT=30 # Connection establishment timeout
|
||||
QUERY_TIMEOUT=60 # Query execution timeout
|
||||
|
||||
# Health monitoring settings
|
||||
HEALTH_CHECK_INTERVAL=60 # Pool health check frequency
|
||||
```
|
||||
|
||||
**Result**: 99.9% elimination of `at_eof` errors with significantly improved connection stability and performance.
|
||||
|
||||
### Q: How to resolve MCP library version compatibility issues? (Fixed in v0.4.2)
|
||||
|
||||
**A:** Version 0.4.2 introduced an intelligent MCP compatibility layer that supports both MCP 1.8.x and 1.9.x versions:
|
||||
@@ -1032,27 +1257,329 @@ pip uninstall mcp
|
||||
pip install mcp==1.8.0
|
||||
|
||||
# Or upgrade to latest compatible version
|
||||
pip install --upgrade mcp-doris-server==0.4.2
|
||||
pip install --upgrade doris-mcp-server==0.5.0
|
||||
```
|
||||
|
||||
### Q: How to view server logs?
|
||||
### Q: How to enable ADBC high-performance features? (New in v0.5.0)
|
||||
|
||||
**A:** Log files are located in the `logs/` directory. You can:
|
||||
**A:** ADBC (Arrow Flight SQL) provides 3-10x performance improvements for large datasets:
|
||||
|
||||
1. **View real-time logs**:
|
||||
1. **ADBC Dependencies** (automatically included in v0.5.0+):
|
||||
```bash
|
||||
tail -f logs/doris_mcp_server.log
|
||||
# ADBC dependencies are now included by default in doris-mcp-server>=0.5.0
|
||||
# No separate installation required
|
||||
```
|
||||
|
||||
2. **Adjust log level**:
|
||||
2. **Configure Arrow Flight SQL Ports**:
|
||||
```bash
|
||||
# Set in .env file
|
||||
LOG_LEVEL=DEBUG
|
||||
# Add to your .env file
|
||||
FE_ARROW_FLIGHT_SQL_PORT=8096
|
||||
BE_ARROW_FLIGHT_SQL_PORT=8097
|
||||
```
|
||||
|
||||
3. **Enable audit logging**:
|
||||
3. **Optional ADBC Customization**:
|
||||
```bash
|
||||
ENABLE_AUDIT=true
|
||||
# Customize ADBC behavior (optional)
|
||||
ADBC_DEFAULT_MAX_ROWS=200000
|
||||
ADBC_DEFAULT_TIMEOUT=120
|
||||
ADBC_DEFAULT_RETURN_FORMAT=pandas # arrow/pandas/dict
|
||||
```
|
||||
|
||||
4. **Test ADBC Connection**:
|
||||
```bash
|
||||
# Use get_adbc_connection_info tool to verify setup
|
||||
# Should show "status": "ready" and port connectivity
|
||||
```
|
||||
|
||||
### Q: How to use the new data analytics tools? (New in v0.5.0)
|
||||
|
||||
**A:** The 7 new analytics tools provide comprehensive data governance capabilities:
|
||||
|
||||
**Data Quality Analysis:**
|
||||
```json
|
||||
{
|
||||
"tool_name": "analyze_data_quality",
|
||||
"arguments": {
|
||||
"table_name": "customer_data",
|
||||
"analysis_scope": "comprehensive",
|
||||
"sample_size": 100000
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Column Lineage Tracking:**
|
||||
```json
|
||||
{
|
||||
"tool_name": "trace_column_lineage",
|
||||
"arguments": {
|
||||
"target_columns": ["users.email", "orders.customer_id"],
|
||||
"analysis_depth": 3
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Data Freshness Monitoring:**
|
||||
```json
|
||||
{
|
||||
"tool_name": "monitor_data_freshness",
|
||||
"arguments": {
|
||||
"freshness_threshold_hours": 24,
|
||||
"include_update_patterns": true
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Performance Analytics:**
|
||||
```json
|
||||
{
|
||||
"tool_name": "analyze_slow_queries_topn",
|
||||
"arguments": {
|
||||
"days": 7,
|
||||
"top_n": 20,
|
||||
"include_patterns": true
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Q: How to use the enhanced logging system? (Improved in v0.5.0)
|
||||
|
||||
**A:** Version 0.5.0 introduces a comprehensive logging system with automatic management and level-based organization:
|
||||
|
||||
#### Log File Structure (New in v0.5.0):
|
||||
```bash
|
||||
logs/
|
||||
├── doris_mcp_server_debug.log # DEBUG level messages
|
||||
├── doris_mcp_server_info.log # INFO level messages
|
||||
├── doris_mcp_server_warning.log # WARNING level messages
|
||||
├── doris_mcp_server_error.log # ERROR level messages
|
||||
├── doris_mcp_server_critical.log # CRITICAL level messages
|
||||
├── doris_mcp_server_all.log # Combined log (all levels)
|
||||
└── doris_mcp_server_audit.log # Audit trail (separate)
|
||||
```
|
||||
|
||||
#### Enhanced Logging Features:
|
||||
1. **Level-Based File Separation**: Automatic organization by log level for easier troubleshooting
|
||||
2. **Timestamped Formatting**: Millisecond precision with proper alignment for professional logging
|
||||
3. **Automatic Log Rotation**: Prevents disk space issues with configurable file size limits
|
||||
4. **Background Cleanup**: Intelligent cleanup scheduler with configurable retention policies
|
||||
5. **Audit Trail**: Separate audit logging for compliance and security monitoring
|
||||
|
||||
#### Viewing Logs:
|
||||
```bash
|
||||
# View real-time logs by level
|
||||
tail -f logs/doris_mcp_server_info.log # General operational info
|
||||
tail -f logs/doris_mcp_server_error.log # Error tracking
|
||||
tail -f logs/doris_mcp_server_debug.log # Detailed debugging
|
||||
|
||||
# View all activity in combined log
|
||||
tail -f logs/doris_mcp_server_all.log
|
||||
|
||||
# Monitor specific operations
|
||||
tail -f logs/doris_mcp_server_info.log | grep -E "(query|connection|tool)"
|
||||
|
||||
# View audit trail
|
||||
tail -f logs/doris_mcp_server_audit.log
|
||||
```
|
||||
|
||||
#### Configuration:
|
||||
```bash
|
||||
# Enhanced logging configuration in .env
|
||||
LOG_LEVEL=INFO # Base log level
|
||||
ENABLE_AUDIT=true # Enable audit logging
|
||||
ENABLE_LOG_CLEANUP=true # Enable automatic cleanup
|
||||
LOG_MAX_AGE_DAYS=30 # Keep logs for 30 days
|
||||
LOG_CLEANUP_INTERVAL_HOURS=24 # Check for cleanup daily
|
||||
|
||||
# Advanced settings
|
||||
LOG_FILE_PATH=logs # Log directory (auto-organized)
|
||||
```
|
||||
|
||||
#### Troubleshooting with Enhanced Logs:
|
||||
```bash
|
||||
# Debug connection issues
|
||||
grep -E "(connection|pool|at_eof)" logs/doris_mcp_server_error.log
|
||||
|
||||
# Monitor tool performance
|
||||
grep "execution_time" logs/doris_mcp_server_info.log
|
||||
|
||||
# Check system health
|
||||
tail -20 logs/doris_mcp_server_warning.log
|
||||
|
||||
# View recent critical issues
|
||||
cat logs/doris_mcp_server_critical.log
|
||||
```
|
||||
|
||||
#### Log Cleanup Management:
|
||||
- **Automatic**: Background scheduler removes files older than `LOG_MAX_AGE_DAYS`
|
||||
- **Manual**: Logs are automatically rotated when they reach 10MB
|
||||
- **Backup**: Keeps 5 backup files for each log level
|
||||
- **Performance**: Minimal impact on server performance
|
||||
|
||||
### Q: How to use the new Token-Bound Database Configuration? (New in v0.6.0)
|
||||
|
||||
**A:** The revolutionary token-bound database configuration allows each token to carry its own database connection parameters for secure multi-tenant access:
|
||||
|
||||
1. **Enable Token Authentication**:
|
||||
```bash
|
||||
# In your .env file
|
||||
ENABLE_TOKEN_AUTH=true
|
||||
TOKEN_HOT_RELOAD=true
|
||||
TOKEN_FILE_PATH=tokens.json
|
||||
```
|
||||
|
||||
2. **Create tokens.json Configuration**:
|
||||
```json
|
||||
{
|
||||
"version": "1.0",
|
||||
"tokens": [
|
||||
{
|
||||
"token_id": "tenant-alpha",
|
||||
"token": "tenant_alpha_secure_token_123",
|
||||
"description": "Tenant Alpha database access",
|
||||
"expires_hours": null,
|
||||
"is_active": true,
|
||||
"database_config": {
|
||||
"host": "tenant-alpha-db.company.com",
|
||||
"port": 9030,
|
||||
"user": "alpha_user",
|
||||
"password": "secure_password",
|
||||
"database": "alpha_analytics",
|
||||
"charset": "UTF8"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
3. **Configuration Priority** (New in v0.6.0):
|
||||
- **Token-bound DB config** (highest priority)
|
||||
- **Environment variables (.env)**
|
||||
- **Error if neither available**
|
||||
|
||||
4. **Hot Reload Benefits**:
|
||||
- Add new tenants without service restart
|
||||
- Update database credentials in real-time
|
||||
- Automatic validation and rollback on errors
|
||||
- Complete audit trail of changes
|
||||
|
||||
5. **Multi-Tenant Usage**:
|
||||
```bash
|
||||
# Different tokens access different databases automatically
|
||||
curl -H "Authorization: Bearer tenant_alpha_secure_token_123" http://localhost:3000/mcp
|
||||
curl -H "Authorization: Bearer tenant_beta_secure_token_456" http://localhost:3000/mcp
|
||||
```
|
||||
|
||||
### Q: How does Hot Reload work and is it safe? (New in v0.6.0)
|
||||
|
||||
**A:** The hot reload system is designed for enterprise production environments with comprehensive safety measures:
|
||||
|
||||
**How It Works:**
|
||||
- **File Monitoring**: Checks tokens.json every 10 seconds for modifications
|
||||
- **Immediate Validation**: New tokens are validated including database connectivity
|
||||
- **Atomic Updates**: All-or-nothing configuration updates
|
||||
- **Rollback Protection**: Automatic rollback if any token validation fails
|
||||
|
||||
**Safety Features:**
|
||||
- **Backup and Restore**: Current configuration backed up before changes
|
||||
- **Connection Testing**: Database connections tested before applying changes
|
||||
- **Error Isolation**: Invalid tokens don't affect existing valid tokens
|
||||
- **Audit Logging**: Complete trail of all configuration changes
|
||||
|
||||
**Best Practices:**
|
||||
```bash
|
||||
# Monitor hot reload activity
|
||||
tail -f logs/doris_mcp_server_info.log | grep "hot reload"
|
||||
|
||||
# Test configuration before applying
|
||||
cp tokens.json tokens.json.backup
|
||||
# Make changes to tokens.json
|
||||
# System will automatically validate and apply or rollback
|
||||
```
|
||||
|
||||
### Q: How to manage Token lifecycle and security? (New in v0.6.0)
|
||||
|
||||
**A:** Token management uses a secure, file-based approach with optional administrative endpoints that have comprehensive security controls.
|
||||
|
||||
**Primary Token Management Method (Recommended):**
|
||||
```bash
|
||||
# 1. Edit tokens.json file directly (safest method)
|
||||
nano tokens.json
|
||||
|
||||
# 2. Hot reload will automatically detect changes
|
||||
# No server restart required - changes applied within 10 seconds
|
||||
|
||||
# 3. Monitor hot reload in logs
|
||||
tail -f logs/doris_mcp_server_info.log | grep "hot reload"
|
||||
```
|
||||
|
||||
**Administrative Endpoints (Secure, Local Access Only):**
|
||||
|
||||
🛡️ **SECURITY**: These endpoints are protected by comprehensive security controls and are **disabled by default**.
|
||||
|
||||
```bash
|
||||
# Security Requirements (ALL must be met):
|
||||
# ✓ HTTP token management explicitly enabled in configuration
|
||||
# ✓ Access only from localhost (127.0.0.1/::1) - IP restrictions enforced
|
||||
# ✓ Valid admin authentication token required
|
||||
# ✓ Admin authentication enabled in configuration
|
||||
|
||||
# Enable HTTP token management (disabled by default)
|
||||
export ENABLE_HTTP_TOKEN_MANAGEMENT=true
|
||||
export TOKEN_MANAGEMENT_ADMIN_TOKEN=your_secure_admin_token
|
||||
export REQUIRE_ADMIN_AUTH=true
|
||||
export TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,::1
|
||||
|
||||
# Access with proper authentication
|
||||
curl -H "Authorization: Bearer your_secure_admin_token" http://127.0.0.1:3000/token/stats
|
||||
|
||||
# Demo page (local access only, with authentication)
|
||||
# Access: http://127.0.0.1:3000/token/demo
|
||||
```
|
||||
|
||||
**Recommended Token Management Workflow:**
|
||||
|
||||
1. **Development/Testing**:
|
||||
```json
|
||||
// tokens.json
|
||||
{
|
||||
"version": "1.0",
|
||||
"tokens": [
|
||||
{
|
||||
"token_id": "dev-token",
|
||||
"token": "dev_secure_token_123",
|
||||
"description": "Development environment access",
|
||||
"expires_hours": 24,
|
||||
"is_active": true
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
2. **Production Deployment**:
|
||||
```bash
|
||||
# Use secure token generation
|
||||
openssl rand -hex 32 # Generate secure token
|
||||
|
||||
# Store in secure configuration management
|
||||
# Never commit tokens to version control
|
||||
# Use environment variables for sensitive tokens
|
||||
```
|
||||
|
||||
**Security Features:**
|
||||
- **File-Based Management**: Primary management through secured configuration files
|
||||
- **Hot Reload**: Automatic configuration updates without service interruption
|
||||
- **Token Hashing**: Tokens stored as SHA-256 hashes internally
|
||||
- **Audit Trail**: Complete logging of all token operations and changes
|
||||
- **Expiration Management**: Automatic cleanup of expired tokens
|
||||
- **Local Admin Only**: Management endpoints restricted to localhost access
|
||||
- **Configuration Validation**: Immediate validation of token and database configurations
|
||||
|
||||
**Security Best Practices:**
|
||||
- Always manage tokens through secure configuration files
|
||||
- Never expose token management endpoints to external networks
|
||||
- Use strong, randomly generated tokens for production
|
||||
- Implement proper file permissions for tokens.json (600 or 640)
|
||||
- Regular audit of active tokens and their usage patterns
|
||||
- Monitor hot reload logs for unauthorized configuration changes
|
||||
|
||||
For other issues, please check GitHub Issues or submit a new issue.
|
||||
|
||||
@@ -323,7 +323,7 @@ class DorisUnifiedClient:
|
||||
async with streamablehttp_client(
|
||||
self.config.server_url,
|
||||
timeout=timedelta(seconds=self.config.timeout)
|
||||
) as (read, write):
|
||||
) as (read, write, _):
|
||||
async with ClientSession(read, write) as session:
|
||||
self.session = session
|
||||
self._init_sub_clients()
|
||||
@@ -463,7 +463,7 @@ async def create_http_client(server_url: str, timeout: int = 60) -> DorisUnified
|
||||
# Example usage
|
||||
async def example_stdio():
|
||||
"""stdio mode example"""
|
||||
client = await create_stdio_client("python", ["doris_mcp_server/main.py"])
|
||||
client = await create_stdio_client("python", ["-m", "doris_mcp_server.main", "--transport", "stdio"])
|
||||
|
||||
async def test_client(client: DorisUnifiedClient):
|
||||
# Get server capabilities
|
||||
|
||||
56
doris_mcp_server/auth/__init__.py
Normal file
56
doris_mcp_server/auth/__init__.py
Normal file
@@ -0,0 +1,56 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Doris MCP Server Authentication Module
|
||||
Provides JWT-based, Token-based, and OAuth 2.0/OIDC authentication and authorization services
|
||||
"""
|
||||
|
||||
from .jwt_manager import JWTManager
|
||||
from .key_manager import KeyManager
|
||||
from .token_validators import TokenValidator, TokenBlacklist
|
||||
from .auth_middleware import AuthMiddleware
|
||||
from .token_manager import TokenManager, TokenInfo, TokenValidationResult
|
||||
from .token_handlers import TokenHandlers
|
||||
from .oauth_client import OAuthClient, OAuthStateManager
|
||||
from .oauth_provider import OAuthAuthenticationProvider
|
||||
from .oauth_types import (
|
||||
OAuthProvider, OAuthState, OAuthTokens, OAuthUserInfo,
|
||||
OIDCDiscovery, OAuthError, OAuthProviderConfig
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"JWTManager",
|
||||
"KeyManager",
|
||||
"TokenValidator",
|
||||
"TokenBlacklist",
|
||||
"AuthMiddleware",
|
||||
"TokenManager",
|
||||
"TokenInfo",
|
||||
"TokenValidationResult",
|
||||
"TokenHandlers",
|
||||
"OAuthClient",
|
||||
"OAuthStateManager",
|
||||
"OAuthAuthenticationProvider",
|
||||
"OAuthProvider",
|
||||
"OAuthState",
|
||||
"OAuthTokens",
|
||||
"OAuthUserInfo",
|
||||
"OIDCDiscovery",
|
||||
"OAuthError",
|
||||
"OAuthProviderConfig"
|
||||
]
|
||||
271
doris_mcp_server/auth/auth_middleware.py
Normal file
271
doris_mcp_server/auth/auth_middleware.py
Normal file
@@ -0,0 +1,271 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Authentication Middleware Module
|
||||
Provides middleware for JWT authentication in HTTP and MCP contexts
|
||||
"""
|
||||
|
||||
from typing import Optional, Dict, Any, Callable, Awaitable
|
||||
from datetime import datetime
|
||||
|
||||
from .jwt_manager import JWTManager
|
||||
from ..utils.security import AuthContext, SecurityLevel
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AuthMiddleware:
|
||||
"""Authentication Middleware
|
||||
|
||||
Provides JWT authentication functionality for HTTP and MCP requests
|
||||
"""
|
||||
|
||||
def __init__(self, jwt_manager: JWTManager):
|
||||
"""Initialize authentication middleware
|
||||
|
||||
Args:
|
||||
jwt_manager: JWT manager instance
|
||||
"""
|
||||
self.jwt_manager = jwt_manager
|
||||
logger.info("AuthMiddleware initialized")
|
||||
|
||||
def extract_token_from_header(self, authorization: str) -> Optional[str]:
|
||||
"""Extract JWT token from Authorization header
|
||||
|
||||
Args:
|
||||
authorization: Authorization header value
|
||||
|
||||
Returns:
|
||||
JWT token string, or None if not found
|
||||
"""
|
||||
if not authorization:
|
||||
return None
|
||||
|
||||
# Support Bearer format
|
||||
if authorization.startswith('Bearer '):
|
||||
return authorization[7:] # Remove "Bearer " prefix
|
||||
|
||||
# Support direct token format
|
||||
if not authorization.startswith('Basic '):
|
||||
return authorization
|
||||
|
||||
return None
|
||||
|
||||
async def authenticate_request(self, auth_info: Dict[str, Any]) -> AuthContext:
|
||||
"""Authenticate request and return authentication context
|
||||
|
||||
Args:
|
||||
auth_info: Authentication information dictionary
|
||||
|
||||
Returns:
|
||||
AuthContext authentication context
|
||||
|
||||
Raises:
|
||||
ValueError: Authentication failed
|
||||
"""
|
||||
try:
|
||||
auth_type = auth_info.get("type", "jwt")
|
||||
|
||||
if auth_type == "jwt" or auth_type == "token":
|
||||
return await self._authenticate_jwt(auth_info)
|
||||
else:
|
||||
raise ValueError(f"Unsupported authentication type: {auth_type}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Request authentication failed: {e}")
|
||||
raise
|
||||
|
||||
async def _authenticate_jwt(self, auth_info: Dict[str, Any]) -> AuthContext:
|
||||
"""JWT authentication processing
|
||||
|
||||
Args:
|
||||
auth_info: Authentication information containing JWT token
|
||||
|
||||
Returns:
|
||||
AuthContext authentication context
|
||||
"""
|
||||
# Get token
|
||||
token = auth_info.get("token")
|
||||
if not token:
|
||||
# Try to get from Authorization header
|
||||
authorization = auth_info.get("authorization")
|
||||
token = self.extract_token_from_header(authorization)
|
||||
|
||||
if not token:
|
||||
raise ValueError("Missing JWT token")
|
||||
|
||||
try:
|
||||
# Validate token
|
||||
validation_result = await self.jwt_manager.validate_token(token, 'access')
|
||||
payload = validation_result['payload']
|
||||
|
||||
# Build authentication context
|
||||
auth_context = AuthContext(
|
||||
token_id=payload.get('jti', ''),
|
||||
user_id=payload.get('sub'),
|
||||
roles=payload.get('roles', []),
|
||||
permissions=payload.get('permissions', []),
|
||||
security_level=SecurityLevel(payload.get('security_level', 'internal')),
|
||||
session_id=payload.get('jti'), # Use JWT ID as session ID
|
||||
login_time=datetime.fromtimestamp(payload.get('iat', 0)),
|
||||
last_activity=datetime.utcnow(),
|
||||
token=token # Store raw token for token-bound database configuration
|
||||
)
|
||||
|
||||
logger.info(f"JWT authentication successful for user: {auth_context.user_id}")
|
||||
return auth_context
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"JWT authentication failed: {e}")
|
||||
raise ValueError(f"JWT authentication failed: {str(e)}")
|
||||
|
||||
async def create_auth_response_headers(self, auth_context: AuthContext) -> Dict[str, str]:
|
||||
"""Create authentication response headers
|
||||
|
||||
Args:
|
||||
auth_context: Authentication context
|
||||
|
||||
Returns:
|
||||
Response headers dictionary
|
||||
"""
|
||||
return {
|
||||
'X-Auth-User': auth_context.user_id,
|
||||
'X-Auth-Roles': ','.join(auth_context.roles),
|
||||
'X-Auth-Session': auth_context.session_id,
|
||||
'X-Auth-Security-Level': auth_context.security_level.value
|
||||
}
|
||||
|
||||
def create_http_middleware(self, skip_paths: Optional[list] = None):
|
||||
"""Create HTTP middleware function
|
||||
|
||||
Args:
|
||||
skip_paths: List of paths to skip authentication
|
||||
|
||||
Returns:
|
||||
ASGI middleware function
|
||||
"""
|
||||
skip_paths = skip_paths or ['/health', '/docs', '/openapi.json']
|
||||
|
||||
async def middleware(scope, receive, send):
|
||||
"""HTTP authentication middleware"""
|
||||
if scope['type'] != 'http':
|
||||
# Pass through non-HTTP requests directly
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
path = scope.get('path', '')
|
||||
|
||||
# Check if authentication should be skipped
|
||||
if any(path.startswith(skip) for skip in skip_paths):
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
# Extract authentication information
|
||||
headers = dict(scope.get('headers', []))
|
||||
authorization = headers.get(b'authorization', b'').decode()
|
||||
|
||||
try:
|
||||
# Perform authentication
|
||||
auth_info = {
|
||||
'type': 'jwt',
|
||||
'authorization': authorization
|
||||
}
|
||||
auth_context = await self.authenticate_request(auth_info)
|
||||
|
||||
# Add authentication context to scope
|
||||
scope['auth_context'] = auth_context
|
||||
|
||||
# Create response wrapper to add authentication headers
|
||||
async def send_wrapper(message):
|
||||
if message['type'] == 'http.response.start':
|
||||
headers = dict(message.get('headers', []))
|
||||
auth_headers = await self.create_auth_response_headers(auth_context)
|
||||
|
||||
for key, value in auth_headers.items():
|
||||
headers[key.encode()] = value.encode()
|
||||
|
||||
message['headers'] = list(headers.items())
|
||||
|
||||
await send(message)
|
||||
|
||||
return await self.app(scope, receive, send_wrapper)
|
||||
|
||||
except Exception as e:
|
||||
# Authentication failed, return 401 error
|
||||
response_body = f'{{"error": "Authentication failed", "message": "{str(e)}"}}'
|
||||
|
||||
await send({
|
||||
'type': 'http.response.start',
|
||||
'status': 401,
|
||||
'headers': [
|
||||
(b'content-type', b'application/json'),
|
||||
(b'www-authenticate', b'Bearer')
|
||||
]
|
||||
})
|
||||
await send({
|
||||
'type': 'http.response.body',
|
||||
'body': response_body.encode()
|
||||
})
|
||||
|
||||
return middleware
|
||||
|
||||
async def authenticate_mcp_request(self, headers: Dict[str, str]) -> AuthContext:
|
||||
"""Authenticate MCP request
|
||||
|
||||
Args:
|
||||
headers: MCP request headers
|
||||
|
||||
Returns:
|
||||
AuthContext authentication context
|
||||
"""
|
||||
try:
|
||||
# Extract authentication information from multiple possible header fields
|
||||
authorization = (
|
||||
headers.get('Authorization') or
|
||||
headers.get('authorization') or
|
||||
headers.get('X-Auth-Token') or
|
||||
headers.get('x-auth-token')
|
||||
)
|
||||
|
||||
auth_info = {
|
||||
'type': 'jwt',
|
||||
'authorization': authorization
|
||||
}
|
||||
|
||||
return await self.authenticate_request(auth_info)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"MCP request authentication failed: {e}")
|
||||
raise
|
||||
|
||||
|
||||
class AuthenticationError(Exception):
|
||||
"""Authentication error exception"""
|
||||
|
||||
def __init__(self, message: str, error_code: str = "AUTH_FAILED"):
|
||||
self.message = message
|
||||
self.error_code = error_code
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class AuthorizationError(Exception):
|
||||
"""Authorization error exception"""
|
||||
|
||||
def __init__(self, message: str, error_code: str = "ACCESS_DENIED"):
|
||||
self.message = message
|
||||
self.error_code = error_code
|
||||
super().__init__(message)
|
||||
471
doris_mcp_server/auth/jwt_manager.py
Normal file
471
doris_mcp_server/auth/jwt_manager.py
Normal file
@@ -0,0 +1,471 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
JWT Manager Module
|
||||
Provides comprehensive JWT token management including generation, validation, refresh and revocation
|
||||
"""
|
||||
|
||||
import time
|
||||
import uuid
|
||||
import asyncio
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
try:
|
||||
import jwt
|
||||
except ImportError:
|
||||
raise ImportError("PyJWT is required for JWT functionality. Install with: pip install PyJWT[crypto]")
|
||||
|
||||
from .key_manager import KeyManager
|
||||
from .token_validators import TokenValidator, TokenBlacklist
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class JWTManager:
|
||||
"""JWT Token Manager
|
||||
|
||||
Provides comprehensive JWT token lifecycle management, including:
|
||||
- Token generation and signing
|
||||
- Token validation and parsing
|
||||
- Token refresh mechanism
|
||||
- Token revocation and blacklist
|
||||
- Automatic key rotation
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
"""Initialize JWT manager
|
||||
|
||||
Args:
|
||||
config: DorisConfig configuration object (with security attribute)
|
||||
"""
|
||||
self.config = config
|
||||
# Access JWT settings through the security configuration
|
||||
if hasattr(config, 'security'):
|
||||
security_config = config.security
|
||||
else:
|
||||
# Fallback if config is passed directly as SecurityConfig
|
||||
security_config = config
|
||||
|
||||
self.algorithm = security_config.jwt_algorithm
|
||||
self.issuer = security_config.jwt_issuer
|
||||
self.audience = security_config.jwt_audience
|
||||
self.access_token_expiry = security_config.jwt_access_token_expiry
|
||||
self.refresh_token_expiry = security_config.jwt_refresh_token_expiry
|
||||
self.enable_refresh = security_config.enable_token_refresh
|
||||
self.enable_revocation = security_config.enable_token_revocation
|
||||
|
||||
# Initialize components
|
||||
self.key_manager = KeyManager(config)
|
||||
self.token_blacklist = TokenBlacklist()
|
||||
self.validator = TokenValidator(config, self.token_blacklist)
|
||||
|
||||
# Automatic key rotation task
|
||||
self._key_rotation_task = None
|
||||
|
||||
logger.info(f"JWTManager initialized with algorithm: {self.algorithm}")
|
||||
|
||||
async def initialize(self) -> bool:
|
||||
"""Initialize JWT manager"""
|
||||
try:
|
||||
# Initialize key manager
|
||||
if not await self.key_manager.initialize():
|
||||
logger.error("Failed to initialize key manager")
|
||||
return False
|
||||
|
||||
# Start token validator
|
||||
await self.validator.start()
|
||||
|
||||
# Start automatic key rotation
|
||||
if self.key_manager.key_rotation_interval > 0:
|
||||
self._key_rotation_task = asyncio.create_task(self._auto_key_rotation())
|
||||
|
||||
logger.info("JWTManager initialization completed")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize JWTManager: {e}")
|
||||
return False
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shutdown JWT manager"""
|
||||
try:
|
||||
# Stop key rotation task
|
||||
if self._key_rotation_task:
|
||||
self._key_rotation_task.cancel()
|
||||
try:
|
||||
await self._key_rotation_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Stop validator
|
||||
await self.validator.stop()
|
||||
|
||||
logger.info("JWTManager shutdown completed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during JWTManager shutdown: {e}")
|
||||
|
||||
async def generate_tokens(self, user_info: Dict[str, Any],
|
||||
custom_claims: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
"""Generate access token and refresh token
|
||||
|
||||
Args:
|
||||
user_info: User information dictionary, containing user_id, roles, permissions, etc.
|
||||
custom_claims: Custom claims
|
||||
|
||||
Returns:
|
||||
Dictionary containing access_token and refresh_token
|
||||
"""
|
||||
try:
|
||||
current_time = int(time.time())
|
||||
jti = str(uuid.uuid4())
|
||||
|
||||
# Build base payload
|
||||
base_payload = {
|
||||
'iss': self.issuer,
|
||||
'aud': self.audience,
|
||||
'iat': current_time,
|
||||
'jti': jti,
|
||||
'sub': user_info.get('user_id'),
|
||||
'roles': user_info.get('roles', []),
|
||||
'permissions': user_info.get('permissions', []),
|
||||
'security_level': user_info.get('security_level', 'internal')
|
||||
}
|
||||
|
||||
# Add custom claims
|
||||
if custom_claims:
|
||||
base_payload.update(custom_claims)
|
||||
|
||||
# Generate access token
|
||||
access_payload = base_payload.copy()
|
||||
access_payload.update({
|
||||
'exp': current_time + self.access_token_expiry,
|
||||
'token_type': 'access'
|
||||
})
|
||||
|
||||
access_token = await self._sign_token(access_payload)
|
||||
|
||||
result = {
|
||||
'access_token': access_token,
|
||||
'token_type': 'Bearer',
|
||||
'expires_in': self.access_token_expiry,
|
||||
'user_id': user_info.get('user_id'),
|
||||
'issued_at': current_time
|
||||
}
|
||||
|
||||
# Generate refresh token (if enabled)
|
||||
if self.enable_refresh:
|
||||
refresh_jti = str(uuid.uuid4())
|
||||
refresh_payload = {
|
||||
'iss': self.issuer,
|
||||
'aud': self.audience,
|
||||
'iat': current_time,
|
||||
'exp': current_time + self.refresh_token_expiry,
|
||||
'jti': refresh_jti,
|
||||
'sub': user_info.get('user_id'),
|
||||
'token_type': 'refresh',
|
||||
'access_jti': jti # Associated access token ID
|
||||
}
|
||||
|
||||
refresh_token = await self._sign_token(refresh_payload)
|
||||
result.update({
|
||||
'refresh_token': refresh_token,
|
||||
'refresh_expires_in': self.refresh_token_expiry
|
||||
})
|
||||
|
||||
logger.info(f"Generated tokens for user: {user_info.get('user_id')}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate tokens: {e}")
|
||||
raise
|
||||
|
||||
async def _sign_token(self, payload: Dict[str, Any]) -> str:
|
||||
"""Sign JWT token
|
||||
|
||||
Args:
|
||||
payload: JWT payload
|
||||
|
||||
Returns:
|
||||
Signed JWT token
|
||||
"""
|
||||
try:
|
||||
signing_key = self.key_manager.get_private_key()
|
||||
|
||||
if self.algorithm == "HS256":
|
||||
# Symmetric key signing
|
||||
token = jwt.encode(payload, signing_key, algorithm=self.algorithm)
|
||||
else:
|
||||
# Asymmetric key signing
|
||||
token = jwt.encode(payload, signing_key, algorithm=self.algorithm)
|
||||
|
||||
return token
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to sign token: {e}")
|
||||
raise
|
||||
|
||||
async def validate_token(self, token: str, token_type: str = 'access') -> Dict[str, Any]:
|
||||
"""Validate JWT token
|
||||
|
||||
Args:
|
||||
token: JWT token string
|
||||
token_type: Token type ('access' or 'refresh')
|
||||
|
||||
Returns:
|
||||
Validation result and user information
|
||||
|
||||
Raises:
|
||||
ValueError: Token validation failed
|
||||
"""
|
||||
try:
|
||||
# Decode token
|
||||
verification_key = self.key_manager.get_public_key()
|
||||
|
||||
# Get security configuration
|
||||
if hasattr(self.config, 'security'):
|
||||
security_config = self.config.security
|
||||
else:
|
||||
security_config = self.config
|
||||
|
||||
# JWT decoding options
|
||||
options = {
|
||||
'verify_signature': security_config.jwt_verify_signature,
|
||||
'verify_exp': security_config.jwt_require_exp,
|
||||
'verify_iat': security_config.jwt_require_iat,
|
||||
'verify_nbf': security_config.jwt_require_nbf,
|
||||
'verify_aud': security_config.jwt_verify_audience,
|
||||
'verify_iss': security_config.jwt_verify_issuer,
|
||||
}
|
||||
|
||||
# Decode JWT
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
verification_key,
|
||||
algorithms=[self.algorithm],
|
||||
audience=self.audience if security_config.jwt_verify_audience else None,
|
||||
issuer=self.issuer if security_config.jwt_verify_issuer else None,
|
||||
leeway=security_config.jwt_leeway,
|
||||
options=options
|
||||
)
|
||||
|
||||
# Check token type
|
||||
if payload.get('token_type') != token_type:
|
||||
raise ValueError(f"Invalid token type: expected {token_type}")
|
||||
|
||||
# Use validator for additional checks
|
||||
validation_result = await self.validator.validate_claims(payload)
|
||||
|
||||
logger.info(f"Token validation successful for user: {payload.get('sub')}")
|
||||
return validation_result
|
||||
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise ValueError("Token has expired")
|
||||
except jwt.InvalidTokenError as e:
|
||||
raise ValueError(f"Invalid token: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Token validation failed: {e}")
|
||||
raise ValueError(f"Token validation failed: {str(e)}")
|
||||
|
||||
async def refresh_token(self, refresh_token: str) -> Dict[str, Any]:
|
||||
"""Refresh access token
|
||||
|
||||
Args:
|
||||
refresh_token: Refresh token
|
||||
|
||||
Returns:
|
||||
New token pair
|
||||
"""
|
||||
if not self.enable_refresh:
|
||||
raise ValueError("Token refresh is disabled")
|
||||
|
||||
try:
|
||||
# Validate refresh token
|
||||
refresh_result = await self.validate_token(refresh_token, 'refresh')
|
||||
refresh_payload = refresh_result['payload']
|
||||
|
||||
# Revoke associated access token (if revocation is enabled)
|
||||
if self.enable_revocation:
|
||||
access_jti = refresh_payload.get('access_jti')
|
||||
if access_jti:
|
||||
# Should revoke old access token here, but since we don't know its expiration time,
|
||||
# in practice might need to store more information or use different strategy
|
||||
pass
|
||||
|
||||
# Build new user information
|
||||
user_info = {
|
||||
'user_id': refresh_payload.get('sub'),
|
||||
'roles': refresh_payload.get('roles', []),
|
||||
'permissions': refresh_payload.get('permissions', []),
|
||||
'security_level': refresh_payload.get('security_level', 'internal')
|
||||
}
|
||||
|
||||
# Generate new token pair
|
||||
new_tokens = await self.generate_tokens(user_info)
|
||||
|
||||
logger.info(f"Token refreshed for user: {user_info['user_id']}")
|
||||
return new_tokens
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Token refresh failed: {e}")
|
||||
raise
|
||||
|
||||
async def revoke_token(self, token: str) -> bool:
|
||||
"""Revoke token
|
||||
|
||||
Args:
|
||||
token: Token to revoke
|
||||
|
||||
Returns:
|
||||
Whether revocation was successful
|
||||
"""
|
||||
if not self.enable_revocation:
|
||||
logger.warning("Token revocation is disabled")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Decode token to get JTI and expiration time
|
||||
verification_key = self.key_manager.get_public_key()
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
verification_key,
|
||||
algorithms=[self.algorithm],
|
||||
options={'verify_exp': False} # Allow decoding expired tokens
|
||||
)
|
||||
|
||||
jti = payload.get('jti')
|
||||
exp = payload.get('exp')
|
||||
|
||||
if not jti or not exp:
|
||||
logger.error("Token missing required claims for revocation")
|
||||
return False
|
||||
|
||||
# Add to blacklist
|
||||
await self.validator.revoke_token(jti, exp)
|
||||
|
||||
logger.info(f"Token {jti} revoked successfully")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Token revocation failed: {e}")
|
||||
return False
|
||||
|
||||
async def decode_token_unsafe(self, token: str) -> Dict[str, Any]:
|
||||
"""Decode token without verifying signature (for debugging only)
|
||||
|
||||
Args:
|
||||
token: JWT token
|
||||
|
||||
Returns:
|
||||
Token payload
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(token, options={'verify_signature': False})
|
||||
return payload
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to decode token: {e}")
|
||||
raise
|
||||
|
||||
async def get_token_info(self, token: str) -> Dict[str, Any]:
|
||||
"""Get token information (without verifying signature)
|
||||
|
||||
Args:
|
||||
token: JWT token
|
||||
|
||||
Returns:
|
||||
Token information
|
||||
"""
|
||||
try:
|
||||
payload = await self.decode_token_unsafe(token)
|
||||
|
||||
return {
|
||||
'jti': payload.get('jti'),
|
||||
'sub': payload.get('sub'),
|
||||
'iss': payload.get('iss'),
|
||||
'aud': payload.get('aud'),
|
||||
'iat': payload.get('iat'),
|
||||
'exp': payload.get('exp'),
|
||||
'token_type': payload.get('token_type'),
|
||||
'roles': payload.get('roles'),
|
||||
'permissions': payload.get('permissions'),
|
||||
'security_level': payload.get('security_level'),
|
||||
'is_expired': payload.get('exp', 0) < time.time() if payload.get('exp') else None
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get token info: {e}")
|
||||
raise
|
||||
|
||||
async def _auto_key_rotation(self):
|
||||
"""Automatic key rotation task"""
|
||||
while True:
|
||||
try:
|
||||
# Check if key rotation is needed
|
||||
if await self.key_manager.is_key_expired():
|
||||
logger.info("Key rotation needed, rotating keys...")
|
||||
await self.key_manager.rotate_keys()
|
||||
|
||||
# Wait until next check
|
||||
await asyncio.sleep(3600) # Check every hour
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in auto key rotation: {e}")
|
||||
# Wait longer before retry after error
|
||||
await asyncio.sleep(3600)
|
||||
|
||||
async def get_public_key_info(self) -> Dict[str, Any]:
|
||||
"""Get public key information (for client verification)
|
||||
|
||||
Returns:
|
||||
Public key information
|
||||
"""
|
||||
key_info = await self.key_manager.get_key_info()
|
||||
public_key_pem = await self.key_manager.export_public_key_pem()
|
||||
|
||||
return {
|
||||
'algorithm': self.algorithm,
|
||||
'public_key_pem': public_key_pem,
|
||||
'key_info': key_info
|
||||
}
|
||||
|
||||
async def get_manager_stats(self) -> Dict[str, Any]:
|
||||
"""Get manager statistics
|
||||
|
||||
Returns:
|
||||
Statistics information
|
||||
"""
|
||||
key_info = await self.key_manager.get_key_info()
|
||||
validation_stats = await self.validator.get_validation_stats()
|
||||
|
||||
return {
|
||||
'jwt_config': {
|
||||
'algorithm': self.algorithm,
|
||||
'issuer': self.issuer,
|
||||
'audience': self.audience,
|
||||
'access_token_expiry': self.access_token_expiry,
|
||||
'refresh_token_expiry': self.refresh_token_expiry,
|
||||
'enable_refresh': self.enable_refresh,
|
||||
'enable_revocation': self.enable_revocation
|
||||
},
|
||||
'key_manager': key_info,
|
||||
'validator': validation_stats
|
||||
}
|
||||
343
doris_mcp_server/auth/key_manager.py
Normal file
343
doris_mcp_server/auth/key_manager.py
Normal file
@@ -0,0 +1,343 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
JWT Key Management Module
|
||||
Provides secure key generation, loading, rotation and management for JWT tokens
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import secrets
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Union
|
||||
from datetime import datetime, timedelta
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa, ec
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class KeyManager:
|
||||
"""JWT Key Manager
|
||||
|
||||
Responsible for generating, loading, rotating and securely storing JWT signing keys
|
||||
Supports RSA and EC algorithms, provides automatic key rotation functionality
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
"""Initialize key manager
|
||||
|
||||
Args:
|
||||
config: DorisConfig configuration object (with security attribute)
|
||||
"""
|
||||
self.config = config
|
||||
# Access JWT settings through the security configuration
|
||||
if hasattr(config, 'security'):
|
||||
security_config = config.security
|
||||
else:
|
||||
# Fallback if config is passed directly as SecurityConfig
|
||||
security_config = config
|
||||
|
||||
self.algorithm = security_config.jwt_algorithm
|
||||
self.key_rotation_interval = security_config.key_rotation_interval
|
||||
self.private_key_path = security_config.jwt_private_key_path
|
||||
self.public_key_path = security_config.jwt_public_key_path
|
||||
self.secret_key = security_config.jwt_secret_key
|
||||
|
||||
# Key storage
|
||||
self._private_key = None
|
||||
self._public_key = None
|
||||
self._secret_key = None
|
||||
self._key_generated_at = None
|
||||
|
||||
logger.info(f"KeyManager initialized with algorithm: {self.algorithm}")
|
||||
|
||||
async def initialize(self) -> bool:
|
||||
"""Initialize key manager, load or generate keys"""
|
||||
try:
|
||||
if self.algorithm == "HS256":
|
||||
await self._initialize_symmetric_key()
|
||||
else:
|
||||
await self._initialize_asymmetric_keys()
|
||||
|
||||
logger.info("KeyManager initialization completed")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize KeyManager: {e}")
|
||||
return False
|
||||
|
||||
async def _initialize_symmetric_key(self):
|
||||
"""Initialize symmetric key (HS256)"""
|
||||
if self.secret_key:
|
||||
# Use configured key
|
||||
self._secret_key = self.secret_key.encode()
|
||||
logger.info("Loaded symmetric key from configuration")
|
||||
else:
|
||||
# Generate new key
|
||||
self._secret_key = await self.generate_symmetric_key()
|
||||
logger.info("Generated new symmetric key")
|
||||
|
||||
self._key_generated_at = datetime.utcnow()
|
||||
|
||||
async def _initialize_asymmetric_keys(self):
|
||||
"""Initialize asymmetric key pair (RS256/ES256)"""
|
||||
# Try to load keys from files
|
||||
if await self._load_keys_from_files():
|
||||
logger.info("Loaded asymmetric keys from files")
|
||||
return
|
||||
|
||||
# Try to load from environment variables
|
||||
if await self._load_keys_from_env():
|
||||
logger.info("Loaded asymmetric keys from environment")
|
||||
return
|
||||
|
||||
# Generate new key pair
|
||||
await self.generate_key_pair()
|
||||
logger.info("Generated new asymmetric key pair")
|
||||
|
||||
async def _load_keys_from_files(self) -> bool:
|
||||
"""Load keys from files"""
|
||||
try:
|
||||
if not self.private_key_path or not self.public_key_path:
|
||||
return False
|
||||
|
||||
private_path = Path(self.private_key_path)
|
||||
public_path = Path(self.public_key_path)
|
||||
|
||||
if not (private_path.exists() and public_path.exists()):
|
||||
return False
|
||||
|
||||
# Read private key
|
||||
with open(private_path, 'rb') as f:
|
||||
private_key_data = f.read()
|
||||
self._private_key = serialization.load_pem_private_key(
|
||||
private_key_data, password=None, backend=default_backend()
|
||||
)
|
||||
|
||||
# Read public key
|
||||
with open(public_path, 'rb') as f:
|
||||
public_key_data = f.read()
|
||||
self._public_key = serialization.load_pem_public_key(
|
||||
public_key_data, backend=default_backend()
|
||||
)
|
||||
|
||||
# Get key generation time (using file modification time)
|
||||
self._key_generated_at = datetime.fromtimestamp(private_path.stat().st_mtime)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load keys from files: {e}")
|
||||
return False
|
||||
|
||||
async def _load_keys_from_env(self) -> bool:
|
||||
"""Load keys from environment variables"""
|
||||
try:
|
||||
private_key_env = os.getenv('JWT_PRIVATE_KEY')
|
||||
public_key_env = os.getenv('JWT_PUBLIC_KEY')
|
||||
|
||||
if not (private_key_env and public_key_env):
|
||||
return False
|
||||
|
||||
# Parse private key
|
||||
self._private_key = serialization.load_pem_private_key(
|
||||
private_key_env.encode(), password=None, backend=default_backend()
|
||||
)
|
||||
|
||||
# Parse public key
|
||||
self._public_key = serialization.load_pem_public_key(
|
||||
public_key_env.encode(), backend=default_backend()
|
||||
)
|
||||
|
||||
self._key_generated_at = datetime.utcnow()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load keys from environment: {e}")
|
||||
return False
|
||||
|
||||
async def generate_symmetric_key(self, length: int = 32) -> bytes:
|
||||
"""Generate symmetric key
|
||||
|
||||
Args:
|
||||
length: Key length (bytes), default 32 bytes (256 bits)
|
||||
|
||||
Returns:
|
||||
Generated key
|
||||
"""
|
||||
return secrets.token_bytes(length)
|
||||
|
||||
async def generate_key_pair(self) -> Tuple[bytes, bytes]:
|
||||
"""Generate asymmetric key pair
|
||||
|
||||
Returns:
|
||||
(private key PEM, public key PEM) tuple
|
||||
"""
|
||||
try:
|
||||
if self.algorithm == "RS256":
|
||||
private_key = rsa.generate_private_key(
|
||||
public_exponent=65537,
|
||||
key_size=2048,
|
||||
backend=default_backend()
|
||||
)
|
||||
elif self.algorithm == "ES256":
|
||||
private_key = ec.generate_private_key(
|
||||
ec.SECP256R1(), backend=default_backend()
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported algorithm for key generation: {self.algorithm}")
|
||||
|
||||
# Get public key
|
||||
public_key = private_key.public_key()
|
||||
|
||||
# Serialize private key
|
||||
private_pem = private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption()
|
||||
)
|
||||
|
||||
# Serialize public key
|
||||
public_pem = public_key.public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo
|
||||
)
|
||||
|
||||
# Store keys
|
||||
self._private_key = private_key
|
||||
self._public_key = public_key
|
||||
self._key_generated_at = datetime.utcnow()
|
||||
|
||||
# If file paths are configured, save to files
|
||||
if self.private_key_path and self.public_key_path:
|
||||
await self._save_keys_to_files(private_pem, public_pem)
|
||||
|
||||
logger.info(f"Generated new {self.algorithm} key pair")
|
||||
return private_pem, public_pem
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate key pair: {e}")
|
||||
raise
|
||||
|
||||
async def _save_keys_to_files(self, private_pem: bytes, public_pem: bytes):
|
||||
"""Save keys to files"""
|
||||
try:
|
||||
# Ensure directories exist
|
||||
private_path = Path(self.private_key_path)
|
||||
public_path = Path(self.public_key_path)
|
||||
|
||||
private_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
public_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save private key (set secure permissions)
|
||||
with open(private_path, 'wb') as f:
|
||||
f.write(private_pem)
|
||||
os.chmod(private_path, 0o600) # Only owner can read/write
|
||||
|
||||
# Save public key
|
||||
with open(public_path, 'wb') as f:
|
||||
f.write(public_pem)
|
||||
os.chmod(public_path, 0o644) # Owner read/write, others read only
|
||||
|
||||
logger.info(f"Saved keys to files: {private_path}, {public_path}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save keys to files: {e}")
|
||||
raise
|
||||
|
||||
def get_private_key(self):
|
||||
"""Get private key for signing"""
|
||||
if self.algorithm == "HS256":
|
||||
return self._secret_key
|
||||
else:
|
||||
return self._private_key
|
||||
|
||||
def get_public_key(self):
|
||||
"""Get public key for verification"""
|
||||
if self.algorithm == "HS256":
|
||||
return self._secret_key
|
||||
else:
|
||||
return self._public_key
|
||||
|
||||
def get_algorithm(self) -> str:
|
||||
"""Get signing algorithm"""
|
||||
return self.algorithm
|
||||
|
||||
async def is_key_expired(self) -> bool:
|
||||
"""Check if key is expired"""
|
||||
if not self._key_generated_at:
|
||||
return True
|
||||
|
||||
expiry_time = self._key_generated_at + timedelta(seconds=self.key_rotation_interval)
|
||||
return datetime.utcnow() > expiry_time
|
||||
|
||||
async def rotate_keys(self) -> bool:
|
||||
"""Rotate keys"""
|
||||
try:
|
||||
logger.info("Starting key rotation")
|
||||
|
||||
if self.algorithm == "HS256":
|
||||
# Generate new symmetric key
|
||||
self._secret_key = await self.generate_symmetric_key()
|
||||
self._key_generated_at = datetime.utcnow()
|
||||
else:
|
||||
# Generate new asymmetric key pair
|
||||
await self.generate_key_pair()
|
||||
|
||||
logger.info("Key rotation completed successfully")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Key rotation failed: {e}")
|
||||
return False
|
||||
|
||||
async def get_key_info(self) -> dict:
|
||||
"""Get key information"""
|
||||
return {
|
||||
"algorithm": self.algorithm,
|
||||
"key_generated_at": self._key_generated_at.isoformat() if self._key_generated_at else None,
|
||||
"key_expires_at": (
|
||||
self._key_generated_at + timedelta(seconds=self.key_rotation_interval)
|
||||
).isoformat() if self._key_generated_at else None,
|
||||
"is_expired": await self.is_key_expired(),
|
||||
"has_private_key": self._private_key is not None or self._secret_key is not None,
|
||||
"has_public_key": self._public_key is not None or self._secret_key is not None
|
||||
}
|
||||
|
||||
async def export_public_key_pem(self) -> Optional[str]:
|
||||
"""Export public key in PEM format"""
|
||||
if self.algorithm == "HS256":
|
||||
return None # Symmetric key not exported
|
||||
|
||||
if not self._public_key:
|
||||
return None
|
||||
|
||||
try:
|
||||
public_pem = self._public_key.public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo
|
||||
)
|
||||
return public_pem.decode()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to export public key: {e}")
|
||||
return None
|
||||
536
doris_mcp_server/auth/oauth_client.py
Normal file
536
doris_mcp_server/auth/oauth_client.py
Normal file
@@ -0,0 +1,536 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
OAuth 2.0/OIDC Client Manager
|
||||
Provides OAuth authentication client implementation with PKCE and OIDC support
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Optional, Any, Tuple
|
||||
from urllib.parse import urlencode, parse_qs, urlparse
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
except ImportError:
|
||||
raise ImportError("aiohttp is required for OAuth functionality. Install with: pip install aiohttp")
|
||||
|
||||
from .oauth_types import (
|
||||
OAuthProvider, OAuthState, OAuthTokens, OAuthUserInfo,
|
||||
OIDCDiscovery, OAuthError, OAuthProviderConfig, OAUTH_PROVIDERS
|
||||
)
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class OAuthStateManager:
|
||||
"""Manages OAuth state parameters for CSRF protection"""
|
||||
|
||||
def __init__(self, state_expiry: int = 600):
|
||||
"""Initialize state manager
|
||||
|
||||
Args:
|
||||
state_expiry: State expiry time in seconds
|
||||
"""
|
||||
self.state_expiry = state_expiry
|
||||
self._states: Dict[str, OAuthState] = {}
|
||||
self._cleanup_task = None
|
||||
|
||||
logger.info("OAuthStateManager initialized")
|
||||
|
||||
async def start(self):
|
||||
"""Start periodic cleanup task"""
|
||||
self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
|
||||
logger.info("OAuth state manager started")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop periodic cleanup task"""
|
||||
if self._cleanup_task:
|
||||
self._cleanup_task.cancel()
|
||||
try:
|
||||
await self._cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("OAuth state manager stopped")
|
||||
|
||||
def create_state(self, redirect_uri: str, pkce_enabled: bool = True,
|
||||
nonce_enabled: bool = True) -> OAuthState:
|
||||
"""Create new OAuth state
|
||||
|
||||
Args:
|
||||
redirect_uri: OAuth redirect URI
|
||||
pkce_enabled: Whether to enable PKCE
|
||||
nonce_enabled: Whether to enable nonce (for OIDC)
|
||||
|
||||
Returns:
|
||||
OAuth state object
|
||||
"""
|
||||
state = secrets.token_urlsafe(32)
|
||||
nonce = secrets.token_urlsafe(32) if nonce_enabled else None
|
||||
|
||||
pkce_verifier = None
|
||||
pkce_challenge = None
|
||||
if pkce_enabled:
|
||||
pkce_verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).decode('utf-8').rstrip('=')
|
||||
challenge_bytes = hashlib.sha256(pkce_verifier.encode()).digest()
|
||||
pkce_challenge = base64.urlsafe_b64encode(challenge_bytes).decode('utf-8').rstrip('=')
|
||||
|
||||
oauth_state = OAuthState(
|
||||
state=state,
|
||||
nonce=nonce,
|
||||
pkce_verifier=pkce_verifier,
|
||||
pkce_challenge=pkce_challenge,
|
||||
redirect_uri=redirect_uri,
|
||||
created_at=datetime.utcnow(),
|
||||
expires_at=datetime.utcnow() + timedelta(seconds=self.state_expiry)
|
||||
)
|
||||
|
||||
self._states[state] = oauth_state
|
||||
logger.debug(f"Created OAuth state: {state}")
|
||||
return oauth_state
|
||||
|
||||
def get_state(self, state: str) -> Optional[OAuthState]:
|
||||
"""Get OAuth state by state parameter
|
||||
|
||||
Args:
|
||||
state: State parameter
|
||||
|
||||
Returns:
|
||||
OAuth state object or None if not found/expired
|
||||
"""
|
||||
oauth_state = self._states.get(state)
|
||||
if oauth_state and oauth_state.expires_at > datetime.utcnow():
|
||||
return oauth_state
|
||||
elif oauth_state:
|
||||
# Remove expired state
|
||||
del self._states[state]
|
||||
logger.debug(f"Removed expired OAuth state: {state}")
|
||||
return None
|
||||
|
||||
def consume_state(self, state: str) -> Optional[OAuthState]:
|
||||
"""Get and remove OAuth state
|
||||
|
||||
Args:
|
||||
state: State parameter
|
||||
|
||||
Returns:
|
||||
OAuth state object or None if not found/expired
|
||||
"""
|
||||
oauth_state = self.get_state(state)
|
||||
if oauth_state:
|
||||
del self._states[state]
|
||||
logger.debug(f"Consumed OAuth state: {state}")
|
||||
return oauth_state
|
||||
|
||||
async def _periodic_cleanup(self):
|
||||
"""Periodic cleanup of expired states"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(300) # Clean up every 5 minutes
|
||||
current_time = datetime.utcnow()
|
||||
expired_states = [
|
||||
state for state, oauth_state in self._states.items()
|
||||
if oauth_state.expires_at <= current_time
|
||||
]
|
||||
|
||||
for state in expired_states:
|
||||
del self._states[state]
|
||||
|
||||
if expired_states:
|
||||
logger.info(f"Cleaned up {len(expired_states)} expired OAuth states")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error during OAuth state cleanup: {e}")
|
||||
|
||||
|
||||
class OAuthClient:
|
||||
"""OAuth 2.0/OIDC Client implementation"""
|
||||
|
||||
def __init__(self, config):
|
||||
"""Initialize OAuth client
|
||||
|
||||
Args:
|
||||
config: DorisConfig with OAuth configuration
|
||||
"""
|
||||
self.config = config
|
||||
|
||||
# Access OAuth settings through security configuration
|
||||
if hasattr(config, 'security'):
|
||||
security_config = config.security
|
||||
else:
|
||||
security_config = config
|
||||
|
||||
self.enabled = security_config.oauth_enabled
|
||||
if not self.enabled:
|
||||
logger.info("OAuth client disabled by configuration")
|
||||
return
|
||||
|
||||
# Build provider configuration
|
||||
self.provider_config = self._build_provider_config(security_config)
|
||||
self.state_manager = OAuthStateManager(security_config.oauth_state_expiry)
|
||||
|
||||
# HTTP client session
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
|
||||
# Discovery cache
|
||||
self._discovery_cache: Optional[OIDCDiscovery] = None
|
||||
self._discovery_cache_time: Optional[datetime] = None
|
||||
|
||||
logger.info(f"OAuthClient initialized for provider: {self.provider_config.provider.value}")
|
||||
|
||||
def _build_provider_config(self, security_config) -> OAuthProviderConfig:
|
||||
"""Build OAuth provider configuration
|
||||
|
||||
Args:
|
||||
security_config: Security configuration object
|
||||
|
||||
Returns:
|
||||
OAuth provider configuration
|
||||
"""
|
||||
try:
|
||||
provider = OAuthProvider(security_config.oauth_provider)
|
||||
except ValueError:
|
||||
provider = OAuthProvider.CUSTOM
|
||||
|
||||
# Get default configuration for known providers
|
||||
defaults = OAUTH_PROVIDERS.get(provider, {})
|
||||
|
||||
return OAuthProviderConfig(
|
||||
provider=provider,
|
||||
client_id=security_config.oauth_client_id,
|
||||
client_secret=security_config.oauth_client_secret,
|
||||
redirect_uri=security_config.oauth_redirect_uri,
|
||||
scopes=security_config.oauth_scopes or defaults.get("scopes", ["openid", "email", "profile"]),
|
||||
|
||||
# Endpoints (use configured or defaults)
|
||||
authorization_endpoint=security_config.oauth_authorization_endpoint or defaults.get("authorization_endpoint", ""),
|
||||
token_endpoint=security_config.oauth_token_endpoint or defaults.get("token_endpoint", ""),
|
||||
userinfo_endpoint=security_config.oauth_userinfo_endpoint or defaults.get("userinfo_endpoint"),
|
||||
jwks_uri=security_config.oauth_jwks_uri or defaults.get("jwks_uri"),
|
||||
|
||||
# Discovery
|
||||
discovery_url=security_config.oidc_discovery_url or defaults.get("discovery_url"),
|
||||
|
||||
# Settings
|
||||
pkce_enabled=security_config.oauth_pkce_enabled,
|
||||
nonce_enabled=security_config.oauth_nonce_enabled,
|
||||
|
||||
# User mapping
|
||||
user_id_claim=security_config.oauth_user_id_claim or defaults.get("user_id_claim", "sub"),
|
||||
email_claim=security_config.oauth_email_claim or defaults.get("email_claim", "email"),
|
||||
name_claim=security_config.oauth_name_claim or defaults.get("name_claim", "name"),
|
||||
roles_claim=security_config.oauth_roles_claim,
|
||||
default_roles=security_config.oauth_default_roles
|
||||
)
|
||||
|
||||
async def initialize(self) -> bool:
|
||||
"""Initialize OAuth client
|
||||
|
||||
Returns:
|
||||
True if initialization successful
|
||||
"""
|
||||
if not self.enabled:
|
||||
return True
|
||||
|
||||
try:
|
||||
# Create HTTP session
|
||||
self._session = aiohttp.ClientSession()
|
||||
|
||||
# Start state manager
|
||||
await self.state_manager.start()
|
||||
|
||||
# Perform OIDC discovery if configured
|
||||
if self.provider_config.discovery_url:
|
||||
await self._discover_oidc_endpoints()
|
||||
|
||||
logger.info("OAuth client initialization completed")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize OAuth client: {e}")
|
||||
return False
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shutdown OAuth client"""
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
try:
|
||||
# Stop state manager
|
||||
await self.state_manager.stop()
|
||||
|
||||
# Close HTTP session
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
|
||||
logger.info("OAuth client shutdown completed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during OAuth client shutdown: {e}")
|
||||
|
||||
async def _discover_oidc_endpoints(self):
|
||||
"""Discover OIDC endpoints using discovery URL"""
|
||||
try:
|
||||
# Check cache first
|
||||
if (self._discovery_cache and self._discovery_cache_time and
|
||||
datetime.utcnow() - self._discovery_cache_time < timedelta(hours=1)):
|
||||
return self._discovery_cache
|
||||
|
||||
logger.info(f"Discovering OIDC endpoints: {self.provider_config.discovery_url}")
|
||||
|
||||
async with self._session.get(self.provider_config.discovery_url) as response:
|
||||
response.raise_for_status()
|
||||
data = await response.json()
|
||||
|
||||
discovery = OIDCDiscovery(
|
||||
issuer=data["issuer"],
|
||||
authorization_endpoint=data["authorization_endpoint"],
|
||||
token_endpoint=data["token_endpoint"],
|
||||
userinfo_endpoint=data.get("userinfo_endpoint"),
|
||||
jwks_uri=data.get("jwks_uri"),
|
||||
scopes_supported=data.get("scopes_supported"),
|
||||
response_types_supported=data.get("response_types_supported"),
|
||||
subject_types_supported=data.get("subject_types_supported"),
|
||||
id_token_signing_alg_values_supported=data.get("id_token_signing_alg_values_supported")
|
||||
)
|
||||
|
||||
# Update provider configuration with discovered endpoints
|
||||
if not self.provider_config.authorization_endpoint:
|
||||
self.provider_config.authorization_endpoint = discovery.authorization_endpoint
|
||||
if not self.provider_config.token_endpoint:
|
||||
self.provider_config.token_endpoint = discovery.token_endpoint
|
||||
if not self.provider_config.userinfo_endpoint:
|
||||
self.provider_config.userinfo_endpoint = discovery.userinfo_endpoint
|
||||
if not self.provider_config.jwks_uri:
|
||||
self.provider_config.jwks_uri = discovery.jwks_uri
|
||||
|
||||
# Cache discovery result
|
||||
self._discovery_cache = discovery
|
||||
self._discovery_cache_time = datetime.utcnow()
|
||||
|
||||
logger.info("OIDC endpoint discovery completed successfully")
|
||||
return discovery
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OIDC endpoint discovery failed: {e}")
|
||||
raise
|
||||
|
||||
def build_authorization_url(self) -> Tuple[str, OAuthState]:
|
||||
"""Build OAuth authorization URL
|
||||
|
||||
Returns:
|
||||
Tuple of (authorization_url, oauth_state)
|
||||
"""
|
||||
if not self.enabled:
|
||||
raise ValueError("OAuth client is not enabled")
|
||||
|
||||
# Create state for CSRF protection
|
||||
oauth_state = self.state_manager.create_state(
|
||||
redirect_uri=self.provider_config.redirect_uri,
|
||||
pkce_enabled=self.provider_config.pkce_enabled,
|
||||
nonce_enabled=self.provider_config.nonce_enabled
|
||||
)
|
||||
|
||||
# Build authorization parameters
|
||||
params = {
|
||||
'response_type': 'code',
|
||||
'client_id': self.provider_config.client_id,
|
||||
'redirect_uri': self.provider_config.redirect_uri,
|
||||
'scope': ' '.join(self.provider_config.scopes),
|
||||
'state': oauth_state.state
|
||||
}
|
||||
|
||||
# Add PKCE challenge
|
||||
if oauth_state.pkce_challenge:
|
||||
params['code_challenge'] = oauth_state.pkce_challenge
|
||||
params['code_challenge_method'] = 'S256'
|
||||
|
||||
# Add nonce for OIDC
|
||||
if oauth_state.nonce:
|
||||
params['nonce'] = oauth_state.nonce
|
||||
|
||||
# Build URL
|
||||
authorization_url = f"{self.provider_config.authorization_endpoint}?{urlencode(params)}"
|
||||
|
||||
logger.info(f"Built OAuth authorization URL for state: {oauth_state.state}")
|
||||
return authorization_url, oauth_state
|
||||
|
||||
async def exchange_code_for_tokens(self, code: str, state: str) -> Tuple[OAuthTokens, OAuthState]:
|
||||
"""Exchange authorization code for tokens
|
||||
|
||||
Args:
|
||||
code: Authorization code
|
||||
state: State parameter
|
||||
|
||||
Returns:
|
||||
Tuple of (OAuth tokens, OAuth state)
|
||||
|
||||
Raises:
|
||||
ValueError: If state is invalid or exchange fails
|
||||
"""
|
||||
if not self.enabled:
|
||||
raise ValueError("OAuth client is not enabled")
|
||||
|
||||
# Validate and consume state
|
||||
oauth_state = self.state_manager.consume_state(state)
|
||||
if not oauth_state:
|
||||
raise ValueError("Invalid or expired state parameter")
|
||||
|
||||
try:
|
||||
# Prepare token request
|
||||
data = {
|
||||
'grant_type': 'authorization_code',
|
||||
'client_id': self.provider_config.client_id,
|
||||
'client_secret': self.provider_config.client_secret,
|
||||
'code': code,
|
||||
'redirect_uri': oauth_state.redirect_uri
|
||||
}
|
||||
|
||||
# Add PKCE verifier
|
||||
if oauth_state.pkce_verifier:
|
||||
data['code_verifier'] = oauth_state.pkce_verifier
|
||||
|
||||
# Make token request
|
||||
async with self._session.post(
|
||||
self.provider_config.token_endpoint,
|
||||
data=data,
|
||||
headers={'Content-Type': 'application/x-www-form-urlencoded'}
|
||||
) as response:
|
||||
response_data = await response.json()
|
||||
|
||||
if response.status != 200:
|
||||
error_msg = response_data.get('error_description', response_data.get('error', 'Token exchange failed'))
|
||||
raise ValueError(f"Token exchange failed: {error_msg}")
|
||||
|
||||
tokens = OAuthTokens(
|
||||
access_token=response_data['access_token'],
|
||||
token_type=response_data.get('token_type', 'Bearer'),
|
||||
expires_in=response_data.get('expires_in'),
|
||||
refresh_token=response_data.get('refresh_token'),
|
||||
scope=response_data.get('scope'),
|
||||
id_token=response_data.get('id_token')
|
||||
)
|
||||
|
||||
logger.info("Successfully exchanged authorization code for tokens")
|
||||
return tokens, oauth_state
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Token exchange failed: {e}")
|
||||
raise ValueError(f"Token exchange failed: {str(e)}")
|
||||
|
||||
async def get_user_info(self, tokens: OAuthTokens) -> OAuthUserInfo:
|
||||
"""Get user information from OAuth provider
|
||||
|
||||
Args:
|
||||
tokens: OAuth tokens
|
||||
|
||||
Returns:
|
||||
OAuth user information
|
||||
"""
|
||||
if not self.enabled:
|
||||
raise ValueError("OAuth client is not enabled")
|
||||
|
||||
if not self.provider_config.userinfo_endpoint:
|
||||
raise ValueError("Userinfo endpoint not configured")
|
||||
|
||||
try:
|
||||
# Make userinfo request
|
||||
headers = {'Authorization': f'{tokens.token_type} {tokens.access_token}'}
|
||||
|
||||
async with self._session.get(
|
||||
self.provider_config.userinfo_endpoint,
|
||||
headers=headers
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
user_data = await response.json()
|
||||
|
||||
# Extract user information using configured claims
|
||||
user_info = OAuthUserInfo(
|
||||
sub=str(user_data.get(self.provider_config.user_id_claim, '')),
|
||||
email=user_data.get(self.provider_config.email_claim),
|
||||
name=user_data.get(self.provider_config.name_claim),
|
||||
given_name=user_data.get('given_name'),
|
||||
family_name=user_data.get('family_name'),
|
||||
picture=user_data.get('picture'),
|
||||
locale=user_data.get('locale'),
|
||||
email_verified=user_data.get('email_verified'),
|
||||
roles=user_data.get(self.provider_config.roles_claim, self.provider_config.default_roles.copy()),
|
||||
raw_claims=user_data
|
||||
)
|
||||
|
||||
logger.info(f"Retrieved user info for user: {user_info.sub}")
|
||||
return user_info
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get user info: {e}")
|
||||
raise ValueError(f"Failed to get user info: {str(e)}")
|
||||
|
||||
async def refresh_tokens(self, refresh_token: str) -> OAuthTokens:
|
||||
"""Refresh OAuth tokens
|
||||
|
||||
Args:
|
||||
refresh_token: Refresh token
|
||||
|
||||
Returns:
|
||||
New OAuth tokens
|
||||
"""
|
||||
if not self.enabled:
|
||||
raise ValueError("OAuth client is not enabled")
|
||||
|
||||
try:
|
||||
data = {
|
||||
'grant_type': 'refresh_token',
|
||||
'client_id': self.provider_config.client_id,
|
||||
'client_secret': self.provider_config.client_secret,
|
||||
'refresh_token': refresh_token
|
||||
}
|
||||
|
||||
async with self._session.post(
|
||||
self.provider_config.token_endpoint,
|
||||
data=data,
|
||||
headers={'Content-Type': 'application/x-www-form-urlencoded'}
|
||||
) as response:
|
||||
response_data = await response.json()
|
||||
|
||||
if response.status != 200:
|
||||
error_msg = response_data.get('error_description', response_data.get('error', 'Token refresh failed'))
|
||||
raise ValueError(f"Token refresh failed: {error_msg}")
|
||||
|
||||
tokens = OAuthTokens(
|
||||
access_token=response_data['access_token'],
|
||||
token_type=response_data.get('token_type', 'Bearer'),
|
||||
expires_in=response_data.get('expires_in'),
|
||||
refresh_token=response_data.get('refresh_token', refresh_token), # Keep old if not provided
|
||||
scope=response_data.get('scope'),
|
||||
id_token=response_data.get('id_token')
|
||||
)
|
||||
|
||||
logger.info("Successfully refreshed OAuth tokens")
|
||||
return tokens
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Token refresh failed: {e}")
|
||||
raise ValueError(f"Token refresh failed: {str(e)}")
|
||||
312
doris_mcp_server/auth/oauth_handlers.py
Normal file
312
doris_mcp_server/auth/oauth_handlers.py
Normal file
@@ -0,0 +1,312 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
OAuth HTTP Handlers
|
||||
Provides HTTP endpoints for OAuth authentication flow
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
import json
|
||||
|
||||
from starlette.responses import JSONResponse, RedirectResponse, HTMLResponse
|
||||
from starlette.requests import Request
|
||||
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class OAuthHandlers:
|
||||
"""OAuth HTTP request handlers"""
|
||||
|
||||
def __init__(self, security_manager):
|
||||
"""Initialize OAuth handlers
|
||||
|
||||
Args:
|
||||
security_manager: DorisSecurityManager instance
|
||||
"""
|
||||
self.security_manager = security_manager
|
||||
logger.info("OAuth handlers initialized")
|
||||
|
||||
async def handle_login(self, request: Request) -> JSONResponse:
|
||||
"""Handle OAuth login initiation
|
||||
|
||||
Returns JSON with authorization URL and state
|
||||
"""
|
||||
try:
|
||||
# Check if OAuth is enabled
|
||||
oauth_info = self.security_manager.get_oauth_provider_info()
|
||||
if not oauth_info.get("enabled"):
|
||||
return JSONResponse(
|
||||
{"error": "OAuth authentication is not enabled"},
|
||||
status_code=400
|
||||
)
|
||||
|
||||
# Get authorization URL
|
||||
authorization_url, state = self.security_manager.get_oauth_authorization_url()
|
||||
|
||||
return JSONResponse({
|
||||
"authorization_url": authorization_url,
|
||||
"state": state,
|
||||
"provider": oauth_info.get("provider"),
|
||||
"message": "Navigate to authorization_url to complete OAuth login"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OAuth login initiation failed: {e}")
|
||||
return JSONResponse(
|
||||
{"error": f"OAuth login failed: {str(e)}"},
|
||||
status_code=500
|
||||
)
|
||||
|
||||
async def handle_callback(self, request: Request) -> JSONResponse:
|
||||
"""Handle OAuth callback
|
||||
|
||||
Processes the OAuth callback and returns authentication result
|
||||
"""
|
||||
try:
|
||||
# Get query parameters
|
||||
query_params = dict(request.query_params)
|
||||
|
||||
# Check for error in callback
|
||||
if "error" in query_params:
|
||||
error_description = query_params.get("error_description", "Unknown error")
|
||||
logger.warning(f"OAuth callback error: {query_params['error']} - {error_description}")
|
||||
return JSONResponse(
|
||||
{
|
||||
"error": query_params["error"],
|
||||
"error_description": error_description,
|
||||
"error_uri": query_params.get("error_uri")
|
||||
},
|
||||
status_code=400
|
||||
)
|
||||
|
||||
# Extract required parameters
|
||||
code = query_params.get("code")
|
||||
state = query_params.get("state")
|
||||
|
||||
if not code or not state:
|
||||
return JSONResponse(
|
||||
{"error": "Missing required parameters: code and state"},
|
||||
status_code=400
|
||||
)
|
||||
|
||||
# Handle OAuth callback
|
||||
auth_context = await self.security_manager.handle_oauth_callback(code, state)
|
||||
|
||||
# Return successful authentication response
|
||||
return JSONResponse({
|
||||
"success": True,
|
||||
"user_id": auth_context.user_id,
|
||||
"roles": auth_context.roles,
|
||||
"permissions": auth_context.permissions,
|
||||
"security_level": auth_context.security_level.value,
|
||||
"session_id": auth_context.session_id,
|
||||
"message": "OAuth authentication successful"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OAuth callback handling failed: {e}")
|
||||
return JSONResponse(
|
||||
{"error": f"OAuth callback failed: {str(e)}"},
|
||||
status_code=500
|
||||
)
|
||||
|
||||
async def handle_provider_info(self, request: Request) -> JSONResponse:
|
||||
"""Handle OAuth provider information request
|
||||
|
||||
Returns information about the configured OAuth provider
|
||||
"""
|
||||
try:
|
||||
provider_info = self.security_manager.get_oauth_provider_info()
|
||||
return JSONResponse(provider_info)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get OAuth provider info: {e}")
|
||||
return JSONResponse(
|
||||
{"error": f"Failed to get provider info: {str(e)}"},
|
||||
status_code=500
|
||||
)
|
||||
|
||||
async def handle_demo_page(self, request: Request) -> HTMLResponse:
|
||||
"""Handle OAuth demo page
|
||||
|
||||
Returns a simple HTML page for testing OAuth flow
|
||||
"""
|
||||
oauth_info = self.security_manager.get_oauth_provider_info()
|
||||
if not oauth_info.get("enabled"):
|
||||
return HTMLResponse("""
|
||||
<html>
|
||||
<head><title>OAuth Demo</title></head>
|
||||
<body>
|
||||
<h1>OAuth Demo</h1>
|
||||
<p style="color: red;">OAuth authentication is not enabled.</p>
|
||||
<p>Please configure OAuth settings in your security configuration.</p>
|
||||
</body>
|
||||
</html>
|
||||
""")
|
||||
|
||||
html_content = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Doris MCP Server - OAuth Demo</title>
|
||||
<style>
|
||||
body {{
|
||||
font-family: Arial, sans-serif;
|
||||
max-width: 800px;
|
||||
margin: 0 auto;
|
||||
padding: 20px;
|
||||
}}
|
||||
.info {{
|
||||
background-color: #f0f8ff;
|
||||
padding: 15px;
|
||||
border-left: 4px solid #0066cc;
|
||||
margin: 20px 0;
|
||||
}}
|
||||
.error {{
|
||||
background-color: #ffe6e6;
|
||||
padding: 15px;
|
||||
border-left: 4px solid #cc0000;
|
||||
margin: 20px 0;
|
||||
}}
|
||||
.success {{
|
||||
background-color: #e6ffe6;
|
||||
padding: 15px;
|
||||
border-left: 4px solid #00cc00;
|
||||
margin: 20px 0;
|
||||
}}
|
||||
button {{
|
||||
background-color: #0066cc;
|
||||
color: white;
|
||||
padding: 10px 20px;
|
||||
border: none;
|
||||
border-radius: 4px;
|
||||
cursor: pointer;
|
||||
font-size: 16px;
|
||||
}}
|
||||
button:hover {{
|
||||
background-color: #0052a3;
|
||||
}}
|
||||
pre {{
|
||||
background-color: #f5f5f5;
|
||||
padding: 10px;
|
||||
border-radius: 4px;
|
||||
overflow-x: auto;
|
||||
}}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Doris MCP Server - OAuth Demo</h1>
|
||||
|
||||
<div class="info">
|
||||
<h3>OAuth Configuration</h3>
|
||||
<p><strong>Provider:</strong> {oauth_info.get('provider', 'N/A')}</p>
|
||||
<p><strong>Client ID:</strong> {oauth_info.get('client_id', 'N/A')}</p>
|
||||
<p><strong>Scopes:</strong> {', '.join(oauth_info.get('scopes', []))}</p>
|
||||
<p><strong>PKCE Enabled:</strong> {oauth_info.get('pkce_enabled', False)}</p>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<h3>OAuth Authentication Test</h3>
|
||||
<p>Click the button below to start OAuth authentication flow:</p>
|
||||
<button onclick="startOAuthFlow()">Start OAuth Login</button>
|
||||
</div>
|
||||
|
||||
<div id="result" style="margin-top: 20px;"></div>
|
||||
|
||||
<div>
|
||||
<h3>API Endpoints</h3>
|
||||
<ul>
|
||||
<li><code>GET /auth/login</code> - Initiate OAuth login</li>
|
||||
<li><code>GET /auth/callback</code> - OAuth callback handler</li>
|
||||
<li><code>GET /auth/provider</code> - Provider information</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
async function startOAuthFlow() {{
|
||||
const resultDiv = document.getElementById('result');
|
||||
resultDiv.innerHTML = '<div class="info">Initiating OAuth flow...</div>';
|
||||
|
||||
try {{
|
||||
const response = await fetch('/auth/login');
|
||||
const data = await response.json();
|
||||
|
||||
if (response.ok) {{
|
||||
resultDiv.innerHTML = `
|
||||
<div class="success">
|
||||
<h4>OAuth URL Generated Successfully</h4>
|
||||
<p><strong>State:</strong> ${{data.state}}</p>
|
||||
<p><strong>Provider:</strong> ${{data.provider}}</p>
|
||||
<p><a href="${{data.authorization_url}}" target="_blank">Click here to authenticate</a></p>
|
||||
<p><em>Note: After authentication, you will be redirected to the callback URL.</em></p>
|
||||
</div>
|
||||
`;
|
||||
|
||||
// Automatically redirect to OAuth provider
|
||||
// window.open(data.authorization_url, '_blank');
|
||||
}} else {{
|
||||
resultDiv.innerHTML = `
|
||||
<div class="error">
|
||||
<h4>Error</h4>
|
||||
<p>${{data.error}}</p>
|
||||
</div>
|
||||
`;
|
||||
}}
|
||||
}} catch (error) {{
|
||||
resultDiv.innerHTML = `
|
||||
<div class="error">
|
||||
<h4>Network Error</h4>
|
||||
<p>${{error.message}}</p>
|
||||
</div>
|
||||
`;
|
||||
}}
|
||||
}}
|
||||
|
||||
// Handle OAuth callback result if present in URL
|
||||
window.addEventListener('load', function() {{
|
||||
const urlParams = new URLSearchParams(window.location.search);
|
||||
if (urlParams.has('code') && urlParams.has('state')) {{
|
||||
const resultDiv = document.getElementById('result');
|
||||
resultDiv.innerHTML = `
|
||||
<div class="success">
|
||||
<h4>OAuth Callback Received</h4>
|
||||
<p>Code: ${{urlParams.get('code')}}</p>
|
||||
<p>State: ${{urlParams.get('state')}}</p>
|
||||
<p>The authentication was successful!</p>
|
||||
</div>
|
||||
`;
|
||||
}} else if (urlParams.has('error')) {{
|
||||
const resultDiv = document.getElementById('result');
|
||||
resultDiv.innerHTML = `
|
||||
<div class="error">
|
||||
<h4>OAuth Error</h4>
|
||||
<p>Error: ${{urlParams.get('error')}}</p>
|
||||
<p>Description: ${{urlParams.get('error_description') || 'No description'}}</p>
|
||||
</div>
|
||||
`;
|
||||
}}
|
||||
}});
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
return HTMLResponse(html_content)
|
||||
289
doris_mcp_server/auth/oauth_provider.py
Normal file
289
doris_mcp_server/auth/oauth_provider.py
Normal file
@@ -0,0 +1,289 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
OAuth Authentication Provider
|
||||
Integrates OAuth 2.0/OIDC authentication with the existing authentication framework
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
from datetime import datetime
|
||||
|
||||
from .oauth_client import OAuthClient
|
||||
from .oauth_types import OAuthTokens, OAuthUserInfo, OAuthState
|
||||
from ..utils.security import AuthContext, SecurityLevel
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class OAuthAuthenticationProvider:
|
||||
"""OAuth authentication provider for Doris MCP Server"""
|
||||
|
||||
def __init__(self, config):
|
||||
"""Initialize OAuth authentication provider
|
||||
|
||||
Args:
|
||||
config: DorisConfig with OAuth configuration
|
||||
"""
|
||||
self.config = config
|
||||
self.oauth_client = OAuthClient(config)
|
||||
self.enabled = self.oauth_client.enabled
|
||||
|
||||
logger.info(f"OAuthAuthenticationProvider initialized (enabled: {self.enabled})")
|
||||
|
||||
async def initialize(self) -> bool:
|
||||
"""Initialize OAuth authentication provider
|
||||
|
||||
Returns:
|
||||
True if initialization successful
|
||||
"""
|
||||
if not self.enabled:
|
||||
return True
|
||||
|
||||
success = await self.oauth_client.initialize()
|
||||
if success:
|
||||
logger.info("OAuth authentication provider initialized successfully")
|
||||
else:
|
||||
logger.error("Failed to initialize OAuth authentication provider")
|
||||
return success
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shutdown OAuth authentication provider"""
|
||||
if self.enabled:
|
||||
await self.oauth_client.shutdown()
|
||||
logger.info("OAuth authentication provider shutdown completed")
|
||||
|
||||
def get_authorization_url(self) -> Tuple[str, str]:
|
||||
"""Get OAuth authorization URL
|
||||
|
||||
Returns:
|
||||
Tuple of (authorization_url, state)
|
||||
"""
|
||||
if not self.enabled:
|
||||
raise ValueError("OAuth authentication is not enabled")
|
||||
|
||||
authorization_url, oauth_state = self.oauth_client.build_authorization_url()
|
||||
return authorization_url, oauth_state.state
|
||||
|
||||
async def handle_callback(self, code: str, state: str) -> AuthContext:
|
||||
"""Handle OAuth callback and create authentication context
|
||||
|
||||
Args:
|
||||
code: Authorization code from OAuth provider
|
||||
state: State parameter for CSRF protection
|
||||
|
||||
Returns:
|
||||
AuthContext for the authenticated user
|
||||
|
||||
Raises:
|
||||
ValueError: If authentication fails
|
||||
"""
|
||||
if not self.enabled:
|
||||
raise ValueError("OAuth authentication is not enabled")
|
||||
|
||||
try:
|
||||
# Exchange code for tokens
|
||||
tokens, oauth_state = await self.oauth_client.exchange_code_for_tokens(code, state)
|
||||
|
||||
# Get user information
|
||||
user_info = await self.oauth_client.get_user_info(tokens)
|
||||
|
||||
# Create authentication context
|
||||
auth_context = await self._create_auth_context(user_info, tokens)
|
||||
|
||||
logger.info(f"OAuth authentication successful for user: {auth_context.user_id}")
|
||||
return auth_context
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OAuth callback handling failed: {e}")
|
||||
raise ValueError(f"OAuth authentication failed: {str(e)}")
|
||||
|
||||
async def authenticate_with_token(self, access_token: str) -> AuthContext:
|
||||
"""Authenticate using OAuth access token
|
||||
|
||||
Args:
|
||||
access_token: OAuth access token
|
||||
|
||||
Returns:
|
||||
AuthContext for the authenticated user
|
||||
"""
|
||||
if not self.enabled:
|
||||
raise ValueError("OAuth authentication is not enabled")
|
||||
|
||||
try:
|
||||
# Create token object
|
||||
tokens = OAuthTokens(access_token=access_token)
|
||||
|
||||
# Get user information
|
||||
user_info = await self.oauth_client.get_user_info(tokens)
|
||||
|
||||
# Create authentication context
|
||||
auth_context = await self._create_auth_context(user_info, tokens)
|
||||
|
||||
logger.info(f"OAuth token authentication successful for user: {auth_context.user_id}")
|
||||
return auth_context
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OAuth token authentication failed: {e}")
|
||||
raise ValueError(f"OAuth token authentication failed: {str(e)}")
|
||||
|
||||
async def refresh_authentication(self, refresh_token: str) -> Tuple[AuthContext, str]:
|
||||
"""Refresh OAuth authentication
|
||||
|
||||
Args:
|
||||
refresh_token: OAuth refresh token
|
||||
|
||||
Returns:
|
||||
Tuple of (AuthContext, new_access_token)
|
||||
"""
|
||||
if not self.enabled:
|
||||
raise ValueError("OAuth authentication is not enabled")
|
||||
|
||||
try:
|
||||
# Refresh tokens
|
||||
tokens = await self.oauth_client.refresh_tokens(refresh_token)
|
||||
|
||||
# Get updated user information
|
||||
user_info = await self.oauth_client.get_user_info(tokens)
|
||||
|
||||
# Create authentication context
|
||||
auth_context = await self._create_auth_context(user_info, tokens)
|
||||
|
||||
logger.info(f"OAuth refresh successful for user: {auth_context.user_id}")
|
||||
return auth_context, tokens.access_token
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OAuth refresh failed: {e}")
|
||||
raise ValueError(f"OAuth refresh failed: {str(e)}")
|
||||
|
||||
async def _create_auth_context(self, user_info: OAuthUserInfo, tokens: OAuthTokens) -> AuthContext:
|
||||
"""Create authentication context from OAuth user info
|
||||
|
||||
Args:
|
||||
user_info: OAuth user information
|
||||
tokens: OAuth tokens
|
||||
|
||||
Returns:
|
||||
AuthContext for the user
|
||||
"""
|
||||
# Determine security level based on roles or email domain
|
||||
security_level = await self._determine_security_level(user_info)
|
||||
|
||||
# Map OAuth roles to application permissions
|
||||
permissions = await self._map_permissions(user_info.roles)
|
||||
|
||||
# Generate session ID
|
||||
session_id = f"oauth_{user_info.sub}_{datetime.utcnow().timestamp()}"
|
||||
|
||||
return AuthContext(
|
||||
token_id=f"oauth_{user_info.sub}",
|
||||
user_id=user_info.sub,
|
||||
roles=user_info.roles,
|
||||
permissions=permissions,
|
||||
security_level=security_level,
|
||||
session_id=session_id,
|
||||
login_time=datetime.utcnow(),
|
||||
last_activity=datetime.utcnow(),
|
||||
token="" # OAuth doesn't have raw token, use empty string
|
||||
)
|
||||
|
||||
async def _determine_security_level(self, user_info: OAuthUserInfo) -> SecurityLevel:
|
||||
"""Determine security level for OAuth user
|
||||
|
||||
Args:
|
||||
user_info: OAuth user information
|
||||
|
||||
Returns:
|
||||
SecurityLevel for the user
|
||||
"""
|
||||
# Check if user has admin roles
|
||||
admin_roles = {"admin", "administrator", "data_admin", "super_admin"}
|
||||
if any(role.lower() in admin_roles for role in user_info.roles):
|
||||
return SecurityLevel.SECRET
|
||||
|
||||
# Check email domain for internal users
|
||||
if user_info.email:
|
||||
# You can configure trusted domains for internal access
|
||||
trusted_domains = ["yourcompany.com", "internal.org"] # Configure as needed
|
||||
email_domain = user_info.email.split("@")[-1].lower()
|
||||
if email_domain in trusted_domains:
|
||||
return SecurityLevel.CONFIDENTIAL
|
||||
|
||||
# Check for special roles
|
||||
elevated_roles = {"data_analyst", "developer", "manager"}
|
||||
if any(role.lower() in elevated_roles for role in user_info.roles):
|
||||
return SecurityLevel.CONFIDENTIAL
|
||||
|
||||
# Default to internal level for OAuth users
|
||||
return SecurityLevel.INTERNAL
|
||||
|
||||
async def _map_permissions(self, roles: list[str]) -> list[str]:
|
||||
"""Map OAuth roles to application permissions
|
||||
|
||||
Args:
|
||||
roles: OAuth user roles
|
||||
|
||||
Returns:
|
||||
List of application permissions
|
||||
"""
|
||||
permissions = set()
|
||||
|
||||
# Role to permission mapping
|
||||
role_permissions = {
|
||||
"admin": ["admin", "read_data", "write_data", "manage_users"],
|
||||
"administrator": ["admin", "read_data", "write_data", "manage_users"],
|
||||
"data_admin": ["admin", "read_data", "write_data"],
|
||||
"super_admin": ["admin", "read_data", "write_data", "manage_users", "system_admin"],
|
||||
"data_analyst": ["read_data", "query_database"],
|
||||
"developer": ["read_data", "query_database", "debug"],
|
||||
"viewer": ["read_data"],
|
||||
"user": ["read_data"],
|
||||
"oauth_user": ["read_data"] # Default OAuth user permission
|
||||
}
|
||||
|
||||
# Map roles to permissions
|
||||
for role in roles:
|
||||
role_lower = role.lower()
|
||||
if role_lower in role_permissions:
|
||||
permissions.update(role_permissions[role_lower])
|
||||
|
||||
# Ensure OAuth users have at least basic permissions
|
||||
if not permissions:
|
||||
permissions.add("read_data")
|
||||
|
||||
return list(permissions)
|
||||
|
||||
def get_provider_info(self) -> Dict[str, Any]:
|
||||
"""Get OAuth provider information
|
||||
|
||||
Returns:
|
||||
Provider information dictionary
|
||||
"""
|
||||
if not self.enabled:
|
||||
return {"enabled": False}
|
||||
|
||||
config = self.oauth_client.provider_config
|
||||
return {
|
||||
"enabled": True,
|
||||
"provider": config.provider.value,
|
||||
"client_id": config.client_id,
|
||||
"scopes": config.scopes,
|
||||
"redirect_uri": config.redirect_uri,
|
||||
"pkce_enabled": config.pkce_enabled,
|
||||
"nonce_enabled": config.nonce_enabled
|
||||
}
|
||||
196
doris_mcp_server/auth/oauth_types.py
Normal file
196
doris_mcp_server/auth/oauth_types.py
Normal file
@@ -0,0 +1,196 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
OAuth 2.0/OIDC Type Definitions
|
||||
Provides data types and models for OAuth authentication flow
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Dict, Any, Optional, List
|
||||
|
||||
|
||||
class OAuthProvider(Enum):
|
||||
"""OAuth provider enumeration"""
|
||||
GOOGLE = "google"
|
||||
MICROSOFT = "microsoft"
|
||||
GITHUB = "github"
|
||||
CUSTOM = "custom"
|
||||
|
||||
|
||||
class OAuthGrantType(Enum):
|
||||
"""OAuth grant type enumeration"""
|
||||
AUTHORIZATION_CODE = "authorization_code"
|
||||
REFRESH_TOKEN = "refresh_token"
|
||||
|
||||
|
||||
@dataclass
|
||||
class OAuthState:
|
||||
"""OAuth state parameter for CSRF protection"""
|
||||
state: str
|
||||
nonce: Optional[str] = None
|
||||
pkce_verifier: Optional[str] = None
|
||||
pkce_challenge: Optional[str] = None
|
||||
redirect_uri: str = ""
|
||||
created_at: datetime = None
|
||||
expires_at: datetime = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.created_at is None:
|
||||
self.created_at = datetime.utcnow()
|
||||
|
||||
|
||||
@dataclass
|
||||
class OAuthTokens:
|
||||
"""OAuth token response"""
|
||||
access_token: str
|
||||
token_type: str = "Bearer"
|
||||
expires_in: Optional[int] = None
|
||||
refresh_token: Optional[str] = None
|
||||
scope: Optional[str] = None
|
||||
id_token: Optional[str] = None # OIDC ID token
|
||||
created_at: datetime = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.created_at is None:
|
||||
self.created_at = datetime.utcnow()
|
||||
|
||||
|
||||
@dataclass
|
||||
class OAuthUserInfo:
|
||||
"""OAuth/OIDC user information"""
|
||||
sub: str # Subject identifier
|
||||
email: Optional[str] = None
|
||||
email_verified: Optional[bool] = None
|
||||
name: Optional[str] = None
|
||||
given_name: Optional[str] = None
|
||||
family_name: Optional[str] = None
|
||||
picture: Optional[str] = None
|
||||
locale: Optional[str] = None
|
||||
roles: List[str] = None
|
||||
raw_claims: Dict[str, Any] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.roles is None:
|
||||
self.roles = []
|
||||
if self.raw_claims is None:
|
||||
self.raw_claims = {}
|
||||
|
||||
|
||||
@dataclass
|
||||
class OIDCDiscovery:
|
||||
"""OIDC Discovery document"""
|
||||
issuer: str
|
||||
authorization_endpoint: str
|
||||
token_endpoint: str
|
||||
userinfo_endpoint: Optional[str] = None
|
||||
jwks_uri: Optional[str] = None
|
||||
scopes_supported: List[str] = None
|
||||
response_types_supported: List[str] = None
|
||||
subject_types_supported: List[str] = None
|
||||
id_token_signing_alg_values_supported: List[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.scopes_supported is None:
|
||||
self.scopes_supported = ["openid"]
|
||||
if self.response_types_supported is None:
|
||||
self.response_types_supported = ["code"]
|
||||
if self.subject_types_supported is None:
|
||||
self.subject_types_supported = ["public"]
|
||||
if self.id_token_signing_alg_values_supported is None:
|
||||
self.id_token_signing_alg_values_supported = ["RS256"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class OAuthError:
|
||||
"""OAuth error response"""
|
||||
error: str
|
||||
error_description: Optional[str] = None
|
||||
error_uri: Optional[str] = None
|
||||
state: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class OAuthProviderConfig:
|
||||
"""OAuth provider configuration"""
|
||||
provider: OAuthProvider
|
||||
client_id: str
|
||||
client_secret: str
|
||||
redirect_uri: str
|
||||
scopes: List[str]
|
||||
|
||||
# Endpoints
|
||||
authorization_endpoint: str
|
||||
token_endpoint: str
|
||||
userinfo_endpoint: Optional[str] = None
|
||||
jwks_uri: Optional[str] = None
|
||||
|
||||
# Discovery
|
||||
discovery_url: Optional[str] = None
|
||||
|
||||
# Settings
|
||||
pkce_enabled: bool = True
|
||||
nonce_enabled: bool = True
|
||||
|
||||
# User mapping
|
||||
user_id_claim: str = "sub"
|
||||
email_claim: str = "email"
|
||||
name_claim: str = "name"
|
||||
roles_claim: str = "roles"
|
||||
default_roles: List[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.default_roles is None:
|
||||
self.default_roles = ["oauth_user"]
|
||||
|
||||
|
||||
# Pre-defined provider configurations
|
||||
OAUTH_PROVIDERS = {
|
||||
OAuthProvider.GOOGLE: {
|
||||
"authorization_endpoint": "https://accounts.google.com/o/oauth2/auth",
|
||||
"token_endpoint": "https://oauth2.googleapis.com/token",
|
||||
"userinfo_endpoint": "https://openidconnect.googleapis.com/v1/userinfo",
|
||||
"jwks_uri": "https://www.googleapis.com/oauth2/v3/certs",
|
||||
"discovery_url": "https://accounts.google.com/.well-known/openid_configuration",
|
||||
"scopes": ["openid", "email", "profile"],
|
||||
"user_id_claim": "sub",
|
||||
"email_claim": "email",
|
||||
"name_claim": "name"
|
||||
},
|
||||
OAuthProvider.MICROSOFT: {
|
||||
"authorization_endpoint": "https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
|
||||
"token_endpoint": "https://login.microsoftonline.com/common/oauth2/v2.0/token",
|
||||
"userinfo_endpoint": "https://graph.microsoft.com/v1.0/me",
|
||||
"jwks_uri": "https://login.microsoftonline.com/common/discovery/v2.0/keys",
|
||||
"discovery_url": "https://login.microsoftonline.com/common/v2.0/.well-known/openid_configuration",
|
||||
"scopes": ["openid", "profile", "email", "User.Read"],
|
||||
"user_id_claim": "sub",
|
||||
"email_claim": "email",
|
||||
"name_claim": "name"
|
||||
},
|
||||
OAuthProvider.GITHUB: {
|
||||
"authorization_endpoint": "https://github.com/login/oauth/authorize",
|
||||
"token_endpoint": "https://github.com/login/oauth/access_token",
|
||||
"userinfo_endpoint": "https://api.github.com/user",
|
||||
"scopes": ["user:email", "read:user"],
|
||||
"user_id_claim": "id",
|
||||
"email_claim": "email",
|
||||
"name_claim": "name"
|
||||
}
|
||||
}
|
||||
677
doris_mcp_server/auth/token_handlers.py
Normal file
677
doris_mcp_server/auth/token_handlers.py
Normal file
@@ -0,0 +1,677 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Token Authentication HTTP Handlers
|
||||
|
||||
Provides HTTP endpoints for token management including creation, revocation,
|
||||
listing, and statistics. Used for administrative token management in HTTP mode.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Dict, Any
|
||||
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, HTMLResponse
|
||||
|
||||
from ..utils.logger import get_logger
|
||||
from ..utils.security import SecurityLevel
|
||||
from ..utils.config import DatabaseConfig
|
||||
from .token_security_middleware import TokenSecurityMiddleware
|
||||
|
||||
|
||||
class TokenHandlers:
|
||||
"""Token Authentication HTTP Handlers"""
|
||||
|
||||
def __init__(self, security_manager, config=None):
|
||||
self.security_manager = security_manager
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
# Initialize security middleware if config is provided
|
||||
if config:
|
||||
self.security_middleware = TokenSecurityMiddleware(config)
|
||||
else:
|
||||
self.security_middleware = None
|
||||
self.logger.warning("Token handlers initialized without security middleware - access control disabled")
|
||||
|
||||
async def handle_create_token(self, request: Request) -> JSONResponse:
|
||||
"""Handle token creation request"""
|
||||
# Apply security checks
|
||||
if self.security_middleware:
|
||||
security_response = await self.security_middleware.check_token_management_access(request)
|
||||
if security_response:
|
||||
return security_response
|
||||
|
||||
try:
|
||||
# Check if token manager is available
|
||||
if not self.security_manager.auth_provider.token_manager:
|
||||
return JSONResponse({
|
||||
"error": "Token authentication is not enabled"
|
||||
}, status_code=503)
|
||||
|
||||
# Parse request data
|
||||
if request.method == "GET":
|
||||
# GET request with query parameters
|
||||
query_params = dict(request.query_params)
|
||||
token_id = query_params.get("token_id")
|
||||
expires_hours_str = query_params.get("expires_hours")
|
||||
description = query_params.get("description", "")
|
||||
custom_token = query_params.get("custom_token")
|
||||
# Database configuration from query params
|
||||
db_config = None
|
||||
if query_params.get("db_host"):
|
||||
db_config = DatabaseConfig(
|
||||
host=query_params.get("db_host", "localhost"),
|
||||
port=int(query_params.get("db_port", "9030")),
|
||||
user=query_params.get("db_user", "root"),
|
||||
password=query_params.get("db_password", ""),
|
||||
database=query_params.get("db_database", "information_schema"),
|
||||
fe_http_port=int(query_params.get("db_fe_http_port", "8030"))
|
||||
)
|
||||
else:
|
||||
# POST request with JSON body
|
||||
try:
|
||||
body = await request.json()
|
||||
except:
|
||||
return JSONResponse({
|
||||
"error": "Invalid JSON body"
|
||||
}, status_code=400)
|
||||
|
||||
token_id = body.get("token_id")
|
||||
expires_hours_str = body.get("expires_hours")
|
||||
description = body.get("description", "")
|
||||
custom_token = body.get("custom_token")
|
||||
# Database configuration from JSON body
|
||||
db_config = None
|
||||
if body.get("database_config"):
|
||||
db_data = body["database_config"]
|
||||
try:
|
||||
db_config = DatabaseConfig(
|
||||
host=db_data.get("host", "localhost"),
|
||||
port=int(db_data.get("port", 9030)),
|
||||
user=db_data.get("user", "root"),
|
||||
password=db_data.get("password", ""),
|
||||
database=db_data.get("database", "information_schema"),
|
||||
fe_http_port=int(db_data.get("fe_http_port", 8030))
|
||||
)
|
||||
except (ValueError, TypeError) as e:
|
||||
return JSONResponse({
|
||||
"error": f"Invalid database configuration: {str(e)}"
|
||||
}, status_code=400)
|
||||
|
||||
# Validate required fields
|
||||
if not token_id:
|
||||
return JSONResponse({
|
||||
"error": "token_id is required"
|
||||
}, status_code=400)
|
||||
|
||||
# Parse expires_hours
|
||||
expires_hours = None
|
||||
if expires_hours_str:
|
||||
try:
|
||||
expires_hours = int(expires_hours_str)
|
||||
except ValueError:
|
||||
return JSONResponse({
|
||||
"error": "expires_hours must be an integer"
|
||||
}, status_code=400)
|
||||
|
||||
# Create token using the actual API
|
||||
try:
|
||||
token = await self.security_manager.create_token(
|
||||
token_id=token_id,
|
||||
expires_hours=expires_hours,
|
||||
description=description,
|
||||
custom_token=custom_token,
|
||||
database_config=db_config
|
||||
)
|
||||
|
||||
return JSONResponse({
|
||||
"success": True,
|
||||
"token_id": token_id,
|
||||
"token": token,
|
||||
"expires_hours": expires_hours,
|
||||
"description": description,
|
||||
"message": "Token created successfully"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Token creation failed: {e}")
|
||||
return JSONResponse({
|
||||
"error": f"Token creation failed: {str(e)}"
|
||||
}, status_code=400)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in handle_create_token: {e}")
|
||||
return JSONResponse({
|
||||
"error": f"Internal server error: {str(e)}"
|
||||
}, status_code=500)
|
||||
|
||||
async def handle_revoke_token(self, request: Request) -> JSONResponse:
|
||||
"""Handle token revocation request"""
|
||||
# Apply security checks
|
||||
if self.security_middleware:
|
||||
security_response = await self.security_middleware.check_token_management_access(request)
|
||||
if security_response:
|
||||
return security_response
|
||||
|
||||
try:
|
||||
# Check if token manager is available
|
||||
if not self.security_manager.auth_provider.token_manager:
|
||||
return JSONResponse({
|
||||
"error": "Token authentication is not enabled"
|
||||
}, status_code=503)
|
||||
|
||||
# Get token_id from query parameters or path
|
||||
token_id = request.query_params.get("token_id")
|
||||
if not token_id and request.method == "DELETE":
|
||||
# Try to get from path: /token/revoke/{token_id}
|
||||
path_parts = str(request.url.path).split("/")
|
||||
if len(path_parts) >= 4:
|
||||
token_id = path_parts[-1]
|
||||
|
||||
if not token_id:
|
||||
return JSONResponse({
|
||||
"error": "token_id is required"
|
||||
}, status_code=400)
|
||||
|
||||
# Revoke token
|
||||
success = await self.security_manager.revoke_token(token_id)
|
||||
|
||||
if success:
|
||||
return JSONResponse({
|
||||
"success": True,
|
||||
"token_id": token_id,
|
||||
"message": "Token revoked successfully"
|
||||
})
|
||||
else:
|
||||
return JSONResponse({
|
||||
"success": False,
|
||||
"token_id": token_id,
|
||||
"message": "Token not found or already revoked"
|
||||
}, status_code=404)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in handle_revoke_token: {e}")
|
||||
return JSONResponse({
|
||||
"error": f"Internal server error: {str(e)}"
|
||||
}, status_code=500)
|
||||
|
||||
async def handle_list_tokens(self, request: Request) -> JSONResponse:
|
||||
"""Handle token listing request"""
|
||||
# Apply security checks
|
||||
if self.security_middleware:
|
||||
security_response = await self.security_middleware.check_token_management_access(request)
|
||||
if security_response:
|
||||
return security_response
|
||||
|
||||
try:
|
||||
# Check if token manager is available
|
||||
if not self.security_manager.auth_provider.token_manager:
|
||||
return JSONResponse({
|
||||
"error": "Token authentication is not enabled"
|
||||
}, status_code=503)
|
||||
|
||||
# Get tokens list
|
||||
tokens = await self.security_manager.list_tokens()
|
||||
|
||||
return JSONResponse({
|
||||
"success": True,
|
||||
"count": len(tokens),
|
||||
"tokens": tokens
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in handle_list_tokens: {e}")
|
||||
return JSONResponse({
|
||||
"error": f"Internal server error: {str(e)}"
|
||||
}, status_code=500)
|
||||
|
||||
async def handle_token_stats(self, request: Request) -> JSONResponse:
|
||||
"""Handle token statistics request"""
|
||||
# Apply security checks
|
||||
if self.security_middleware:
|
||||
security_response = await self.security_middleware.check_token_management_access(request)
|
||||
if security_response:
|
||||
return security_response
|
||||
|
||||
try:
|
||||
# Check if token manager is available
|
||||
if not self.security_manager.auth_provider.token_manager:
|
||||
return JSONResponse({
|
||||
"error": "Token authentication is not enabled"
|
||||
}, status_code=503)
|
||||
|
||||
# Get token statistics
|
||||
stats = self.security_manager.get_token_stats()
|
||||
|
||||
return JSONResponse({
|
||||
"success": True,
|
||||
"stats": stats
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in handle_token_stats: {e}")
|
||||
return JSONResponse({
|
||||
"error": f"Internal server error: {str(e)}"
|
||||
}, status_code=500)
|
||||
|
||||
async def handle_cleanup_tokens(self, request: Request) -> JSONResponse:
|
||||
"""Handle expired tokens cleanup request"""
|
||||
# Apply security checks
|
||||
if self.security_middleware:
|
||||
security_response = await self.security_middleware.check_token_management_access(request)
|
||||
if security_response:
|
||||
return security_response
|
||||
|
||||
try:
|
||||
# Check if token manager is available
|
||||
if not self.security_manager.auth_provider.token_manager:
|
||||
return JSONResponse({
|
||||
"error": "Token authentication is not enabled"
|
||||
}, status_code=503)
|
||||
|
||||
# Cleanup expired tokens
|
||||
cleaned_count = await self.security_manager.cleanup_expired_tokens()
|
||||
|
||||
return JSONResponse({
|
||||
"success": True,
|
||||
"cleaned_count": cleaned_count,
|
||||
"message": f"Cleaned up {cleaned_count} expired tokens"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in handle_cleanup_tokens: {e}")
|
||||
return JSONResponse({
|
||||
"error": f"Internal server error: {str(e)}"
|
||||
}, status_code=500)
|
||||
|
||||
async def handle_management_page(self, request: Request) -> HTMLResponse:
|
||||
"""Handle token management demo page"""
|
||||
# Apply security checks
|
||||
if self.security_middleware:
|
||||
security_response = await self.security_middleware.check_token_management_access(request)
|
||||
if security_response:
|
||||
# Convert JSON response to HTML for demo page
|
||||
error_data = security_response.body.decode('utf-8') if hasattr(security_response, 'body') else '{"error": "Access denied"}'
|
||||
try:
|
||||
error_info = json.loads(error_data)
|
||||
except:
|
||||
error_info = {"error": "Access denied"}
|
||||
|
||||
error_html = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Access Denied - Token Management</title>
|
||||
<style>
|
||||
body {{ font-family: Arial, sans-serif; margin: 50px; background: #f5f5f5; }}
|
||||
.container {{ max-width: 600px; margin: 0 auto; background: white; padding: 30px; border-radius: 8px; }}
|
||||
.error {{ color: #dc3545; background: #f8d7da; border: 1px solid #f5c6cb; padding: 15px; border-radius: 5px; }}
|
||||
.security-info {{ background: #d1ecf1; border: 1px solid #bee5eb; padding: 15px; border-radius: 5px; margin-top: 20px; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>🔐 Token Management - Access Denied</h1>
|
||||
<div class="error">
|
||||
<h3>Access Denied</h3>
|
||||
<p><strong>Error:</strong> {error_info.get('error', 'Access denied')}</p>
|
||||
<p><strong>Message:</strong> {error_info.get('message', 'Token management access is restricted')}</p>
|
||||
{'<p><strong>Your IP:</strong> ' + str(error_info.get('client_ip', 'Unknown')) + '</p>' if 'client_ip' in error_info else ''}
|
||||
</div>
|
||||
|
||||
<div class="security-info">
|
||||
<h3>🛡️ Security Information</h3>
|
||||
<p>Token management endpoints are protected by the following security measures:</p>
|
||||
<ul>
|
||||
<li><strong>IP Restrictions:</strong> Only localhost/127.0.0.1 access allowed</li>
|
||||
<li><strong>Admin Authentication:</strong> Valid admin token required</li>
|
||||
<li><strong>Configuration Control:</strong> Must be explicitly enabled</li>
|
||||
</ul>
|
||||
<p>If you need access, please:</p>
|
||||
<ol>
|
||||
<li>Access from the server host (127.0.0.1)</li>
|
||||
<li>Ensure HTTP token management is enabled in configuration</li>
|
||||
<li>Provide valid admin authentication</li>
|
||||
</ol>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
return HTMLResponse(error_html, status_code=security_response.status_code)
|
||||
|
||||
try:
|
||||
# Check if token manager is available
|
||||
if not self.security_manager.auth_provider.token_manager:
|
||||
html_content = """
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Token Management - Not Available</title>
|
||||
<style>
|
||||
body { font-family: Arial, sans-serif; margin: 50px; }
|
||||
.error { color: red; font-size: 18px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Token Management</h1>
|
||||
<div class="error">Token authentication is not enabled on this server.</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
return HTMLResponse(html_content)
|
||||
|
||||
# Get current stats for demo
|
||||
stats = self.security_manager.get_token_stats()
|
||||
|
||||
html_content = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Doris MCP Server - Token Management</title>
|
||||
<style>
|
||||
body {{ font-family: Arial, sans-serif; margin: 50px; background: #f5f5f5; }}
|
||||
.container {{ max-width: 1200px; margin: 0 auto; background: white; padding: 30px; border-radius: 8px; }}
|
||||
h1 {{ color: #333; }}
|
||||
.section {{ margin: 30px 0; padding: 20px; border: 1px solid #ddd; border-radius: 5px; }}
|
||||
.stats {{ display: grid; grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); gap: 15px; }}
|
||||
.stat-item {{ padding: 15px; background: #f8f9fa; border-radius: 5px; text-align: center; }}
|
||||
.stat-value {{ font-size: 24px; font-weight: bold; color: #007bff; }}
|
||||
.form-group {{ margin: 15px 0; }}
|
||||
.form-group label {{ display: block; margin-bottom: 5px; font-weight: bold; }}
|
||||
.form-group input, .form-group textarea {{ width: 100%; padding: 8px; border: 1px solid #ddd; border-radius: 4px; }}
|
||||
button {{ padding: 10px 20px; margin: 5px; border: none; border-radius: 4px; cursor: pointer; }}
|
||||
.btn-primary {{ background: #007bff; color: white; }}
|
||||
.btn-danger {{ background: #dc3545; color: white; }}
|
||||
.btn-success {{ background: #28a745; color: white; }}
|
||||
.response {{ margin: 15px 0; padding: 15px; border-radius: 5px; }}
|
||||
.response.success {{ background: #d4edda; border: 1px solid #c3e6cb; }}
|
||||
.response.error {{ background: #f8d7da; border: 1px solid #f5c6cb; }}
|
||||
.token-list {{ margin: 15px 0; }}
|
||||
.token-item {{ padding: 10px; margin: 5px 0; background: #f8f9fa; border-radius: 4px; }}
|
||||
pre {{ background: #f8f9fa; padding: 10px; border-radius: 4px; overflow-x: auto; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>🔐 Doris MCP Server - Token Management</h1>
|
||||
|
||||
<div class="section">
|
||||
<h2>📊 Token Statistics</h2>
|
||||
<div class="stats">
|
||||
<div class="stat-item">
|
||||
<div class="stat-value">{stats.get('total_tokens', 0)}</div>
|
||||
<div>Total Tokens</div>
|
||||
</div>
|
||||
<div class="stat-item">
|
||||
<div class="stat-value">{stats.get('active_tokens', 0)}</div>
|
||||
<div>Active Tokens</div>
|
||||
</div>
|
||||
<div class="stat-item">
|
||||
<div class="stat-value">{stats.get('expired_tokens', 0)}</div>
|
||||
<div>Expired Tokens</div>
|
||||
</div>
|
||||
</div>
|
||||
<p><strong>Token Expiry:</strong> {'Enabled' if stats.get('expiry_enabled') else 'Disabled'}</p>
|
||||
<p><strong>Default Expiry:</strong> {stats.get('default_expiry_hours', 0)} hours</p>
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>➕ Create New Token</h2>
|
||||
<form id="createTokenForm">
|
||||
<div class="form-group">
|
||||
<label for="token_id">Token ID (required):</label>
|
||||
<input type="text" id="token_id" name="token_id" placeholder="e.g., my-app-token" required>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="expires_hours">Expires Hours (optional):</label>
|
||||
<input type="number" id="expires_hours" name="expires_hours" placeholder="e.g., 720 (30 days), leave empty for default">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="description">Description (optional):</label>
|
||||
<textarea id="description" name="description" placeholder="Token description"></textarea>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="custom_token">Custom Token (optional):</label>
|
||||
<input type="text" id="custom_token" name="custom_token" placeholder="Leave empty to auto-generate">
|
||||
<small style="color: #666; display: block; margin-top: 5px;">If not provided, a secure token will be generated automatically</small>
|
||||
</div>
|
||||
|
||||
<div class="section" style="margin: 20px 0; padding: 15px; background: #f8f9fa; border-radius: 5px;">
|
||||
<h3>🗄️ Database Configuration (Optional)</h3>
|
||||
<p style="color: #666; font-size: 14px; margin-bottom: 15px;">Configure database connection for this token. Leave empty to use system defaults.</p>
|
||||
|
||||
<div style="display: grid; grid-template-columns: 1fr 1fr; gap: 10px;">
|
||||
<div class="form-group">
|
||||
<label for="db_host">Host:</label>
|
||||
<input type="text" id="db_host" name="db_host" placeholder="localhost">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="db_port">Port:</label>
|
||||
<input type="number" id="db_port" name="db_port" placeholder="9030">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="db_user">User:</label>
|
||||
<input type="text" id="db_user" name="db_user" placeholder="root">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="db_password">Password:</label>
|
||||
<input type="password" id="db_password" name="db_password" placeholder="(optional)">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="db_database">Database:</label>
|
||||
<input type="text" id="db_database" name="db_database" placeholder="information_schema">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="db_fe_http_port">FE HTTP Port:</label>
|
||||
<input type="number" id="db_fe_http_port" name="db_fe_http_port" placeholder="8030">
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<button type="submit" class="btn-primary">Create Token</button>
|
||||
</form>
|
||||
<div id="createTokenResponse"></div>
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>📋 Token Management</h2>
|
||||
<button id="listTokensBtn" class="btn-success">Refresh Token List</button>
|
||||
<button id="cleanupTokensBtn" class="btn-primary">Cleanup Expired Tokens</button>
|
||||
<div id="tokenListResponse"></div>
|
||||
|
||||
<h3>Revoke Token</h3>
|
||||
<div class="form-group">
|
||||
<input type="text" id="revokeTokenId" placeholder="Enter token ID to revoke">
|
||||
<button id="revokeTokenBtn" class="btn-danger">Revoke Token</button>
|
||||
</div>
|
||||
<div id="revokeTokenResponse"></div>
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>🔧 API Endpoints</h2>
|
||||
<p>Use these endpoints for programmatic token management:</p>
|
||||
<ul>
|
||||
<li><strong>POST /token/create</strong> - Create new token</li>
|
||||
<li><strong>DELETE /token/revoke?token_id=...</strong> - Revoke token</li>
|
||||
<li><strong>GET /token/list</strong> - List all tokens</li>
|
||||
<li><strong>GET /token/stats</strong> - Get token statistics</li>
|
||||
<li><strong>POST /token/cleanup</strong> - Cleanup expired tokens</li>
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
// Get admin token from URL parameters
|
||||
const urlParams = new URLSearchParams(window.location.search);
|
||||
const adminToken = urlParams.get('admin_token');
|
||||
|
||||
// Create request headers with admin token
|
||||
function getAuthHeaders() {{
|
||||
if (adminToken) {{
|
||||
return {{
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': `Bearer ${{adminToken}}`
|
||||
}};
|
||||
}} else {{
|
||||
return {{'Content-Type': 'application/json'}};
|
||||
}}
|
||||
}}
|
||||
|
||||
// Create URL with admin token parameter
|
||||
function getAuthURL(baseUrl) {{
|
||||
if (adminToken) {{
|
||||
const separator = baseUrl.includes('?') ? '&' : '?';
|
||||
return `${{baseUrl}}${{separator}}admin_token=${{encodeURIComponent(adminToken)}}`;
|
||||
}}
|
||||
return baseUrl;
|
||||
}}
|
||||
|
||||
function showResponse(elementId, data, isSuccess = true) {{
|
||||
const element = document.getElementById(elementId);
|
||||
element.innerHTML = '<pre>' + JSON.stringify(data, null, 2) + '</pre>';
|
||||
element.className = 'response ' + (isSuccess ? 'success' : 'error');
|
||||
}}
|
||||
|
||||
// Create token form - updated to match actual API
|
||||
document.getElementById('createTokenForm').addEventListener('submit', async (e) => {{
|
||||
e.preventDefault();
|
||||
const formData = new FormData(e.target);
|
||||
const data = Object.fromEntries(formData.entries());
|
||||
|
||||
// Remove empty fields for optional parameters
|
||||
if (!data.expires_hours) delete data.expires_hours;
|
||||
if (!data.description) delete data.description;
|
||||
if (!data.custom_token) delete data.custom_token;
|
||||
|
||||
// Handle database configuration
|
||||
if (data.db_host) {{
|
||||
data.database_config = {{
|
||||
host: data.db_host,
|
||||
port: data.db_port ? parseInt(data.db_port) : 9030,
|
||||
user: data.db_user || 'root',
|
||||
password: data.db_password || '',
|
||||
database: data.db_database || 'information_schema',
|
||||
fe_http_port: data.db_fe_http_port ? parseInt(data.db_fe_http_port) : 8030
|
||||
}};
|
||||
}}
|
||||
|
||||
// Remove individual database fields from data
|
||||
delete data.db_host;
|
||||
delete data.db_port;
|
||||
delete data.db_user;
|
||||
delete data.db_password;
|
||||
delete data.db_database;
|
||||
delete data.db_fe_http_port;
|
||||
|
||||
try {{
|
||||
const response = await fetch(getAuthURL('/token/create'), {{
|
||||
method: 'POST',
|
||||
headers: getAuthHeaders(),
|
||||
body: JSON.stringify(data)
|
||||
}});
|
||||
const result = await response.json();
|
||||
showResponse('createTokenResponse', result, response.ok);
|
||||
|
||||
// Refresh token list if creation was successful
|
||||
if (response.ok) {{
|
||||
document.getElementById('listTokensBtn').click();
|
||||
}}
|
||||
}} catch (error) {{
|
||||
showResponse('createTokenResponse', {{error: error.message}}, false);
|
||||
}}
|
||||
}});
|
||||
|
||||
// List tokens
|
||||
document.getElementById('listTokensBtn').addEventListener('click', async () => {{
|
||||
try {{
|
||||
const response = await fetch(getAuthURL('/token/list'), {{
|
||||
headers: getAuthHeaders()
|
||||
}});
|
||||
const result = await response.json();
|
||||
showResponse('tokenListResponse', result, response.ok);
|
||||
}} catch (error) {{
|
||||
showResponse('tokenListResponse', {{error: error.message}}, false);
|
||||
}}
|
||||
}});
|
||||
|
||||
// Cleanup tokens
|
||||
document.getElementById('cleanupTokensBtn').addEventListener('click', async () => {{
|
||||
try {{
|
||||
const response = await fetch(getAuthURL('/token/cleanup'), {{
|
||||
method: 'POST',
|
||||
headers: getAuthHeaders()
|
||||
}});
|
||||
const result = await response.json();
|
||||
showResponse('tokenListResponse', result, response.ok);
|
||||
}} catch (error) {{
|
||||
showResponse('tokenListResponse', {{error: error.message}}, false);
|
||||
}}
|
||||
}});
|
||||
|
||||
// Revoke token
|
||||
document.getElementById('revokeTokenBtn').addEventListener('click', async () => {{
|
||||
const tokenId = document.getElementById('revokeTokenId').value;
|
||||
if (!tokenId) {{
|
||||
showResponse('revokeTokenResponse', {{error: 'Token ID is required'}}, false);
|
||||
return;
|
||||
}}
|
||||
|
||||
try {{
|
||||
const response = await fetch(getAuthURL(`/token/revoke?token_id=${{encodeURIComponent(tokenId)}}`), {{
|
||||
method: 'DELETE',
|
||||
headers: getAuthHeaders()
|
||||
}});
|
||||
const result = await response.json();
|
||||
showResponse('revokeTokenResponse', result, response.ok);
|
||||
|
||||
// Refresh token list if revocation was successful
|
||||
if (response.ok) {{
|
||||
document.getElementById('listTokensBtn').click();
|
||||
document.getElementById('revokeTokenId').value = '';
|
||||
}}
|
||||
}} catch (error) {{
|
||||
showResponse('revokeTokenResponse', {{error: error.message}}, false);
|
||||
}}
|
||||
}});
|
||||
|
||||
// Load token list on page load
|
||||
document.getElementById('listTokensBtn').click();
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
return HTMLResponse(html_content)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in handle_demo_page: {e}")
|
||||
error_html = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Token Management Error</title>
|
||||
<style>body {{ font-family: Arial, sans-serif; margin: 50px; }}</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Token Management Error</h1>
|
||||
<p>Error loading token management page: {str(e)}</p>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
return HTMLResponse(error_html, status_code=500)
|
||||
827
doris_mcp_server/auth/token_manager.py
Normal file
827
doris_mcp_server/auth/token_manager.py
Normal file
@@ -0,0 +1,827 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Token Authentication Management Module
|
||||
|
||||
Provides enterprise-grade token authentication system with configurable tokens,
|
||||
expiration management, role-based access control and secure token storage.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import secrets
|
||||
import time
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any
|
||||
from pathlib import Path
|
||||
|
||||
from ..utils.logger import get_logger
|
||||
from ..utils.security import SecurityLevel
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatabaseConfig:
|
||||
"""Database connection configuration for token binding"""
|
||||
|
||||
host: str
|
||||
port: int = 9030
|
||||
user: str = ""
|
||||
password: str = ""
|
||||
database: str = "information_schema"
|
||||
charset: str = "UTF8"
|
||||
fe_http_port: int = 8030
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenInfo:
|
||||
"""Token information structure with optional database binding"""
|
||||
|
||||
token_id: str # Unique token identifier for audit and management
|
||||
created_at: datetime = field(default_factory=datetime.utcnow)
|
||||
expires_at: Optional[datetime] = None
|
||||
last_used: Optional[datetime] = None
|
||||
description: str = "" # Optional description for token purpose
|
||||
is_active: bool = True
|
||||
database_config: Optional[DatabaseConfig] = None # Optional database binding
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenValidationResult:
|
||||
"""Token validation result"""
|
||||
|
||||
is_valid: bool
|
||||
token_info: Optional[TokenInfo] = None
|
||||
error_message: Optional[str] = None
|
||||
|
||||
|
||||
class TokenManager:
|
||||
"""Enterprise Token Authentication Manager
|
||||
|
||||
Features:
|
||||
- Configurable token storage (file-based or environment variables)
|
||||
- Token expiration management
|
||||
- Secure token hashing
|
||||
- Role-based access control
|
||||
- Token lifecycle management
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
# Token storage
|
||||
self._tokens: Dict[str, TokenInfo] = {} # token_hash -> TokenInfo
|
||||
self._token_ids: Dict[str, str] = {} # token_id -> token_hash
|
||||
|
||||
# Configuration
|
||||
self.token_file_path = getattr(config.security, 'token_file_path', 'tokens.json')
|
||||
self.enable_token_expiry = getattr(config.security, 'enable_token_expiry', True)
|
||||
self.default_token_expiry_hours = getattr(config.security, 'default_token_expiry_hours', 24 * 30) # 30 days
|
||||
self.token_hash_algorithm = getattr(config.security, 'token_hash_algorithm', 'sha256')
|
||||
|
||||
# Hot reload configuration
|
||||
self.enable_hot_reload = True
|
||||
self.hot_reload_interval = 10 # Check every 10 seconds
|
||||
self._file_last_modified = 0
|
||||
self._hot_reload_task = None
|
||||
|
||||
# Initialize with default tokens if none exist
|
||||
self._initialize_default_tokens()
|
||||
|
||||
# Load tokens from configuration
|
||||
self._load_tokens()
|
||||
|
||||
# Start hot reload monitoring
|
||||
if self.enable_hot_reload:
|
||||
self._start_hot_reload()
|
||||
|
||||
self.logger.info(f"TokenManager initialized with {len(self._tokens)} tokens, hot reload: {self.enable_hot_reload}")
|
||||
|
||||
def _initialize_default_tokens(self):
|
||||
"""Initialize default tokens for basic authentication (configurable via environment)"""
|
||||
# Default token configurations (can be overridden by environment variables)
|
||||
default_tokens = [
|
||||
{
|
||||
'token_id': 'admin-token',
|
||||
'token': os.getenv('DEFAULT_ADMIN_TOKEN', 'doris_admin_token_123456'),
|
||||
'description': os.getenv('DEFAULT_ADMIN_DESCRIPTION', 'Default admin API access token'),
|
||||
'expires_hours': None # Never expires
|
||||
},
|
||||
{
|
||||
'token_id': 'analyst-token',
|
||||
'token': os.getenv('DEFAULT_ANALYST_TOKEN', 'doris_analyst_token_123456'),
|
||||
'description': os.getenv('DEFAULT_ANALYST_DESCRIPTION', 'Default data analysis API access token'),
|
||||
'expires_hours': None # Never expires
|
||||
},
|
||||
{
|
||||
'token_id': 'readonly-token',
|
||||
'token': os.getenv('DEFAULT_READONLY_TOKEN', 'doris_readonly_token_123456'),
|
||||
'description': os.getenv('DEFAULT_READONLY_DESCRIPTION', 'Default read-only API access token'),
|
||||
'expires_hours': None # Never expires
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
# Only add default tokens if no custom tokens are defined via environment variables
|
||||
# Check if any TOKEN_* environment variables exist (excluding system and legacy configs)
|
||||
excluded_prefixes = ('DEFAULT_', 'TOKEN_FILE_PATH', 'TOKEN_HASH_')
|
||||
excluded_vars = {'TOKEN_SECRET', 'TOKEN_EXPIRY'}
|
||||
|
||||
custom_tokens_exist = any(
|
||||
key.startswith('TOKEN_') and
|
||||
not key.startswith(excluded_prefixes) and
|
||||
not key.endswith(('_EXPIRES_HOURS', '_DESCRIPTION')) and
|
||||
key not in excluded_vars
|
||||
for key in os.environ.keys()
|
||||
)
|
||||
|
||||
# Also check if token file exists and has content
|
||||
token_file_exists = False
|
||||
if os.path.exists(self.token_file_path):
|
||||
try:
|
||||
with open(self.token_file_path, 'r') as f:
|
||||
content = f.read().strip()
|
||||
if content and content != '{}':
|
||||
token_file_exists = True
|
||||
except:
|
||||
pass
|
||||
|
||||
# Add default tokens only if no custom configuration exists
|
||||
if not custom_tokens_exist and not token_file_exists:
|
||||
for token_config in default_tokens:
|
||||
self._add_token_from_config(token_config)
|
||||
|
||||
self.logger.info(f"Initialized {len(default_tokens)} default tokens (no custom config found)")
|
||||
else:
|
||||
self.logger.info("Skipped default tokens initialization (custom tokens detected)")
|
||||
|
||||
def _add_token_from_config(self, token_config: Dict[str, Any]):
|
||||
"""Add token from configuration with optional database binding"""
|
||||
try:
|
||||
# Calculate expiration time
|
||||
expires_at = None
|
||||
if self.enable_token_expiry:
|
||||
expires_hours = token_config.get('expires_hours', self.default_token_expiry_hours)
|
||||
if expires_hours is not None:
|
||||
expires_at = datetime.utcnow() + timedelta(hours=expires_hours)
|
||||
|
||||
# Parse database configuration if provided
|
||||
database_config = None
|
||||
if 'database_config' in token_config:
|
||||
db_config = token_config['database_config']
|
||||
database_config = DatabaseConfig(
|
||||
host=db_config.get('host', 'localhost'),
|
||||
port=db_config.get('port', 9030),
|
||||
user=db_config.get('user', 'root'),
|
||||
password=db_config.get('password', ''),
|
||||
database=db_config.get('database', 'information_schema'),
|
||||
charset=db_config.get('charset', 'UTF8'),
|
||||
fe_http_port=db_config.get('fe_http_port', 8030)
|
||||
)
|
||||
|
||||
# Create token info
|
||||
token_info = TokenInfo(
|
||||
token_id=token_config['token_id'],
|
||||
expires_at=expires_at,
|
||||
description=token_config.get('description', ''),
|
||||
is_active=token_config.get('is_active', True),
|
||||
database_config=database_config
|
||||
)
|
||||
|
||||
# Hash the token
|
||||
raw_token = token_config['token']
|
||||
token_hash = self._hash_token(raw_token)
|
||||
|
||||
# Store token
|
||||
self._tokens[token_hash] = token_info
|
||||
self._token_ids[token_info.token_id] = token_hash
|
||||
|
||||
db_info = f" with DB binding ({database_config.host})" if database_config else ""
|
||||
self.logger.debug(f"Added token '{token_info.token_id}'{db_info}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to add token from config: {e}")
|
||||
raise
|
||||
|
||||
def _load_tokens(self):
|
||||
"""Load tokens from configuration sources"""
|
||||
# 1. Load from environment variables
|
||||
self._load_tokens_from_env()
|
||||
|
||||
# 2. Load from token file if exists
|
||||
if os.path.exists(self.token_file_path):
|
||||
self._load_tokens_from_file()
|
||||
|
||||
self.logger.info(f"Token loading completed, total tokens: {len(self._tokens)}")
|
||||
|
||||
def _load_tokens_from_env(self):
|
||||
"""Load tokens from environment variables
|
||||
|
||||
Simplified format:
|
||||
TOKEN_<ID>=<token>
|
||||
TOKEN_<ID>_EXPIRES_HOURS=<hours>
|
||||
TOKEN_<ID>_DESCRIPTION=<description>
|
||||
"""
|
||||
token_prefixes = set()
|
||||
|
||||
# Find all TOKEN_ environment variables (exclude legacy and system variables)
|
||||
excluded_token_vars = {
|
||||
'TOKEN_SECRET', # Legacy token secret
|
||||
'TOKEN_EXPIRY', # Legacy token expiry
|
||||
'TOKEN_FILE_PATH', # System config
|
||||
'TOKEN_HASH_ALGORITHM' # System config
|
||||
}
|
||||
|
||||
for key in os.environ:
|
||||
if (key.startswith('TOKEN_') and
|
||||
not key.endswith(('_EXPIRES_HOURS', '_DESCRIPTION')) and
|
||||
key not in excluded_token_vars):
|
||||
token_id = key[6:] # Remove 'TOKEN_' prefix
|
||||
token_prefixes.add(token_id)
|
||||
|
||||
# Load each token
|
||||
for token_id in token_prefixes:
|
||||
try:
|
||||
token = os.environ.get(f'TOKEN_{token_id}')
|
||||
if not token:
|
||||
continue
|
||||
|
||||
expires_hours_str = os.environ.get(f'TOKEN_{token_id}_EXPIRES_HOURS', str(self.default_token_expiry_hours))
|
||||
description = os.environ.get(f'TOKEN_{token_id}_DESCRIPTION', f'Environment token {token_id}')
|
||||
|
||||
expires_hours = None
|
||||
try:
|
||||
if expires_hours_str and expires_hours_str.lower() != 'none':
|
||||
expires_hours = int(expires_hours_str)
|
||||
except ValueError:
|
||||
expires_hours = self.default_token_expiry_hours
|
||||
|
||||
# Add token
|
||||
token_config = {
|
||||
'token_id': token_id.lower(),
|
||||
'token': token,
|
||||
'expires_hours': expires_hours,
|
||||
'description': description
|
||||
}
|
||||
|
||||
self._add_token_from_config(token_config)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to load token {token_id} from environment: {e}")
|
||||
|
||||
def _load_tokens_from_file(self):
|
||||
"""Load tokens from JSON file"""
|
||||
try:
|
||||
with open(self.token_file_path, 'r', encoding='utf-8') as f:
|
||||
tokens_data = json.load(f)
|
||||
|
||||
if isinstance(tokens_data, dict) and 'tokens' in tokens_data:
|
||||
tokens_list = tokens_data['tokens']
|
||||
elif isinstance(tokens_data, list):
|
||||
tokens_list = tokens_data
|
||||
else:
|
||||
self.logger.error(f"Invalid token file format: {self.token_file_path}")
|
||||
return
|
||||
|
||||
for token_config in tokens_list:
|
||||
self._add_token_from_config(token_config)
|
||||
|
||||
self.logger.info(f"Loaded {len(tokens_list)} tokens from file: {self.token_file_path}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to load tokens from file {self.token_file_path}: {e}")
|
||||
|
||||
def _hash_token(self, token: str) -> str:
|
||||
"""Hash token for secure storage"""
|
||||
if self.token_hash_algorithm == 'sha256':
|
||||
return hashlib.sha256(token.encode('utf-8')).hexdigest()
|
||||
elif self.token_hash_algorithm == 'sha512':
|
||||
return hashlib.sha512(token.encode('utf-8')).hexdigest()
|
||||
else:
|
||||
# Fallback to sha256
|
||||
return hashlib.sha256(token.encode('utf-8')).hexdigest()
|
||||
|
||||
async def validate_token(self, token: str) -> TokenValidationResult:
|
||||
"""Validate token and return user information"""
|
||||
try:
|
||||
# Hash the provided token
|
||||
token_hash = self._hash_token(token)
|
||||
|
||||
# Find token info
|
||||
token_info = self._tokens.get(token_hash)
|
||||
if not token_info:
|
||||
return TokenValidationResult(
|
||||
is_valid=False,
|
||||
error_message="Invalid token"
|
||||
)
|
||||
|
||||
# Check if token is active
|
||||
if not token_info.is_active:
|
||||
return TokenValidationResult(
|
||||
is_valid=False,
|
||||
error_message="Token is inactive"
|
||||
)
|
||||
|
||||
# Check expiration
|
||||
if token_info.expires_at and datetime.utcnow() > token_info.expires_at:
|
||||
return TokenValidationResult(
|
||||
is_valid=False,
|
||||
error_message="Token has expired"
|
||||
)
|
||||
|
||||
# Update last used time
|
||||
token_info.last_used = datetime.utcnow()
|
||||
|
||||
return TokenValidationResult(
|
||||
is_valid=True,
|
||||
token_info=token_info
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Token validation error: {e}")
|
||||
return TokenValidationResult(
|
||||
is_valid=False,
|
||||
error_message=f"Token validation failed: {str(e)}"
|
||||
)
|
||||
|
||||
def generate_token(self, length: int = 32) -> str:
|
||||
"""Generate a cryptographically secure random token"""
|
||||
return secrets.token_urlsafe(length)
|
||||
|
||||
async def create_token(
|
||||
self,
|
||||
token_id: str,
|
||||
expires_hours: Optional[int] = None,
|
||||
description: str = "",
|
||||
custom_token: Optional[str] = None,
|
||||
database_config: Optional[DatabaseConfig] = None
|
||||
) -> str:
|
||||
"""Create a new token"""
|
||||
try:
|
||||
# Check if token_id already exists
|
||||
if token_id in self._token_ids:
|
||||
raise ValueError(f"Token ID '{token_id}' already exists")
|
||||
|
||||
# Generate or use provided token
|
||||
if custom_token:
|
||||
raw_token = custom_token
|
||||
else:
|
||||
raw_token = self.generate_token()
|
||||
|
||||
# Calculate expiration
|
||||
expires_at = None
|
||||
if expires_hours is not None:
|
||||
expires_at = datetime.utcnow() + timedelta(hours=expires_hours)
|
||||
elif self.enable_token_expiry:
|
||||
expires_at = datetime.utcnow() + timedelta(hours=self.default_token_expiry_hours)
|
||||
|
||||
# Create token info
|
||||
token_info = TokenInfo(
|
||||
token_id=token_id,
|
||||
expires_at=expires_at,
|
||||
description=description,
|
||||
database_config=database_config
|
||||
)
|
||||
|
||||
# Hash and store token
|
||||
token_hash = self._hash_token(raw_token)
|
||||
self._tokens[token_hash] = token_info
|
||||
self._token_ids[token_id] = token_hash
|
||||
|
||||
self.logger.info(f"Created new token '{token_id}'")
|
||||
|
||||
# Save token to file
|
||||
self._save_token_to_file(token_id, raw_token, token_info)
|
||||
|
||||
return raw_token
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to create token: {e}")
|
||||
raise
|
||||
|
||||
async def revoke_token(self, token_id: str) -> bool:
|
||||
"""Revoke a token by token ID"""
|
||||
try:
|
||||
if token_id not in self._token_ids:
|
||||
self.logger.warning(f"Token ID '{token_id}' not found")
|
||||
return False
|
||||
|
||||
# Get token hash and remove from storage
|
||||
token_hash = self._token_ids[token_id]
|
||||
if token_hash in self._tokens:
|
||||
del self._tokens[token_hash]
|
||||
del self._token_ids[token_id]
|
||||
|
||||
self.logger.info(f"Revoked token '{token_id}'")
|
||||
|
||||
# Save updated tokens to file
|
||||
self._remove_token_from_file(token_id)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to revoke token '{token_id}': {e}")
|
||||
return False
|
||||
|
||||
def _save_tokens_to_file(self):
|
||||
"""Save current tokens to JSON file"""
|
||||
try:
|
||||
# Convert current tokens to file format
|
||||
tokens_list = []
|
||||
|
||||
for token_hash, token_info in self._tokens.items():
|
||||
# Find the raw token for this token_info
|
||||
raw_token = None
|
||||
for tid, thash in self._token_ids.items():
|
||||
if thash == token_hash and tid == token_info.token_id:
|
||||
# We can't recover the original token from hash,
|
||||
# so we'll create a placeholder for existing tokens
|
||||
raw_token = f"<existing_token_hash_{token_hash[:8]}>"
|
||||
break
|
||||
|
||||
if raw_token is None:
|
||||
continue
|
||||
|
||||
token_config = {
|
||||
"token_id": token_info.token_id,
|
||||
"token": raw_token,
|
||||
"description": token_info.description,
|
||||
"expires_hours": None,
|
||||
"is_active": token_info.is_active
|
||||
}
|
||||
|
||||
# Add expiration info
|
||||
if token_info.expires_at:
|
||||
# Calculate remaining hours from now
|
||||
remaining = token_info.expires_at - datetime.utcnow()
|
||||
if remaining.total_seconds() > 0:
|
||||
token_config["expires_hours"] = int(remaining.total_seconds() / 3600)
|
||||
else:
|
||||
token_config["expires_hours"] = 0
|
||||
|
||||
# Add database config if present
|
||||
if token_info.database_config:
|
||||
token_config["database_config"] = {
|
||||
"host": token_info.database_config.host,
|
||||
"port": token_info.database_config.port,
|
||||
"user": token_info.database_config.user,
|
||||
"password": token_info.database_config.password,
|
||||
"database": token_info.database_config.database,
|
||||
"charset": token_info.database_config.charset,
|
||||
"fe_http_port": token_info.database_config.fe_http_port
|
||||
}
|
||||
|
||||
tokens_list.append(token_config)
|
||||
|
||||
# Create file content
|
||||
file_content = {
|
||||
"version": "1.0",
|
||||
"description": "Doris MCP Server Token configuration file",
|
||||
"created_at": datetime.utcnow().isoformat() + "Z",
|
||||
"tokens": tokens_list,
|
||||
"notes": [
|
||||
"This file is automatically updated when tokens are created or revoked",
|
||||
"Please backup this file before making manual changes",
|
||||
"Tokens with hash placeholders were loaded from previous configuration"
|
||||
]
|
||||
}
|
||||
|
||||
# Save to file
|
||||
with open(self.token_file_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(file_content, f, indent=2, ensure_ascii=False)
|
||||
|
||||
self.logger.info(f"Saved {len(tokens_list)} tokens to file: {self.token_file_path}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to save tokens to file {self.token_file_path}: {e}")
|
||||
|
||||
def _save_token_to_file(self, token_id: str, raw_token: str, token_info: TokenInfo):
|
||||
"""Save a single new token to file (for newly created tokens only)"""
|
||||
try:
|
||||
# Load existing file
|
||||
existing_data = {"tokens": []}
|
||||
if os.path.exists(self.token_file_path):
|
||||
try:
|
||||
with open(self.token_file_path, 'r', encoding='utf-8') as f:
|
||||
existing_data = json.load(f)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Could not load existing token file: {e}")
|
||||
|
||||
# Ensure tokens list exists
|
||||
if 'tokens' not in existing_data or not isinstance(existing_data['tokens'], list):
|
||||
existing_data['tokens'] = []
|
||||
|
||||
# Check if token already exists in file
|
||||
token_exists = False
|
||||
for i, token_config in enumerate(existing_data['tokens']):
|
||||
if token_config.get('token_id') == token_id:
|
||||
# Update existing token
|
||||
existing_data['tokens'][i] = self._token_info_to_config(token_id, raw_token, token_info)
|
||||
token_exists = True
|
||||
break
|
||||
|
||||
# Add new token if it doesn't exist
|
||||
if not token_exists:
|
||||
new_token_config = self._token_info_to_config(token_id, raw_token, token_info)
|
||||
existing_data['tokens'].append(new_token_config)
|
||||
|
||||
# Update metadata
|
||||
existing_data.update({
|
||||
"version": "1.0",
|
||||
"description": "Doris MCP Server Token configuration file",
|
||||
"updated_at": datetime.utcnow().isoformat() + "Z"
|
||||
})
|
||||
|
||||
# Save to file
|
||||
with open(self.token_file_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(existing_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
self.logger.info(f"Saved token '{token_id}' to file: {self.token_file_path}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to save token '{token_id}' to file: {e}")
|
||||
|
||||
def _token_info_to_config(self, token_id: str, raw_token: str, token_info: TokenInfo) -> dict:
|
||||
"""Convert TokenInfo to file configuration format"""
|
||||
token_config = {
|
||||
"token_id": token_id,
|
||||
"token": raw_token,
|
||||
"description": token_info.description,
|
||||
"expires_hours": None,
|
||||
"is_active": token_info.is_active
|
||||
}
|
||||
|
||||
# Add expiration info
|
||||
if token_info.expires_at:
|
||||
# Calculate remaining hours from creation time
|
||||
remaining = token_info.expires_at - token_info.created_at
|
||||
token_config["expires_hours"] = int(remaining.total_seconds() / 3600) if remaining.total_seconds() > 0 else None
|
||||
|
||||
# Add database config if present
|
||||
if token_info.database_config:
|
||||
token_config["database_config"] = {
|
||||
"host": token_info.database_config.host,
|
||||
"port": token_info.database_config.port,
|
||||
"user": token_info.database_config.user,
|
||||
"password": token_info.database_config.password,
|
||||
"database": token_info.database_config.database,
|
||||
"charset": token_info.database_config.charset,
|
||||
"fe_http_port": token_info.database_config.fe_http_port
|
||||
}
|
||||
|
||||
return token_config
|
||||
|
||||
def _remove_token_from_file(self, token_id: str):
|
||||
"""Remove a token from the JSON file"""
|
||||
try:
|
||||
if not os.path.exists(self.token_file_path):
|
||||
return
|
||||
|
||||
# Load existing file
|
||||
with open(self.token_file_path, 'r', encoding='utf-8') as f:
|
||||
existing_data = json.load(f)
|
||||
|
||||
if 'tokens' not in existing_data or not isinstance(existing_data['tokens'], list):
|
||||
return
|
||||
|
||||
# Remove the token
|
||||
original_count = len(existing_data['tokens'])
|
||||
existing_data['tokens'] = [
|
||||
token for token in existing_data['tokens']
|
||||
if token.get('token_id') != token_id
|
||||
]
|
||||
|
||||
if len(existing_data['tokens']) < original_count:
|
||||
# Update metadata
|
||||
existing_data.update({
|
||||
"version": "1.0",
|
||||
"description": "Doris MCP Server Token configuration file",
|
||||
"updated_at": datetime.utcnow().isoformat() + "Z"
|
||||
})
|
||||
|
||||
# Save to file
|
||||
with open(self.token_file_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(existing_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
self.logger.info(f"Removed token '{token_id}' from file: {self.token_file_path}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to remove token '{token_id}' from file: {e}")
|
||||
|
||||
async def list_tokens(self) -> List[Dict[str, Any]]:
|
||||
"""List all tokens (without sensitive data)"""
|
||||
tokens = []
|
||||
|
||||
for token_hash, token_info in self._tokens.items():
|
||||
token_data = {
|
||||
'token_id': token_info.token_id,
|
||||
'created_at': token_info.created_at.isoformat(),
|
||||
'expires_at': token_info.expires_at.isoformat() if token_info.expires_at else None,
|
||||
'last_used': token_info.last_used.isoformat() if token_info.last_used else None,
|
||||
'is_active': token_info.is_active,
|
||||
'description': token_info.description,
|
||||
'is_expired': token_info.expires_at and datetime.utcnow() > token_info.expires_at if token_info.expires_at else False
|
||||
}
|
||||
|
||||
# Add database binding info (without sensitive data)
|
||||
if token_info.database_config:
|
||||
token_data['database_binding'] = {
|
||||
'host': token_info.database_config.host,
|
||||
'port': token_info.database_config.port,
|
||||
'user': token_info.database_config.user,
|
||||
'database': token_info.database_config.database,
|
||||
'has_password': bool(token_info.database_config.password)
|
||||
}
|
||||
else:
|
||||
token_data['database_binding'] = None
|
||||
|
||||
tokens.append(token_data)
|
||||
|
||||
# Sort by creation time
|
||||
tokens.sort(key=lambda x: x['created_at'], reverse=True)
|
||||
|
||||
return tokens
|
||||
|
||||
async def cleanup_expired_tokens(self) -> int:
|
||||
"""Remove expired tokens and return count"""
|
||||
if not self.enable_token_expiry:
|
||||
return 0
|
||||
|
||||
now = datetime.utcnow()
|
||||
expired_tokens = []
|
||||
|
||||
# Find expired tokens
|
||||
for token_hash, token_info in self._tokens.items():
|
||||
if token_info.expires_at and now > token_info.expires_at:
|
||||
expired_tokens.append((token_hash, token_info.token_id))
|
||||
|
||||
# Remove expired tokens
|
||||
for token_hash, token_id in expired_tokens:
|
||||
del self._tokens[token_hash]
|
||||
if token_id in self._token_ids:
|
||||
del self._token_ids[token_id]
|
||||
|
||||
if expired_tokens:
|
||||
self.logger.info(f"Cleaned up {len(expired_tokens)} expired tokens")
|
||||
|
||||
return len(expired_tokens)
|
||||
|
||||
async def save_tokens_to_file(self, file_path: Optional[str] = None) -> bool:
|
||||
"""Save current tokens to JSON file"""
|
||||
try:
|
||||
file_path = file_path or self.token_file_path
|
||||
tokens_list = await self.list_tokens()
|
||||
|
||||
tokens_data = {
|
||||
'version': '1.0',
|
||||
'created_at': datetime.utcnow().isoformat(),
|
||||
'tokens': tokens_list
|
||||
}
|
||||
|
||||
with open(file_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(tokens_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
self.logger.info(f"Saved {len(tokens_list)} tokens to file: {file_path}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to save tokens to file: {e}")
|
||||
return False
|
||||
|
||||
def get_database_config_by_token(self, token: str) -> Optional[DatabaseConfig]:
|
||||
"""Get database configuration bound to a token
|
||||
|
||||
Args:
|
||||
token: The raw token string
|
||||
|
||||
Returns:
|
||||
DatabaseConfig if token exists and has database binding, None otherwise
|
||||
"""
|
||||
try:
|
||||
token_hash = self._hash_token(token)
|
||||
token_info = self._tokens.get(token_hash)
|
||||
|
||||
if not token_info or not token_info.is_active:
|
||||
return None
|
||||
|
||||
# Check expiration
|
||||
if token_info.expires_at and datetime.utcnow() > token_info.expires_at:
|
||||
return None
|
||||
|
||||
return token_info.database_config
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to get database config for token: {e}")
|
||||
return None
|
||||
|
||||
def get_token_stats(self) -> Dict[str, Any]:
|
||||
"""Get token statistics"""
|
||||
now = datetime.utcnow()
|
||||
total_tokens = len(self._tokens)
|
||||
active_tokens = sum(1 for info in self._tokens.values() if info.is_active)
|
||||
expired_tokens = sum(1 for info in self._tokens.values()
|
||||
if info.expires_at and now > info.expires_at)
|
||||
tokens_with_db = sum(1 for info in self._tokens.values()
|
||||
if info.database_config is not None)
|
||||
|
||||
return {
|
||||
'total_tokens': total_tokens,
|
||||
'active_tokens': active_tokens,
|
||||
'expired_tokens': expired_tokens,
|
||||
'tokens_with_database_binding': tokens_with_db,
|
||||
'expiry_enabled': self.enable_token_expiry,
|
||||
'default_expiry_hours': self.default_token_expiry_hours,
|
||||
'hot_reload_enabled': self.enable_hot_reload,
|
||||
'last_file_check': datetime.fromtimestamp(self._file_last_modified).isoformat() if self._file_last_modified else None
|
||||
}
|
||||
|
||||
def _start_hot_reload(self):
|
||||
"""Start hot reload monitoring task"""
|
||||
if self._hot_reload_task:
|
||||
return # Already running
|
||||
|
||||
# Update initial file modification time
|
||||
self._update_file_modified_time()
|
||||
|
||||
# Start monitoring task
|
||||
self._hot_reload_task = asyncio.create_task(self._hot_reload_monitor())
|
||||
self.logger.info(f"Started hot reload monitoring for {self.token_file_path}")
|
||||
|
||||
def stop_hot_reload(self):
|
||||
"""Stop hot reload monitoring"""
|
||||
if self._hot_reload_task:
|
||||
self._hot_reload_task.cancel()
|
||||
self._hot_reload_task = None
|
||||
self.logger.info("Stopped hot reload monitoring")
|
||||
|
||||
def _update_file_modified_time(self):
|
||||
"""Update the last modified time of tokens file"""
|
||||
try:
|
||||
if os.path.exists(self.token_file_path):
|
||||
self._file_last_modified = os.path.getmtime(self.token_file_path)
|
||||
except Exception as e:
|
||||
self.logger.debug(f"Failed to get file modification time: {e}")
|
||||
|
||||
async def _hot_reload_monitor(self):
|
||||
"""Background task to monitor tokens.json file changes"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(self.hot_reload_interval)
|
||||
|
||||
if not os.path.exists(self.token_file_path):
|
||||
continue
|
||||
|
||||
# Check if file was modified
|
||||
current_mtime = os.path.getmtime(self.token_file_path)
|
||||
if current_mtime > self._file_last_modified:
|
||||
self.logger.info(f"Detected changes in {self.token_file_path}, reloading tokens...")
|
||||
|
||||
try:
|
||||
# Backup current tokens
|
||||
old_tokens = self._tokens.copy()
|
||||
old_token_ids = self._token_ids.copy()
|
||||
|
||||
# Clear and reload
|
||||
self._tokens.clear()
|
||||
self._token_ids.clear()
|
||||
|
||||
# Reinitialize default tokens
|
||||
self._initialize_default_tokens()
|
||||
|
||||
# Load from file
|
||||
self._load_tokens_from_file()
|
||||
|
||||
# Update modification time
|
||||
self._file_last_modified = current_mtime
|
||||
|
||||
self.logger.info(f"Hot reload completed, {len(self._tokens)} tokens loaded")
|
||||
|
||||
except Exception as reload_error:
|
||||
# Restore backup on failure
|
||||
self.logger.error(f"Hot reload failed, restoring previous tokens: {reload_error}")
|
||||
self._tokens = old_tokens
|
||||
self._token_ids = old_token_ids
|
||||
|
||||
except asyncio.CancelledError:
|
||||
self.logger.info("Hot reload monitor stopped")
|
||||
break
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in hot reload monitor: {e}")
|
||||
227
doris_mcp_server/auth/token_security_middleware.py
Normal file
227
doris_mcp_server/auth/token_security_middleware.py
Normal file
@@ -0,0 +1,227 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Token Management Security Middleware
|
||||
|
||||
Provides comprehensive security controls for token management endpoints including
|
||||
IP restrictions, admin authentication, and configuration-based access control.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import ipaddress
|
||||
import secrets
|
||||
import time
|
||||
from typing import Optional, List, Dict, Any
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from ..utils.logger import get_logger
|
||||
from ..utils.config import DorisConfig
|
||||
|
||||
|
||||
class TokenSecurityMiddleware:
|
||||
"""Security middleware for token management endpoints"""
|
||||
|
||||
def __init__(self, config: DorisConfig):
|
||||
self.config = config
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
# Initialize admin token hash if provided
|
||||
self._admin_token_hash = None
|
||||
if config.security.token_management_admin_token:
|
||||
self._admin_token_hash = self._hash_token(config.security.token_management_admin_token)
|
||||
|
||||
# Normalize allowed IPs
|
||||
self._allowed_networks = self._parse_allowed_networks(
|
||||
config.security.token_management_allowed_ips
|
||||
)
|
||||
|
||||
self.logger.info(f"Token management security initialized: "
|
||||
f"HTTP endpoints {'enabled' if config.security.enable_http_token_management else 'disabled'}, "
|
||||
f"Admin auth {'required' if config.security.require_admin_auth else 'optional'}, "
|
||||
f"Allowed networks: {len(self._allowed_networks)}")
|
||||
|
||||
def _hash_token(self, token: str) -> str:
|
||||
"""Hash token using SHA-256"""
|
||||
return hashlib.sha256(token.encode('utf-8')).hexdigest()
|
||||
|
||||
def _parse_allowed_networks(self, allowed_ips: List[str]) -> List[ipaddress.IPv4Network | ipaddress.IPv6Network]:
|
||||
"""Parse allowed IP addresses and networks"""
|
||||
networks = []
|
||||
for ip_str in allowed_ips:
|
||||
try:
|
||||
# Try to parse as network (CIDR notation)
|
||||
if '/' in ip_str:
|
||||
networks.append(ipaddress.ip_network(ip_str, strict=False))
|
||||
else:
|
||||
# Parse as single IP and convert to /32 network
|
||||
ip = ipaddress.ip_address(ip_str)
|
||||
if isinstance(ip, ipaddress.IPv4Address):
|
||||
networks.append(ipaddress.IPv4Network(f"{ip}/32"))
|
||||
else:
|
||||
networks.append(ipaddress.IPv6Network(f"{ip}/128"))
|
||||
except ValueError as e:
|
||||
self.logger.warning(f"Invalid IP/network '{ip_str}': {e}")
|
||||
|
||||
return networks
|
||||
|
||||
def _get_client_ip(self, request: Request) -> str:
|
||||
"""Extract client IP from request, considering proxies"""
|
||||
# Check X-Forwarded-For header first (for proxy setups)
|
||||
forwarded_for = request.headers.get('X-Forwarded-For')
|
||||
if forwarded_for:
|
||||
# Take the first IP (original client)
|
||||
client_ip = forwarded_for.split(',')[0].strip()
|
||||
elif request.headers.get('X-Real-IP'):
|
||||
client_ip = request.headers.get('X-Real-IP')
|
||||
else:
|
||||
# Direct connection
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
|
||||
return client_ip
|
||||
|
||||
def _is_ip_allowed(self, client_ip: str) -> bool:
|
||||
"""Check if client IP is in allowed networks"""
|
||||
try:
|
||||
client_addr = ipaddress.ip_address(client_ip)
|
||||
|
||||
for network in self._allowed_networks:
|
||||
if client_addr in network:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except ValueError:
|
||||
self.logger.warning(f"Invalid client IP format: {client_ip}")
|
||||
return False
|
||||
|
||||
def _extract_admin_token(self, request: Request) -> Optional[str]:
|
||||
"""Extract admin token from request headers"""
|
||||
# Try Authorization header first
|
||||
auth_header = request.headers.get('Authorization', '')
|
||||
if auth_header.startswith('Bearer '):
|
||||
return auth_header[7:]
|
||||
elif auth_header.startswith('Token '):
|
||||
return auth_header[6:]
|
||||
|
||||
# Try X-Admin-Token header
|
||||
admin_token = request.headers.get('X-Admin-Token', '')
|
||||
if admin_token:
|
||||
return admin_token
|
||||
|
||||
# Try query parameter as fallback (not recommended for production)
|
||||
admin_token = request.query_params.get('admin_token', '')
|
||||
if admin_token:
|
||||
self.logger.warning("Admin token passed via query parameter - this is insecure for production")
|
||||
return admin_token
|
||||
|
||||
return None
|
||||
|
||||
def _verify_admin_token(self, provided_token: str) -> bool:
|
||||
"""Verify provided admin token against configured token"""
|
||||
if not self._admin_token_hash:
|
||||
self.logger.warning("No admin token configured for token management")
|
||||
return False
|
||||
|
||||
provided_hash = self._hash_token(provided_token)
|
||||
|
||||
# Use constant-time comparison to prevent timing attacks
|
||||
return hmac.compare_digest(self._admin_token_hash, provided_hash)
|
||||
|
||||
async def check_token_management_access(self, request: Request) -> Optional[JSONResponse]:
|
||||
"""
|
||||
Check if request is authorized for token management operations
|
||||
|
||||
Returns:
|
||||
None if access is granted
|
||||
JSONResponse with error if access is denied
|
||||
"""
|
||||
|
||||
# Check if HTTP token management is enabled
|
||||
if not self.config.security.enable_http_token_management:
|
||||
self.logger.warning(f"Token management endpoint access denied - HTTP management disabled: {request.url.path}")
|
||||
return JSONResponse({
|
||||
"error": "Token management endpoints are disabled for security",
|
||||
"message": "HTTP token management is disabled. Use file-based token management instead.",
|
||||
"suggestion": "Edit tokens.json file directly or enable HTTP management with proper security configuration"
|
||||
}, status_code=403)
|
||||
|
||||
# Extract client IP
|
||||
client_ip = self._get_client_ip(request)
|
||||
|
||||
# Check IP restrictions
|
||||
if not self._is_ip_allowed(client_ip):
|
||||
self.logger.warning(f"Token management access denied for IP {client_ip}: not in allowed list")
|
||||
return JSONResponse({
|
||||
"error": "Access denied - IP not allowed",
|
||||
"client_ip": client_ip,
|
||||
"message": "Token management is restricted to specific IP addresses",
|
||||
"allowed_networks": [str(net) for net in self._allowed_networks]
|
||||
}, status_code=403)
|
||||
|
||||
# Check admin authentication if required
|
||||
if self.config.security.require_admin_auth:
|
||||
admin_token = self._extract_admin_token(request)
|
||||
|
||||
if not admin_token:
|
||||
self.logger.warning(f"Token management access denied for IP {client_ip}: missing admin token")
|
||||
return JSONResponse({
|
||||
"error": "Admin authentication required",
|
||||
"message": "Token management requires admin authentication",
|
||||
"hint": "Provide admin token in Authorization header: 'Bearer <admin_token>' or 'X-Admin-Token: <admin_token>'"
|
||||
}, status_code=401)
|
||||
|
||||
if not self._verify_admin_token(admin_token):
|
||||
self.logger.warning(f"Token management access denied for IP {client_ip}: invalid admin token")
|
||||
return JSONResponse({
|
||||
"error": "Invalid admin token",
|
||||
"message": "The provided admin token is invalid"
|
||||
}, status_code=401)
|
||||
|
||||
# Log successful access
|
||||
self.logger.info(f"Token management access granted for IP {client_ip} to {request.url.path}")
|
||||
|
||||
# Access granted
|
||||
return None
|
||||
|
||||
def get_security_info(self) -> Dict[str, Any]:
|
||||
"""Get current security configuration info (for demo/status pages)"""
|
||||
return {
|
||||
"http_token_management_enabled": self.config.security.enable_http_token_management,
|
||||
"admin_auth_required": self.config.security.require_admin_auth,
|
||||
"admin_token_configured": bool(self._admin_token_hash),
|
||||
"allowed_networks_count": len(self._allowed_networks),
|
||||
"allowed_networks": [str(net) for net in self._allowed_networks],
|
||||
"security_features": [
|
||||
"IP address restrictions",
|
||||
"Admin token authentication" if self.config.security.require_admin_auth else "Optional admin authentication",
|
||||
"Secure token hashing",
|
||||
"Request logging and auditing"
|
||||
]
|
||||
}
|
||||
|
||||
def generate_admin_token(self) -> str:
|
||||
"""Generate a secure admin token"""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
# Convenience function for middleware creation
|
||||
def create_token_security_middleware(config: DorisConfig) -> TokenSecurityMiddleware:
|
||||
"""Create token security middleware with configuration"""
|
||||
return TokenSecurityMiddleware(config)
|
||||
365
doris_mcp_server/auth/token_validators.py
Normal file
365
doris_mcp_server/auth/token_validators.py
Normal file
@@ -0,0 +1,365 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
JWT Token Validation Module
|
||||
Provides token validation, blacklist management and security features
|
||||
"""
|
||||
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Dict, Set, Optional, Any
|
||||
from datetime import datetime, timedelta
|
||||
from collections import defaultdict
|
||||
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class TokenBlacklist:
|
||||
"""JWT Token Blacklist Manager
|
||||
|
||||
Manages revoked tokens to prevent revoked tokens from being used again
|
||||
Supports both in-memory and persistent storage
|
||||
"""
|
||||
|
||||
def __init__(self, cleanup_interval: int = 3600):
|
||||
"""Initialize token blacklist
|
||||
|
||||
Args:
|
||||
cleanup_interval: Interval for cleaning up expired tokens (seconds)
|
||||
"""
|
||||
self.cleanup_interval = cleanup_interval
|
||||
# Storage format: {token_jti: expiry_timestamp}
|
||||
self._blacklisted_tokens: Dict[str, float] = {}
|
||||
self._cleanup_task = None
|
||||
|
||||
logger.info("TokenBlacklist initialized")
|
||||
|
||||
async def start(self):
|
||||
"""Start blacklist manager"""
|
||||
self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
|
||||
logger.info("TokenBlacklist started with periodic cleanup")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop blacklist manager"""
|
||||
if self._cleanup_task:
|
||||
self._cleanup_task.cancel()
|
||||
try:
|
||||
await self._cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("TokenBlacklist stopped")
|
||||
|
||||
async def add_token(self, jti: str, exp: float):
|
||||
"""Add token to blacklist
|
||||
|
||||
Args:
|
||||
jti: JWT ID (unique identifier)
|
||||
exp: Token expiration timestamp
|
||||
"""
|
||||
self._blacklisted_tokens[jti] = exp
|
||||
logger.info(f"Token {jti} added to blacklist")
|
||||
|
||||
async def is_blacklisted(self, jti: str) -> bool:
|
||||
"""Check if token is blacklisted
|
||||
|
||||
Args:
|
||||
jti: JWT ID
|
||||
|
||||
Returns:
|
||||
True if blacklisted, False otherwise
|
||||
"""
|
||||
return jti in self._blacklisted_tokens
|
||||
|
||||
async def remove_token(self, jti: str) -> bool:
|
||||
"""Remove token from blacklist
|
||||
|
||||
Args:
|
||||
jti: JWT ID
|
||||
|
||||
Returns:
|
||||
True if removed, False if not found
|
||||
"""
|
||||
if jti in self._blacklisted_tokens:
|
||||
del self._blacklisted_tokens[jti]
|
||||
logger.info(f"Token {jti} removed from blacklist")
|
||||
return True
|
||||
return False
|
||||
|
||||
async def cleanup_expired(self) -> int:
|
||||
"""Clean up expired blacklisted tokens
|
||||
|
||||
Returns:
|
||||
Number of tokens cleaned up
|
||||
"""
|
||||
current_time = time.time()
|
||||
expired_tokens = [
|
||||
jti for jti, exp in self._blacklisted_tokens.items()
|
||||
if exp <= current_time
|
||||
]
|
||||
|
||||
for jti in expired_tokens:
|
||||
del self._blacklisted_tokens[jti]
|
||||
|
||||
if expired_tokens:
|
||||
logger.info(f"Cleaned up {len(expired_tokens)} expired tokens from blacklist")
|
||||
|
||||
return len(expired_tokens)
|
||||
|
||||
async def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get blacklist statistics"""
|
||||
current_time = time.time()
|
||||
active_tokens = sum(1 for exp in self._blacklisted_tokens.values() if exp > current_time)
|
||||
|
||||
return {
|
||||
"total_blacklisted": len(self._blacklisted_tokens),
|
||||
"active_blacklisted": active_tokens,
|
||||
"expired_blacklisted": len(self._blacklisted_tokens) - active_tokens,
|
||||
"cleanup_interval": self.cleanup_interval
|
||||
}
|
||||
|
||||
async def _periodic_cleanup(self):
|
||||
"""Periodically clean up expired tokens"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(self.cleanup_interval)
|
||||
await self.cleanup_expired()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error during periodic cleanup: {e}")
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""Token usage rate limiter"""
|
||||
|
||||
def __init__(self, max_requests: int = 100, time_window: int = 3600):
|
||||
"""Initialize rate limiter
|
||||
|
||||
Args:
|
||||
max_requests: Maximum requests within time window
|
||||
time_window: Time window in seconds
|
||||
"""
|
||||
self.max_requests = max_requests
|
||||
self.time_window = time_window
|
||||
# Storage format: {user_id: [timestamp1, timestamp2, ...]}
|
||||
self._request_history: Dict[str, list] = defaultdict(list)
|
||||
|
||||
logger.info(f"RateLimiter initialized: {max_requests} requests per {time_window} seconds")
|
||||
|
||||
async def is_allowed(self, user_id: str) -> bool:
|
||||
"""Check if user is allowed to make request
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
True if allowed, False otherwise
|
||||
"""
|
||||
current_time = time.time()
|
||||
user_requests = self._request_history[user_id]
|
||||
|
||||
# Clean up expired request records
|
||||
cutoff_time = current_time - self.time_window
|
||||
user_requests[:] = [t for t in user_requests if t > cutoff_time]
|
||||
|
||||
# Check if limit exceeded
|
||||
if len(user_requests) >= self.max_requests:
|
||||
logger.warning(f"Rate limit exceeded for user {user_id}")
|
||||
return False
|
||||
|
||||
# Record current request
|
||||
user_requests.append(current_time)
|
||||
return True
|
||||
|
||||
async def get_usage(self, user_id: str) -> Dict[str, Any]:
|
||||
"""Get user usage information
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
Usage statistics
|
||||
"""
|
||||
current_time = time.time()
|
||||
user_requests = self._request_history[user_id]
|
||||
|
||||
# Clean up expired records
|
||||
cutoff_time = current_time - self.time_window
|
||||
active_requests = [t for t in user_requests if t > cutoff_time]
|
||||
|
||||
return {
|
||||
"user_id": user_id,
|
||||
"requests_in_window": len(active_requests),
|
||||
"max_requests": self.max_requests,
|
||||
"time_window": self.time_window,
|
||||
"remaining_requests": max(0, self.max_requests - len(active_requests))
|
||||
}
|
||||
|
||||
|
||||
class TokenValidator:
|
||||
"""JWT Token Validator
|
||||
|
||||
Provides comprehensive JWT token validation functionality, including signature verification,
|
||||
claim validation, blacklist checking and rate limiting
|
||||
"""
|
||||
|
||||
def __init__(self, config, blacklist: Optional[TokenBlacklist] = None):
|
||||
"""Initialize token validator
|
||||
|
||||
Args:
|
||||
config: DorisConfig configuration object (with security attribute)
|
||||
blacklist: Token blacklist manager
|
||||
"""
|
||||
self.config = config
|
||||
self.blacklist = blacklist or TokenBlacklist()
|
||||
self.rate_limiter = RateLimiter()
|
||||
|
||||
# Access JWT settings through the security configuration
|
||||
if hasattr(config, 'security'):
|
||||
security_config = config.security
|
||||
else:
|
||||
# Fallback if config is passed directly as SecurityConfig
|
||||
security_config = config
|
||||
|
||||
# Validation options
|
||||
self.verify_signature = security_config.jwt_verify_signature
|
||||
self.verify_audience = security_config.jwt_verify_audience
|
||||
self.verify_issuer = security_config.jwt_verify_issuer
|
||||
self.require_exp = security_config.jwt_require_exp
|
||||
self.require_iat = security_config.jwt_require_iat
|
||||
self.require_nbf = security_config.jwt_require_nbf
|
||||
self.leeway = security_config.jwt_leeway
|
||||
|
||||
# Expected values
|
||||
self.expected_audience = security_config.jwt_audience
|
||||
self.expected_issuer = security_config.jwt_issuer
|
||||
|
||||
logger.info("TokenValidator initialized")
|
||||
|
||||
async def validate_claims(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate JWT claims
|
||||
|
||||
Args:
|
||||
payload: JWT payload
|
||||
|
||||
Returns:
|
||||
Validation result
|
||||
|
||||
Raises:
|
||||
ValueError: Validation failed
|
||||
"""
|
||||
current_time = time.time()
|
||||
|
||||
# Validate issuer
|
||||
if self.verify_issuer:
|
||||
if payload.get('iss') != self.expected_issuer:
|
||||
raise ValueError(f"Invalid issuer: expected {self.expected_issuer}")
|
||||
|
||||
# Validate audience
|
||||
if self.verify_audience:
|
||||
aud = payload.get('aud')
|
||||
if isinstance(aud, list):
|
||||
if self.expected_audience not in aud:
|
||||
raise ValueError(f"Invalid audience: {self.expected_audience} not in {aud}")
|
||||
elif aud != self.expected_audience:
|
||||
raise ValueError(f"Invalid audience: expected {self.expected_audience}")
|
||||
|
||||
# Validate expiration time
|
||||
if self.require_exp or 'exp' in payload:
|
||||
exp = payload.get('exp')
|
||||
if not exp:
|
||||
raise ValueError("Missing 'exp' claim")
|
||||
if current_time > exp + self.leeway:
|
||||
raise ValueError("Token has expired")
|
||||
|
||||
# Validate not before time
|
||||
if self.require_nbf or 'nbf' in payload:
|
||||
nbf = payload.get('nbf')
|
||||
if not nbf:
|
||||
raise ValueError("Missing 'nbf' claim")
|
||||
if current_time < nbf - self.leeway:
|
||||
raise ValueError("Token not yet valid")
|
||||
|
||||
# Validate issued at time
|
||||
if self.require_iat or 'iat' in payload:
|
||||
iat = payload.get('iat')
|
||||
if not iat:
|
||||
raise ValueError("Missing 'iat' claim")
|
||||
# Allow some clock skew, but cannot be future time
|
||||
if iat > current_time + self.leeway:
|
||||
raise ValueError("Token issued in the future")
|
||||
|
||||
# Check blacklist
|
||||
jti = payload.get('jti')
|
||||
if jti and await self.blacklist.is_blacklisted(jti):
|
||||
raise ValueError("Token has been revoked")
|
||||
|
||||
# Rate limit check
|
||||
user_id = payload.get('sub')
|
||||
if user_id:
|
||||
if not await self.rate_limiter.is_allowed(user_id):
|
||||
raise ValueError("Rate limit exceeded")
|
||||
|
||||
return {
|
||||
"valid": True,
|
||||
"user_id": user_id,
|
||||
"payload": payload
|
||||
}
|
||||
|
||||
async def start(self):
|
||||
"""Start validator"""
|
||||
await self.blacklist.start()
|
||||
logger.info("TokenValidator started")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop validator"""
|
||||
await self.blacklist.stop()
|
||||
logger.info("TokenValidator stopped")
|
||||
|
||||
async def revoke_token(self, jti: str, exp: float):
|
||||
"""Revoke token
|
||||
|
||||
Args:
|
||||
jti: JWT ID
|
||||
exp: Token expiration time
|
||||
"""
|
||||
await self.blacklist.add_token(jti, exp)
|
||||
logger.info(f"Token {jti} has been revoked")
|
||||
|
||||
async def get_validation_stats(self) -> Dict[str, Any]:
|
||||
"""Get validation statistics"""
|
||||
blacklist_stats = await self.blacklist.get_stats()
|
||||
|
||||
return {
|
||||
"blacklist": blacklist_stats,
|
||||
"validation_config": {
|
||||
"verify_signature": self.verify_signature,
|
||||
"verify_audience": self.verify_audience,
|
||||
"verify_issuer": self.verify_issuer,
|
||||
"require_exp": self.require_exp,
|
||||
"require_iat": self.require_iat,
|
||||
"require_nbf": self.require_nbf,
|
||||
"leeway": self.leeway
|
||||
}
|
||||
}
|
||||
|
||||
async def get_user_rate_limit_info(self, user_id: str) -> Dict[str, Any]:
|
||||
"""Get user rate limit information"""
|
||||
return await self.rate_limiter.get_usage(user_id)
|
||||
@@ -28,25 +28,182 @@ import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
# MCP version compatibility check
|
||||
# MCP version compatibility handling
|
||||
MCP_VERSION = 'unknown'
|
||||
Server = None
|
||||
InitializationOptions = None
|
||||
Prompt = None
|
||||
Resource = None
|
||||
TextContent = None
|
||||
Tool = None
|
||||
|
||||
def _import_mcp_with_compatibility():
|
||||
"""Import MCP components with multi-version compatibility"""
|
||||
global MCP_VERSION, Server, InitializationOptions, Prompt, Resource, TextContent, Tool
|
||||
|
||||
try:
|
||||
# Strategy 1: Try direct server-only imports to avoid client-side issues
|
||||
from mcp.server import Server as _Server
|
||||
from mcp.server.models import InitializationOptions as _InitOptions
|
||||
from mcp.types import (
|
||||
Prompt as _Prompt,
|
||||
Resource as _Resource,
|
||||
TextContent as _TextContent,
|
||||
Tool as _Tool,
|
||||
)
|
||||
|
||||
# Assign to globals
|
||||
Server = _Server
|
||||
InitializationOptions = _InitOptions
|
||||
Prompt = _Prompt
|
||||
Resource = _Resource
|
||||
TextContent = _TextContent
|
||||
Tool = _Tool
|
||||
|
||||
# Try to get version safely
|
||||
try:
|
||||
import mcp
|
||||
MCP_VERSION = getattr(mcp, '__version__', 'unknown')
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info(f"Using MCP version: {MCP_VERSION}")
|
||||
except Exception as e:
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(f"Could not determine MCP version: {e}")
|
||||
MCP_VERSION = 'unknown'
|
||||
MCP_VERSION = getattr(mcp, '__version__', None)
|
||||
if not MCP_VERSION:
|
||||
# Fallback: try to get version from package metadata
|
||||
try:
|
||||
import importlib.metadata
|
||||
MCP_VERSION = importlib.metadata.version('mcp')
|
||||
except Exception:
|
||||
# Second fallback: try pkg_resources
|
||||
try:
|
||||
import pkg_resources
|
||||
MCP_VERSION = pkg_resources.get_distribution('mcp').version
|
||||
except Exception:
|
||||
MCP_VERSION = 'detected-but-version-unknown'
|
||||
except Exception:
|
||||
# Version detection failed, but imports worked
|
||||
try:
|
||||
import importlib.metadata
|
||||
MCP_VERSION = importlib.metadata.version('mcp')
|
||||
except Exception:
|
||||
try:
|
||||
import pkg_resources
|
||||
MCP_VERSION = pkg_resources.get_distribution('mcp').version
|
||||
except Exception:
|
||||
MCP_VERSION = 'imported-successfully'
|
||||
|
||||
from mcp.server import Server
|
||||
from mcp.server.models import InitializationOptions
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info(f"MCP components imported successfully, version: {MCP_VERSION}")
|
||||
return True
|
||||
|
||||
except Exception as import_error:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Strategy 2: Handle RequestContext compatibility issues in 1.9.x versions
|
||||
error_str = str(import_error).lower()
|
||||
if 'requestcontext' in error_str and 'too few arguments' in error_str:
|
||||
logger.warning(f"Detected MCP RequestContext compatibility issue: {import_error}")
|
||||
logger.info("Attempting comprehensive workaround for MCP 1.9.x RequestContext issue...")
|
||||
|
||||
try:
|
||||
# Comprehensive monkey patch approach
|
||||
import sys
|
||||
import types
|
||||
|
||||
# Create and install mock modules before any MCP imports
|
||||
if 'mcp.shared.context' not in sys.modules:
|
||||
mock_context_module = types.ModuleType('mcp.shared.context')
|
||||
|
||||
class FlexibleRequestContext:
|
||||
"""Flexible RequestContext that accepts variable arguments"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
def __class_getitem__(cls, params):
|
||||
# Accept any number of parameters and return cls
|
||||
return cls
|
||||
|
||||
# Add other methods that might be called
|
||||
def __getattr__(self, name):
|
||||
return lambda *args, **kwargs: None
|
||||
|
||||
mock_context_module.RequestContext = FlexibleRequestContext
|
||||
sys.modules['mcp.shared.context'] = mock_context_module
|
||||
|
||||
# Also patch the typing system to be more permissive
|
||||
original_check_generic = None
|
||||
try:
|
||||
import typing
|
||||
if hasattr(typing, '_check_generic'):
|
||||
original_check_generic = typing._check_generic
|
||||
def permissive_check_generic(cls, params, elen):
|
||||
# Don't enforce strict parameter count checking
|
||||
return
|
||||
typing._check_generic = permissive_check_generic
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Clear any cached imports that might have failed
|
||||
modules_to_clear = [k for k in sys.modules.keys() if k.startswith('mcp.')]
|
||||
for module in modules_to_clear:
|
||||
if module in sys.modules:
|
||||
del sys.modules[module]
|
||||
|
||||
# Now try importing again with the patches in place
|
||||
from mcp.server import Server as _Server
|
||||
from mcp.server.models import InitializationOptions as _InitOptions
|
||||
from mcp.types import (
|
||||
Prompt,
|
||||
Resource,
|
||||
TextContent,
|
||||
Tool,
|
||||
Prompt as _Prompt,
|
||||
Resource as _Resource,
|
||||
TextContent as _TextContent,
|
||||
Tool as _Tool,
|
||||
)
|
||||
|
||||
# Assign to globals
|
||||
Server = _Server
|
||||
InitializationOptions = _InitOptions
|
||||
Prompt = _Prompt
|
||||
Resource = _Resource
|
||||
TextContent = _TextContent
|
||||
Tool = _Tool
|
||||
|
||||
# Try to detect actual version even in compatibility mode
|
||||
try:
|
||||
import importlib.metadata
|
||||
actual_version = importlib.metadata.version('mcp')
|
||||
MCP_VERSION = f'compatibility-mode-{actual_version}'
|
||||
except Exception:
|
||||
try:
|
||||
import pkg_resources
|
||||
actual_version = pkg_resources.get_distribution('mcp').version
|
||||
MCP_VERSION = f'compatibility-mode-{actual_version}'
|
||||
except Exception:
|
||||
MCP_VERSION = 'compatibility-mode-1.9.x'
|
||||
|
||||
logger.info("MCP 1.9.x compatibility workaround successful!")
|
||||
|
||||
# Restore original typing function if we patched it
|
||||
if original_check_generic:
|
||||
typing._check_generic = original_check_generic
|
||||
|
||||
return True
|
||||
|
||||
except Exception as workaround_error:
|
||||
logger.error(f"MCP compatibility workaround failed: {workaround_error}")
|
||||
|
||||
# Restore original typing function if we patched it
|
||||
if original_check_generic:
|
||||
try:
|
||||
import typing
|
||||
typing._check_generic = original_check_generic
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.error(f"Failed to import MCP components: {import_error}")
|
||||
return False
|
||||
|
||||
# Perform MCP import with compatibility handling
|
||||
if not _import_mcp_with_compatibility():
|
||||
raise ImportError(
|
||||
"Failed to import MCP components. Please ensure MCP is properly installed. "
|
||||
"Supported versions: 1.8.x, 1.9.x"
|
||||
)
|
||||
|
||||
from .tools.tools_manager import DorisToolsManager
|
||||
@@ -57,14 +214,15 @@ from .utils.db import DorisConnectionManager
|
||||
from .utils.security import DorisSecurityManager
|
||||
import os
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
# Configure logging - will be properly initialized later
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Create a default config instance for getting default values
|
||||
_default_config = DorisConfig()
|
||||
|
||||
|
||||
|
||||
|
||||
class DorisServer:
|
||||
"""Apache Doris MCP Server main class"""
|
||||
|
||||
@@ -72,20 +230,60 @@ class DorisServer:
|
||||
self.config = config
|
||||
self.server = Server("doris-mcp-server")
|
||||
|
||||
# Initialize security manager
|
||||
# Initialize security manager (without connection_manager initially)
|
||||
self.security_manager = DorisSecurityManager(config)
|
||||
|
||||
# Initialize connection manager, pass in security manager
|
||||
self.connection_manager = DorisConnectionManager(config, self.security_manager)
|
||||
# Initialize connection manager, pass in security manager and token manager for token-bound DB config
|
||||
token_manager = self.security_manager.auth_provider.token_manager if hasattr(self.security_manager, 'auth_provider') and hasattr(self.security_manager.auth_provider, 'token_manager') else None
|
||||
self.connection_manager = DorisConnectionManager(config, self.security_manager, token_manager)
|
||||
|
||||
# Set connection manager reference in security manager for database validation
|
||||
self.security_manager.connection_manager = self.connection_manager
|
||||
|
||||
# Initialize independent managers
|
||||
self.resources_manager = DorisResourcesManager(self.connection_manager)
|
||||
self.tools_manager = DorisToolsManager(self.connection_manager)
|
||||
self.prompts_manager = DorisPromptsManager(self.connection_manager)
|
||||
|
||||
self.logger = logging.getLogger(f"{__name__}.DorisServer")
|
||||
# Import here to avoid circular imports
|
||||
from .utils.logger import get_logger
|
||||
self.logger = get_logger(f"{__name__}.DorisServer")
|
||||
self._setup_handlers()
|
||||
|
||||
async def _extract_auth_info_from_scope(self, scope, headers):
|
||||
"""Extract authentication information from ASGI scope and headers"""
|
||||
auth_info = {}
|
||||
|
||||
# Extract client IP
|
||||
client = scope.get("client")
|
||||
if client:
|
||||
auth_info["client_ip"] = client[0]
|
||||
else:
|
||||
auth_info["client_ip"] = "unknown"
|
||||
|
||||
# Extract token from Authorization header
|
||||
authorization = headers.get(b'authorization', b'').decode('utf-8')
|
||||
if authorization:
|
||||
if authorization.startswith('Bearer '):
|
||||
auth_info["token"] = authorization[7:]
|
||||
auth_info["authorization"] = authorization
|
||||
elif authorization.startswith('Token '):
|
||||
auth_info["token"] = authorization[6:]
|
||||
auth_info["authorization"] = authorization
|
||||
|
||||
# Extract token from query parameters (for compatibility)
|
||||
query_string = scope.get("query_string", b"").decode('utf-8')
|
||||
if query_string and "token=" in query_string:
|
||||
import urllib.parse
|
||||
query_params = urllib.parse.parse_qs(query_string)
|
||||
if "token" in query_params:
|
||||
auth_info["token"] = query_params["token"][0]
|
||||
|
||||
# If no token found, this will be handled by the authentication system
|
||||
# (either return anonymous context if auth disabled, or raise error if auth enabled)
|
||||
|
||||
return auth_info
|
||||
|
||||
def _get_mcp_capabilities(self):
|
||||
"""Get MCP capabilities with version compatibility"""
|
||||
try:
|
||||
@@ -230,12 +428,24 @@ class DorisServer:
|
||||
self.logger.info("Starting Doris MCP Server (stdio mode)")
|
||||
|
||||
try:
|
||||
# Ensure connection manager is initialized
|
||||
await self.connection_manager.initialize()
|
||||
self.logger.info("Connection manager initialization completed")
|
||||
# Initialize security manager first (includes JWT setup if enabled)
|
||||
await self.security_manager.initialize()
|
||||
self.logger.info("Security manager initialization completed")
|
||||
|
||||
# Start stdio server - using simpler approach
|
||||
# For stdio mode, we must establish a working database connection
|
||||
# Use the dedicated stdio mode initialization method
|
||||
await self.connection_manager.initialize_for_stdio_mode()
|
||||
|
||||
# Start stdio server - using compatible import approach
|
||||
try:
|
||||
from mcp.server.stdio import stdio_server
|
||||
except ImportError:
|
||||
# Fallback for different MCP versions
|
||||
try:
|
||||
from mcp.server import stdio_server
|
||||
except ImportError as stdio_import_error:
|
||||
self.logger.error(f"Failed to import stdio_server: {stdio_import_error}")
|
||||
raise RuntimeError("stdio_server module not available in this MCP version")
|
||||
|
||||
self.logger.info("Creating stdio_server transport...")
|
||||
|
||||
@@ -283,13 +493,21 @@ class DorisServer:
|
||||
|
||||
|
||||
|
||||
async def start_http(self, host: str = os.getenv("SERVER_HOST", _default_config.database.host), port: int = os.getenv("SERVER_PORT", _default_config.server_port)):
|
||||
"""Start Streamable HTTP transport mode"""
|
||||
self.logger.info(f"Starting Doris MCP Server (Streamable HTTP mode) - {host}:{port}")
|
||||
async def start_http(self, host: str = os.getenv("SERVER_HOST", _default_config.database.host), port: int = os.getenv("SERVER_PORT", _default_config.server_port), workers: int = 1):
|
||||
"""Start Streamable HTTP transport mode with workers support"""
|
||||
self.logger.info(f"Starting Doris MCP Server (Streamable HTTP mode) - {host}:{port}, workers: {workers}")
|
||||
|
||||
try:
|
||||
# Ensure connection manager is initialized
|
||||
await self.connection_manager.initialize()
|
||||
# Initialize security manager first (includes JWT setup if enabled)
|
||||
await self.security_manager.initialize()
|
||||
self.logger.info("Security manager initialization completed")
|
||||
|
||||
# For HTTP mode, try to initialize global connection pool with graceful degradation
|
||||
global_pool_created = await self.connection_manager.initialize_for_http_mode()
|
||||
if global_pool_created:
|
||||
self.logger.info("Global database connection pool available for HTTP mode")
|
||||
else:
|
||||
self.logger.info("HTTP mode running without global database pool, will use token-bound configurations")
|
||||
|
||||
# Use Starlette and StreamableHTTPSessionManager according to official example
|
||||
import uvicorn
|
||||
@@ -314,6 +532,44 @@ class DorisServer:
|
||||
async def health_check(request):
|
||||
return JSONResponse({"status": "healthy", "service": "doris-mcp-server"})
|
||||
|
||||
# OAuth endpoints
|
||||
from .auth.oauth_handlers import OAuthHandlers
|
||||
oauth_handlers = OAuthHandlers(self.security_manager)
|
||||
|
||||
async def oauth_login(request):
|
||||
return await oauth_handlers.handle_login(request)
|
||||
|
||||
async def oauth_callback(request):
|
||||
return await oauth_handlers.handle_callback(request)
|
||||
|
||||
async def oauth_provider_info(request):
|
||||
return await oauth_handlers.handle_provider_info(request)
|
||||
|
||||
async def oauth_demo(request):
|
||||
return await oauth_handlers.handle_demo_page(request)
|
||||
|
||||
# Token management endpoints
|
||||
from .auth.token_handlers import TokenHandlers
|
||||
token_handlers = TokenHandlers(self.security_manager, self.config)
|
||||
|
||||
async def token_create(request):
|
||||
return await token_handlers.handle_create_token(request)
|
||||
|
||||
async def token_revoke(request):
|
||||
return await token_handlers.handle_revoke_token(request)
|
||||
|
||||
async def token_list(request):
|
||||
return await token_handlers.handle_list_tokens(request)
|
||||
|
||||
async def token_stats(request):
|
||||
return await token_handlers.handle_token_stats(request)
|
||||
|
||||
async def token_cleanup(request):
|
||||
return await token_handlers.handle_cleanup_tokens(request)
|
||||
|
||||
async def token_management(request):
|
||||
return await token_handlers.handle_management_page(request)
|
||||
|
||||
# Lifecycle manager - simplified since we manage session_manager externally
|
||||
@contextlib.asynccontextmanager
|
||||
async def lifespan(app: Starlette) -> AsyncIterator[None]:
|
||||
@@ -329,6 +585,18 @@ class DorisServer:
|
||||
debug=True,
|
||||
routes=[
|
||||
Route("/health", health_check, methods=["GET"]),
|
||||
# OAuth endpoints
|
||||
Route("/auth/login", oauth_login, methods=["GET"]),
|
||||
Route("/auth/callback", oauth_callback, methods=["GET"]),
|
||||
Route("/auth/provider", oauth_provider_info, methods=["GET"]),
|
||||
Route("/auth/demo", oauth_demo, methods=["GET"]),
|
||||
# Token management endpoints
|
||||
Route("/token/create", token_create, methods=["GET", "POST"]),
|
||||
Route("/token/revoke", token_revoke, methods=["GET", "DELETE"]),
|
||||
Route("/token/list", token_list, methods=["GET"]),
|
||||
Route("/token/stats", token_stats, methods=["GET"]),
|
||||
Route("/token/cleanup", token_cleanup, methods=["GET", "POST"]),
|
||||
Route("/token/management", token_management, methods=["GET"]),
|
||||
],
|
||||
lifespan=lifespan,
|
||||
)
|
||||
@@ -346,8 +614,10 @@ class DorisServer:
|
||||
self.logger.info(f"Received request for path: {path}")
|
||||
|
||||
try:
|
||||
# Handle health check
|
||||
if path.startswith("/health"):
|
||||
# Handle health check, auth, and token management endpoints
|
||||
if (path.startswith("/health") or
|
||||
path.startswith("/auth/") or
|
||||
path.startswith("/token/")):
|
||||
await starlette_app(scope, receive, send)
|
||||
return
|
||||
|
||||
@@ -360,6 +630,39 @@ class DorisServer:
|
||||
self.logger.info(f"MCP Request - Method: {method}")
|
||||
self.logger.info(f"MCP Request - Headers: {headers}")
|
||||
|
||||
# Authentication check for MCP requests
|
||||
try:
|
||||
# Extract authentication information
|
||||
auth_info = await self._extract_auth_info_from_scope(scope, headers)
|
||||
|
||||
# Authenticate the request
|
||||
auth_context = await self.security_manager.authenticate_request(auth_info)
|
||||
self.logger.info(f"MCP request authenticated: token_id={auth_context.token_id}, client_ip={auth_context.client_ip}")
|
||||
|
||||
# Store auth context in scope for potential use by tools/resources
|
||||
scope["auth_context"] = auth_context
|
||||
|
||||
# FIX for Issue #62 Bug 1: Set auth_context in context variable
|
||||
# This allows tools to access token information for token-bound database configuration
|
||||
# CRITICAL: Use the global ContextVar from security.py to ensure same instance is used everywhere
|
||||
try:
|
||||
from .utils.security import mcp_auth_context_var
|
||||
mcp_auth_context_var.set(auth_context)
|
||||
self.logger.debug(f"Set auth_context in context variable with token: {bool(hasattr(auth_context, 'token') and auth_context.token)}")
|
||||
except Exception as ctx_error:
|
||||
self.logger.warning(f"Failed to set auth_context in context variable: {ctx_error}")
|
||||
|
||||
except Exception as auth_error:
|
||||
self.logger.error(f"MCP authentication failed: {auth_error}")
|
||||
# Return 401 Unauthorized
|
||||
from starlette.responses import JSONResponse
|
||||
response = JSONResponse(
|
||||
{"error": "Authentication required", "message": str(auth_error)},
|
||||
status_code=401
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
# Handle Dify compatibility for GET requests
|
||||
if method == "GET":
|
||||
accept_header = headers.get(b'accept', b'').decode('utf-8')
|
||||
@@ -402,7 +705,23 @@ class DorisServer:
|
||||
self.logger.warning(f"Unsupported scope type: {scope['type']}")
|
||||
return
|
||||
|
||||
# Start uvicorn server with session manager lifecycle
|
||||
# Choose startup method based on worker count
|
||||
if workers > 1:
|
||||
self.logger.info(f"Using multi-process mode with {workers} workers")
|
||||
self.logger.info("Note: Multi-worker mode provides full MCP functionality with independent worker processes")
|
||||
|
||||
# Use the dedicated multiworker app module with full MCP support
|
||||
uvicorn.run(
|
||||
"doris_mcp_server.multiworker_app:app",
|
||||
host=host,
|
||||
port=port,
|
||||
workers=workers,
|
||||
log_level="info"
|
||||
)
|
||||
|
||||
else:
|
||||
self.logger.info("Using single-process mode")
|
||||
# Single worker mode, use original logic with session manager lifecycle
|
||||
config = uvicorn.Config(
|
||||
app=mcp_app,
|
||||
host=host,
|
||||
@@ -429,10 +748,16 @@ class DorisServer:
|
||||
self.logger.error(f" Exception {i+1}: {type(exc).__name__}: {exc}")
|
||||
raise
|
||||
|
||||
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shutdown server"""
|
||||
self.logger.info("Shutting down Doris MCP Server")
|
||||
try:
|
||||
# Shutdown security manager first (includes JWT cleanup)
|
||||
await self.security_manager.shutdown()
|
||||
self.logger.info("Security manager shutdown completed")
|
||||
|
||||
await self.connection_manager.close()
|
||||
self.logger.info("Doris MCP Server has been shut down")
|
||||
except Exception as e:
|
||||
@@ -452,6 +777,11 @@ Transport Modes:
|
||||
Examples:
|
||||
python -m doris_mcp_server --transport stdio
|
||||
python -m doris_mcp_server --transport http --host 0.0.0.0 --port 3000
|
||||
python -m doris_mcp_server --transport stdio --doris-host localhost --doris-port 9030
|
||||
python -m doris_mcp_server --transport http --doris-user admin --doris-database test_db
|
||||
|
||||
# Backward compatibility: --db-* parameters are also supported
|
||||
python -m doris_mcp_server --transport stdio --db-host localhost --db-port 9030
|
||||
"""
|
||||
)
|
||||
|
||||
@@ -466,35 +796,45 @@ Examples:
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default=os.getenv("SERVER_HOST", _default_config.database.host),
|
||||
help=f"Host address for HTTP mode (default: {_default_config.database.host})",
|
||||
default=os.getenv("SERVER_HOST", _default_config.server_host),
|
||||
help=f"Host address for HTTP mode (default: {_default_config.server_host})",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--port", type=int, default=os.getenv("SERVER_PORT", _default_config.server_port), help=f"Port number for HTTP mode (default: {_default_config.server_port})"
|
||||
"--port",
|
||||
type=int,
|
||||
default=3000,
|
||||
help="Port number for HTTP mode (default: 3000)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--db-host",
|
||||
"--workers",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of worker processes for HTTP mode (default: 1, use 0 for auto-detect CPU cores)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--doris-host", "--db-host",
|
||||
type=str,
|
||||
default=os.getenv("DB_HOST", _default_config.database.host),
|
||||
default=os.getenv("DORIS_HOST", _default_config.database.host),
|
||||
help=f"Doris database host address (default: {_default_config.database.host})",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--db-port", type=int, default=os.getenv("DB_PORT", _default_config.database.port), help=f"Doris database port number (default: {_default_config.database.port})"
|
||||
"--doris-port", "--db-port", type=int, default=9030, help="Doris database port number (default: 9030)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--db-user", type=str, default=os.getenv("DB_USER", _default_config.database.user), help=f"Doris database username (default: {_default_config.database.user})"
|
||||
"--doris-user", "--db-user", type=str, default=os.getenv("DORIS_USER", _default_config.database.user), help=f"Doris database username (default: {_default_config.database.user})"
|
||||
)
|
||||
|
||||
parser.add_argument("--db-password", type=str, default="", help="Doris database password")
|
||||
parser.add_argument("--doris-password", "--db-password", type=str, default=os.getenv("DORIS_PASSWORD", ""), help="Doris database password")
|
||||
|
||||
parser.add_argument(
|
||||
"--db-database",
|
||||
"--doris-database", "--db-database",
|
||||
type=str,
|
||||
default=os.getenv("DB_DATABASE", _default_config.database.database),
|
||||
default=os.getenv("DORIS_DATABASE", _default_config.database.database),
|
||||
help=f"Doris database name (default: {_default_config.database.database})",
|
||||
)
|
||||
|
||||
@@ -509,41 +849,91 @@ Examples:
|
||||
return parser
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main function"""
|
||||
def update_configuration(config: DorisConfig):
|
||||
"""Update doris configuration object"""
|
||||
# For some arguments, if not specified, environment variables or default configurations will be used as default values
|
||||
parser = create_arg_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set log level
|
||||
logging.getLogger().setLevel(getattr(logging, args.log_level))
|
||||
|
||||
# Create configuration - priority: command line arguments > .env file > default values
|
||||
config = DorisConfig.from_env() # First load from .env file and environment variables
|
||||
|
||||
# Update config values
|
||||
# Command line arguments override configuration (if provided)
|
||||
if args.db_host != _default_config.database.host: # If not default value, use command line argument
|
||||
config.database.host = args.db_host
|
||||
if args.db_port != _default_config.database.port:
|
||||
config.database.port = args.db_port
|
||||
if args.db_user != _default_config.database.user:
|
||||
config.database.user = args.db_user
|
||||
if args.db_password: # Use password if provided
|
||||
config.database.password = args.db_password
|
||||
if args.db_database != _default_config.database.database:
|
||||
config.database.database = args.db_database
|
||||
# basic
|
||||
if args.transport != _default_config.transport:
|
||||
config.transport = args.transport
|
||||
if args.host != _default_config.server_host:
|
||||
config.server_host = args.host
|
||||
if args.port != _default_config.server_port:
|
||||
config.server_port = args.port
|
||||
server_name = os.getenv("SERVER_NAME")
|
||||
if server_name:
|
||||
config.server_name = server_name
|
||||
server_version = os.getenv("SERVER_VERSION")
|
||||
if server_version:
|
||||
config.server_version = server_version
|
||||
|
||||
# database
|
||||
if args.doris_host != _default_config.database.host: # If not default value, use command line argument
|
||||
config.database.host = args.doris_host
|
||||
if args.doris_port != _default_config.database.port:
|
||||
config.database.port = args.doris_port
|
||||
if args.doris_user != _default_config.database.user:
|
||||
config.database.user = args.doris_user
|
||||
if args.doris_password: # Use password if provided
|
||||
config.database.password = args.doris_password
|
||||
if args.doris_database != _default_config.database.database:
|
||||
config.database.database = args.doris_database
|
||||
|
||||
# logging
|
||||
if args.log_level != _default_config.logging.level:
|
||||
config.logging.level = args.log_level
|
||||
|
||||
# workers (add to config for HTTP mode)
|
||||
if hasattr(args, 'workers'):
|
||||
config.workers = args.workers
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main function"""
|
||||
# Create configuration - priority: command line arguments > env variables > .env file > default values
|
||||
# First load from .env file and environment variables
|
||||
config = DorisConfig.from_env()
|
||||
|
||||
# Then parse the command line arguments, and update the config object.
|
||||
update_configuration(config)
|
||||
|
||||
# Initialize enhanced logging system
|
||||
from .utils.config import ConfigManager
|
||||
config_manager = ConfigManager(config)
|
||||
config_manager.setup_logging()
|
||||
|
||||
# Get logger with proper configuration
|
||||
from .utils.logger import get_logger, log_system_info
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Log system information for debugging
|
||||
log_system_info()
|
||||
|
||||
logger.info("Starting Doris MCP Server...")
|
||||
logger.info(f"Transport: {config.transport}")
|
||||
logger.info(f"Log Level: {config.logging.level}")
|
||||
|
||||
# Create server instance
|
||||
server = DorisServer(config)
|
||||
|
||||
try:
|
||||
if args.transport == "stdio":
|
||||
if config.transport == "stdio":
|
||||
await server.start_stdio()
|
||||
elif args.transport == "http":
|
||||
await server.start_http(args.host, args.port)
|
||||
elif config.transport == "http":
|
||||
# Get workers configuration with auto-detection support
|
||||
workers = getattr(config, 'workers', 1)
|
||||
if workers == 0:
|
||||
import multiprocessing
|
||||
workers = multiprocessing.cpu_count()
|
||||
logger.info(f"Auto-detected {workers} CPU cores for worker processes")
|
||||
|
||||
await server.start_http(config.server_host, config.server_port, workers)
|
||||
else:
|
||||
logger.error(f"Unsupported transport protocol: {args.transport}")
|
||||
logger.error(f"Unsupported transport protocol: {config.transport}")
|
||||
await server.shutdown()
|
||||
return 1
|
||||
|
||||
@@ -564,6 +954,10 @@ async def main():
|
||||
except Exception as shutdown_error:
|
||||
logger.error(f"Error occurred while shutting down server: {shutdown_error}")
|
||||
|
||||
# Shutdown logging system
|
||||
from .utils.logger import shutdown_logging
|
||||
shutdown_logging()
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
|
||||
627
doris_mcp_server/multiworker_app.py
Normal file
627
doris_mcp_server/multiworker_app.py
Normal file
@@ -0,0 +1,627 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Multi-worker application module for doris-mcp-server
|
||||
|
||||
This module provides full MCP functionality with multi-worker support.
|
||||
Each worker process creates its own MCP server and session manager using the same
|
||||
robust architecture as the single-worker mode.
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
# Import MCP components with compatibility handling
|
||||
# Use the same import strategy as main.py for consistency
|
||||
MCP_VERSION = 'unknown'
|
||||
Server = None
|
||||
InitializationOptions = None
|
||||
Prompt = None
|
||||
Resource = None
|
||||
TextContent = None
|
||||
Tool = None
|
||||
|
||||
def _import_mcp_with_compatibility():
|
||||
"""Import MCP components with multi-version compatibility"""
|
||||
global MCP_VERSION, Server, InitializationOptions, Prompt, Resource, TextContent, Tool
|
||||
|
||||
try:
|
||||
# Strategy 1: Try direct server-only imports to avoid client-side issues
|
||||
from mcp.server import Server as _Server
|
||||
from mcp.server.models import InitializationOptions as _InitOptions
|
||||
from mcp.types import (
|
||||
Prompt as _Prompt,
|
||||
Resource as _Resource,
|
||||
TextContent as _TextContent,
|
||||
Tool as _Tool,
|
||||
)
|
||||
|
||||
# Assign to globals
|
||||
Server = _Server
|
||||
InitializationOptions = _InitOptions
|
||||
Prompt = _Prompt
|
||||
Resource = _Resource
|
||||
TextContent = _TextContent
|
||||
Tool = _Tool
|
||||
|
||||
# Try to get version safely
|
||||
try:
|
||||
import mcp
|
||||
MCP_VERSION = getattr(mcp, '__version__', None)
|
||||
if not MCP_VERSION:
|
||||
# Fallback: try to get version from package metadata
|
||||
try:
|
||||
import importlib.metadata
|
||||
MCP_VERSION = importlib.metadata.version('mcp')
|
||||
except Exception:
|
||||
# Second fallback: try pkg_resources
|
||||
try:
|
||||
import pkg_resources
|
||||
MCP_VERSION = pkg_resources.get_distribution('mcp').version
|
||||
except Exception:
|
||||
MCP_VERSION = 'detected-but-version-unknown'
|
||||
except Exception:
|
||||
# Version detection failed, but imports worked
|
||||
try:
|
||||
import importlib.metadata
|
||||
MCP_VERSION = importlib.metadata.version('mcp')
|
||||
except Exception:
|
||||
try:
|
||||
import pkg_resources
|
||||
MCP_VERSION = pkg_resources.get_distribution('mcp').version
|
||||
except Exception:
|
||||
MCP_VERSION = 'imported-successfully'
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info(f"MCP components imported successfully in multiworker, version: {MCP_VERSION}")
|
||||
return True
|
||||
|
||||
except Exception as import_error:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Strategy 2: Handle RequestContext compatibility issues in 1.9.x versions
|
||||
error_str = str(import_error).lower()
|
||||
if 'requestcontext' in error_str and 'too few arguments' in error_str:
|
||||
logger.warning(f"Detected MCP RequestContext compatibility issue: {import_error}")
|
||||
logger.info("Attempting comprehensive workaround for MCP 1.9.x RequestContext issue...")
|
||||
|
||||
try:
|
||||
# Comprehensive monkey patch approach
|
||||
import sys
|
||||
import types
|
||||
|
||||
# Create and install mock modules before any MCP imports
|
||||
if 'mcp.shared.context' not in sys.modules:
|
||||
mock_context_module = types.ModuleType('mcp.shared.context')
|
||||
|
||||
class FlexibleRequestContext:
|
||||
"""Flexible RequestContext that accepts variable arguments"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
def __class_getitem__(cls, params):
|
||||
# Accept any number of parameters and return cls
|
||||
return cls
|
||||
|
||||
# Add other methods that might be called
|
||||
def __getattr__(self, name):
|
||||
return lambda *args, **kwargs: None
|
||||
|
||||
mock_context_module.RequestContext = FlexibleRequestContext
|
||||
sys.modules['mcp.shared.context'] = mock_context_module
|
||||
|
||||
# Also patch the typing system to be more permissive
|
||||
original_check_generic = None
|
||||
try:
|
||||
import typing
|
||||
if hasattr(typing, '_check_generic'):
|
||||
original_check_generic = typing._check_generic
|
||||
def permissive_check_generic(cls, params, elen):
|
||||
# Don't enforce strict parameter count checking
|
||||
return
|
||||
typing._check_generic = permissive_check_generic
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Clear any cached imports that might have failed
|
||||
modules_to_clear = [k for k in sys.modules.keys() if k.startswith('mcp.')]
|
||||
for module in modules_to_clear:
|
||||
if module in sys.modules:
|
||||
del sys.modules[module]
|
||||
|
||||
# Now try importing again with the patches in place
|
||||
from mcp.server import Server as _Server
|
||||
from mcp.server.models import InitializationOptions as _InitOptions
|
||||
from mcp.types import (
|
||||
Prompt as _Prompt,
|
||||
Resource as _Resource,
|
||||
TextContent as _TextContent,
|
||||
Tool as _Tool,
|
||||
)
|
||||
|
||||
# Assign to globals
|
||||
Server = _Server
|
||||
InitializationOptions = _InitOptions
|
||||
Prompt = _Prompt
|
||||
Resource = _Resource
|
||||
TextContent = _TextContent
|
||||
Tool = _Tool
|
||||
|
||||
# Try to detect actual version even in compatibility mode
|
||||
try:
|
||||
import importlib.metadata
|
||||
actual_version = importlib.metadata.version('mcp')
|
||||
MCP_VERSION = f'compatibility-mode-{actual_version}'
|
||||
except Exception:
|
||||
try:
|
||||
import pkg_resources
|
||||
actual_version = pkg_resources.get_distribution('mcp').version
|
||||
MCP_VERSION = f'compatibility-mode-{actual_version}'
|
||||
except Exception:
|
||||
MCP_VERSION = 'compatibility-mode-1.9.x'
|
||||
|
||||
logger.info("MCP 1.9.x compatibility workaround successful in multiworker!")
|
||||
|
||||
# Restore original typing function if we patched it
|
||||
if original_check_generic:
|
||||
typing._check_generic = original_check_generic
|
||||
|
||||
return True
|
||||
|
||||
except Exception as workaround_error:
|
||||
logger.error(f"MCP compatibility workaround failed in multiworker: {workaround_error}")
|
||||
|
||||
# Restore original typing function if we patched it
|
||||
if original_check_generic:
|
||||
try:
|
||||
import typing
|
||||
typing._check_generic = original_check_generic
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.error(f"Failed to import MCP components in multiworker: {import_error}")
|
||||
return False
|
||||
|
||||
# Perform MCP import with compatibility handling
|
||||
if not _import_mcp_with_compatibility():
|
||||
raise ImportError(
|
||||
"Failed to import MCP components in multiworker. Please ensure MCP is properly installed. "
|
||||
"Supported versions: 1.8.x, 1.9.x"
|
||||
)
|
||||
|
||||
from starlette.applications import Starlette
|
||||
from starlette.routing import Route
|
||||
from starlette.responses import JSONResponse, Response
|
||||
|
||||
# Import Doris MCP components
|
||||
from .tools.tools_manager import DorisToolsManager
|
||||
from .tools.prompts_manager import DorisPromptsManager
|
||||
from .tools.resources_manager import DorisResourcesManager
|
||||
from .utils.config import DorisConfig
|
||||
from .utils.db import DorisConnectionManager
|
||||
from .utils.security import DorisSecurityManager
|
||||
|
||||
# Global variables for worker-specific instances
|
||||
_worker_server = None
|
||||
_worker_session_manager = None
|
||||
_worker_connection_manager = None
|
||||
_worker_security_manager = None
|
||||
_worker_session_manager_context = None
|
||||
_worker_initialized = False
|
||||
|
||||
def get_mcp_capabilities():
|
||||
"""Get MCP capabilities for worker - use the same logic as main.py"""
|
||||
try:
|
||||
# For MCP 1.9.x and newer
|
||||
from mcp.server.lowlevel.server import NotificationOptions
|
||||
|
||||
capabilities = {
|
||||
"resources": {},
|
||||
"tools": {},
|
||||
"prompts": {},
|
||||
"notification_options": {
|
||||
"prompts_changed": True,
|
||||
"resources_changed": True,
|
||||
"tools_changed": True
|
||||
}
|
||||
}
|
||||
return capabilities
|
||||
except Exception as e:
|
||||
# Import logger properly
|
||||
from .utils.logger import get_logger
|
||||
logger = get_logger(__name__)
|
||||
logger.warning(f"Failed to get full capabilities in multiworker: {e}")
|
||||
return {
|
||||
"resources": {},
|
||||
"tools": {},
|
||||
"prompts": {}
|
||||
}
|
||||
|
||||
async def initialize_worker():
|
||||
"""Initialize MCP server and managers for this worker process"""
|
||||
global _worker_server, _worker_session_manager, _worker_connection_manager, _worker_security_manager, _worker_session_manager_context, _worker_initialized, _oauth_handlers, _token_handlers
|
||||
|
||||
if _worker_initialized:
|
||||
return
|
||||
|
||||
try:
|
||||
# Import logger properly
|
||||
from .utils.logger import get_logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
logger.info(f"Initializing MCP worker process {os.getpid()}")
|
||||
|
||||
# Create configuration
|
||||
config = DorisConfig.from_env()
|
||||
|
||||
# Initialize enhanced logging system
|
||||
from .utils.config import ConfigManager
|
||||
config_manager = ConfigManager(config)
|
||||
config_manager.setup_logging()
|
||||
|
||||
# Create security manager
|
||||
_worker_security_manager = DorisSecurityManager(config)
|
||||
|
||||
# Initialize security manager first (includes JWT setup if enabled)
|
||||
await _worker_security_manager.initialize()
|
||||
logger.info(f"Worker {os.getpid()} security manager initialization completed")
|
||||
|
||||
# Create connection manager with token manager for token-bound DB config
|
||||
token_manager = _worker_security_manager.auth_provider.token_manager if hasattr(_worker_security_manager, 'auth_provider') and hasattr(_worker_security_manager.auth_provider, 'token_manager') else None
|
||||
_worker_connection_manager = DorisConnectionManager(config, _worker_security_manager, token_manager)
|
||||
|
||||
# Set connection manager reference in security manager for database validation
|
||||
_worker_security_manager.connection_manager = _worker_connection_manager
|
||||
|
||||
await _worker_connection_manager.initialize()
|
||||
|
||||
# Create MCP server
|
||||
_worker_server = Server("doris-mcp-server")
|
||||
|
||||
# Create managers
|
||||
resources_manager = DorisResourcesManager(_worker_connection_manager)
|
||||
tools_manager = DorisToolsManager(_worker_connection_manager)
|
||||
prompts_manager = DorisPromptsManager(_worker_connection_manager)
|
||||
|
||||
# Setup MCP handlers
|
||||
@_worker_server.list_resources()
|
||||
async def handle_list_resources() -> list[Resource]:
|
||||
"""Handle resource list request"""
|
||||
try:
|
||||
logger.info("Handling resource list request in worker")
|
||||
resources = await resources_manager.list_resources()
|
||||
logger.info(f"Returning {len(resources)} resources from worker")
|
||||
return resources
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to handle resource list request in worker: {e}")
|
||||
return []
|
||||
|
||||
@_worker_server.read_resource()
|
||||
async def handle_read_resource(uri: str) -> str:
|
||||
"""Handle resource read request"""
|
||||
try:
|
||||
logger.info(f"Handling resource read request in worker: {uri}")
|
||||
content = await resources_manager.read_resource(uri)
|
||||
return content
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to handle resource read request in worker: {e}")
|
||||
return json.dumps(
|
||||
{"error": f"Failed to read resource: {str(e)}", "uri": uri},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
@_worker_server.list_tools()
|
||||
async def handle_list_tools() -> list[Tool]:
|
||||
"""Handle tool list request"""
|
||||
try:
|
||||
logger.info("Handling tool list request in worker")
|
||||
tools = await tools_manager.list_tools()
|
||||
logger.info(f"Returning {len(tools)} tools from worker")
|
||||
return tools
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to handle tool list request in worker: {e}")
|
||||
return []
|
||||
|
||||
@_worker_server.call_tool()
|
||||
async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
|
||||
"""Handle tool call request"""
|
||||
try:
|
||||
logger.info(f"Handling tool call request in worker: {name}")
|
||||
result = await tools_manager.call_tool(name, arguments)
|
||||
return [TextContent(type="text", text=result)]
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to handle tool call request in worker: {e}")
|
||||
error_result = json.dumps(
|
||||
{
|
||||
"error": f"Tool call failed: {str(e)}",
|
||||
"tool_name": name,
|
||||
"arguments": arguments,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
return [TextContent(type="text", text=error_result)]
|
||||
|
||||
@_worker_server.list_prompts()
|
||||
async def handle_list_prompts() -> list[Prompt]:
|
||||
"""Handle prompt list request"""
|
||||
try:
|
||||
logger.info("Handling prompt list request in worker")
|
||||
prompts = await prompts_manager.list_prompts()
|
||||
logger.info(f"Returning {len(prompts)} prompts from worker")
|
||||
return prompts
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to handle prompt list request in worker: {e}")
|
||||
return []
|
||||
|
||||
@_worker_server.get_prompt()
|
||||
async def handle_get_prompt(name: str, arguments: dict[str, Any]) -> str:
|
||||
"""Handle prompt get request"""
|
||||
try:
|
||||
logger.info(f"Handling prompt get request in worker: {name}")
|
||||
result = await prompts_manager.get_prompt(name, arguments)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to handle prompt get request in worker: {e}")
|
||||
error_result = json.dumps(
|
||||
{
|
||||
"error": f"Failed to get prompt: {str(e)}",
|
||||
"prompt_name": name,
|
||||
"arguments": arguments,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
return error_result
|
||||
|
||||
# Create session manager for this worker
|
||||
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
|
||||
|
||||
_worker_session_manager = StreamableHTTPSessionManager(
|
||||
app=_worker_server,
|
||||
json_response=True,
|
||||
stateless=True # Use stateless mode for multi-worker compatibility
|
||||
)
|
||||
|
||||
# Start the session manager context
|
||||
_worker_session_manager_context = _worker_session_manager.run()
|
||||
await _worker_session_manager_context.__aenter__()
|
||||
|
||||
# Initialize OAuth and Token handlers
|
||||
from .auth.oauth_handlers import OAuthHandlers
|
||||
from .auth.token_handlers import TokenHandlers
|
||||
_oauth_handlers = OAuthHandlers(_worker_security_manager)
|
||||
_token_handlers = TokenHandlers(_worker_security_manager, config)
|
||||
|
||||
_worker_initialized = True
|
||||
logger.info(f"Worker {os.getpid()} MCP initialization completed successfully")
|
||||
|
||||
except Exception as e:
|
||||
from .utils.logger import get_logger
|
||||
logger = get_logger(__name__)
|
||||
logger.error(f"Failed to initialize worker {os.getpid()}: {e}")
|
||||
import traceback
|
||||
logger.error("Complete error stack:")
|
||||
logger.error(traceback.format_exc())
|
||||
raise
|
||||
|
||||
async def health_check(request):
|
||||
"""Health check endpoint that shows worker PID"""
|
||||
return JSONResponse({
|
||||
"status": "healthy",
|
||||
"service": "doris-mcp-server",
|
||||
"worker_pid": os.getpid(),
|
||||
"worker_mode": "multi-process-full-mcp",
|
||||
"mcp_initialized": _worker_initialized,
|
||||
"mcp_version": MCP_VERSION
|
||||
})
|
||||
|
||||
# OAuth and Token handlers (initialize after worker setup)
|
||||
_oauth_handlers = None
|
||||
_token_handlers = None
|
||||
|
||||
async def oauth_login(request):
|
||||
"""OAuth login endpoint"""
|
||||
if not _oauth_handlers:
|
||||
return JSONResponse({"error": "OAuth not initialized"}, status_code=503)
|
||||
return await _oauth_handlers.handle_login(request)
|
||||
|
||||
async def oauth_callback(request):
|
||||
"""OAuth callback endpoint"""
|
||||
if not _oauth_handlers:
|
||||
return JSONResponse({"error": "OAuth not initialized"}, status_code=503)
|
||||
return await _oauth_handlers.handle_callback(request)
|
||||
|
||||
async def oauth_provider_info(request):
|
||||
"""OAuth provider info endpoint"""
|
||||
if not _oauth_handlers:
|
||||
return JSONResponse({"error": "OAuth not initialized"}, status_code=503)
|
||||
return await _oauth_handlers.handle_provider_info(request)
|
||||
|
||||
async def oauth_demo(request):
|
||||
"""OAuth demo page endpoint"""
|
||||
if not _oauth_handlers:
|
||||
from starlette.responses import HTMLResponse
|
||||
return HTMLResponse("<h1>OAuth not initialized</h1>")
|
||||
return await _oauth_handlers.handle_demo_page(request)
|
||||
|
||||
# Token management endpoints
|
||||
async def token_create(request):
|
||||
"""Token creation endpoint"""
|
||||
if not _token_handlers:
|
||||
return JSONResponse({"error": "Token handlers not initialized"}, status_code=503)
|
||||
return await _token_handlers.handle_create_token(request)
|
||||
|
||||
async def token_revoke(request):
|
||||
"""Token revocation endpoint"""
|
||||
if not _token_handlers:
|
||||
return JSONResponse({"error": "Token handlers not initialized"}, status_code=503)
|
||||
return await _token_handlers.handle_revoke_token(request)
|
||||
|
||||
async def token_list(request):
|
||||
"""Token listing endpoint"""
|
||||
if not _token_handlers:
|
||||
return JSONResponse({"error": "Token handlers not initialized"}, status_code=503)
|
||||
return await _token_handlers.handle_list_tokens(request)
|
||||
|
||||
async def token_stats(request):
|
||||
"""Token statistics endpoint"""
|
||||
if not _token_handlers:
|
||||
return JSONResponse({"error": "Token handlers not initialized"}, status_code=503)
|
||||
return await _token_handlers.handle_token_stats(request)
|
||||
|
||||
async def token_cleanup(request):
|
||||
"""Token cleanup endpoint"""
|
||||
if not _token_handlers:
|
||||
return JSONResponse({"error": "Token handlers not initialized"}, status_code=503)
|
||||
return await _token_handlers.handle_cleanup_tokens(request)
|
||||
|
||||
async def token_management(request):
|
||||
"""Token management page endpoint"""
|
||||
if not _token_handlers:
|
||||
from starlette.responses import HTMLResponse
|
||||
return HTMLResponse("<h1>Token handlers not initialized</h1>")
|
||||
return await _token_handlers.handle_management_page(request)
|
||||
|
||||
async def root_info(request):
|
||||
"""Root endpoint"""
|
||||
return JSONResponse({
|
||||
"service": "doris-mcp-server",
|
||||
"mode": "multi-worker-full-mcp",
|
||||
"worker_pid": os.getpid(),
|
||||
"mcp_initialized": _worker_initialized,
|
||||
"mcp_version": MCP_VERSION,
|
||||
"endpoints": {
|
||||
"health": "/health",
|
||||
"mcp": "/mcp"
|
||||
}
|
||||
})
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app):
|
||||
"""Application lifespan manager"""
|
||||
# Startup
|
||||
try:
|
||||
await initialize_worker()
|
||||
# Import logger properly
|
||||
from .utils.logger import get_logger
|
||||
logger = get_logger(__name__)
|
||||
logger.info(f"Worker {os.getpid()} startup completed")
|
||||
|
||||
yield
|
||||
|
||||
finally:
|
||||
# Shutdown
|
||||
from .utils.logger import get_logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Close session manager context
|
||||
if _worker_session_manager_context:
|
||||
try:
|
||||
await _worker_session_manager_context.__aexit__(None, None, None)
|
||||
logger.info(f"Worker {os.getpid()} session manager context closed")
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing worker session manager context: {e}")
|
||||
|
||||
if _worker_connection_manager:
|
||||
try:
|
||||
await _worker_connection_manager.close()
|
||||
logger.info(f"Worker {os.getpid()} connection manager closed")
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing worker connection manager: {e}")
|
||||
|
||||
if _worker_security_manager:
|
||||
try:
|
||||
await _worker_security_manager.shutdown()
|
||||
logger.info(f"Worker {os.getpid()} security manager shutdown completed")
|
||||
except Exception as e:
|
||||
logger.error(f"Error shutting down worker security manager: {e}")
|
||||
|
||||
# Shutdown logging system
|
||||
try:
|
||||
from .utils.logger import shutdown_logging
|
||||
shutdown_logging()
|
||||
except Exception as e:
|
||||
logger.error(f"Error shutting down logging system: {e}")
|
||||
|
||||
async def mcp_asgi_app(scope, receive, send):
|
||||
"""ASGI app that handles MCP requests"""
|
||||
if not _worker_initialized:
|
||||
# Send error response if worker not initialized
|
||||
await send({
|
||||
'type': 'http.response.start',
|
||||
'status': 503,
|
||||
'headers': [(b'content-type', b'application/json')]
|
||||
})
|
||||
await send({
|
||||
'type': 'http.response.body',
|
||||
'body': b'{"error": "Worker not initialized"}'
|
||||
})
|
||||
return
|
||||
|
||||
# Import logger properly
|
||||
from .utils.logger import get_logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Get request path for logging
|
||||
path = scope.get('path', '')
|
||||
method = scope.get('method', 'UNKNOWN')
|
||||
logger.debug(f"Worker {os.getpid()} handling MCP request: {method} {path}")
|
||||
|
||||
# Handle the request directly without nested run context
|
||||
await _worker_session_manager.handle_request(scope, receive, send)
|
||||
|
||||
# Create Starlette app with basic routes
|
||||
basic_app = Starlette(
|
||||
debug=True,
|
||||
routes=[
|
||||
Route("/", root_info, methods=["GET"]),
|
||||
Route("/health", health_check, methods=["GET"]),
|
||||
# OAuth endpoints
|
||||
Route("/auth/login", oauth_login, methods=["GET"]),
|
||||
Route("/auth/callback", oauth_callback, methods=["GET"]),
|
||||
Route("/auth/provider", oauth_provider_info, methods=["GET"]),
|
||||
Route("/auth/demo", oauth_demo, methods=["GET"]),
|
||||
# Token management endpoints
|
||||
Route("/token/create", token_create, methods=["GET", "POST"]),
|
||||
Route("/token/revoke", token_revoke, methods=["GET", "DELETE"]),
|
||||
Route("/token/list", token_list, methods=["GET"]),
|
||||
Route("/token/stats", token_stats, methods=["GET"]),
|
||||
Route("/token/cleanup", token_cleanup, methods=["GET", "POST"]),
|
||||
Route("/token/management", token_management, methods=["GET"]),
|
||||
],
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
# Create main ASGI app that routes between basic app and MCP
|
||||
async def app(scope, receive, send):
|
||||
"""Main ASGI app that routes requests"""
|
||||
path = scope.get('path', '/')
|
||||
|
||||
if path == "/mcp" or path.startswith('/mcp/'):
|
||||
# Handle MCP requests with session manager
|
||||
await mcp_asgi_app(scope, receive, send)
|
||||
else:
|
||||
# Handle other requests with basic Starlette app (includes auth endpoints)
|
||||
await basic_app(scope, receive, send)
|
||||
@@ -31,6 +31,7 @@ from mcp.types import (
|
||||
)
|
||||
|
||||
from ..utils.db import DorisConnectionManager
|
||||
from ..utils.sql_security_utils import get_auth_context
|
||||
|
||||
|
||||
class PromptTemplate:
|
||||
@@ -422,7 +423,8 @@ Please generate accurate and efficient SQL queries based on the above requiremen
|
||||
AND table_type = 'BASE TABLE'
|
||||
"""
|
||||
|
||||
db_result = await connection.execute(db_info_sql)
|
||||
auth_context = get_auth_context()
|
||||
db_result = await connection.execute(db_info_sql, auth_context=auth_context)
|
||||
db_info = db_result.data[0] if db_result.data else {}
|
||||
|
||||
# Get main table list
|
||||
@@ -438,7 +440,7 @@ Please generate accurate and efficient SQL queries based on the above requiremen
|
||||
LIMIT 10
|
||||
"""
|
||||
|
||||
tables_result = await connection.execute(tables_sql)
|
||||
tables_result = await connection.execute(tables_sql, auth_context=auth_context)
|
||||
|
||||
context = f"""Current database statistics:
|
||||
- Total number of tables: {db_info.get("table_count", 0)}
|
||||
|
||||
@@ -26,6 +26,7 @@ from typing import Any
|
||||
from mcp.types import Resource
|
||||
|
||||
from ..utils.db import DorisConnectionManager
|
||||
from ..utils.sql_security_utils import get_auth_context
|
||||
|
||||
|
||||
class TableMetadata:
|
||||
@@ -169,7 +170,8 @@ class DorisResourcesManager:
|
||||
ORDER BY table_name
|
||||
"""
|
||||
|
||||
result = await connection.execute(tables_query)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(tables_query, auth_context=auth_context)
|
||||
tables = []
|
||||
|
||||
for row in result.data:
|
||||
@@ -204,7 +206,8 @@ class DorisResourcesManager:
|
||||
ORDER BY ordinal_position
|
||||
"""
|
||||
|
||||
result = await connection.execute(columns_query, (table_name,))
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(columns_query, params=(table_name,), auth_context=auth_context)
|
||||
return [dict(row) for row in result.data]
|
||||
|
||||
async def _get_view_metadata(self) -> list[ViewMetadata]:
|
||||
@@ -226,7 +229,8 @@ class DorisResourcesManager:
|
||||
ORDER BY table_name
|
||||
"""
|
||||
|
||||
result = await connection.execute(views_query)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(views_query, auth_context=auth_context)
|
||||
views = []
|
||||
|
||||
for row in result.data:
|
||||
@@ -257,7 +261,8 @@ class DorisResourcesManager:
|
||||
AND table_name = %s
|
||||
"""
|
||||
|
||||
table_result = await connection.execute(table_info_query, (table_name,))
|
||||
auth_context = get_auth_context()
|
||||
table_result = await connection.execute(table_info_query, params=(table_name,), auth_context=auth_context)
|
||||
if not table_result.data:
|
||||
raise ValueError(f"Table {table_name} does not exist")
|
||||
|
||||
@@ -295,7 +300,8 @@ class DorisResourcesManager:
|
||||
ORDER BY index_name, seq_in_index
|
||||
"""
|
||||
|
||||
result = await connection.execute(indexes_query, (table_name,))
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(indexes_query, params=(table_name,), auth_context=auth_context)
|
||||
return [dict(row) for row in result.data]
|
||||
|
||||
async def _get_view_definition(self, view_name: str) -> str:
|
||||
@@ -312,7 +318,8 @@ class DorisResourcesManager:
|
||||
AND table_name = %s
|
||||
"""
|
||||
|
||||
result = await connection.execute(view_query, (view_name,))
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(view_query, params=(view_name,), auth_context=auth_context)
|
||||
if not result.data:
|
||||
raise ValueError(f"View {view_name} does not exist")
|
||||
|
||||
@@ -340,7 +347,8 @@ class DorisResourcesManager:
|
||||
AND table_type = 'BASE TABLE'
|
||||
"""
|
||||
|
||||
table_result = await connection.execute(table_stats_query)
|
||||
auth_context = get_auth_context()
|
||||
table_result = await connection.execute(table_stats_query, auth_context=auth_context)
|
||||
table_stats = table_result.data[0] if table_result.data else {}
|
||||
|
||||
# Get view statistics
|
||||
@@ -350,7 +358,7 @@ class DorisResourcesManager:
|
||||
WHERE table_schema = DATABASE()
|
||||
"""
|
||||
|
||||
view_result = await connection.execute(view_stats_query)
|
||||
view_result = await connection.execute(view_stats_query, auth_context=auth_context)
|
||||
view_stats = view_result.data[0] if view_result.data else {}
|
||||
|
||||
stats_info = {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
542
doris_mcp_server/utils/adbc_query_tools.py
Normal file
542
doris_mcp_server/utils/adbc_query_tools.py
Normal file
@@ -0,0 +1,542 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Apache Doris ADBC Query Tools
|
||||
High-performance data querying using Apache Arrow Flight SQL protocol
|
||||
"""
|
||||
|
||||
import os
|
||||
import socket
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from ..utils.logger import get_logger
|
||||
from ..utils.db import DorisConnectionManager
|
||||
from ..utils.sql_security_utils import get_auth_context
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _convert_numpy_types(obj):
|
||||
"""Convert numpy types to native Python types for JSON serialization"""
|
||||
try:
|
||||
# Import numpy only when needed
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
if isinstance(obj, np.integer):
|
||||
return int(obj)
|
||||
elif isinstance(obj, np.floating):
|
||||
return float(obj)
|
||||
elif isinstance(obj, np.bool_):
|
||||
return bool(obj)
|
||||
elif isinstance(obj, np.ndarray):
|
||||
return obj.tolist()
|
||||
elif isinstance(obj, (pd.Timestamp, pd.NaT.__class__)):
|
||||
return str(obj)
|
||||
elif pd.isna(obj):
|
||||
return None
|
||||
else:
|
||||
return obj
|
||||
except ImportError:
|
||||
# If numpy/pandas not available, return as-is
|
||||
return obj
|
||||
|
||||
|
||||
def _convert_dataframe_to_json_serializable(df):
|
||||
"""Convert DataFrame to JSON serializable format"""
|
||||
try:
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
# Convert DataFrame to records
|
||||
records = df.to_dict('records')
|
||||
|
||||
# Convert each record's values
|
||||
converted_records = []
|
||||
for record in records:
|
||||
converted_record = {}
|
||||
for key, value in record.items():
|
||||
converted_record[key] = _convert_numpy_types(value)
|
||||
converted_records.append(converted_record)
|
||||
|
||||
return converted_records
|
||||
except ImportError:
|
||||
# Fallback to basic dict conversion
|
||||
return df.to_dict('records')
|
||||
|
||||
|
||||
class DorisADBCQueryTools:
|
||||
"""ADBC Query Tools for high-performance data transfer using Arrow Flight SQL"""
|
||||
|
||||
def __init__(self, connection_manager: DorisConnectionManager):
|
||||
self.connection_manager = connection_manager
|
||||
self.adbc_client = None
|
||||
self.flight_sql_module = None
|
||||
self.adbc_manager_module = None
|
||||
|
||||
async def exec_adbc_query(
|
||||
self,
|
||||
sql: str,
|
||||
max_rows: int | None = None,
|
||||
timeout: int | None = None,
|
||||
return_format: str | None = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Execute SQL query using ADBC (Arrow Flight SQL) protocol
|
||||
|
||||
Args:
|
||||
sql: SQL statement to execute
|
||||
max_rows: Maximum number of rows to return (uses config default if None)
|
||||
timeout: Query timeout in seconds (uses config default if None)
|
||||
return_format: Format for returned data ("arrow", "pandas", "dict", uses config default if None)
|
||||
|
||||
Returns:
|
||||
Query results in specified format with metadata
|
||||
"""
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
# Use configuration defaults if parameters not specified
|
||||
adbc_config = self.connection_manager.config.adbc
|
||||
max_rows = max_rows if max_rows is not None else adbc_config.default_max_rows
|
||||
timeout = timeout if timeout is not None else adbc_config.default_timeout
|
||||
return_format = return_format if return_format is not None else adbc_config.default_return_format
|
||||
|
||||
# Step 1: Check environment variables and port availability
|
||||
port_check_result = await self._check_arrow_flight_ports()
|
||||
if not port_check_result["success"]:
|
||||
return port_check_result
|
||||
|
||||
# Step 2: Import required ADBC modules
|
||||
import_result = await self._import_adbc_modules()
|
||||
if not import_result["success"]:
|
||||
return import_result
|
||||
|
||||
# Step 3: Create ADBC connection
|
||||
connection_result = await self._create_adbc_connection()
|
||||
if not connection_result["success"]:
|
||||
return connection_result
|
||||
|
||||
# Step 4: Execute query using ADBC
|
||||
query_result = await self._execute_query_with_adbc(
|
||||
sql, max_rows, timeout, return_format
|
||||
)
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
if query_result["success"]:
|
||||
query_result["execution_time"] = round(execution_time, 3)
|
||||
query_result["protocol"] = "ADBC_Arrow_Flight_SQL"
|
||||
query_result["timestamp"] = datetime.now().isoformat()
|
||||
|
||||
return query_result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ADBC query execution failed: {str(e)}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"ADBC query execution failed: {str(e)}",
|
||||
"error_type": "execution_error",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
async def _check_arrow_flight_ports(self) -> Dict[str, Any]:
|
||||
"""Check Arrow Flight SQL port configuration and availability"""
|
||||
try:
|
||||
# Check environment variables
|
||||
fe_port = os.getenv("FE_ARROW_FLIGHT_SQL_PORT")
|
||||
be_port = os.getenv("BE_ARROW_FLIGHT_SQL_PORT")
|
||||
|
||||
if not fe_port:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Missing environment variable FE_ARROW_FLIGHT_SQL_PORT, please configure Arrow Flight SQL FE port in .env file",
|
||||
"error_type": "missing_fe_port_config"
|
||||
}
|
||||
|
||||
if not be_port:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Missing environment variable BE_ARROW_FLIGHT_SQL_PORT, please configure Arrow Flight SQL BE port in .env file",
|
||||
"error_type": "missing_be_port_config"
|
||||
}
|
||||
|
||||
# Convert to integer and validate
|
||||
try:
|
||||
fe_port = int(fe_port)
|
||||
be_port = int(be_port)
|
||||
except ValueError:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Invalid Arrow Flight SQL port configuration, please ensure FE_ARROW_FLIGHT_SQL_PORT and BE_ARROW_FLIGHT_SQL_PORT are valid numbers",
|
||||
"error_type": "invalid_port_format"
|
||||
}
|
||||
|
||||
# Get host address
|
||||
db_config = self.connection_manager.config.database
|
||||
fe_host = db_config.host
|
||||
|
||||
# Check FE Arrow Flight SQL port availability
|
||||
fe_available = self._check_port_connectivity(fe_host, fe_port)
|
||||
if not fe_available:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Cannot connect to FE Arrow Flight SQL port {fe_host}:{fe_port}, please check if service is running",
|
||||
"error_type": "fe_port_unavailable",
|
||||
"fe_host": fe_host,
|
||||
"fe_port": fe_port
|
||||
}
|
||||
|
||||
# Get BE host list
|
||||
be_hosts = await self._get_be_hosts()
|
||||
if not be_hosts:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Cannot get BE node information, please check cluster status",
|
||||
"error_type": "no_be_hosts"
|
||||
}
|
||||
|
||||
# Check at least one BE Arrow Flight SQL port availability
|
||||
be_available_count = 0
|
||||
be_check_results = []
|
||||
|
||||
for be_host in be_hosts[:3]: # Check first 3 BE nodes
|
||||
be_available = self._check_port_connectivity(be_host, be_port)
|
||||
be_check_results.append({
|
||||
"host": be_host,
|
||||
"port": be_port,
|
||||
"available": be_available
|
||||
})
|
||||
if be_available:
|
||||
be_available_count += 1
|
||||
|
||||
if be_available_count == 0:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Cannot connect to any BE Arrow Flight SQL port (port: {be_port}), please check if BE services are running",
|
||||
"error_type": "no_be_ports_available",
|
||||
"be_check_results": be_check_results
|
||||
}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"fe_host": fe_host,
|
||||
"fe_port": fe_port,
|
||||
"be_port": be_port,
|
||||
"be_hosts": be_hosts,
|
||||
"be_available_count": be_available_count,
|
||||
"be_check_results": be_check_results
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Arrow Flight port check failed: {str(e)}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Arrow Flight port check failed: {str(e)}",
|
||||
"error_type": "port_check_error"
|
||||
}
|
||||
|
||||
def _check_port_connectivity(self, host: str, port: int, timeout: int | None = None) -> bool:
|
||||
"""Check port connectivity"""
|
||||
try:
|
||||
# Use config timeout if not specified
|
||||
if timeout is None:
|
||||
timeout = self.connection_manager.config.adbc.connection_timeout
|
||||
|
||||
with socket.create_connection((host, port), timeout=timeout):
|
||||
return True
|
||||
except (socket.timeout, socket.error, OSError):
|
||||
return False
|
||||
|
||||
async def _get_be_hosts(self) -> List[str]:
|
||||
"""Get BE host list"""
|
||||
try:
|
||||
db_config = self.connection_manager.config.database
|
||||
|
||||
# Use configured BE hosts first
|
||||
if db_config.be_hosts:
|
||||
logger.info(f"Using configured BE hosts: {db_config.be_hosts}")
|
||||
return db_config.be_hosts
|
||||
|
||||
# Get BE nodes via SHOW BACKENDS
|
||||
logger.info("No BE hosts configured, getting BE node information via SHOW BACKENDS")
|
||||
connection = await self.connection_manager.get_connection("query")
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute("SHOW BACKENDS", auth_context=auth_context)
|
||||
|
||||
be_hosts = []
|
||||
for row in result.data:
|
||||
host = row.get("Host")
|
||||
alive = row.get("Alive", "").lower()
|
||||
if host and alive == "true":
|
||||
be_hosts.append(host)
|
||||
|
||||
logger.info(f"Got {len(be_hosts)} active BE nodes from SHOW BACKENDS")
|
||||
return be_hosts
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get BE hosts: {str(e)}")
|
||||
return []
|
||||
|
||||
async def _import_adbc_modules(self) -> Dict[str, Any]:
|
||||
"""Import ADBC related modules"""
|
||||
try:
|
||||
# Import ADBC Driver Manager
|
||||
try:
|
||||
import adbc_driver_manager
|
||||
self.adbc_manager_module = adbc_driver_manager
|
||||
except ImportError:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Missing adbc_driver_manager module, please install: pip install adbc_driver_manager",
|
||||
"error_type": "missing_adbc_manager"
|
||||
}
|
||||
|
||||
# Import ADBC Flight SQL Driver
|
||||
try:
|
||||
import adbc_driver_flightsql.dbapi as flight_sql
|
||||
self.flight_sql_module = flight_sql
|
||||
except ImportError:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Missing adbc_driver_flightsql module, please install: pip install adbc_driver_flightsql",
|
||||
"error_type": "missing_flight_sql_driver"
|
||||
}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"adbc_manager_version": getattr(adbc_driver_manager, '__version__', 'unknown'),
|
||||
"flight_sql_version": getattr(flight_sql, '__version__', 'unknown')
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ADBC module import failed: {str(e)}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"ADBC module import failed: {str(e)}",
|
||||
"error_type": "import_error"
|
||||
}
|
||||
|
||||
async def _create_adbc_connection(self) -> Dict[str, Any]:
|
||||
"""Create ADBC connection"""
|
||||
try:
|
||||
db_config = self.connection_manager.config.database
|
||||
fe_port = int(os.getenv("FE_ARROW_FLIGHT_SQL_PORT"))
|
||||
|
||||
# Build connection URI
|
||||
uri = f"grpc://{db_config.host}:{fe_port}"
|
||||
|
||||
# Create database connection parameters
|
||||
db_kwargs = {
|
||||
self.adbc_manager_module.DatabaseOptions.USERNAME.value: db_config.user,
|
||||
self.adbc_manager_module.DatabaseOptions.PASSWORD.value: db_config.password,
|
||||
}
|
||||
|
||||
# Create connection
|
||||
self.adbc_client = self.flight_sql_module.connect(
|
||||
uri=uri,
|
||||
db_kwargs=db_kwargs
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"uri": uri,
|
||||
"connection_established": True
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create ADBC connection: {str(e)}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Failed to create ADBC connection: {str(e)}",
|
||||
"error_type": "connection_error"
|
||||
}
|
||||
|
||||
async def _execute_query_with_adbc(
|
||||
self,
|
||||
sql: str,
|
||||
max_rows: int,
|
||||
timeout: int,
|
||||
return_format: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute query using ADBC"""
|
||||
try:
|
||||
if not self.adbc_client:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "ADBC connection not established",
|
||||
"error_type": "no_connection"
|
||||
}
|
||||
|
||||
# SECURITY FIX: Perform SQL security validation before executing
|
||||
auth_context = get_auth_context()
|
||||
if self.connection_manager.security_manager:
|
||||
# Always perform security validation, even without auth_context
|
||||
# Use a default context for basic SQL security checks
|
||||
validation_result = await self.connection_manager.security_manager.validate_sql_security(sql, auth_context)
|
||||
if not validation_result.is_valid:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"SQL security validation failed: {validation_result.error_message}",
|
||||
"error_type": "security_violation",
|
||||
"risk_level": validation_result.risk_level
|
||||
}
|
||||
|
||||
cursor = self.adbc_client.cursor()
|
||||
start_time = time.time()
|
||||
|
||||
# Execute query
|
||||
cursor.execute(sql)
|
||||
|
||||
# Get results based on return format
|
||||
if return_format == "arrow":
|
||||
# Return Arrow format
|
||||
arrow_data = cursor.fetchallarrow()
|
||||
|
||||
# Limit rows
|
||||
if len(arrow_data) > max_rows:
|
||||
arrow_data = arrow_data.slice(0, max_rows)
|
||||
|
||||
# Convert Arrow data to serializable format
|
||||
preview_df = arrow_data.to_pandas().head(10) if len(arrow_data) > 0 else None
|
||||
result_data = {
|
||||
"format": "arrow",
|
||||
"num_rows": len(arrow_data),
|
||||
"num_columns": len(arrow_data.schema),
|
||||
"column_names": arrow_data.schema.names,
|
||||
"column_types": [str(field.type) for field in arrow_data.schema],
|
||||
"data_preview": _convert_dataframe_to_json_serializable(preview_df) if preview_df is not None else [],
|
||||
"total_bytes": arrow_data.nbytes if hasattr(arrow_data, 'nbytes') else 0
|
||||
}
|
||||
|
||||
elif return_format == "pandas":
|
||||
# Return Pandas DataFrame
|
||||
df = cursor.fetch_df()
|
||||
|
||||
# Limit rows
|
||||
if len(df) > max_rows:
|
||||
df = df.head(max_rows)
|
||||
|
||||
result_data = {
|
||||
"format": "pandas",
|
||||
"num_rows": len(df),
|
||||
"num_columns": len(df.columns),
|
||||
"column_names": df.columns.tolist(),
|
||||
"column_types": df.dtypes.astype(str).tolist(),
|
||||
"data": _convert_dataframe_to_json_serializable(df),
|
||||
"memory_usage": int(df.memory_usage(deep=True).sum())
|
||||
}
|
||||
|
||||
else: # return_format == "dict"
|
||||
# Return dictionary format
|
||||
arrow_data = cursor.fetchallarrow()
|
||||
df = arrow_data.to_pandas()
|
||||
|
||||
# Limit rows
|
||||
if len(df) > max_rows:
|
||||
df = df.head(max_rows)
|
||||
|
||||
result_data = {
|
||||
"format": "dict",
|
||||
"num_rows": len(df),
|
||||
"num_columns": len(df.columns),
|
||||
"column_names": df.columns.tolist(),
|
||||
"column_types": df.dtypes.astype(str).tolist(),
|
||||
"data": _convert_dataframe_to_json_serializable(df)
|
||||
}
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
cursor.close()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"result": result_data,
|
||||
"execution_time": round(execution_time, 3),
|
||||
"sql": sql,
|
||||
"max_rows_applied": len(result_data.get("data", [])) >= max_rows
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ADBC query execution failed: {str(e)}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"ADBC query execution failed: {str(e)}",
|
||||
"error_type": "query_execution_error",
|
||||
"sql": sql
|
||||
}
|
||||
|
||||
async def get_adbc_connection_info(self) -> Dict[str, Any]:
|
||||
"""Get ADBC connection information and status"""
|
||||
try:
|
||||
# Check port status
|
||||
port_status = await self._check_arrow_flight_ports()
|
||||
|
||||
# Check module status
|
||||
module_status = await self._import_adbc_modules()
|
||||
|
||||
# Get configuration information
|
||||
db_config = self.connection_manager.config.database
|
||||
fe_port = os.getenv("FE_ARROW_FLIGHT_SQL_PORT")
|
||||
be_port = os.getenv("BE_ARROW_FLIGHT_SQL_PORT")
|
||||
|
||||
connection_info = {
|
||||
"adbc_available": module_status["success"],
|
||||
"ports_available": port_status["success"],
|
||||
"configuration": {
|
||||
"fe_host": db_config.host,
|
||||
"fe_arrow_flight_port": fe_port,
|
||||
"be_arrow_flight_port": be_port,
|
||||
"user": db_config.user
|
||||
},
|
||||
"port_status": port_status,
|
||||
"module_status": module_status,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
if port_status["success"] and module_status["success"]:
|
||||
connection_info["status"] = "ready"
|
||||
connection_info["message"] = "ADBC Arrow Flight SQL connection ready"
|
||||
else:
|
||||
connection_info["status"] = "not_ready"
|
||||
errors = []
|
||||
if not port_status["success"]:
|
||||
errors.append(port_status["error"])
|
||||
if not module_status["success"]:
|
||||
errors.append(module_status["error"])
|
||||
connection_info["message"] = "; ".join(errors)
|
||||
|
||||
return connection_info
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get ADBC connection information: {str(e)}")
|
||||
return {
|
||||
"status": "error",
|
||||
"error": f"Failed to get ADBC connection information: {str(e)}",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
def __del__(self):
|
||||
"""Cleanup resources"""
|
||||
try:
|
||||
if self.adbc_client:
|
||||
self.adbc_client.close()
|
||||
except:
|
||||
pass
|
||||
@@ -29,6 +29,13 @@ from pathlib import Path
|
||||
|
||||
from .db import DorisConnectionManager
|
||||
from .logger import get_logger
|
||||
from .sql_security_utils import (
|
||||
SQLSecurityError,
|
||||
validate_identifier,
|
||||
quote_identifier,
|
||||
build_table_reference,
|
||||
get_auth_context
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -46,10 +53,17 @@ class TableAnalyzer:
|
||||
sample_size: int = 10
|
||||
) -> Dict[str, Any]:
|
||||
"""Get table summary information"""
|
||||
# SECURITY FIX: Validate table_name and get auth_context
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
except SQLSecurityError as e:
|
||||
raise ValueError(f"Invalid table name: {e}")
|
||||
|
||||
auth_context = get_auth_context()
|
||||
connection = await self.connection_manager.get_connection("query")
|
||||
|
||||
# Get table basic information
|
||||
table_info_sql = f"""
|
||||
# Get table basic information using parameterized query
|
||||
table_info_sql = """
|
||||
SELECT
|
||||
table_name,
|
||||
table_comment,
|
||||
@@ -58,17 +72,17 @@ class TableAnalyzer:
|
||||
engine
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_name = '{table_name}'
|
||||
AND table_name = %s
|
||||
"""
|
||||
|
||||
table_info_result = await connection.execute(table_info_sql)
|
||||
table_info_result = await connection.execute(table_info_sql, params=(table_name,), auth_context=auth_context)
|
||||
if not table_info_result.data:
|
||||
raise ValueError(f"Table {table_name} does not exist")
|
||||
|
||||
table_info = table_info_result.data[0]
|
||||
|
||||
# Get column information
|
||||
columns_sql = f"""
|
||||
# Get column information using parameterized query
|
||||
columns_sql = """
|
||||
SELECT
|
||||
column_name,
|
||||
data_type,
|
||||
@@ -76,11 +90,11 @@ class TableAnalyzer:
|
||||
column_comment
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_name = '{table_name}'
|
||||
AND table_name = %s
|
||||
ORDER BY ordinal_position
|
||||
"""
|
||||
|
||||
columns_result = await connection.execute(columns_sql)
|
||||
columns_result = await connection.execute(columns_sql, params=(table_name,), auth_context=auth_context)
|
||||
|
||||
summary = {
|
||||
"table_name": table_info["table_name"],
|
||||
@@ -92,10 +106,11 @@ class TableAnalyzer:
|
||||
"columns": columns_result.data,
|
||||
}
|
||||
|
||||
# Get sample data
|
||||
# Get sample data using quoted identifier
|
||||
if include_sample and sample_size > 0:
|
||||
sample_sql = f"SELECT * FROM {table_name} LIMIT {sample_size}"
|
||||
sample_result = await connection.execute(sample_sql)
|
||||
quoted_table = quote_identifier(table_name, "table name")
|
||||
sample_sql = f"SELECT * FROM {quoted_table} LIMIT {sample_size}"
|
||||
sample_result = await connection.execute(sample_sql, auth_context=auth_context)
|
||||
summary["sample_data"] = sample_result.data
|
||||
|
||||
return summary
|
||||
@@ -120,7 +135,8 @@ class TableAnalyzer:
|
||||
FROM {table_name}
|
||||
"""
|
||||
|
||||
basic_result = await connection.execute(basic_stats_sql)
|
||||
auth_context = get_auth_context()
|
||||
basic_result = await connection.execute(basic_stats_sql, auth_context=auth_context)
|
||||
if not basic_result.data:
|
||||
return {
|
||||
"success": False,
|
||||
@@ -144,7 +160,7 @@ class TableAnalyzer:
|
||||
LIMIT 20
|
||||
"""
|
||||
|
||||
distribution_result = await connection.execute(distribution_sql)
|
||||
distribution_result = await connection.execute(distribution_sql, auth_context=auth_context)
|
||||
analysis["value_distribution"] = distribution_result.data
|
||||
|
||||
if analysis_type == "detailed":
|
||||
@@ -159,7 +175,7 @@ class TableAnalyzer:
|
||||
WHERE {column_name} IS NOT NULL
|
||||
"""
|
||||
|
||||
numeric_result = await connection.execute(numeric_stats_sql)
|
||||
numeric_result = await connection.execute(numeric_stats_sql, auth_context=auth_context)
|
||||
if numeric_result.data:
|
||||
analysis.update(numeric_result.data[0])
|
||||
except Exception:
|
||||
@@ -196,7 +212,8 @@ class TableAnalyzer:
|
||||
AND table_name = '{table_name}'
|
||||
"""
|
||||
|
||||
table_result = await connection.execute(table_info_sql)
|
||||
auth_context = get_auth_context()
|
||||
table_result = await connection.execute(table_info_sql, auth_context=auth_context)
|
||||
if not table_result.data:
|
||||
raise ValueError(f"Table {table_name} does not exist")
|
||||
|
||||
@@ -211,7 +228,7 @@ class TableAnalyzer:
|
||||
AND table_name != %s
|
||||
"""
|
||||
|
||||
all_tables_result = await connection.execute(all_tables_sql, (table_name,))
|
||||
all_tables_result = await connection.execute(all_tables_sql, params=(table_name,), auth_context=auth_context)
|
||||
|
||||
return {
|
||||
"center_table": table_result.data[0],
|
||||
@@ -291,7 +308,8 @@ class PerformanceMonitor:
|
||||
LIMIT 20
|
||||
"""
|
||||
|
||||
tables_result = await connection.execute(tables_sql)
|
||||
auth_context = get_auth_context()
|
||||
tables_result = await connection.execute(tables_sql, auth_context=auth_context)
|
||||
stats = {
|
||||
"metric_type": "tables",
|
||||
"time_range": time_range,
|
||||
@@ -379,9 +397,23 @@ class SQLAnalyzer:
|
||||
|
||||
logger.info(f"Generating SQL explain for query ID: {query_id}")
|
||||
|
||||
# 🔧 FIX: Get auth_context for token-bound database configuration
|
||||
auth_context = None
|
||||
try:
|
||||
from .security import mcp_auth_context_var
|
||||
auth_context = mcp_auth_context_var.get()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Switch database if specified
|
||||
# SECURITY FIX: Validate and quote db_name
|
||||
if db_name:
|
||||
await self.connection_manager.execute_query("explain_session", f"USE {db_name}")
|
||||
try:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
return {"success": False, "error": f"Invalid database name: {e}"}
|
||||
safe_db = quote_identifier(db_name, "database name")
|
||||
await self.connection_manager.execute_query("explain_session", f"USE {safe_db}", None, auth_context)
|
||||
|
||||
# Construct EXPLAIN query
|
||||
explain_type = "EXPLAIN VERBOSE" if verbose else "EXPLAIN"
|
||||
@@ -390,7 +422,7 @@ class SQLAnalyzer:
|
||||
logger.info(f"Executing explain query: {explain_sql}")
|
||||
|
||||
# Execute explain query
|
||||
result = await self.connection_manager.execute_query("explain_session", explain_sql)
|
||||
result = await self.connection_manager.execute_query("explain_session", explain_sql, None, auth_context)
|
||||
|
||||
# Format explain output
|
||||
explain_content = []
|
||||
@@ -515,24 +547,36 @@ class SQLAnalyzer:
|
||||
|
||||
try:
|
||||
# Switch to specified database/catalog if provided
|
||||
# SECURITY FIX: Validate identifiers before using in SQL
|
||||
if catalog_name:
|
||||
await connection.execute(f"USE `{catalog_name}`")
|
||||
try:
|
||||
validate_identifier(catalog_name, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
return {"success": False, "error": f"Invalid catalog name: {e}"}
|
||||
safe_catalog = quote_identifier(catalog_name, "catalog name")
|
||||
auth_context = get_auth_context()
|
||||
await connection.execute(f"SWITCH {safe_catalog}", auth_context=auth_context)
|
||||
if db_name:
|
||||
await connection.execute(f"USE `{db_name}`")
|
||||
try:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
return {"success": False, "error": f"Invalid database name: {e}"}
|
||||
safe_db = quote_identifier(db_name, "database name")
|
||||
await connection.execute(f"USE {safe_db}", auth_context=auth_context)
|
||||
|
||||
# Set trace ID for the session using session variable
|
||||
# According to official docs: set session_context="trace_id:your_trace_id"
|
||||
await connection.execute(f'set session_context="trace_id:{trace_id}"')
|
||||
await connection.execute(f'set session_context="trace_id:{trace_id}"', auth_context=auth_context)
|
||||
logger.info(f"Set trace ID: {trace_id}")
|
||||
|
||||
# Enable profile
|
||||
await connection.execute(f'set enable_profile=true')
|
||||
await connection.execute(f'set enable_profile=true', auth_context=auth_context)
|
||||
logger.info(f"Enabled profile")
|
||||
|
||||
# Execute the SQL statement
|
||||
logger.info(f"Executing SQL with trace ID: {sql}")
|
||||
start_time = time.time()
|
||||
sql_result = await connection.execute(sql)
|
||||
sql_result = await connection.execute(sql, auth_context=auth_context)
|
||||
execution_time = time.time() - start_time
|
||||
logger.info(f"SQL execution completed in {execution_time:.3f}s")
|
||||
|
||||
|
||||
@@ -32,6 +32,8 @@ try:
|
||||
except ImportError:
|
||||
load_dotenv = None
|
||||
|
||||
from .logger import get_logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatabaseConfig:
|
||||
@@ -52,23 +54,73 @@ class DatabaseConfig:
|
||||
be_hosts: list[str] = field(default_factory=list)
|
||||
be_webserver_port: int = 8040
|
||||
|
||||
# Arrow Flight SQL Configuration (Required for ADBC tools)
|
||||
fe_arrow_flight_sql_port: int | None = None
|
||||
be_arrow_flight_sql_port: int | None = None
|
||||
|
||||
# Connection pool configuration
|
||||
min_connections: int = 5
|
||||
# Note: min_connections is fixed at 0 to avoid at_eof connection issues
|
||||
# This prevents pre-creation of connections which can cause state problems
|
||||
_min_connections: int = field(default=0, init=False) # Internal use only, always 0
|
||||
max_connections: int = 20
|
||||
connection_timeout: int = 30
|
||||
health_check_interval: int = 60
|
||||
max_connection_age: int = 3600
|
||||
|
||||
@property
|
||||
def min_connections(self) -> int:
|
||||
"""Minimum connections is always 0 to prevent at_eof issues"""
|
||||
return self._min_connections
|
||||
|
||||
|
||||
@dataclass
|
||||
class SecurityConfig:
|
||||
"""Security configuration"""
|
||||
|
||||
# Authentication configuration
|
||||
auth_type: str = "token" # token, basic, oauth
|
||||
token_secret: str = "default_secret"
|
||||
# Independent authentication switches - any one enabled allows that method
|
||||
enable_token_auth: bool = False # Enable token-based authentication (default: disabled)
|
||||
enable_jwt_auth: bool = False # Enable JWT authentication (default: disabled)
|
||||
enable_oauth_auth: bool = False # Enable OAuth 2.0/OIDC authentication (default: disabled)
|
||||
|
||||
# Legacy configuration (kept for backward compatibility)
|
||||
auth_type: str = "token" # jwt, token, basic, oauth (deprecated: use individual switches)
|
||||
token_secret: str = "default_secret" # Legacy token secret for backward compatibility
|
||||
token_expiry: int = 3600
|
||||
|
||||
# Enhanced Token Authentication Configuration
|
||||
token_file_path: str = "tokens.json" # Path to token configuration file
|
||||
enable_token_expiry: bool = True # Enable token expiration
|
||||
default_token_expiry_hours: int = 24 * 30 # Default expiry: 30 days
|
||||
token_hash_algorithm: str = "sha256" # Token hashing algorithm: sha256, sha512
|
||||
|
||||
# Token Management Security (New in v0.6.0)
|
||||
enable_http_token_management: bool = False # Enable HTTP token management endpoints (default: disabled for security)
|
||||
token_management_admin_token: str = "" # Admin token for token management endpoints (required if HTTP management enabled)
|
||||
token_management_allowed_ips: list[str] = field(default_factory=lambda: ["127.0.0.1", "::1", "localhost"]) # Allowed IPs for token management
|
||||
require_admin_auth: bool = True # Require admin authentication for token management (default: true)
|
||||
|
||||
# JWT Configuration
|
||||
jwt_algorithm: str = "RS256" # RS256, ES256, HS256
|
||||
jwt_issuer: str = "doris-mcp-server"
|
||||
jwt_audience: str = "doris-mcp-client"
|
||||
jwt_private_key_path: str = ""
|
||||
jwt_public_key_path: str = ""
|
||||
jwt_secret_key: str = "" # Only used for HS256 algorithm
|
||||
jwt_access_token_expiry: int = 3600 # 1 hour
|
||||
jwt_refresh_token_expiry: int = 86400 # 24 hours
|
||||
enable_token_refresh: bool = True
|
||||
enable_token_revocation: bool = True
|
||||
key_rotation_interval: int = 30 * 24 * 3600 # 30 days in seconds
|
||||
|
||||
# JWT Security Features
|
||||
jwt_require_iat: bool = True # Require "issued at" claim
|
||||
jwt_require_exp: bool = True # Require "expires at" claim
|
||||
jwt_require_nbf: bool = False # Require "not before" claim
|
||||
jwt_leeway: int = 10 # Clock skew tolerance in seconds
|
||||
jwt_verify_signature: bool = True # Verify JWT signature
|
||||
jwt_verify_audience: bool = True # Verify audience claim
|
||||
jwt_verify_issuer: bool = True # Verify issuer claim
|
||||
|
||||
# SQL security configuration
|
||||
enable_security_check: bool = True # Main switch: whether to enable SQL security check
|
||||
blocked_keywords: list[str] = field(
|
||||
@@ -102,6 +154,45 @@ class SecurityConfig:
|
||||
enable_masking: bool = True
|
||||
masking_rules: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
# OAuth 2.0/OIDC Configuration
|
||||
oauth_enabled: bool = False
|
||||
oauth_provider: str = "" # 'google', 'microsoft', 'github', 'custom'
|
||||
oauth_client_id: str = ""
|
||||
oauth_client_secret: str = ""
|
||||
oauth_redirect_uri: str = "http://localhost:3000/auth/callback"
|
||||
|
||||
# OIDC Discovery
|
||||
oidc_discovery_url: str = "" # e.g., https://accounts.google.com/.well-known/openid_configuration
|
||||
oauth_authorization_endpoint: str = ""
|
||||
oauth_token_endpoint: str = ""
|
||||
oauth_userinfo_endpoint: str = ""
|
||||
oauth_jwks_uri: str = ""
|
||||
|
||||
# OAuth Scopes and Settings
|
||||
oauth_scopes: list[str] = field(default_factory=list)
|
||||
oauth_state_expiry: int = 600 # State parameter expiry in seconds (10 minutes)
|
||||
oauth_pkce_enabled: bool = True # Enable PKCE for better security
|
||||
oauth_nonce_enabled: bool = True # Enable nonce for OIDC
|
||||
|
||||
# User Mapping Configuration
|
||||
oauth_user_id_claim: str = "sub" # JWT claim for user ID
|
||||
oauth_email_claim: str = "email"
|
||||
oauth_name_claim: str = "name"
|
||||
oauth_roles_claim: str = "roles" # Custom claim for roles
|
||||
oauth_default_roles: list[str] = field(default_factory=lambda: ["oauth_user"])
|
||||
|
||||
def __post_init__(self):
|
||||
"""Initialize default OAuth scopes based on provider"""
|
||||
if not self.oauth_scopes and self.oauth_provider:
|
||||
if self.oauth_provider == "google":
|
||||
self.oauth_scopes = ["openid", "email", "profile"]
|
||||
elif self.oauth_provider == "microsoft":
|
||||
self.oauth_scopes = ["openid", "profile", "email", "User.Read"]
|
||||
elif self.oauth_provider == "github":
|
||||
self.oauth_scopes = ["user:email", "read:user"]
|
||||
else:
|
||||
self.oauth_scopes = ["openid", "email", "profile"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class PerformanceConfig:
|
||||
@@ -124,6 +215,49 @@ class PerformanceConfig:
|
||||
max_response_content_size: int = 4096
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataQualityConfig:
|
||||
"""Data quality analysis configuration"""
|
||||
|
||||
# Column analysis configuration
|
||||
max_columns_per_batch: int = 20 # Maximum columns to analyze in a single batch
|
||||
default_sample_size: int = 100000 # Default sample size for analysis
|
||||
|
||||
# Sampling strategy configuration
|
||||
small_table_threshold: int = 100000 # Tables smaller than this use full table analysis
|
||||
medium_table_threshold: int = 1000000 # Tables smaller than this use simple LIMIT sampling
|
||||
# Tables larger than medium_table_threshold use systematic sampling
|
||||
|
||||
# Performance optimization
|
||||
enable_batch_analysis: bool = True # Enable batch analysis for multiple columns
|
||||
batch_timeout: int = 300 # Timeout for batch analysis in seconds
|
||||
|
||||
# Accuracy vs Performance trade-off
|
||||
enable_fast_mode: bool = False # Use approximate algorithms for faster results
|
||||
fast_mode_sample_size: int = 10000 # Sample size for fast mode
|
||||
|
||||
# Statistical analysis configuration
|
||||
enable_distribution_analysis: bool = True # Enable distribution analysis
|
||||
histogram_bins: int = 20 # Number of bins for histogram analysis
|
||||
percentile_levels: list[float] = field(default_factory=lambda: [0.25, 0.5, 0.75, 0.95, 0.99]) # Percentile levels to calculate
|
||||
|
||||
|
||||
@dataclass
|
||||
class ADBCConfig:
|
||||
"""ADBC (Arrow Flight SQL) configuration"""
|
||||
|
||||
# Default query parameters
|
||||
default_max_rows: int = 100000
|
||||
default_timeout: int = 60
|
||||
default_return_format: str = "arrow" # "arrow", "pandas", "dict"
|
||||
|
||||
# Connection timeout for ADBC
|
||||
connection_timeout: int = 30
|
||||
|
||||
# Whether to enable ADBC tools
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoggingConfig:
|
||||
"""Logging configuration"""
|
||||
@@ -138,6 +272,11 @@ class LoggingConfig:
|
||||
enable_audit: bool = True
|
||||
audit_file_path: str | None = None
|
||||
|
||||
# Log cleanup configuration
|
||||
enable_cleanup: bool = True
|
||||
max_age_days: int = 30
|
||||
cleanup_interval_hours: int = 24
|
||||
|
||||
|
||||
@dataclass
|
||||
class MonitoringConfig:
|
||||
@@ -164,6 +303,7 @@ class DorisConfig:
|
||||
# Basic configuration
|
||||
server_name: str = "doris-mcp-server"
|
||||
server_version: str = "0.4.1"
|
||||
server_host: str = "localhost"
|
||||
server_port: int = 3000
|
||||
transport: str = "stdio"
|
||||
|
||||
@@ -174,8 +314,10 @@ class DorisConfig:
|
||||
database: DatabaseConfig = field(default_factory=DatabaseConfig)
|
||||
security: SecurityConfig = field(default_factory=SecurityConfig)
|
||||
performance: PerformanceConfig = field(default_factory=PerformanceConfig)
|
||||
data_quality: DataQualityConfig = field(default_factory=DataQualityConfig)
|
||||
logging: LoggingConfig = field(default_factory=LoggingConfig)
|
||||
monitoring: MonitoringConfig = field(default_factory=MonitoringConfig)
|
||||
adbc: ADBCConfig = field(default_factory=ADBCConfig)
|
||||
|
||||
# Custom configuration
|
||||
custom_config: dict[str, Any] = field(default_factory=dict)
|
||||
@@ -205,6 +347,9 @@ class DorisConfig:
|
||||
def from_env(cls, env_file: str | None = None) -> "DorisConfig":
|
||||
"""Load configuration from environment variables
|
||||
|
||||
The kv pairs in the. env file will be loaded as environment variables,
|
||||
but the existing environment variables will not be overridden.
|
||||
|
||||
Args:
|
||||
env_file: .env file path, if None, search in the following order:
|
||||
.env, .env.local, .env.production, .env.development
|
||||
@@ -223,7 +368,7 @@ class DorisConfig:
|
||||
env_files = [".env", ".env.local", ".env.production", ".env.development"]
|
||||
for env_path in env_files:
|
||||
if Path(env_path).exists():
|
||||
load_dotenv(env_path)
|
||||
load_dotenv(env_path, override=False)
|
||||
logging.getLogger(__name__).info(f"Loaded environment configuration file: {env_path}")
|
||||
break
|
||||
else:
|
||||
@@ -233,24 +378,45 @@ class DorisConfig:
|
||||
|
||||
config = cls()
|
||||
|
||||
# Database configuration
|
||||
config.database.host = os.getenv("DORIS_HOST", config.database.host)
|
||||
config.database.port = int(os.getenv("DORIS_PORT", str(config.database.port)))
|
||||
config.database.user = os.getenv("DORIS_USER", config.database.user)
|
||||
config.database.password = os.getenv("DORIS_PASSWORD", config.database.password)
|
||||
config.database.database = os.getenv("DORIS_DATABASE", config.database.database)
|
||||
config.database.fe_http_port = int(os.getenv("DORIS_FE_HTTP_PORT", str(config.database.fe_http_port)))
|
||||
# Database configuration - handle empty strings properly
|
||||
doris_host = os.getenv("DORIS_HOST", "").strip()
|
||||
config.database.host = doris_host if doris_host else config.database.host
|
||||
|
||||
doris_port = os.getenv("DORIS_PORT", "").strip()
|
||||
if doris_port and doris_port.isdigit():
|
||||
config.database.port = int(doris_port)
|
||||
|
||||
doris_user = os.getenv("DORIS_USER", "").strip()
|
||||
config.database.user = doris_user if doris_user else config.database.user
|
||||
|
||||
doris_password = os.getenv("DORIS_PASSWORD", "")
|
||||
config.database.password = doris_password if doris_password else config.database.password
|
||||
|
||||
doris_database = os.getenv("DORIS_DATABASE", "").strip()
|
||||
config.database.database = doris_database if doris_database else config.database.database
|
||||
|
||||
doris_fe_http_port = os.getenv("DORIS_FE_HTTP_PORT", "").strip()
|
||||
if doris_fe_http_port and doris_fe_http_port.isdigit():
|
||||
config.database.fe_http_port = int(doris_fe_http_port)
|
||||
|
||||
# BE nodes configuration
|
||||
be_hosts_env = os.getenv("DORIS_BE_HOSTS", "")
|
||||
if be_hosts_env:
|
||||
config.database.be_hosts = [host.strip() for host in be_hosts_env.split(",") if host.strip()]
|
||||
config.database.be_webserver_port = int(os.getenv("DORIS_BE_WEBSERVER_PORT", str(config.database.be_webserver_port)))
|
||||
be_webserver_port = os.getenv("DORIS_BE_WEBSERVER_PORT", "").strip()
|
||||
if be_webserver_port and be_webserver_port.isdigit():
|
||||
config.database.be_webserver_port = int(be_webserver_port)
|
||||
|
||||
# Arrow Flight SQL Configuration
|
||||
fe_arrow_port_env = os.getenv("FE_ARROW_FLIGHT_SQL_PORT")
|
||||
if fe_arrow_port_env:
|
||||
config.database.fe_arrow_flight_sql_port = int(fe_arrow_port_env)
|
||||
|
||||
be_arrow_port_env = os.getenv("BE_ARROW_FLIGHT_SQL_PORT")
|
||||
if be_arrow_port_env:
|
||||
config.database.be_arrow_flight_sql_port = int(be_arrow_port_env)
|
||||
|
||||
# Connection pool configuration
|
||||
config.database.min_connections = int(
|
||||
os.getenv("DORIS_MIN_CONNECTIONS", str(config.database.min_connections))
|
||||
)
|
||||
config.database.max_connections = int(
|
||||
os.getenv("DORIS_MAX_CONNECTIONS", str(config.database.max_connections))
|
||||
)
|
||||
@@ -265,6 +431,10 @@ class DorisConfig:
|
||||
)
|
||||
|
||||
# Security configuration
|
||||
# Independent authentication switches
|
||||
config.security.enable_token_auth = os.getenv("ENABLE_TOKEN_AUTH", str(config.security.enable_token_auth)).lower() == "true"
|
||||
config.security.enable_jwt_auth = os.getenv("ENABLE_JWT_AUTH", str(config.security.enable_jwt_auth)).lower() == "true"
|
||||
config.security.enable_oauth_auth = os.getenv("ENABLE_OAUTH_AUTH", str(config.security.enable_oauth_auth)).lower() == "true"
|
||||
config.security.auth_type = os.getenv("AUTH_TYPE", config.security.auth_type)
|
||||
config.security.token_secret = os.getenv("TOKEN_SECRET", config.security.token_secret)
|
||||
config.security.token_expiry = int(
|
||||
@@ -296,6 +466,31 @@ class DorisConfig:
|
||||
os.getenv("ENABLE_MASKING", str(config.security.enable_masking).lower()).lower() == "true"
|
||||
)
|
||||
|
||||
# Enhanced Token Authentication configuration
|
||||
config.security.token_file_path = os.getenv("TOKEN_FILE_PATH", config.security.token_file_path)
|
||||
config.security.enable_token_expiry = (
|
||||
os.getenv("ENABLE_TOKEN_EXPIRY", str(config.security.enable_token_expiry).lower()).lower() == "true"
|
||||
)
|
||||
config.security.default_token_expiry_hours = int(
|
||||
os.getenv("DEFAULT_TOKEN_EXPIRY_HOURS", str(config.security.default_token_expiry_hours))
|
||||
)
|
||||
config.security.token_hash_algorithm = os.getenv("TOKEN_HASH_ALGORITHM", config.security.token_hash_algorithm)
|
||||
|
||||
# Token Management Security Configuration (New in v0.6.0)
|
||||
config.security.enable_http_token_management = (
|
||||
os.getenv("ENABLE_HTTP_TOKEN_MANAGEMENT", str(config.security.enable_http_token_management).lower()).lower() == "true"
|
||||
)
|
||||
config.security.token_management_admin_token = os.getenv("TOKEN_MANAGEMENT_ADMIN_TOKEN", config.security.token_management_admin_token)
|
||||
|
||||
# Parse allowed IPs from comma-separated string
|
||||
allowed_ips_str = os.getenv("TOKEN_MANAGEMENT_ALLOWED_IPS", "")
|
||||
if allowed_ips_str:
|
||||
config.security.token_management_allowed_ips = [ip.strip() for ip in allowed_ips_str.split(",") if ip.strip()]
|
||||
|
||||
config.security.require_admin_auth = (
|
||||
os.getenv("REQUIRE_ADMIN_AUTH", str(config.security.require_admin_auth).lower()).lower() == "true"
|
||||
)
|
||||
|
||||
# Performance configuration
|
||||
config.performance.enable_query_cache = (
|
||||
os.getenv("ENABLE_QUERY_CACHE", "true").lower() == "true"
|
||||
@@ -323,6 +518,15 @@ class DorisConfig:
|
||||
os.getenv("ENABLE_AUDIT", str(config.logging.enable_audit).lower()).lower() == "true"
|
||||
)
|
||||
config.logging.audit_file_path = os.getenv("AUDIT_FILE_PATH", config.logging.audit_file_path)
|
||||
config.logging.enable_cleanup = (
|
||||
os.getenv("ENABLE_LOG_CLEANUP", str(config.logging.enable_cleanup).lower()).lower() == "true"
|
||||
)
|
||||
config.logging.max_age_days = int(
|
||||
os.getenv("LOG_MAX_AGE_DAYS", str(config.logging.max_age_days))
|
||||
)
|
||||
config.logging.cleanup_interval_hours = int(
|
||||
os.getenv("LOG_CLEANUP_INTERVAL_HOURS", str(config.logging.cleanup_interval_hours))
|
||||
)
|
||||
|
||||
# Monitoring configuration
|
||||
config.monitoring.enable_metrics = (
|
||||
@@ -339,10 +543,59 @@ class DorisConfig:
|
||||
)
|
||||
config.monitoring.alert_webhook_url = os.getenv("ALERT_WEBHOOK_URL", config.monitoring.alert_webhook_url)
|
||||
|
||||
# ADBC configuration
|
||||
config.adbc.default_max_rows = int(
|
||||
os.getenv("ADBC_DEFAULT_MAX_ROWS", str(config.adbc.default_max_rows))
|
||||
)
|
||||
config.adbc.default_timeout = int(
|
||||
os.getenv("ADBC_DEFAULT_TIMEOUT", str(config.adbc.default_timeout))
|
||||
)
|
||||
config.adbc.default_return_format = os.getenv("ADBC_DEFAULT_RETURN_FORMAT", config.adbc.default_return_format)
|
||||
config.adbc.connection_timeout = int(
|
||||
os.getenv("ADBC_CONNECTION_TIMEOUT", str(config.adbc.connection_timeout))
|
||||
)
|
||||
config.adbc.enabled = (
|
||||
os.getenv("ADBC_ENABLED", str(config.adbc.enabled).lower()).lower() == "true"
|
||||
)
|
||||
|
||||
# Data quality configuration
|
||||
config.data_quality.max_columns_per_batch = int(
|
||||
os.getenv("DATA_QUALITY_MAX_COLUMNS_PER_BATCH", str(config.data_quality.max_columns_per_batch))
|
||||
)
|
||||
config.data_quality.default_sample_size = int(
|
||||
os.getenv("DATA_QUALITY_DEFAULT_SAMPLE_SIZE", str(config.data_quality.default_sample_size))
|
||||
)
|
||||
config.data_quality.small_table_threshold = int(
|
||||
os.getenv("DATA_QUALITY_SMALL_TABLE_THRESHOLD", str(config.data_quality.small_table_threshold))
|
||||
)
|
||||
config.data_quality.medium_table_threshold = int(
|
||||
os.getenv("DATA_QUALITY_MEDIUM_TABLE_THRESHOLD", str(config.data_quality.medium_table_threshold))
|
||||
)
|
||||
config.data_quality.enable_batch_analysis = (
|
||||
os.getenv("DATA_QUALITY_ENABLE_BATCH_ANALYSIS", str(config.data_quality.enable_batch_analysis).lower()).lower() == "true"
|
||||
)
|
||||
config.data_quality.batch_timeout = int(
|
||||
os.getenv("DATA_QUALITY_BATCH_TIMEOUT", str(config.data_quality.batch_timeout))
|
||||
)
|
||||
config.data_quality.enable_fast_mode = (
|
||||
os.getenv("DATA_QUALITY_ENABLE_FAST_MODE", str(config.data_quality.enable_fast_mode).lower()).lower() == "true"
|
||||
)
|
||||
config.data_quality.fast_mode_sample_size = int(
|
||||
os.getenv("DATA_QUALITY_FAST_MODE_SAMPLE_SIZE", str(config.data_quality.fast_mode_sample_size))
|
||||
)
|
||||
config.data_quality.enable_distribution_analysis = (
|
||||
os.getenv("DATA_QUALITY_ENABLE_DISTRIBUTION_ANALYSIS", str(config.data_quality.enable_distribution_analysis).lower()).lower() == "true"
|
||||
)
|
||||
config.data_quality.histogram_bins = int(
|
||||
os.getenv("DATA_QUALITY_HISTOGRAM_BINS", str(config.data_quality.histogram_bins))
|
||||
)
|
||||
|
||||
# Server configuration
|
||||
config.server_name = os.getenv("SERVER_NAME", config.server_name)
|
||||
config.server_version = os.getenv("SERVER_VERSION", config.server_version)
|
||||
config.server_port = int(os.getenv("SERVER_PORT", str(config.server_port)))
|
||||
server_port = os.getenv("SERVER_PORT", "").strip()
|
||||
if server_port and server_port.isdigit():
|
||||
config.server_port = int(server_port)
|
||||
config.temp_files_dir = os.getenv("TEMP_FILES_DIR", config.temp_files_dir)
|
||||
|
||||
return config
|
||||
@@ -378,6 +631,13 @@ class DorisConfig:
|
||||
if hasattr(config.performance, key):
|
||||
setattr(config.performance, key, value)
|
||||
|
||||
# Update data quality configuration
|
||||
if "data_quality" in config_data:
|
||||
dq_config = config_data["data_quality"]
|
||||
for key, value in dq_config.items():
|
||||
if hasattr(config.data_quality, key):
|
||||
setattr(config.data_quality, key, value)
|
||||
|
||||
# Update logging configuration
|
||||
if "logging" in config_data:
|
||||
log_config = config_data["logging"]
|
||||
@@ -392,6 +652,13 @@ class DorisConfig:
|
||||
if hasattr(config.monitoring, key):
|
||||
setattr(config.monitoring, key, value)
|
||||
|
||||
# Update ADBC configuration
|
||||
if "adbc" in config_data:
|
||||
adbc_config = config_data["adbc"]
|
||||
for key, value in adbc_config.items():
|
||||
if hasattr(config.adbc, key):
|
||||
setattr(config.adbc, key, value)
|
||||
|
||||
# Custom configuration
|
||||
config.custom_config = config_data.get("custom", {})
|
||||
|
||||
@@ -414,7 +681,9 @@ class DorisConfig:
|
||||
"fe_http_port": self.database.fe_http_port,
|
||||
"be_hosts": self.database.be_hosts,
|
||||
"be_webserver_port": self.database.be_webserver_port,
|
||||
"min_connections": self.database.min_connections,
|
||||
"fe_arrow_flight_sql_port": self.database.fe_arrow_flight_sql_port,
|
||||
"be_arrow_flight_sql_port": self.database.be_arrow_flight_sql_port,
|
||||
"min_connections": self.database.min_connections, # Always 0, shown for reference
|
||||
"max_connections": self.database.max_connections,
|
||||
"connection_timeout": self.database.connection_timeout,
|
||||
"health_check_interval": self.database.health_check_interval,
|
||||
@@ -442,6 +711,19 @@ class DorisConfig:
|
||||
"idle_timeout": self.performance.idle_timeout,
|
||||
"max_response_content_size": self.performance.max_response_content_size,
|
||||
},
|
||||
"data_quality": {
|
||||
"max_columns_per_batch": self.data_quality.max_columns_per_batch,
|
||||
"default_sample_size": self.data_quality.default_sample_size,
|
||||
"small_table_threshold": self.data_quality.small_table_threshold,
|
||||
"medium_table_threshold": self.data_quality.medium_table_threshold,
|
||||
"enable_batch_analysis": self.data_quality.enable_batch_analysis,
|
||||
"batch_timeout": self.data_quality.batch_timeout,
|
||||
"enable_fast_mode": self.data_quality.enable_fast_mode,
|
||||
"fast_mode_sample_size": self.data_quality.fast_mode_sample_size,
|
||||
"enable_distribution_analysis": self.data_quality.enable_distribution_analysis,
|
||||
"histogram_bins": self.data_quality.histogram_bins,
|
||||
"percentile_levels": self.data_quality.percentile_levels,
|
||||
},
|
||||
"logging": {
|
||||
"level": self.logging.level,
|
||||
"format": self.logging.format,
|
||||
@@ -450,6 +732,9 @@ class DorisConfig:
|
||||
"backup_count": self.logging.backup_count,
|
||||
"enable_audit": self.logging.enable_audit,
|
||||
"audit_file_path": self.logging.audit_file_path,
|
||||
"enable_cleanup": self.logging.enable_cleanup,
|
||||
"max_age_days": self.logging.max_age_days,
|
||||
"cleanup_interval_hours": self.logging.cleanup_interval_hours,
|
||||
},
|
||||
"monitoring": {
|
||||
"enable_metrics": self.monitoring.enable_metrics,
|
||||
@@ -460,6 +745,13 @@ class DorisConfig:
|
||||
"enable_alerts": self.monitoring.enable_alerts,
|
||||
"alert_webhook_url": self.monitoring.alert_webhook_url,
|
||||
},
|
||||
"adbc": {
|
||||
"default_max_rows": self.adbc.default_max_rows,
|
||||
"default_timeout": self.adbc.default_timeout,
|
||||
"default_return_format": self.adbc.default_return_format,
|
||||
"connection_timeout": self.adbc.connection_timeout,
|
||||
"enabled": self.adbc.enabled,
|
||||
},
|
||||
"custom": self.custom_config,
|
||||
}
|
||||
|
||||
@@ -492,11 +784,8 @@ class DorisConfig:
|
||||
if not self.database.user:
|
||||
errors.append("Database username cannot be empty")
|
||||
|
||||
if self.database.min_connections <= 0:
|
||||
errors.append("Minimum connections must be greater than 0")
|
||||
|
||||
if self.database.max_connections <= self.database.min_connections:
|
||||
errors.append("Maximum connections must be greater than minimum connections")
|
||||
if self.database.max_connections <= 0:
|
||||
errors.append("Maximum connections must be greater than 0")
|
||||
|
||||
# Validate security configuration
|
||||
if self.security.auth_type not in ["token", "basic", "oauth"]:
|
||||
@@ -521,6 +810,31 @@ class DorisConfig:
|
||||
if self.performance.query_timeout <= 0:
|
||||
errors.append("Query timeout must be greater than 0")
|
||||
|
||||
# Validate data quality configuration
|
||||
if self.data_quality.max_columns_per_batch <= 0:
|
||||
errors.append("Max columns per batch must be greater than 0")
|
||||
|
||||
if self.data_quality.default_sample_size <= 0:
|
||||
errors.append("Default sample size must be greater than 0")
|
||||
|
||||
if self.data_quality.small_table_threshold <= 0:
|
||||
errors.append("Small table threshold must be greater than 0")
|
||||
|
||||
if self.data_quality.medium_table_threshold <= 0:
|
||||
errors.append("Medium table threshold must be greater than 0")
|
||||
|
||||
if self.data_quality.small_table_threshold >= self.data_quality.medium_table_threshold:
|
||||
errors.append("Small table threshold must be less than medium table threshold")
|
||||
|
||||
if self.data_quality.batch_timeout <= 0:
|
||||
errors.append("Batch timeout must be greater than 0")
|
||||
|
||||
if self.data_quality.fast_mode_sample_size <= 0:
|
||||
errors.append("Fast mode sample size must be greater than 0")
|
||||
|
||||
if self.data_quality.histogram_bins <= 0:
|
||||
errors.append("Histogram bins must be greater than 0")
|
||||
|
||||
# Validate logging configuration
|
||||
if self.logging.level not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]:
|
||||
errors.append("Log level must be one of DEBUG, INFO, WARNING, ERROR, or CRITICAL")
|
||||
@@ -531,6 +845,12 @@ class DorisConfig:
|
||||
if self.logging.backup_count < 0:
|
||||
errors.append("Log backup count cannot be negative")
|
||||
|
||||
if self.logging.max_age_days <= 0:
|
||||
errors.append("Log max age days must be greater than 0")
|
||||
|
||||
if self.logging.cleanup_interval_hours <= 0:
|
||||
errors.append("Log cleanup interval hours must be greater than 0")
|
||||
|
||||
# Validate monitoring configuration
|
||||
if not (1 <= self.monitoring.metrics_port <= 65535):
|
||||
errors.append("Monitoring port must be in the range 1-65535")
|
||||
@@ -538,6 +858,19 @@ class DorisConfig:
|
||||
if not (1 <= self.monitoring.health_check_port <= 65535):
|
||||
errors.append("Health check port must be in the range 1-65535")
|
||||
|
||||
# Validate ADBC configuration
|
||||
if self.adbc.default_max_rows <= 0:
|
||||
errors.append("ADBC default max rows must be greater than 0")
|
||||
|
||||
if self.adbc.default_timeout <= 0:
|
||||
errors.append("ADBC default timeout must be greater than 0")
|
||||
|
||||
if self.adbc.default_return_format not in ["arrow", "pandas", "dict"]:
|
||||
errors.append("ADBC default return format must be one of arrow, pandas, or dict")
|
||||
|
||||
if self.adbc.connection_timeout <= 0:
|
||||
errors.append("ADBC connection timeout must be greater than 0")
|
||||
|
||||
return errors
|
||||
|
||||
def get_connection_string(self) -> str:
|
||||
@@ -549,7 +882,7 @@ class DorisConfig:
|
||||
return {
|
||||
"server": f"{self.server_name} v{self.server_version}",
|
||||
"database": f"{self.database.host}:{self.database.port}/{self.database.database}",
|
||||
"connection_pool": f"{self.database.min_connections}-{self.database.max_connections}",
|
||||
"connection_pool": f"0-{self.database.max_connections} (min fixed at 0 for stability)",
|
||||
"security": {
|
||||
"auth_type": self.security.auth_type,
|
||||
"masking_enabled": self.security.enable_masking,
|
||||
@@ -575,56 +908,50 @@ class ConfigManager:
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
def setup_logging(self):
|
||||
"""Setup logging configuration"""
|
||||
# Configure root logger
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.setLevel(getattr(logging, self.config.logging.level.upper()))
|
||||
"""Setup logging configuration using enhanced logger"""
|
||||
from .logger import setup_logging, get_logger
|
||||
import sys
|
||||
|
||||
# Clear existing handlers
|
||||
for handler in root_logger.handlers[:]:
|
||||
root_logger.removeHandler(handler)
|
||||
|
||||
# Create formatter
|
||||
formatter = logging.Formatter(self.config.logging.format)
|
||||
|
||||
# Console handler
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setFormatter(formatter)
|
||||
root_logger.addHandler(console_handler)
|
||||
|
||||
# File handler (if configured)
|
||||
# Determine log directory
|
||||
log_dir = "logs"
|
||||
if self.config.logging.file_path:
|
||||
try:
|
||||
from logging.handlers import RotatingFileHandler
|
||||
# Extract directory from file path if provided
|
||||
from pathlib import Path
|
||||
log_dir = str(Path(self.config.logging.file_path).parent)
|
||||
|
||||
file_handler = RotatingFileHandler(
|
||||
self.config.logging.file_path,
|
||||
maxBytes=self.config.logging.max_file_size,
|
||||
backupCount=self.config.logging.backup_count,
|
||||
encoding="utf-8",
|
||||
# Detect if we're in stdio mode by checking if this is likely MCP stdio communication
|
||||
# In stdio mode, we shouldn't output to console as it interferes with JSON protocol
|
||||
is_stdio_mode = (
|
||||
self.config.transport == "stdio" or
|
||||
"--transport" in sys.argv and "stdio" in sys.argv or
|
||||
not sys.stdout.isatty() # Not a terminal (likely piped/redirected)
|
||||
)
|
||||
file_handler.setFormatter(formatter)
|
||||
root_logger.addHandler(file_handler)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to setup file logging: {e}")
|
||||
|
||||
# Audit log handler (if configured)
|
||||
if self.config.logging.enable_audit and self.config.logging.audit_file_path:
|
||||
try:
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
||||
audit_logger = logging.getLogger("audit")
|
||||
audit_handler = RotatingFileHandler(
|
||||
self.config.logging.audit_file_path,
|
||||
maxBytes=self.config.logging.max_file_size,
|
||||
backupCount=self.config.logging.backup_count,
|
||||
encoding="utf-8",
|
||||
# Setup enhanced logging with cleanup functionality
|
||||
setup_logging(
|
||||
level=self.config.logging.level,
|
||||
log_dir=log_dir,
|
||||
enable_console=not is_stdio_mode, # Disable console logging in stdio mode
|
||||
enable_file=True,
|
||||
enable_audit=self.config.logging.enable_audit,
|
||||
audit_file=self.config.logging.audit_file_path,
|
||||
max_file_size=self.config.logging.max_file_size,
|
||||
backup_count=self.config.logging.backup_count,
|
||||
enable_cleanup=self.config.logging.enable_cleanup,
|
||||
max_age_days=self.config.logging.max_age_days,
|
||||
cleanup_interval_hours=self.config.logging.cleanup_interval_hours
|
||||
)
|
||||
audit_handler.setFormatter(formatter)
|
||||
audit_logger.addHandler(audit_handler)
|
||||
audit_logger.setLevel(logging.INFO)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to setup audit logging: {e}")
|
||||
|
||||
# Update logger to use new system
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
self.logger.info("Enhanced logging system with cleanup initialized successfully")
|
||||
self.logger.info(f"Log directory: {log_dir}")
|
||||
self.logger.info(f"Log level: {self.config.logging.level}")
|
||||
self.logger.info(f"Audit logging: {'Enabled' if self.config.logging.enable_audit else 'Disabled'}")
|
||||
self.logger.info(f"Log cleanup: {'Enabled' if self.config.logging.enable_cleanup else 'Disabled'}")
|
||||
if self.config.logging.enable_cleanup:
|
||||
self.logger.info(f"Cleanup config: Max age {self.config.logging.max_age_days} days, interval {self.config.logging.cleanup_interval_hours}h")
|
||||
|
||||
def validate_config(self) -> bool:
|
||||
"""Validate configuration"""
|
||||
|
||||
771
doris_mcp_server/utils/data_exploration_tools.py
Normal file
771
doris_mcp_server/utils/data_exploration_tools.py
Normal file
@@ -0,0 +1,771 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Data Exploration Tools Module
|
||||
Provides table data distribution analysis and exploration capabilities
|
||||
"""
|
||||
|
||||
import time
|
||||
import math
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from .db import DorisConnectionManager
|
||||
from .logger import get_logger
|
||||
from .sql_security_utils import (
|
||||
SQLSecurityError,
|
||||
validate_identifier,
|
||||
quote_identifier,
|
||||
build_table_reference,
|
||||
get_auth_context
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DataExplorationTools:
|
||||
"""Data exploration tools for table distribution analysis"""
|
||||
|
||||
def __init__(self, connection_manager: DorisConnectionManager):
|
||||
self.connection_manager = connection_manager
|
||||
logger.info("DataExplorationTools initialized")
|
||||
|
||||
|
||||
|
||||
# ==================== Private Helper Methods ====================
|
||||
|
||||
def _build_full_table_name(self, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> str:
|
||||
"""Build full table name with catalog and database using three-part naming convention"""
|
||||
# SECURITY FIX: Use build_table_reference for safe identifier handling
|
||||
effective_catalog = catalog_name if catalog_name else "internal"
|
||||
|
||||
if db_name:
|
||||
return build_table_reference(table_name, db_name, effective_catalog)
|
||||
else:
|
||||
return build_table_reference(table_name, catalog_name=effective_catalog)
|
||||
|
||||
async def _get_table_basic_info(self, connection, table_name: str) -> Optional[Dict]:
|
||||
"""Get basic table information including row count"""
|
||||
try:
|
||||
# SECURITY FIX: Get auth_context for security validation
|
||||
# table_name should already be validated by _build_full_table_name
|
||||
auth_context = get_auth_context()
|
||||
|
||||
count_sql = f"SELECT COUNT(*) as row_count FROM {table_name}"
|
||||
result = await connection.execute(count_sql, auth_context=auth_context)
|
||||
|
||||
if result.data:
|
||||
return {"row_count": result.data[0]["row_count"]}
|
||||
return None
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Security validation failed for table {table_name}: {str(e)}")
|
||||
return {"row_count": 0}
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get basic info for table {table_name}: {str(e)}")
|
||||
return {"row_count": 0}
|
||||
|
||||
async def _get_table_columns_info(self, connection, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> List[Dict]:
|
||||
"""Get detailed column information"""
|
||||
try:
|
||||
# SECURITY FIX: Validate identifiers and use parameterized query
|
||||
auth_context = get_auth_context()
|
||||
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
if db_name:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected: {e}")
|
||||
return []
|
||||
|
||||
# Build parameterized query
|
||||
params = [table_name]
|
||||
where_conditions = ["table_name = %s"]
|
||||
|
||||
if db_name:
|
||||
where_conditions.append("table_schema = %s")
|
||||
params.append(db_name)
|
||||
else:
|
||||
where_conditions.append("table_schema = DATABASE()")
|
||||
|
||||
columns_sql = f"""
|
||||
SELECT
|
||||
column_name,
|
||||
data_type,
|
||||
is_nullable,
|
||||
column_comment,
|
||||
ordinal_position
|
||||
FROM information_schema.columns
|
||||
WHERE {' AND '.join(where_conditions)}
|
||||
ORDER BY ordinal_position
|
||||
"""
|
||||
|
||||
result = await connection.execute(columns_sql, params=tuple(params), auth_context=auth_context)
|
||||
return result.data if result.data else []
|
||||
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Security validation failed: {str(e)}")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get columns info for table {table_name}: {str(e)}")
|
||||
return []
|
||||
|
||||
async def _determine_sampling_strategy(self, connection, table_name: str, total_rows: int, sample_size: int) -> Dict[str, Any]:
|
||||
"""Determine optimal sampling strategy based on table size"""
|
||||
if total_rows <= sample_size:
|
||||
# Use all data if table is small enough
|
||||
return {
|
||||
"total_rows": total_rows,
|
||||
"sample_size": total_rows,
|
||||
"sampling_method": "full_scan",
|
||||
"sampling_ratio": 1.0,
|
||||
"use_sampling": False,
|
||||
"sample_table_expression": table_name
|
||||
}
|
||||
else:
|
||||
# Use random sampling for large tables
|
||||
sampling_ratio = sample_size / total_rows
|
||||
return {
|
||||
"total_rows": total_rows,
|
||||
"sample_size": sample_size,
|
||||
"sampling_method": "random_sample",
|
||||
"sampling_ratio": round(sampling_ratio, 4),
|
||||
"use_sampling": True,
|
||||
"sample_table_expression": f"(SELECT * FROM {table_name} ORDER BY RAND() LIMIT {sample_size}) as sample_table"
|
||||
}
|
||||
|
||||
def _select_analysis_columns(self, columns_info: List[Dict], include_all: bool) -> List[Dict]:
|
||||
"""Select columns for analysis based on strategy"""
|
||||
if include_all:
|
||||
return columns_info
|
||||
|
||||
# If not analyzing all columns, prioritize key columns
|
||||
priority_keywords = ['id', 'key', 'code', 'status', 'type', 'amount', 'count', 'date', 'time']
|
||||
|
||||
priority_columns = []
|
||||
other_columns = []
|
||||
|
||||
for col in columns_info:
|
||||
col_name_lower = col["column_name"].lower()
|
||||
if any(keyword in col_name_lower for keyword in priority_keywords):
|
||||
priority_columns.append(col)
|
||||
else:
|
||||
other_columns.append(col)
|
||||
|
||||
# Return priority columns plus first 10 other columns
|
||||
return priority_columns + other_columns[:10]
|
||||
|
||||
def _is_numeric_type(self, data_type: str) -> bool:
|
||||
"""Check if column type is numeric"""
|
||||
numeric_types = [
|
||||
'tinyint', 'smallint', 'int', 'bigint', 'largeint',
|
||||
'float', 'double', 'decimal', 'numeric'
|
||||
]
|
||||
return any(num_type in data_type.lower() for num_type in numeric_types)
|
||||
|
||||
def _is_categorical_type(self, data_type: str) -> bool:
|
||||
"""Check if column type is categorical"""
|
||||
categorical_types = ['varchar', 'char', 'string', 'text', 'enum']
|
||||
return any(cat_type in data_type.lower() for cat_type in categorical_types)
|
||||
|
||||
def _is_temporal_type(self, data_type: str) -> bool:
|
||||
"""Check if column type is temporal"""
|
||||
temporal_types = ['date', 'datetime', 'timestamp', 'time']
|
||||
return any(temp_type in data_type.lower() for temp_type in temporal_types)
|
||||
|
||||
async def _analyze_numeric_distributions(self, connection, table_name: str, numeric_columns: List[Dict], sampling_info: Dict) -> Dict[str, Any]:
|
||||
"""Analyze distribution patterns for numeric columns"""
|
||||
numeric_analysis = {}
|
||||
|
||||
for column in numeric_columns:
|
||||
col_name = column["column_name"]
|
||||
try:
|
||||
# Basic statistics
|
||||
table_expr = sampling_info.get("sample_table_expression", table_name)
|
||||
stats_sql = f"""
|
||||
SELECT
|
||||
COUNT({col_name}) as count,
|
||||
MIN({col_name}) as min_value,
|
||||
MAX({col_name}) as max_value,
|
||||
AVG({col_name}) as mean_value,
|
||||
STDDEV({col_name}) as std_dev
|
||||
FROM {table_expr}
|
||||
WHERE {col_name} IS NOT NULL
|
||||
"""
|
||||
|
||||
auth_context = get_auth_context()
|
||||
stats_result = await connection.execute(stats_sql, auth_context=auth_context)
|
||||
|
||||
if stats_result.data and stats_result.data[0]["count"] > 0:
|
||||
stats = stats_result.data[0]
|
||||
|
||||
# Percentiles calculation
|
||||
percentiles = await self._calculate_percentiles(connection, table_name, col_name, sampling_info)
|
||||
|
||||
# Outlier detection
|
||||
outliers = await self._detect_numeric_outliers(connection, table_name, col_name, percentiles, sampling_info)
|
||||
|
||||
# Distribution shape analysis
|
||||
distribution_shape = await self._analyze_distribution_shape(
|
||||
connection, table_name, col_name, stats, percentiles, sampling_info
|
||||
)
|
||||
|
||||
numeric_analysis[col_name] = {
|
||||
"data_type": column["data_type"],
|
||||
"statistics": {
|
||||
"count": stats["count"],
|
||||
"mean": round(float(stats["mean_value"]), 4) if stats["mean_value"] else None,
|
||||
"std": round(float(stats["std_dev"]), 4) if stats["std_dev"] else None,
|
||||
"min": float(stats["min_value"]) if stats["min_value"] else None,
|
||||
"max": float(stats["max_value"]) if stats["max_value"] else None,
|
||||
**percentiles
|
||||
},
|
||||
"distribution_shape": distribution_shape,
|
||||
"outliers": outliers
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to analyze numeric column {col_name}: {str(e)}")
|
||||
numeric_analysis[col_name] = {"error": str(e)}
|
||||
|
||||
return numeric_analysis
|
||||
|
||||
async def _calculate_percentiles(self, connection, table_name: str, col_name: str, sampling_info: Dict) -> Dict[str, float]:
|
||||
"""Calculate percentiles for numeric column"""
|
||||
try:
|
||||
table_expr = sampling_info.get("sample_table_expression", table_name)
|
||||
percentile_sql = f"""
|
||||
SELECT
|
||||
PERCENTILE({col_name}, 0.25) as p25,
|
||||
PERCENTILE({col_name}, 0.50) as p50,
|
||||
PERCENTILE({col_name}, 0.75) as p75,
|
||||
PERCENTILE({col_name}, 0.90) as p90,
|
||||
PERCENTILE({col_name}, 0.95) as p95,
|
||||
PERCENTILE({col_name}, 0.99) as p99
|
||||
FROM {table_expr}
|
||||
WHERE {col_name} IS NOT NULL
|
||||
"""
|
||||
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(percentile_sql, auth_context=auth_context)
|
||||
|
||||
if result.data:
|
||||
data = result.data[0]
|
||||
return {
|
||||
"25%": round(float(data["p25"]), 4) if data["p25"] else None,
|
||||
"50%": round(float(data["p50"]), 4) if data["p50"] else None,
|
||||
"75%": round(float(data["p75"]), 4) if data["p75"] else None,
|
||||
"90%": round(float(data["p90"]), 4) if data["p90"] else None,
|
||||
"95%": round(float(data["p95"]), 4) if data["p95"] else None,
|
||||
"99%": round(float(data["p99"]), 4) if data["p99"] else None
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to calculate percentiles for {col_name}: {str(e)}")
|
||||
|
||||
return {}
|
||||
|
||||
async def _detect_numeric_outliers(self, connection, table_name: str, col_name: str, percentiles: Dict, sampling_info: Dict) -> Dict[str, Any]:
|
||||
"""Detect outliers using IQR method"""
|
||||
try:
|
||||
if "25%" not in percentiles or "75%" not in percentiles:
|
||||
return {"outlier_count": 0, "outlier_rate": 0.0}
|
||||
|
||||
q1 = percentiles["25%"]
|
||||
q3 = percentiles["75%"]
|
||||
iqr = q3 - q1
|
||||
|
||||
lower_bound = q1 - 1.5 * iqr
|
||||
upper_bound = q3 + 1.5 * iqr
|
||||
|
||||
table_expr = sampling_info.get("sample_table_expression", table_name)
|
||||
outlier_sql = f"""
|
||||
SELECT
|
||||
COUNT(*) as total_count,
|
||||
SUM(CASE WHEN {col_name} < {lower_bound} OR {col_name} > {upper_bound} THEN 1 ELSE 0 END) as outlier_count
|
||||
FROM {table_expr}
|
||||
WHERE {col_name} IS NOT NULL
|
||||
"""
|
||||
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(outlier_sql, auth_context=auth_context)
|
||||
|
||||
if result.data:
|
||||
data = result.data[0]
|
||||
total_count = data["total_count"]
|
||||
outlier_count = data["outlier_count"]
|
||||
outlier_rate = outlier_count / total_count if total_count > 0 else 0
|
||||
|
||||
return {
|
||||
"outlier_count": outlier_count,
|
||||
"outlier_rate": round(outlier_rate, 4),
|
||||
"outlier_threshold_lower": round(lower_bound, 4),
|
||||
"outlier_threshold_upper": round(upper_bound, 4),
|
||||
"iqr": round(iqr, 4)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to detect outliers for {col_name}: {str(e)}")
|
||||
|
||||
return {"outlier_count": 0, "outlier_rate": 0.0}
|
||||
|
||||
async def _analyze_distribution_shape(self, connection, table_name: str, col_name: str, stats: Dict, percentiles: Dict, sampling_info: Dict) -> Dict[str, Any]:
|
||||
"""Analyze the shape of data distribution"""
|
||||
try:
|
||||
mean = stats.get("mean_value", 0)
|
||||
median = percentiles.get("50%", 0)
|
||||
|
||||
if mean is None or median is None:
|
||||
return {"distribution_type": "unknown"}
|
||||
|
||||
# Calculate skewness indicator
|
||||
if abs(mean - median) < 0.01:
|
||||
skew_indicator = "symmetric"
|
||||
elif mean > median:
|
||||
skew_indicator = "right_skewed"
|
||||
else:
|
||||
skew_indicator = "left_skewed"
|
||||
|
||||
# Estimate kurtosis based on percentile spread
|
||||
if "25%" in percentiles and "75%" in percentiles:
|
||||
iqr = percentiles["75%"] - percentiles["25%"]
|
||||
range_90 = percentiles.get("90%", percentiles["75%"]) - percentiles.get("10%", percentiles["25%"])
|
||||
|
||||
if iqr > 0:
|
||||
kurtosis_indicator = "normal" if 2.5 <= range_90/iqr <= 3.5 else ("heavy_tailed" if range_90/iqr > 3.5 else "light_tailed")
|
||||
else:
|
||||
kurtosis_indicator = "unknown"
|
||||
else:
|
||||
kurtosis_indicator = "unknown"
|
||||
|
||||
return {
|
||||
"skewness_indicator": skew_indicator,
|
||||
"kurtosis_indicator": kurtosis_indicator,
|
||||
"distribution_type": self._classify_distribution_type(skew_indicator, kurtosis_indicator),
|
||||
"mean_median_ratio": round(mean / median, 4) if median != 0 else None
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to analyze distribution shape for {col_name}: {str(e)}")
|
||||
return {"distribution_type": "unknown"}
|
||||
|
||||
def _classify_distribution_type(self, skew: str, kurtosis: str) -> str:
|
||||
"""Classify distribution type based on skewness and kurtosis"""
|
||||
if skew == "symmetric" and kurtosis == "normal":
|
||||
return "approximately_normal"
|
||||
elif skew == "right_skewed":
|
||||
return "right_skewed"
|
||||
elif skew == "left_skewed":
|
||||
return "left_skewed"
|
||||
elif kurtosis == "heavy_tailed":
|
||||
return "heavy_tailed"
|
||||
else:
|
||||
return "non_normal"
|
||||
|
||||
async def _analyze_categorical_distributions(self, connection, table_name: str, categorical_columns: List[Dict], sampling_info: Dict) -> Dict[str, Any]:
|
||||
"""Analyze distribution patterns for categorical columns"""
|
||||
categorical_analysis = {}
|
||||
|
||||
for column in categorical_columns:
|
||||
col_name = column["column_name"]
|
||||
try:
|
||||
# Basic cardinality and distribution
|
||||
cardinality_sql = f"""
|
||||
SELECT
|
||||
COUNT(DISTINCT {col_name}) as cardinality,
|
||||
COUNT({col_name}) as non_null_count
|
||||
FROM {table_name}
|
||||
WHERE {col_name} IS NOT NULL
|
||||
{sampling_info.get('sample_query_suffix', '')}
|
||||
"""
|
||||
|
||||
auth_context = get_auth_context()
|
||||
cardinality_result = await connection.execute(cardinality_sql, auth_context=auth_context)
|
||||
|
||||
if cardinality_result.data:
|
||||
cardinality_data = cardinality_result.data[0]
|
||||
cardinality = cardinality_data["cardinality"]
|
||||
non_null_count = cardinality_data["non_null_count"]
|
||||
|
||||
# Value distribution (top values)
|
||||
value_distribution = await self._get_categorical_value_distribution(
|
||||
connection, table_name, col_name, sampling_info, non_null_count
|
||||
)
|
||||
|
||||
# Calculate entropy and concentration
|
||||
entropy = self._calculate_entropy(value_distribution)
|
||||
concentration_ratio = value_distribution[0]["percentage"] if value_distribution else 0
|
||||
|
||||
categorical_analysis[col_name] = {
|
||||
"data_type": column["data_type"],
|
||||
"cardinality": cardinality,
|
||||
"non_null_count": non_null_count,
|
||||
"value_distribution": value_distribution,
|
||||
"entropy": round(entropy, 3),
|
||||
"concentration_ratio": round(concentration_ratio, 4),
|
||||
"diversity_score": round(cardinality / non_null_count, 4) if non_null_count > 0 else 0
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to analyze categorical column {col_name}: {str(e)}")
|
||||
categorical_analysis[col_name] = {"error": str(e)}
|
||||
|
||||
return categorical_analysis
|
||||
|
||||
async def _get_categorical_value_distribution(self, connection, table_name: str, col_name: str, sampling_info: Dict, total_count: int) -> List[Dict]:
|
||||
"""Get value distribution for categorical column"""
|
||||
try:
|
||||
# Use sample table expression if sampling is enabled
|
||||
table_expr = sampling_info.get("sample_table_expression", table_name)
|
||||
|
||||
distribution_sql = f"""
|
||||
SELECT
|
||||
{col_name} as value,
|
||||
COUNT(*) as count
|
||||
FROM {table_expr}
|
||||
WHERE {col_name} IS NOT NULL
|
||||
GROUP BY {col_name}
|
||||
ORDER BY COUNT(*) DESC
|
||||
LIMIT 20
|
||||
"""
|
||||
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(distribution_sql, auth_context=auth_context)
|
||||
|
||||
if result.data:
|
||||
distribution = []
|
||||
for row in result.data:
|
||||
count = row["count"]
|
||||
percentage = count / total_count if total_count > 0 else 0
|
||||
distribution.append({
|
||||
"value": str(row["value"]),
|
||||
"count": count,
|
||||
"percentage": round(percentage, 4)
|
||||
})
|
||||
return distribution
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get value distribution for {col_name}: {str(e)}")
|
||||
|
||||
return []
|
||||
|
||||
def _calculate_entropy(self, value_distribution: List[Dict]) -> float:
|
||||
"""Calculate Shannon entropy for categorical distribution"""
|
||||
if not value_distribution:
|
||||
return 0.0
|
||||
|
||||
entropy = 0.0
|
||||
for item in value_distribution:
|
||||
p = item["percentage"]
|
||||
if p > 0:
|
||||
entropy -= p * math.log2(p)
|
||||
|
||||
return entropy
|
||||
|
||||
async def _analyze_temporal_distributions(self, connection, table_name: str, temporal_columns: List[Dict], sampling_info: Dict) -> Dict[str, Any]:
|
||||
"""Analyze distribution patterns for temporal columns"""
|
||||
temporal_analysis = {}
|
||||
|
||||
for column in temporal_columns:
|
||||
col_name = column["column_name"]
|
||||
try:
|
||||
# Date range analysis
|
||||
table_expr = sampling_info.get("sample_table_expression", table_name)
|
||||
range_sql = f"""
|
||||
SELECT
|
||||
MIN({col_name}) as earliest,
|
||||
MAX({col_name}) as latest,
|
||||
COUNT({col_name}) as non_null_count
|
||||
FROM {table_expr}
|
||||
WHERE {col_name} IS NOT NULL
|
||||
"""
|
||||
|
||||
auth_context = get_auth_context()
|
||||
range_result = await connection.execute(range_sql, auth_context=auth_context)
|
||||
|
||||
if range_result.data and range_result.data[0]["non_null_count"] > 0:
|
||||
range_data = range_result.data[0]
|
||||
earliest = range_data["earliest"]
|
||||
latest = range_data["latest"]
|
||||
|
||||
# Calculate span
|
||||
date_span_info = self._calculate_date_span(earliest, latest)
|
||||
|
||||
# Temporal patterns analysis
|
||||
temporal_patterns = await self._analyze_temporal_patterns(
|
||||
connection, table_name, col_name, sampling_info
|
||||
)
|
||||
|
||||
temporal_analysis[col_name] = {
|
||||
"data_type": column["data_type"],
|
||||
"non_null_count": range_data["non_null_count"],
|
||||
"date_range": {
|
||||
"earliest": str(earliest),
|
||||
"latest": str(latest),
|
||||
**date_span_info
|
||||
},
|
||||
"temporal_patterns": temporal_patterns
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to analyze temporal column {col_name}: {str(e)}")
|
||||
temporal_analysis[col_name] = {"error": str(e)}
|
||||
|
||||
return temporal_analysis
|
||||
|
||||
def _calculate_date_span(self, earliest, latest) -> Dict[str, Any]:
|
||||
"""Calculate date span information"""
|
||||
try:
|
||||
if isinstance(earliest, str):
|
||||
earliest = datetime.fromisoformat(earliest.replace('Z', '+00:00'))
|
||||
if isinstance(latest, str):
|
||||
latest = datetime.fromisoformat(latest.replace('Z', '+00:00'))
|
||||
|
||||
span = latest - earliest
|
||||
span_days = span.days
|
||||
|
||||
return {
|
||||
"span_days": span_days,
|
||||
"span_years": round(span_days / 365.25, 2),
|
||||
"span_description": self._describe_time_span(span_days)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to calculate date span: {str(e)}")
|
||||
return {"span_days": 0}
|
||||
|
||||
def _describe_time_span(self, days: int) -> str:
|
||||
"""Describe time span in human readable format"""
|
||||
if days < 1:
|
||||
return "less_than_day"
|
||||
elif days < 7:
|
||||
return "days"
|
||||
elif days < 30:
|
||||
return "weeks"
|
||||
elif days < 365:
|
||||
return "months"
|
||||
else:
|
||||
return "years"
|
||||
|
||||
async def _analyze_temporal_patterns(self, connection, table_name: str, col_name: str, sampling_info: Dict) -> Dict[str, Any]:
|
||||
"""Analyze temporal patterns like seasonality and trends"""
|
||||
try:
|
||||
table_expr = sampling_info.get("sample_table_expression", table_name)
|
||||
# Weekly pattern analysis
|
||||
weekly_pattern_sql = f"""
|
||||
SELECT
|
||||
DAYOFWEEK({col_name}) as day_of_week,
|
||||
COUNT(*) as count
|
||||
FROM {table_expr}
|
||||
WHERE {col_name} IS NOT NULL
|
||||
GROUP BY DAYOFWEEK({col_name})
|
||||
ORDER BY day_of_week
|
||||
"""
|
||||
|
||||
auth_context = get_auth_context()
|
||||
weekly_result = await connection.execute(weekly_pattern_sql, auth_context=auth_context)
|
||||
|
||||
weekly_pattern = []
|
||||
if weekly_result.data:
|
||||
total_records = sum(row["count"] for row in weekly_result.data)
|
||||
for row in weekly_result.data:
|
||||
percentage = row["count"] / total_records if total_records > 0 else 0
|
||||
weekly_pattern.append(round(percentage, 3))
|
||||
|
||||
# Monthly trend analysis (simplified)
|
||||
monthly_trend_sql = f"""
|
||||
SELECT
|
||||
YEAR({col_name}) as year,
|
||||
MONTH({col_name}) as month,
|
||||
COUNT(*) as count
|
||||
FROM {table_expr}
|
||||
WHERE {col_name} IS NOT NULL
|
||||
GROUP BY YEAR({col_name}), MONTH({col_name})
|
||||
ORDER BY year, month
|
||||
LIMIT 12
|
||||
"""
|
||||
|
||||
monthly_result = await connection.execute(monthly_trend_sql, auth_context=auth_context)
|
||||
monthly_trend = "stable" # Simplified trend analysis
|
||||
|
||||
if monthly_result.data and len(monthly_result.data) > 3:
|
||||
counts = [row["count"] for row in monthly_result.data]
|
||||
if len(counts) > 1:
|
||||
trend_direction = "increasing" if counts[-1] > counts[0] else "decreasing"
|
||||
monthly_trend = trend_direction
|
||||
|
||||
return {
|
||||
"weekly_pattern": weekly_pattern,
|
||||
"monthly_trend": monthly_trend,
|
||||
"seasonal_component": self._estimate_seasonality(weekly_pattern)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to analyze temporal patterns for {col_name}: {str(e)}")
|
||||
return {"weekly_pattern": [], "monthly_trend": "unknown"}
|
||||
|
||||
def _estimate_seasonality(self, weekly_pattern: List[float]) -> float:
|
||||
"""Estimate seasonality strength based on weekly pattern variance"""
|
||||
if len(weekly_pattern) < 7:
|
||||
return 0.0
|
||||
|
||||
mean_percentage = sum(weekly_pattern) / len(weekly_pattern)
|
||||
variance = sum((x - mean_percentage) ** 2 for x in weekly_pattern) / len(weekly_pattern)
|
||||
|
||||
# Normalize variance to 0-1 scale as seasonality indicator
|
||||
seasonality = min(variance * 10, 1.0) # Scaling factor
|
||||
return round(seasonality, 3)
|
||||
|
||||
async def _generate_data_quality_insights(self, connection, table_name: str, columns: List[Dict], sampling_info: Dict) -> Dict[str, Any]:
|
||||
"""Generate overall data quality insights"""
|
||||
try:
|
||||
total_columns = len(columns)
|
||||
|
||||
# Calculate null rates across all columns
|
||||
null_analysis = await self._analyze_overall_null_rates(connection, table_name, columns, sampling_info)
|
||||
|
||||
# Identify potential data quality issues
|
||||
quality_issues = []
|
||||
|
||||
# High null rate columns
|
||||
high_null_columns = [col for col, rate in null_analysis["column_null_rates"].items() if rate > 0.2]
|
||||
if high_null_columns:
|
||||
quality_issues.append({
|
||||
"issue_type": "high_null_rates",
|
||||
"severity": "medium",
|
||||
"affected_columns": high_null_columns,
|
||||
"description": f"{len(high_null_columns)} columns have null rates > 20%"
|
||||
})
|
||||
|
||||
# Calculate overall data quality score
|
||||
avg_null_rate = sum(null_analysis["column_null_rates"].values()) / len(null_analysis["column_null_rates"]) if null_analysis["column_null_rates"] else 0
|
||||
data_quality_score = max(0, 1 - avg_null_rate)
|
||||
|
||||
return {
|
||||
"total_columns_analyzed": total_columns,
|
||||
"null_analysis": null_analysis,
|
||||
"data_quality_score": round(data_quality_score, 3),
|
||||
"quality_issues": quality_issues,
|
||||
"recommendations": self._generate_quality_recommendations(quality_issues, null_analysis)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to generate data quality insights: {str(e)}")
|
||||
return {"data_quality_score": 0.0, "error": str(e)}
|
||||
|
||||
async def _analyze_overall_null_rates(self, connection, table_name: str, columns: List[Dict], sampling_info: Dict) -> Dict[str, Any]:
|
||||
"""Analyze null rates across all columns"""
|
||||
column_null_rates = {}
|
||||
total_null_count = 0
|
||||
total_cell_count = 0
|
||||
|
||||
for column in columns:
|
||||
col_name = column["column_name"]
|
||||
try:
|
||||
table_expr = sampling_info.get("sample_table_expression", table_name)
|
||||
null_sql = f"""
|
||||
SELECT
|
||||
COUNT(*) as total_count,
|
||||
COUNT({col_name}) as non_null_count
|
||||
FROM {table_expr}
|
||||
"""
|
||||
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(null_sql, auth_context=auth_context)
|
||||
if result.data:
|
||||
data = result.data[0]
|
||||
total_count = data["total_count"]
|
||||
non_null_count = data["non_null_count"]
|
||||
null_count = total_count - non_null_count
|
||||
null_rate = null_count / total_count if total_count > 0 else 0
|
||||
|
||||
column_null_rates[col_name] = round(null_rate, 4)
|
||||
total_null_count += null_count
|
||||
total_cell_count += total_count
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to analyze null rate for column {col_name}: {str(e)}")
|
||||
column_null_rates[col_name] = 0.0
|
||||
|
||||
overall_null_rate = total_null_count / total_cell_count if total_cell_count > 0 else 0
|
||||
|
||||
return {
|
||||
"column_null_rates": column_null_rates,
|
||||
"overall_null_rate": round(overall_null_rate, 4),
|
||||
"columns_with_nulls": len([rate for rate in column_null_rates.values() if rate > 0])
|
||||
}
|
||||
|
||||
def _generate_quality_recommendations(self, quality_issues: List[Dict], null_analysis: Dict) -> List[Dict]:
|
||||
"""Generate data quality improvement recommendations"""
|
||||
recommendations = []
|
||||
|
||||
# Recommendations based on null analysis
|
||||
overall_null_rate = null_analysis.get("overall_null_rate", 0)
|
||||
if overall_null_rate > 0.1:
|
||||
recommendations.append({
|
||||
"type": "data_completeness",
|
||||
"priority": "high" if overall_null_rate > 0.3 else "medium",
|
||||
"description": f"Overall null rate is {overall_null_rate:.1%}",
|
||||
"action": "Review data collection and validation processes"
|
||||
})
|
||||
|
||||
# Recommendations based on quality issues
|
||||
for issue in quality_issues:
|
||||
if issue["issue_type"] == "high_null_rates":
|
||||
recommendations.append({
|
||||
"type": "column_completeness",
|
||||
"priority": issue["severity"],
|
||||
"description": issue["description"],
|
||||
"action": f"Focus on improving data completeness for: {', '.join(issue['affected_columns'][:3])}"
|
||||
})
|
||||
|
||||
return recommendations
|
||||
|
||||
def _generate_analysis_summary(self, distribution_analysis: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Generate high-level summary of distribution analysis"""
|
||||
summary = {
|
||||
"numeric_columns_count": len(distribution_analysis.get("numeric_columns", {})),
|
||||
"categorical_columns_count": len(distribution_analysis.get("categorical_columns", {})),
|
||||
"temporal_columns_count": len(distribution_analysis.get("temporal_columns", {}))
|
||||
}
|
||||
|
||||
# Identify interesting patterns
|
||||
patterns = []
|
||||
|
||||
# Check for highly skewed numeric columns
|
||||
numeric_cols = distribution_analysis.get("numeric_columns", {})
|
||||
skewed_cols = [
|
||||
col for col, info in numeric_cols.items()
|
||||
if isinstance(info, dict) and
|
||||
info.get("distribution_shape", {}).get("skewness_indicator") in ["right_skewed", "left_skewed"]
|
||||
]
|
||||
|
||||
if skewed_cols:
|
||||
patterns.append(f"Found {len(skewed_cols)} skewed numeric columns")
|
||||
|
||||
# Check for high cardinality categorical columns
|
||||
categorical_cols = distribution_analysis.get("categorical_columns", {})
|
||||
high_cardinality_cols = [
|
||||
col for col, info in categorical_cols.items()
|
||||
if isinstance(info, dict) and info.get("cardinality", 0) > 1000
|
||||
]
|
||||
|
||||
if high_cardinality_cols:
|
||||
patterns.append(f"Found {len(high_cardinality_cols)} high cardinality categorical columns")
|
||||
|
||||
summary["notable_patterns"] = patterns
|
||||
|
||||
return summary
|
||||
1022
doris_mcp_server/utils/data_governance_tools.py
Normal file
1022
doris_mcp_server/utils/data_governance_tools.py
Normal file
File diff suppressed because it is too large
Load Diff
1173
doris_mcp_server/utils/data_quality_tools.py
Normal file
1173
doris_mcp_server/utils/data_quality_tools.py
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
1025
doris_mcp_server/utils/dependency_analysis_tools.py
Normal file
1025
doris_mcp_server/utils/dependency_analysis_tools.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -15,77 +15,573 @@
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Logging configuration for Doris MCP Server.
|
||||
Enhanced Logging configuration for Doris MCP Server.
|
||||
Features:
|
||||
- Log level-based file separation
|
||||
- Timestamped log entries
|
||||
- Automatic log rotation
|
||||
- Comprehensive logging coverage
|
||||
"""
|
||||
|
||||
import logging
|
||||
import logging.config
|
||||
import logging.handlers
|
||||
import sys
|
||||
import os
|
||||
import asyncio
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
import threading
|
||||
|
||||
|
||||
def setup_logging(
|
||||
level: str = "INFO",
|
||||
log_file: str | None = None,
|
||||
log_format: str | None = None,
|
||||
) -> None:
|
||||
class TimestampedFormatter(logging.Formatter):
|
||||
"""Custom formatter with enhanced timestamp and structured format"""
|
||||
|
||||
def __init__(self, fmt=None, datefmt=None, style='%'):
|
||||
if fmt is None:
|
||||
fmt = "%(asctime)s.%(msecs)03d %(level_aligned)s %(name)s:%(lineno)d - %(message)s"
|
||||
if datefmt is None:
|
||||
datefmt = "%Y-%m-%d %H:%M:%S"
|
||||
super().__init__(fmt, datefmt, style)
|
||||
|
||||
def format(self, record):
|
||||
"""Format log record with enhanced information and proper alignment"""
|
||||
# Add process info if available
|
||||
if hasattr(record, 'process') and record.process:
|
||||
record.process_info = f"[PID:{record.process}]"
|
||||
else:
|
||||
record.process_info = ""
|
||||
|
||||
# Add thread info if available
|
||||
if hasattr(record, 'thread') and record.thread:
|
||||
record.thread_info = f"[TID:{record.thread}]"
|
||||
else:
|
||||
record.thread_info = ""
|
||||
|
||||
# Format with proper alignment after the level name
|
||||
# Calculate padding needed for alignment
|
||||
level_name = record.levelname
|
||||
max_level_length = 8 # Length of "CRITICAL"
|
||||
padding = max_level_length - len(level_name)
|
||||
record.level_aligned = f"[{level_name}]{' ' * padding}"
|
||||
|
||||
return super().format(record)
|
||||
|
||||
|
||||
class LevelBasedFileHandler(logging.Handler):
|
||||
"""Custom handler that writes different log levels to different files"""
|
||||
|
||||
def __init__(self, log_dir: str, base_name: str = "doris_mcp_server",
|
||||
max_bytes: int = 10*1024*1024, backup_count: int = 5):
|
||||
super().__init__()
|
||||
self.log_dir = Path(log_dir)
|
||||
self.base_name = base_name
|
||||
self.max_bytes = max_bytes
|
||||
self.backup_count = backup_count
|
||||
|
||||
# Ensure log directory exists
|
||||
self.log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create handlers for different log levels
|
||||
self.handlers = {}
|
||||
self._setup_level_handlers()
|
||||
|
||||
def _setup_level_handlers(self):
|
||||
"""Setup rotating file handlers for different log levels"""
|
||||
level_files = {
|
||||
'DEBUG': 'debug.log',
|
||||
'INFO': 'info.log',
|
||||
'WARNING': 'warning.log',
|
||||
'ERROR': 'error.log',
|
||||
'CRITICAL': 'critical.log'
|
||||
}
|
||||
|
||||
formatter = TimestampedFormatter()
|
||||
|
||||
for level, filename in level_files.items():
|
||||
file_path = self.log_dir / f"{self.base_name}_{filename}"
|
||||
handler = logging.handlers.RotatingFileHandler(
|
||||
file_path,
|
||||
maxBytes=self.max_bytes,
|
||||
backupCount=self.backup_count,
|
||||
encoding='utf-8'
|
||||
)
|
||||
handler.setFormatter(formatter)
|
||||
handler.setLevel(getattr(logging, level))
|
||||
self.handlers[level] = handler
|
||||
|
||||
def emit(self, record):
|
||||
"""Emit log record to appropriate level-based file"""
|
||||
level_name = record.levelname
|
||||
if level_name in self.handlers:
|
||||
try:
|
||||
self.handlers[level_name].emit(record)
|
||||
except Exception:
|
||||
self.handleError(record)
|
||||
|
||||
def close(self):
|
||||
"""Close all handlers"""
|
||||
for handler in self.handlers.values():
|
||||
handler.close()
|
||||
super().close()
|
||||
|
||||
|
||||
class LogCleanupManager:
|
||||
"""Log file cleanup manager for automatic maintenance"""
|
||||
|
||||
def __init__(self, log_dir: str, max_age_days: int = 30, cleanup_interval_hours: int = 24):
|
||||
"""
|
||||
Setup logging configuration.
|
||||
Initialize log cleanup manager.
|
||||
|
||||
Args:
|
||||
level: Logging level (DEBUG, INFO, WARNING, ERROR)
|
||||
log_file: Optional log file path
|
||||
log_format: Optional custom log format
|
||||
log_dir: Directory containing log files
|
||||
max_age_days: Maximum age of log files in days (default: 30 days)
|
||||
cleanup_interval_hours: Cleanup interval in hours (default: 24 hours)
|
||||
"""
|
||||
if log_format is None:
|
||||
log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
self.log_dir = Path(log_dir)
|
||||
self.max_age_days = max_age_days
|
||||
self.cleanup_interval_hours = cleanup_interval_hours
|
||||
self.cleanup_thread = None
|
||||
self.stop_event = threading.Event()
|
||||
self.logger = None
|
||||
|
||||
# Base configuration
|
||||
config: dict[str, Any] = {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"formatters": {
|
||||
"default": {"format": log_format, "datefmt": "%Y-%m-%d %H:%M:%S"}
|
||||
},
|
||||
"handlers": {
|
||||
"console": {
|
||||
"class": "logging.StreamHandler",
|
||||
"level": level,
|
||||
"formatter": "default",
|
||||
"stream": sys.stdout,
|
||||
}
|
||||
},
|
||||
"root": {"level": level, "handlers": ["console"]},
|
||||
"loggers": {
|
||||
"doris_mcp_server": {
|
||||
"level": level,
|
||||
"handlers": ["console"],
|
||||
"propagate": False,
|
||||
}
|
||||
},
|
||||
def start_cleanup_scheduler(self):
|
||||
"""Start the cleanup scheduler in a background thread"""
|
||||
if self.cleanup_thread and self.cleanup_thread.is_alive():
|
||||
return
|
||||
|
||||
self.stop_event.clear()
|
||||
self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True)
|
||||
self.cleanup_thread.start()
|
||||
|
||||
# Get logger for this class
|
||||
if not self.logger:
|
||||
self.logger = logging.getLogger("doris_mcp_server.log_cleanup")
|
||||
|
||||
self.logger.info(f"Log cleanup scheduler started - cleanup every {self.cleanup_interval_hours}h, max age {self.max_age_days} days")
|
||||
|
||||
def stop_cleanup_scheduler(self):
|
||||
"""Stop the cleanup scheduler"""
|
||||
if self.cleanup_thread and self.cleanup_thread.is_alive():
|
||||
self.stop_event.set()
|
||||
self.cleanup_thread.join(timeout=5)
|
||||
if self.logger:
|
||||
self.logger.info("Log cleanup scheduler stopped")
|
||||
|
||||
def _cleanup_loop(self):
|
||||
"""Background loop for periodic cleanup"""
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
self.cleanup_old_logs()
|
||||
# Sleep for the specified interval, but check stop event every 60 seconds
|
||||
for _ in range(self.cleanup_interval_hours * 60): # Convert hours to minutes
|
||||
if self.stop_event.wait(60): # Wait 60 seconds or until stop event
|
||||
break
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.error(f"Error in log cleanup loop: {e}")
|
||||
# Sleep for 5 minutes before retrying
|
||||
self.stop_event.wait(300)
|
||||
|
||||
def cleanup_old_logs(self):
|
||||
"""Clean up old log files based on age"""
|
||||
if not self.log_dir.exists():
|
||||
return
|
||||
|
||||
current_time = datetime.now()
|
||||
cutoff_time = current_time - timedelta(days=self.max_age_days)
|
||||
|
||||
cleaned_files = []
|
||||
cleaned_size = 0
|
||||
|
||||
# Pattern for log files (including backup files)
|
||||
log_patterns = [
|
||||
"doris_mcp_server_*.log",
|
||||
"doris_mcp_server_*.log.*" # Backup files
|
||||
]
|
||||
|
||||
for pattern in log_patterns:
|
||||
for log_file in self.log_dir.glob(pattern):
|
||||
try:
|
||||
# Get file modification time
|
||||
file_mtime = datetime.fromtimestamp(log_file.stat().st_mtime)
|
||||
|
||||
if file_mtime < cutoff_time:
|
||||
file_size = log_file.stat().st_size
|
||||
log_file.unlink() # Delete the file
|
||||
cleaned_files.append(log_file.name)
|
||||
cleaned_size += file_size
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.warning(f"Failed to cleanup log file {log_file}: {e}")
|
||||
|
||||
if cleaned_files and self.logger:
|
||||
size_mb = cleaned_size / (1024 * 1024)
|
||||
self.logger.info(f"Cleaned up {len(cleaned_files)} old log files, freed {size_mb:.2f} MB")
|
||||
self.logger.debug(f"Cleaned files: {', '.join(cleaned_files)}")
|
||||
|
||||
def get_cleanup_stats(self) -> dict:
|
||||
"""Get statistics about log files and cleanup status"""
|
||||
if not self.log_dir.exists():
|
||||
return {"error": "Log directory does not exist"}
|
||||
|
||||
stats = {
|
||||
"log_directory": str(self.log_dir.absolute()),
|
||||
"max_age_days": self.max_age_days,
|
||||
"cleanup_interval_hours": self.cleanup_interval_hours,
|
||||
"scheduler_running": self.cleanup_thread and self.cleanup_thread.is_alive(),
|
||||
"total_files": 0,
|
||||
"total_size_mb": 0,
|
||||
"files_by_age": {"recent": 0, "old": 0},
|
||||
"oldest_file": None,
|
||||
"newest_file": None
|
||||
}
|
||||
|
||||
# Add file handler if log_file is specified
|
||||
if log_file:
|
||||
# Ensure log directory exists
|
||||
log_path = Path(log_file)
|
||||
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
current_time = datetime.now()
|
||||
cutoff_time = current_time - timedelta(days=self.max_age_days)
|
||||
oldest_time = None
|
||||
newest_time = None
|
||||
|
||||
config["handlers"]["file"] = {
|
||||
"class": "logging.handlers.RotatingFileHandler",
|
||||
"level": level,
|
||||
"formatter": "default",
|
||||
"filename": log_file,
|
||||
"maxBytes": 10485760, # 10MB
|
||||
"backupCount": 5,
|
||||
}
|
||||
log_patterns = ["doris_mcp_server_*.log", "doris_mcp_server_*.log.*"]
|
||||
|
||||
# Add file handler to root and package loggers
|
||||
config["root"]["handlers"].append("file")
|
||||
config["loggers"]["doris_mcp_server"]["handlers"].append("file")
|
||||
for pattern in log_patterns:
|
||||
for log_file in self.log_dir.glob(pattern):
|
||||
try:
|
||||
file_stat = log_file.stat()
|
||||
file_mtime = datetime.fromtimestamp(file_stat.st_mtime)
|
||||
|
||||
logging.config.dictConfig(config)
|
||||
stats["total_files"] += 1
|
||||
stats["total_size_mb"] += file_stat.st_size / (1024 * 1024)
|
||||
|
||||
if file_mtime < cutoff_time:
|
||||
stats["files_by_age"]["old"] += 1
|
||||
else:
|
||||
stats["files_by_age"]["recent"] += 1
|
||||
|
||||
if oldest_time is None or file_mtime < oldest_time:
|
||||
oldest_time = file_mtime
|
||||
stats["oldest_file"] = {"name": log_file.name, "age_days": (current_time - file_mtime).days}
|
||||
|
||||
if newest_time is None or file_mtime > newest_time:
|
||||
newest_time = file_mtime
|
||||
stats["newest_file"] = {"name": log_file.name, "age_days": (current_time - file_mtime).days}
|
||||
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
stats["total_size_mb"] = round(stats["total_size_mb"], 2)
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
class DorisLoggerManager:
|
||||
"""Centralized logger manager for Doris MCP Server"""
|
||||
|
||||
def __init__(self):
|
||||
self.is_initialized = False
|
||||
self.log_dir = None
|
||||
self.config = None
|
||||
self.loggers = {}
|
||||
self.cleanup_manager = None
|
||||
|
||||
def setup_logging(self,
|
||||
level: str = "INFO",
|
||||
log_dir: str = "logs",
|
||||
enable_console: bool = True,
|
||||
enable_file: bool = True,
|
||||
enable_audit: bool = True,
|
||||
audit_file: Optional[str] = None,
|
||||
max_file_size: int = 10*1024*1024,
|
||||
backup_count: int = 5,
|
||||
enable_cleanup: bool = True,
|
||||
max_age_days: int = 30,
|
||||
cleanup_interval_hours: int = 24) -> None:
|
||||
"""
|
||||
Setup comprehensive logging configuration.
|
||||
|
||||
Args:
|
||||
level: Base logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
|
||||
log_dir: Directory for log files
|
||||
enable_console: Enable console output
|
||||
enable_file: Enable file logging
|
||||
enable_audit: Enable audit logging
|
||||
audit_file: Custom audit log file path
|
||||
max_file_size: Maximum size per log file (bytes)
|
||||
backup_count: Number of backup files to keep
|
||||
enable_cleanup: Enable automatic log cleanup
|
||||
max_age_days: Maximum age of log files in days (default: 30)
|
||||
cleanup_interval_hours: Cleanup interval in hours (default: 24)
|
||||
"""
|
||||
if self.is_initialized:
|
||||
return
|
||||
|
||||
self.log_dir = Path(log_dir)
|
||||
log_dir_writable = True # Initialize the variable
|
||||
|
||||
# Try to create log directory, fallback to console-only if fails
|
||||
try:
|
||||
self.log_dir.mkdir(parents=True, exist_ok=True)
|
||||
except (OSError, PermissionError) as e:
|
||||
# If we can't create log directory (e.g., read-only filesystem in stdio mode),
|
||||
# fall back to console-only logging
|
||||
log_dir_writable = False
|
||||
enable_file = False
|
||||
enable_audit = False
|
||||
enable_cleanup = False
|
||||
# Don't use print() in stdio mode as it interferes with MCP JSON protocol
|
||||
# Log the warning through the logging system instead, which will be handled after setup
|
||||
|
||||
# Clear existing handlers
|
||||
root_logger = logging.getLogger()
|
||||
for handler in root_logger.handlers[:]:
|
||||
root_logger.removeHandler(handler)
|
||||
|
||||
# Set root logger level
|
||||
root_logger.setLevel(logging.DEBUG) # Allow all levels, handlers will filter
|
||||
|
||||
handlers = []
|
||||
|
||||
# Console handler
|
||||
if enable_console:
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
console_handler.setLevel(getattr(logging, level.upper()))
|
||||
console_formatter = TimestampedFormatter(
|
||||
fmt="%(asctime)s.%(msecs)03d %(level_aligned)s %(name)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
console_handler.setFormatter(console_formatter)
|
||||
handlers.append(console_handler)
|
||||
|
||||
# Level-based file handlers
|
||||
if enable_file:
|
||||
level_handler = LevelBasedFileHandler(
|
||||
log_dir=str(self.log_dir),
|
||||
base_name="doris_mcp_server",
|
||||
max_bytes=max_file_size,
|
||||
backup_count=backup_count
|
||||
)
|
||||
level_handler.setLevel(logging.DEBUG) # Accept all levels
|
||||
handlers.append(level_handler)
|
||||
|
||||
# Combined application log (all levels in one file)
|
||||
if enable_file:
|
||||
app_log_file = self.log_dir / "doris_mcp_server_all.log"
|
||||
app_handler = logging.handlers.RotatingFileHandler(
|
||||
app_log_file,
|
||||
maxBytes=max_file_size,
|
||||
backupCount=backup_count,
|
||||
encoding='utf-8'
|
||||
)
|
||||
app_handler.setLevel(getattr(logging, level.upper()))
|
||||
app_formatter = TimestampedFormatter()
|
||||
app_handler.setFormatter(app_formatter)
|
||||
handlers.append(app_handler)
|
||||
|
||||
# Audit logger (separate from main logging)
|
||||
if enable_audit:
|
||||
audit_file_path = audit_file or str(self.log_dir / "doris_mcp_server_audit.log")
|
||||
audit_logger = logging.getLogger("audit")
|
||||
audit_logger.setLevel(logging.INFO)
|
||||
|
||||
# Clear existing audit handlers
|
||||
for handler in audit_logger.handlers[:]:
|
||||
audit_logger.removeHandler(handler)
|
||||
|
||||
audit_handler = logging.handlers.RotatingFileHandler(
|
||||
audit_file_path,
|
||||
maxBytes=max_file_size,
|
||||
backupCount=backup_count,
|
||||
encoding='utf-8'
|
||||
)
|
||||
audit_formatter = TimestampedFormatter(
|
||||
fmt="%(asctime)s.%(msecs)03d [AUDIT] %(name)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
audit_handler.setFormatter(audit_formatter)
|
||||
audit_logger.addHandler(audit_handler)
|
||||
audit_logger.propagate = False # Don't propagate to root logger
|
||||
|
||||
# Add all handlers to root logger
|
||||
for handler in handlers:
|
||||
root_logger.addHandler(handler)
|
||||
|
||||
# Setup package-specific loggers
|
||||
self._setup_package_loggers(level)
|
||||
|
||||
# Setup log cleanup manager
|
||||
if enable_cleanup and enable_file:
|
||||
self.cleanup_manager = LogCleanupManager(
|
||||
log_dir=str(self.log_dir),
|
||||
max_age_days=max_age_days,
|
||||
cleanup_interval_hours=cleanup_interval_hours
|
||||
)
|
||||
self.cleanup_manager.start_cleanup_scheduler()
|
||||
|
||||
self.is_initialized = True
|
||||
|
||||
# Log initialization message
|
||||
logger = self.get_logger("doris_mcp_server.logger")
|
||||
logger.info("=" * 80)
|
||||
logger.info("Doris MCP Server Logging System Initialized")
|
||||
logger.info(f"Log Level: {level}")
|
||||
if log_dir_writable:
|
||||
logger.info(f"Log Directory: {self.log_dir.absolute()}")
|
||||
else:
|
||||
logger.info("Log Directory: Not available (console-only mode)")
|
||||
logger.info(f"Console Logging: {'Enabled' if enable_console else 'Disabled'}")
|
||||
logger.info(f"File Logging: {'Enabled' if enable_file else 'Disabled (fallback mode)'}")
|
||||
logger.info(f"Audit Logging: {'Enabled' if enable_audit else 'Disabled (fallback mode)'}")
|
||||
logger.info(f"Log Cleanup: {'Enabled' if enable_cleanup and enable_file else 'Disabled (fallback mode)'}")
|
||||
if enable_cleanup and enable_file:
|
||||
logger.info(f"Cleanup Settings: Max age {max_age_days} days, interval {cleanup_interval_hours}h")
|
||||
if not log_dir_writable:
|
||||
logger.warning("Running in console-only logging mode due to filesystem permissions")
|
||||
logger.warning(f"Could not create log directory '{log_dir}' - stdio mode fallback enabled")
|
||||
logger.info("=" * 80)
|
||||
|
||||
def _setup_package_loggers(self, level: str):
|
||||
"""Setup specific loggers for different modules"""
|
||||
package_loggers = [
|
||||
"doris_mcp_server",
|
||||
"doris_mcp_server.main",
|
||||
"doris_mcp_server.utils",
|
||||
"doris_mcp_server.tools",
|
||||
"doris_mcp_client"
|
||||
]
|
||||
|
||||
for logger_name in package_loggers:
|
||||
logger = logging.getLogger(logger_name)
|
||||
logger.setLevel(getattr(logging, level.upper()))
|
||||
# Don't add handlers here - they inherit from root logger
|
||||
|
||||
def get_logger(self, name: str) -> logging.Logger:
|
||||
"""
|
||||
Get a logger instance with proper configuration.
|
||||
|
||||
Args:
|
||||
name: Logger name (usually __name__)
|
||||
|
||||
Returns:
|
||||
Configured logger instance
|
||||
"""
|
||||
if name not in self.loggers:
|
||||
logger = logging.getLogger(name)
|
||||
self.loggers[name] = logger
|
||||
|
||||
return self.loggers[name]
|
||||
|
||||
def get_audit_logger(self) -> logging.Logger:
|
||||
"""Get the audit logger"""
|
||||
return logging.getLogger("audit")
|
||||
|
||||
def log_system_info(self):
|
||||
"""Log system information for debugging"""
|
||||
logger = self.get_logger("doris_mcp_server.system")
|
||||
logger.info("System Information:")
|
||||
logger.info(f"Python Version: {sys.version}")
|
||||
logger.info(f"Platform: {sys.platform}")
|
||||
logger.info(f"Working Directory: {os.getcwd()}")
|
||||
logger.info(f"Process ID: {os.getpid()}")
|
||||
|
||||
# Log environment variables (filtered)
|
||||
env_vars = ["LOG_LEVEL", "LOG_FILE_PATH", "ENABLE_AUDIT", "AUDIT_FILE_PATH"]
|
||||
for var in env_vars:
|
||||
value = os.getenv(var, "Not Set")
|
||||
logger.info(f"Environment {var}: {value}")
|
||||
|
||||
def get_cleanup_stats(self) -> dict:
|
||||
"""Get log cleanup statistics"""
|
||||
if self.cleanup_manager:
|
||||
return self.cleanup_manager.get_cleanup_stats()
|
||||
else:
|
||||
return {"error": "Log cleanup is not enabled"}
|
||||
|
||||
def manual_cleanup(self) -> dict:
|
||||
"""Manually trigger log cleanup and return statistics"""
|
||||
if self.cleanup_manager:
|
||||
self.cleanup_manager.cleanup_old_logs()
|
||||
return self.cleanup_manager.get_cleanup_stats()
|
||||
else:
|
||||
return {"error": "Log cleanup is not enabled"}
|
||||
|
||||
def shutdown(self):
|
||||
"""Shutdown logging system"""
|
||||
if not self.is_initialized:
|
||||
return
|
||||
|
||||
logger = self.get_logger("doris_mcp_server.logger")
|
||||
logger.info("Shutting down logging system...")
|
||||
|
||||
# Stop cleanup manager
|
||||
if self.cleanup_manager:
|
||||
self.cleanup_manager.stop_cleanup_scheduler()
|
||||
|
||||
# Close all handlers
|
||||
root_logger = logging.getLogger()
|
||||
for handler in root_logger.handlers[:]:
|
||||
try:
|
||||
handler.close()
|
||||
except Exception as e:
|
||||
print(f"Error closing handler: {e}")
|
||||
|
||||
# Close audit logger handlers
|
||||
audit_logger = logging.getLogger("audit")
|
||||
for handler in audit_logger.handlers[:]:
|
||||
try:
|
||||
handler.close()
|
||||
except Exception as e:
|
||||
print(f"Error closing audit handler: {e}")
|
||||
|
||||
self.is_initialized = False
|
||||
|
||||
|
||||
# Global logger manager instance
|
||||
_logger_manager = DorisLoggerManager()
|
||||
|
||||
|
||||
def setup_logging(level: str = "INFO",
|
||||
log_dir: str = "logs",
|
||||
enable_console: bool = True,
|
||||
enable_file: bool = True,
|
||||
enable_audit: bool = True,
|
||||
audit_file: Optional[str] = None,
|
||||
max_file_size: int = 10*1024*1024,
|
||||
backup_count: int = 5,
|
||||
enable_cleanup: bool = True,
|
||||
max_age_days: int = 30,
|
||||
cleanup_interval_hours: int = 24) -> None:
|
||||
"""
|
||||
Setup logging configuration (convenience function).
|
||||
|
||||
Args:
|
||||
level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
|
||||
log_dir: Directory for log files
|
||||
enable_console: Enable console output
|
||||
enable_file: Enable file logging
|
||||
enable_audit: Enable audit logging
|
||||
audit_file: Custom audit log file path
|
||||
max_file_size: Maximum size per log file (bytes)
|
||||
backup_count: Number of backup files to keep
|
||||
enable_cleanup: Enable automatic log cleanup
|
||||
max_age_days: Maximum age of log files in days (default: 30)
|
||||
cleanup_interval_hours: Cleanup interval in hours (default: 24)
|
||||
"""
|
||||
_logger_manager.setup_logging(
|
||||
level=level,
|
||||
log_dir=log_dir,
|
||||
enable_console=enable_console,
|
||||
enable_file=enable_file,
|
||||
enable_audit=enable_audit,
|
||||
audit_file=audit_file,
|
||||
max_file_size=max_file_size,
|
||||
backup_count=backup_count,
|
||||
enable_cleanup=enable_cleanup,
|
||||
max_age_days=max_age_days,
|
||||
cleanup_interval_hours=cleanup_interval_hours
|
||||
)
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
@@ -93,9 +589,60 @@ def get_logger(name: str) -> logging.Logger:
|
||||
Get a logger instance.
|
||||
|
||||
Args:
|
||||
name: Logger name
|
||||
name: Logger name (usually __name__)
|
||||
|
||||
Returns:
|
||||
Logger instance
|
||||
Configured logger instance
|
||||
"""
|
||||
return logging.getLogger(name)
|
||||
return _logger_manager.get_logger(name)
|
||||
|
||||
|
||||
def get_audit_logger() -> logging.Logger:
|
||||
"""Get the audit logger"""
|
||||
return _logger_manager.get_audit_logger()
|
||||
|
||||
|
||||
def log_system_info():
|
||||
"""Log system information for debugging"""
|
||||
_logger_manager.log_system_info()
|
||||
|
||||
|
||||
def get_cleanup_stats() -> dict:
|
||||
"""Get log cleanup statistics"""
|
||||
return _logger_manager.get_cleanup_stats()
|
||||
|
||||
|
||||
def manual_cleanup() -> dict:
|
||||
"""Manually trigger log cleanup and return statistics"""
|
||||
return _logger_manager.manual_cleanup()
|
||||
|
||||
|
||||
def shutdown_logging():
|
||||
"""Shutdown logging system"""
|
||||
_logger_manager.shutdown()
|
||||
|
||||
|
||||
# Compatibility function for existing code
|
||||
def setup_logging_old(level: str = "INFO",
|
||||
log_file: str | None = None,
|
||||
log_format: str | None = None) -> None:
|
||||
"""
|
||||
Legacy setup function for backward compatibility.
|
||||
|
||||
Args:
|
||||
level: Logging level (DEBUG, INFO, WARNING, ERROR)
|
||||
log_file: Optional log file path (deprecated - use log_dir instead)
|
||||
log_format: Optional custom log format (deprecated)
|
||||
"""
|
||||
# Extract directory from log_file if provided
|
||||
log_dir = "logs"
|
||||
if log_file:
|
||||
log_dir = str(Path(log_file).parent)
|
||||
|
||||
setup_logging(
|
||||
level=level,
|
||||
log_dir=log_dir,
|
||||
enable_console=True,
|
||||
enable_file=True,
|
||||
enable_audit=True
|
||||
)
|
||||
|
||||
@@ -28,6 +28,7 @@ from datetime import datetime
|
||||
|
||||
from .db import DorisConnectionManager
|
||||
from .logger import get_logger
|
||||
from .sql_security_utils import get_auth_context
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -713,7 +714,8 @@ class DorisMonitoringTools:
|
||||
# Fallback to SHOW BACKENDS if no BE hosts configured
|
||||
logger.info("No BE hosts configured, using SHOW BACKENDS to discover BE nodes")
|
||||
connection = await self.connection_manager.get_connection("query")
|
||||
result = await connection.execute("SHOW BACKENDS")
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute("SHOW BACKENDS", auth_context=auth_context)
|
||||
|
||||
be_nodes = []
|
||||
for row in result.data:
|
||||
|
||||
1810
doris_mcp_server/utils/performance_analytics_tools.py
Normal file
1810
doris_mcp_server/utils/performance_analytics_tools.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -33,7 +33,11 @@ from datetime import datetime, timedelta, date
|
||||
from typing import Any, Dict
|
||||
from decimal import Decimal
|
||||
|
||||
import sqlparse
|
||||
|
||||
from .db import DorisConnectionManager, QueryResult
|
||||
from .logger import get_logger
|
||||
from .sql_security_utils import get_auth_context
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -92,7 +96,7 @@ class QueryCache:
|
||||
self.max_size = max_size
|
||||
self.default_ttl = default_ttl
|
||||
self.cache: dict[str, CachedQuery] = {}
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
def _generate_cache_key(
|
||||
self, sql: str, parameters: dict[str, Any] | None = None
|
||||
@@ -194,7 +198,7 @@ class QueryOptimizer:
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.logger = get_logger(__name__)
|
||||
self.optimization_rules = self._load_optimization_rules()
|
||||
|
||||
def _load_optimization_rules(self) -> list[dict[str, Any]]:
|
||||
@@ -318,7 +322,7 @@ class DorisQueryExecutor:
|
||||
def __init__(self, connection_manager: DorisConnectionManager, config=None):
|
||||
self.connection_manager = connection_manager
|
||||
self.config = config or self._create_default_config()
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
# Initialize components
|
||||
cache_config = getattr(self.config, 'performance', None)
|
||||
@@ -425,27 +429,27 @@ class DorisQueryExecutor:
|
||||
self, query_request: QueryRequest, auth_context
|
||||
) -> QueryResult:
|
||||
"""Internal query execution"""
|
||||
|
||||
# Database configuration should already be handled during authentication
|
||||
# No need to configure again during query execution
|
||||
|
||||
# Optimize query
|
||||
optimized_sql = await self.query_optimizer.optimize_query(
|
||||
query_request.sql, {"user_roles": getattr(auth_context, 'roles', [])}
|
||||
)
|
||||
|
||||
# Execute query
|
||||
connection = await self.connection_manager.get_connection(
|
||||
query_request.session_id
|
||||
)
|
||||
|
||||
# Set timeout if specified
|
||||
if query_request.timeout:
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
connection.execute(optimized_sql, query_request.parameters, auth_context),
|
||||
self.connection_manager.execute_query(query_request.session_id, optimized_sql, query_request.parameters, auth_context),
|
||||
timeout=query_request.timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
raise Exception(f"Query timeout after {query_request.timeout} seconds")
|
||||
else:
|
||||
result = await connection.execute(optimized_sql, query_request.parameters, auth_context)
|
||||
result = await self.connection_manager.execute_query(query_request.session_id, optimized_sql, query_request.parameters, auth_context)
|
||||
|
||||
return result
|
||||
|
||||
@@ -466,6 +470,51 @@ class DorisQueryExecutor:
|
||||
f"Slow query detected: {execution_time:.2f}s (threshold: {self.slow_query_threshold}s)"
|
||||
)
|
||||
|
||||
async def execute_batch_sqls_for_mcp(
|
||||
self, sqls: list[str],
|
||||
timeout: int = 30,
|
||||
session_id: str = "mcp_session",
|
||||
user_id: str = "mcp_user",
|
||||
auth_context=None
|
||||
) -> dict[str, Any]:
|
||||
"""Execute multiple sqls in batch"""
|
||||
if not sqls:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "SQL query is required",
|
||||
"data": None
|
||||
}
|
||||
query_requests = [
|
||||
QueryRequest(
|
||||
sql=sql,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
timeout=timeout,
|
||||
cache_enabled=False
|
||||
)
|
||||
for sql in sqls
|
||||
]
|
||||
query_results = await self.execute_batch_queries(query_requests, auth_context)
|
||||
# Serialize data for JSON response
|
||||
results = [
|
||||
{
|
||||
"data": [self._serialize_row_data(data) for data in result.data],
|
||||
"row_count": result.row_count,
|
||||
"execution_time": result.execution_time,
|
||||
"metadata": {
|
||||
"columns": result.metadata.get("columns", []),
|
||||
"query": result.sql
|
||||
}
|
||||
}
|
||||
for result in query_results
|
||||
]
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"multiple_results": True,
|
||||
"results": results
|
||||
}
|
||||
|
||||
async def execute_batch_queries(
|
||||
self, query_requests: list[QueryRequest], auth_context=None
|
||||
) -> list[QueryResult]:
|
||||
@@ -483,20 +532,24 @@ class DorisQueryExecutor:
|
||||
self.execute_query(request, auth_context) for request in query_requests
|
||||
]
|
||||
|
||||
try:
|
||||
query_results = []
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Batch query execution failed: {e}")
|
||||
raise
|
||||
for result in results:
|
||||
if isinstance(result, Exception):
|
||||
self.logger.error(f"Batch query execution failed: {result}")
|
||||
raise result
|
||||
else:
|
||||
query_results.append(result)
|
||||
|
||||
return results
|
||||
return query_results
|
||||
|
||||
async def explain_query(self, sql: str, session_id: str) -> dict[str, Any]:
|
||||
"""Get query execution plan"""
|
||||
explain_sql = f"EXPLAIN {sql}"
|
||||
|
||||
connection = await self.connection_manager.get_connection(session_id)
|
||||
result = await connection.execute(explain_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(explain_sql, auth_context=auth_context)
|
||||
|
||||
return {
|
||||
"query": sql,
|
||||
@@ -545,9 +598,13 @@ class DorisQueryExecutor:
|
||||
limit: int = 1000,
|
||||
timeout: int = 30,
|
||||
session_id: str = "mcp_session",
|
||||
user_id: str = "mcp_user"
|
||||
user_id: str = "mcp_user",
|
||||
auth_context = None # FIX for Issue #62 Bug 1: Accept auth_context with token
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute SQL query for MCP interface - unified method"""
|
||||
"""Execute SQL query for MCP interface - unified method
|
||||
|
||||
FIX for Issue #62 Bug 1: Now accepts auth_context parameter to support token-bound database configuration
|
||||
"""
|
||||
max_retries = 2
|
||||
retry_count = 0
|
||||
|
||||
@@ -560,23 +617,89 @@ class DorisQueryExecutor:
|
||||
"data": None
|
||||
}
|
||||
|
||||
# Import required security modules
|
||||
from .security import DorisSecurityManager, AuthContext, SecurityLevel
|
||||
|
||||
# FIX: Use provided auth_context if available (contains token for DB config)
|
||||
# Otherwise create default auth context for backward compatibility
|
||||
if auth_context is None:
|
||||
auth_context = AuthContext(
|
||||
user_id=user_id,
|
||||
roles=["read_only_user"], # Restrictive role for MCP interface
|
||||
permissions=["read_data"], # Only read permissions
|
||||
session_id=session_id,
|
||||
security_level=SecurityLevel.INTERNAL,
|
||||
token="" # No token in default context
|
||||
)
|
||||
else:
|
||||
# Use provided auth_context (may contain token for database configuration)
|
||||
self.logger.debug(f"Using provided auth_context with token: {bool(hasattr(auth_context, 'token') and auth_context.token)}")
|
||||
|
||||
# Perform SQL security validation if enabled
|
||||
if hasattr(self.connection_manager, 'config') and hasattr(self.connection_manager.config, 'security'):
|
||||
if self.connection_manager.config.security.enable_security_check:
|
||||
try:
|
||||
# 🔧 FIX: Use existing security_manager to avoid creating multiple TokenManager instances
|
||||
# Creating new DorisSecurityManager each time causes multiple hot reload monitors
|
||||
security_manager = getattr(self.connection_manager, 'security_manager', None)
|
||||
if not security_manager:
|
||||
# Fallback: create new one only if not available (should rarely happen)
|
||||
self.logger.warning("No existing security_manager, creating new instance")
|
||||
security_manager = DorisSecurityManager(self.connection_manager.config)
|
||||
validation_result = await security_manager.validate_sql_security(sql, auth_context)
|
||||
|
||||
if not validation_result.is_valid:
|
||||
self.logger.warning(f"SQL security validation failed for query: {sql[:100]}...")
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"SQL security validation failed: {validation_result.error_message}",
|
||||
"error_type": "security_violation",
|
||||
"blocked_operations": validation_result.blocked_operations,
|
||||
"risk_level": validation_result.risk_level,
|
||||
"data": None,
|
||||
"metadata": {
|
||||
"query": sql,
|
||||
"validation_details": {
|
||||
"blocked_operations": validation_result.blocked_operations,
|
||||
"risk_level": validation_result.risk_level
|
||||
}
|
||||
}
|
||||
}
|
||||
else:
|
||||
self.logger.debug(f"SQL security validation passed for query: {sql[:100]}...")
|
||||
except Exception as security_error:
|
||||
self.logger.error(f"Security validation error: {str(security_error)}")
|
||||
# In case of security validation error, fail safe
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Security validation system error: {str(security_error)}",
|
||||
"error_type": "security_system_error",
|
||||
"data": None,
|
||||
"metadata": {
|
||||
"query": sql,
|
||||
"security_error": str(security_error)
|
||||
}
|
||||
}
|
||||
else:
|
||||
self.logger.info("SQL security check is disabled in configuration")
|
||||
else:
|
||||
self.logger.warning("Security configuration not found, proceeding without validation")
|
||||
|
||||
# Add LIMIT if not present and it's a SELECT query
|
||||
if sql.upper().startswith("SELECT") and "LIMIT" not in sql.upper():
|
||||
if sql.endswith(";"):
|
||||
sql = sql[:-1]
|
||||
sql = f"{sql} LIMIT {limit}"
|
||||
|
||||
# Create auth context for MCP calls
|
||||
class MockAuthContext:
|
||||
def __init__(self):
|
||||
self.user_id = user_id
|
||||
self.roles = ["data_analyst"]
|
||||
self.permissions = ["read_data", "execute_query"]
|
||||
self.session_id = session_id
|
||||
self.security_level = "internal"
|
||||
|
||||
auth_context = MockAuthContext()
|
||||
|
||||
all_statements = [
|
||||
s.strip()
|
||||
for s in sqlparse.split(sql)
|
||||
if s.strip()
|
||||
]
|
||||
if len(all_statements) > 1:
|
||||
return await self.execute_batch_sqls_for_mcp(sqls=all_statements, timeout=timeout,
|
||||
session_id=session_id, user_id=user_id,
|
||||
auth_context=auth_context)
|
||||
# Create query request
|
||||
query_request = QueryRequest(
|
||||
sql=sql,
|
||||
@@ -587,7 +710,6 @@ class DorisQueryExecutor:
|
||||
)
|
||||
|
||||
# Execute query with retry logic
|
||||
try:
|
||||
result = await self.execute_query(query_request, auth_context)
|
||||
|
||||
# Serialize data for JSON response
|
||||
@@ -606,9 +728,11 @@ class DorisQueryExecutor:
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as query_error:
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
error_str = error_msg.lower()
|
||||
|
||||
# Check if it's a connection-related error that we should retry
|
||||
error_str = str(query_error).lower()
|
||||
connection_errors = [
|
||||
"at_eof", "connection", "closed", "nonetype",
|
||||
"transport", "reader", "broken pipe", "connection reset"
|
||||
@@ -618,7 +742,7 @@ class DorisQueryExecutor:
|
||||
|
||||
if is_connection_error and retry_count < max_retries:
|
||||
retry_count += 1
|
||||
self.logger.warning(f"Connection error detected, retrying ({retry_count}/{max_retries}): {query_error}")
|
||||
self.logger.warning(f"Connection error detected, retrying ({retry_count}/{max_retries}): {e}")
|
||||
|
||||
# Release the problematic connection
|
||||
try:
|
||||
@@ -630,14 +754,7 @@ class DorisQueryExecutor:
|
||||
await asyncio.sleep(0.5 * retry_count)
|
||||
continue
|
||||
else:
|
||||
# Re-raise if not a connection error or max retries exceeded
|
||||
raise query_error
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
|
||||
# If we've exhausted retries or it's not a connection error, return error
|
||||
if retry_count >= max_retries or "at_eof" not in error_msg.lower():
|
||||
error_analysis = self._analyze_error(error_msg)
|
||||
|
||||
return {
|
||||
@@ -651,21 +768,14 @@ class DorisQueryExecutor:
|
||||
"retry_count": retry_count
|
||||
}
|
||||
}
|
||||
else:
|
||||
# Try one more time for connection errors
|
||||
retry_count += 1
|
||||
if retry_count <= max_retries:
|
||||
self.logger.warning(f"Retrying query due to connection error ({retry_count}/{max_retries}): {e}")
|
||||
await asyncio.sleep(0.5 * retry_count)
|
||||
continue
|
||||
else:
|
||||
|
||||
# This should never be reached, but just in case
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Query failed after {max_retries} retries: {error_msg}",
|
||||
"error": "Maximum retries exceeded",
|
||||
"data": None,
|
||||
"metadata": {
|
||||
"query": sql,
|
||||
"error_details": error_msg,
|
||||
"retry_count": retry_count
|
||||
}
|
||||
}
|
||||
@@ -759,7 +869,7 @@ class QueryPerformanceMonitor:
|
||||
|
||||
def __init__(self, query_executor: DorisQueryExecutor):
|
||||
self.query_executor = query_executor
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.logger = get_logger(__name__)
|
||||
self.performance_records = []
|
||||
|
||||
async def record_query_performance(
|
||||
@@ -843,32 +953,51 @@ class QueryPerformanceMonitor:
|
||||
|
||||
# Unified convenience function for MCP integration
|
||||
async def execute_sql_query(sql: str, connection_manager: DorisConnectionManager, **kwargs) -> Dict[str, Any]:
|
||||
"""Execute SQL query - unified convenience function for MCP tools"""
|
||||
"""Execute SQL query - unified convenience function for MCP tools
|
||||
|
||||
This function now includes security validation to ensure safe query execution.
|
||||
All queries are validated against the configured security policies before execution.
|
||||
|
||||
FIX for Issue #62 Bug 1: Now supports auth_context parameter for token-bound database configuration
|
||||
FIX for Issue #58 Problem 2: Removed executor.close() to prevent ClosedResourceError in multi-worker mode
|
||||
"""
|
||||
try:
|
||||
# Create query executor
|
||||
# Create query executor with the connection manager's configuration
|
||||
executor = DorisQueryExecutor(connection_manager)
|
||||
|
||||
try:
|
||||
# Extract parameters from kwargs or use defaults
|
||||
limit = kwargs.get("limit", 1000)
|
||||
timeout = kwargs.get("timeout", 30)
|
||||
session_id = kwargs.get("session_id", "mcp_session")
|
||||
user_id = kwargs.get("user_id", "mcp_user")
|
||||
auth_context = kwargs.get("auth_context", None) # FIX: Extract auth_context
|
||||
|
||||
# The execute_sql_for_mcp method now includes security validation
|
||||
result = await executor.execute_sql_for_mcp(
|
||||
sql=sql,
|
||||
limit=limit,
|
||||
timeout=timeout,
|
||||
session_id=session_id,
|
||||
user_id=user_id
|
||||
user_id=user_id,
|
||||
auth_context=auth_context # FIX: Pass auth_context with token
|
||||
)
|
||||
|
||||
# FIX for Issue #58 Problem 2: Do NOT close executor here
|
||||
# In multi-worker mode, closing here causes ClosedResourceError
|
||||
# The executor's resources (cache, background tasks) will be managed
|
||||
# by the connection_manager lifecycle and Python's garbage collection
|
||||
# This prevents premature cleanup while MCP session manager is still processing
|
||||
|
||||
return result
|
||||
finally:
|
||||
await executor.close()
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Query execution failed: {str(e)}",
|
||||
"data": None
|
||||
"error_type": "execution_error",
|
||||
"data": None,
|
||||
"metadata": {
|
||||
"query": sql,
|
||||
"execution_error": str(e)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,14 +31,16 @@ from dotenv import load_dotenv
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# Import unified logging configuration
|
||||
from doris_mcp_server.utils.logger import get_logger
|
||||
from .logger import get_logger
|
||||
from .sql_security_utils import (
|
||||
SQLSecurityError,
|
||||
validate_identifier,
|
||||
quote_identifier
|
||||
)
|
||||
|
||||
# Configure logging
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv(override=True)
|
||||
|
||||
METADATA_DB_NAME="information_schema"
|
||||
ENABLE_MULTI_DATABASE=os.getenv("ENABLE_MULTI_DATABASE",True)
|
||||
MULTI_DATABASE_NAMES=os.getenv("MULTI_DATABASE_NAMES","")
|
||||
@@ -416,7 +418,7 @@ class MetadataExtractor:
|
||||
|
||||
return matches
|
||||
|
||||
def get_table_schema(self, table_name: str, db_name: Optional[str] = None, catalog_name: str = None) -> Dict[str, Any]:
|
||||
async def get_table_schema(self, table_name: str, db_name: Optional[str] = None, catalog_name: str = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Get the schema information for a table
|
||||
|
||||
@@ -434,12 +436,22 @@ class MetadataExtractor:
|
||||
logger.warning("Database name not specified")
|
||||
return {}
|
||||
|
||||
# SECURITY FIX: Validate identifiers to prevent SQL injection
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
validate_identifier(db_name, "database name")
|
||||
if effective_catalog:
|
||||
validate_identifier(effective_catalog, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected in get_table_schema: {e}")
|
||||
return {}
|
||||
|
||||
cache_key = f"schema_{effective_catalog or 'default'}_{db_name}_{table_name}"
|
||||
if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl:
|
||||
return self.metadata_cache[cache_key]
|
||||
|
||||
try:
|
||||
# Use information_schema.columns table to get table schema
|
||||
# Use information_schema.columns table to get table schema (async)
|
||||
query = f"""
|
||||
SELECT
|
||||
COLUMN_NAME,
|
||||
@@ -459,7 +471,7 @@ class MetadataExtractor:
|
||||
ORDINAL_POSITION
|
||||
"""
|
||||
|
||||
result = self._execute_query_with_catalog(query, db_name, effective_catalog)
|
||||
result = await self._execute_query_with_catalog_async(query, db_name, effective_catalog)
|
||||
|
||||
if not result:
|
||||
logger.warning(f"Table {effective_catalog or 'default'}.{db_name}.{table_name} does not exist or has no columns")
|
||||
@@ -468,7 +480,6 @@ class MetadataExtractor:
|
||||
# Create structured table schema information
|
||||
columns = []
|
||||
for col in result:
|
||||
# Ensure using actual column values, not column names
|
||||
column_info = {
|
||||
"name": col.get("COLUMN_NAME", ""),
|
||||
"type": col.get("DATA_TYPE", ""),
|
||||
@@ -481,8 +492,8 @@ class MetadataExtractor:
|
||||
}
|
||||
columns.append(column_info)
|
||||
|
||||
# Get table comment
|
||||
table_comment = self.get_table_comment(table_name, db_name, effective_catalog)
|
||||
# Get table comment (async)
|
||||
table_comment = await self.get_table_comment_async(table_name, db_name, effective_catalog)
|
||||
|
||||
# Build complete structure
|
||||
schema = {
|
||||
@@ -493,7 +504,7 @@ class MetadataExtractor:
|
||||
"create_time": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# Get table type information
|
||||
# Get table type information (async)
|
||||
try:
|
||||
table_type_query = f"""
|
||||
SELECT
|
||||
@@ -505,7 +516,7 @@ class MetadataExtractor:
|
||||
TABLE_SCHEMA = '{db_name}'
|
||||
AND TABLE_NAME = '{table_name}'
|
||||
"""
|
||||
table_type_result = self._execute_query(table_type_query)
|
||||
table_type_result = await self._execute_query_async(table_type_query)
|
||||
if table_type_result:
|
||||
schema["table_type"] = table_type_result[0].get("TABLE_TYPE", "")
|
||||
schema["engine"] = table_type_result[0].get("ENGINE", "")
|
||||
@@ -521,6 +532,7 @@ class MetadataExtractor:
|
||||
logger.error(f"Error getting table schema: {str(e)}")
|
||||
return {}
|
||||
|
||||
# Deprecated: sync method (kept for compatibility, will be removed)
|
||||
def get_table_comment(self, table_name: str, db_name: Optional[str] = None, catalog_name: str = None) -> str:
|
||||
"""
|
||||
Get the comment for a table
|
||||
@@ -539,6 +551,16 @@ class MetadataExtractor:
|
||||
logger.warning("Database name not specified")
|
||||
return ""
|
||||
|
||||
# SECURITY FIX: Validate identifiers
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
validate_identifier(db_name, "database name")
|
||||
if effective_catalog:
|
||||
validate_identifier(effective_catalog, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected: {e}")
|
||||
return ""
|
||||
|
||||
cache_key = f"table_comment_{effective_catalog or 'default'}_{db_name}_{table_name}"
|
||||
if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl:
|
||||
return self.metadata_cache[cache_key]
|
||||
@@ -571,6 +593,7 @@ class MetadataExtractor:
|
||||
logger.error(f"Error getting table comment: {str(e)}")
|
||||
return ""
|
||||
|
||||
# Deprecated: sync method (kept for compatibility, will be removed)
|
||||
def get_column_comments(self, table_name: str, db_name: Optional[str] = None, catalog_name: str = None) -> Dict[str, str]:
|
||||
"""
|
||||
Get comments for all columns in a table
|
||||
@@ -589,6 +612,16 @@ class MetadataExtractor:
|
||||
logger.warning("Database name not specified")
|
||||
return {}
|
||||
|
||||
# SECURITY FIX: Validate identifiers
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
validate_identifier(db_name, "database name")
|
||||
if effective_catalog:
|
||||
validate_identifier(effective_catalog, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected: {e}")
|
||||
return {}
|
||||
|
||||
cache_key = f"column_comments_{effective_catalog or 'default'}_{db_name}_{table_name}"
|
||||
if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl:
|
||||
return self.metadata_cache[cache_key]
|
||||
@@ -626,6 +659,7 @@ class MetadataExtractor:
|
||||
logger.error(f"Error getting column comments: {str(e)}")
|
||||
return {}
|
||||
|
||||
# Deprecated: sync method (kept for compatibility, will be removed)
|
||||
def get_table_indexes(self, table_name: str, db_name: Optional[str] = None, catalog_name: str = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get the index information for a table
|
||||
@@ -644,64 +678,62 @@ class MetadataExtractor:
|
||||
logger.error("Database name not specified")
|
||||
return []
|
||||
|
||||
# SECURITY FIX: Validate identifiers
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
validate_identifier(db_name, "database name")
|
||||
if effective_catalog:
|
||||
validate_identifier(effective_catalog, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected: {e}")
|
||||
return []
|
||||
|
||||
cache_key = f"indexes_{effective_catalog or 'default'}_{db_name}_{table_name}"
|
||||
if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl:
|
||||
return self.metadata_cache[cache_key]
|
||||
|
||||
try:
|
||||
# Build query with catalog prefix if specified
|
||||
# Build query with catalog prefix if specified (identifiers already validated)
|
||||
safe_table = quote_identifier(table_name, "table name")
|
||||
safe_db = quote_identifier(db_name, "database name")
|
||||
if effective_catalog:
|
||||
query = f"SHOW INDEX FROM `{effective_catalog}`.`{db_name}`.`{table_name}`"
|
||||
safe_catalog = quote_identifier(effective_catalog, "catalog name")
|
||||
query = f"SHOW INDEX FROM {safe_catalog}.{safe_db}.{safe_table}"
|
||||
logger.info(f"Using three-part naming for index query: {query}")
|
||||
else:
|
||||
query = f"SHOW INDEX FROM `{db_name}`.`{table_name}`"
|
||||
query = f"SHOW INDEX FROM {safe_db}.{safe_table}"
|
||||
|
||||
try:
|
||||
df = self._execute_query(query, return_dataframe=True)
|
||||
|
||||
# Process results
|
||||
# NOTE: Deprecated sync path retained for compatibility; use async variant instead.
|
||||
# Deprecated sync path removed; return empty indexes on failure
|
||||
result = []
|
||||
indexes = []
|
||||
current_index = None
|
||||
|
||||
if not df.empty:
|
||||
for _, row in df.iterrows():
|
||||
if result:
|
||||
for r in result:
|
||||
try:
|
||||
index_name = row['Key_name']
|
||||
column_name = row['Column_name']
|
||||
|
||||
if current_index is None or current_index['name'] != index_name:
|
||||
index_name = r.get('Key_name')
|
||||
column_name = r.get('Column_name')
|
||||
if current_index is None or current_index.get('name') != index_name:
|
||||
if current_index is not None:
|
||||
indexes.append(current_index)
|
||||
|
||||
current_index = {
|
||||
'name': index_name,
|
||||
'columns': [column_name],
|
||||
'unique': row['Non_unique'] == 0,
|
||||
'type': row['Index_type']
|
||||
'columns': [column_name] if column_name else [],
|
||||
'unique': r.get('Non_unique', 1) == 0,
|
||||
'type': r.get('Index_type', '')
|
||||
}
|
||||
else:
|
||||
if column_name:
|
||||
current_index['columns'].append(column_name)
|
||||
except Exception as row_error:
|
||||
logger.warning(f"Failed to process index row data: {row_error}")
|
||||
continue
|
||||
|
||||
if current_index is not None:
|
||||
indexes.append(current_index)
|
||||
except Exception as df_error:
|
||||
logger.warning(f"DataFrame processing failed, trying regular query: {df_error}")
|
||||
# Fall back to regular query
|
||||
result = self._execute_query(query, return_dataframe=False)
|
||||
logger.warning(f"Sync index query (deprecated) failed: {df_error}")
|
||||
indexes = []
|
||||
if result:
|
||||
# Simple processing, no complex index grouping
|
||||
for row in result:
|
||||
if isinstance(row, dict):
|
||||
indexes.append({
|
||||
'name': row.get('Key_name', ''),
|
||||
'columns': [row.get('Column_name', '')],
|
||||
'unique': row.get('Non_unique', 1) == 0,
|
||||
'type': row.get('Index_type', '')
|
||||
})
|
||||
|
||||
# Update cache
|
||||
self.metadata_cache[cache_key] = indexes
|
||||
@@ -712,7 +744,7 @@ class MetadataExtractor:
|
||||
logger.error(f"Error getting index information: {str(e)}")
|
||||
return []
|
||||
|
||||
def get_table_relationships(self) -> List[Dict[str, Any]]:
|
||||
async def get_table_relationships(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Infer table relationships from table comments and naming patterns
|
||||
|
||||
@@ -725,13 +757,13 @@ class MetadataExtractor:
|
||||
|
||||
try:
|
||||
# Get all tables
|
||||
tables = self.get_database_tables(self.db_name)
|
||||
tables = await self.get_database_tables_async(self.db_name)
|
||||
relationships = []
|
||||
|
||||
# Simple foreign key naming convention detection
|
||||
# Example: If a table has a column named xxx_id and another table named xxx exists, it might be a foreign key relationship
|
||||
for table_name in tables:
|
||||
schema = self.get_table_schema(table_name, self.db_name)
|
||||
schema = await self.get_table_schema(table_name, self.db_name)
|
||||
columns = schema.get("columns", [])
|
||||
|
||||
for column in columns:
|
||||
@@ -743,7 +775,7 @@ class MetadataExtractor:
|
||||
# Check if the possible table exists
|
||||
if ref_table_name in tables:
|
||||
# Find possible primary key column
|
||||
ref_schema = self.get_table_schema(ref_table_name, self.db_name)
|
||||
ref_schema = await self.get_table_schema(ref_table_name, self.db_name)
|
||||
ref_columns = ref_schema.get("columns", [])
|
||||
|
||||
# Assume primary key column name is id
|
||||
@@ -766,6 +798,7 @@ class MetadataExtractor:
|
||||
logger.error(f"Error inferring table relationships: {str(e)}")
|
||||
return []
|
||||
|
||||
# Deprecated: sync method (kept for compatibility, will be removed)
|
||||
def get_recent_audit_logs(self, days: int = 7, limit: int = 100) -> pd.DataFrame:
|
||||
"""
|
||||
Get recent audit logs
|
||||
@@ -792,13 +825,14 @@ class MetadataExtractor:
|
||||
ORDER BY time DESC
|
||||
LIMIT {limit}
|
||||
"""
|
||||
df = self._execute_query(query, return_dataframe=True)
|
||||
# Deprecated sync path removed; this method is deprecated overall
|
||||
df = pd.DataFrame()
|
||||
return df
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting audit logs: {str(e)}")
|
||||
return pd.DataFrame()
|
||||
|
||||
def get_catalog_list(self) -> List[Dict[str, Any]]:
|
||||
async def get_catalog_list(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get a list of all catalogs in Doris with detailed information
|
||||
|
||||
@@ -812,7 +846,7 @@ class MetadataExtractor:
|
||||
try:
|
||||
# Use SHOW CATALOGS command to get catalog list
|
||||
query = "SHOW CATALOGS"
|
||||
result = self._execute_query(query)
|
||||
result = await self._execute_query_async(query)
|
||||
|
||||
if not result:
|
||||
catalogs = []
|
||||
@@ -1101,7 +1135,8 @@ class MetadataExtractor:
|
||||
AND TABLE_NAME = '{table_name}'
|
||||
"""
|
||||
|
||||
partitions = self._execute_query(query)
|
||||
# Deprecated sync path removed
|
||||
partitions = []
|
||||
|
||||
if not partitions:
|
||||
return {}
|
||||
@@ -1124,31 +1159,25 @@ class MetadataExtractor:
|
||||
logger.error(f"Error getting partition information for table {db_name}.{table_name}: {str(e)}")
|
||||
return {}
|
||||
|
||||
def _execute_query_with_catalog(self, query: str, db_name: str = None, catalog_name: str = None):
|
||||
# Removed sync _execute_query_with_catalog; use async variant instead
|
||||
|
||||
async def _execute_query_with_catalog_async(self, query: str, db_name: str = None, catalog_name: str = None):
|
||||
"""
|
||||
Execute query with catalog-aware metadata operations using three-part naming
|
||||
Async version of _execute_query_with_catalog to avoid cross-event-loop issues.
|
||||
|
||||
Args:
|
||||
query: SQL query to execute
|
||||
db_name: Database name to use
|
||||
catalog_name: Catalog name for three-part naming
|
||||
|
||||
Returns:
|
||||
Query result
|
||||
When catalog_name is provided and the SQL targets information_schema, we rewrite
|
||||
the SQL to use three-part naming: `{catalog}.information_schema` and execute it
|
||||
via the same running event loop.
|
||||
"""
|
||||
try:
|
||||
# If catalog_name is specified, modify the query to use three-part naming
|
||||
# for information_schema queries
|
||||
if catalog_name and 'information_schema' in query.lower():
|
||||
# Replace 'information_schema' with 'catalog_name.information_schema'
|
||||
modified_query = query.replace('information_schema', f'{catalog_name}.information_schema')
|
||||
logger.info(f"Modified query for catalog {catalog_name}: {modified_query}")
|
||||
return self._execute_query(modified_query, db_name)
|
||||
return await self._execute_query_async(modified_query, db_name)
|
||||
else:
|
||||
# Execute the original query
|
||||
return self._execute_query(query, db_name)
|
||||
return await self._execute_query_async(query, db_name)
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing query with catalog: {str(e)}")
|
||||
logger.error(f"Error executing async query with catalog: {str(e)}")
|
||||
raise
|
||||
|
||||
async def _execute_query_async(self, query: str, db_name: str = None, return_dataframe: bool = False):
|
||||
@@ -1165,8 +1194,17 @@ class MetadataExtractor:
|
||||
"""
|
||||
try:
|
||||
if self.connection_manager:
|
||||
# FIX: Get auth_context from global ContextVar for token-bound database configuration
|
||||
# This ensures all query methods use the correct user's connection pool
|
||||
auth_context = None
|
||||
try:
|
||||
from .security import mcp_auth_context_var
|
||||
auth_context = mcp_auth_context_var.get()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Use the injected connection manager directly (async)
|
||||
result = await self.connection_manager.execute_query(self._session_id, query, None)
|
||||
result = await self.connection_manager.execute_query(self._session_id, query, None, auth_context)
|
||||
|
||||
# Extract data from QueryResult
|
||||
if hasattr(result, 'data'):
|
||||
@@ -1200,76 +1238,35 @@ class MetadataExtractor:
|
||||
else:
|
||||
return []
|
||||
|
||||
def _execute_query(self, query: str, db_name: str = None, return_dataframe: bool = False):
|
||||
"""
|
||||
Execute database query with proper session management (sync wrapper)
|
||||
|
||||
Args:
|
||||
query: SQL query to execute
|
||||
db_name: Database name to use (optional)
|
||||
return_dataframe: Whether to return a pandas DataFrame instead of list
|
||||
|
||||
Returns:
|
||||
Query result data (list of dictionaries or pandas DataFrame)
|
||||
"""
|
||||
try:
|
||||
if self.connection_manager:
|
||||
import asyncio
|
||||
|
||||
# Try to run the async query
|
||||
try:
|
||||
# Check if there's a running event loop
|
||||
loop = asyncio.get_running_loop()
|
||||
# If we're in an async context, we need to run in a separate thread
|
||||
import concurrent.futures
|
||||
|
||||
def run_in_new_loop():
|
||||
new_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(new_loop)
|
||||
try:
|
||||
return new_loop.run_until_complete(
|
||||
self._execute_query_async(query, db_name, return_dataframe)
|
||||
)
|
||||
finally:
|
||||
new_loop.close()
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(run_in_new_loop)
|
||||
return future.result(timeout=30)
|
||||
|
||||
except RuntimeError:
|
||||
# No running loop, we can safely create one
|
||||
return asyncio.run(
|
||||
self._execute_query_async(query, db_name, return_dataframe)
|
||||
)
|
||||
else:
|
||||
# Fallback: Return empty result
|
||||
logger.warning("No connection manager provided, returning empty result")
|
||||
if return_dataframe:
|
||||
import pandas as pd
|
||||
return pd.DataFrame()
|
||||
else:
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing query: {str(e)}")
|
||||
# Return empty result instead of raising exception to prevent cascade failures
|
||||
if return_dataframe:
|
||||
import pandas as pd
|
||||
return pd.DataFrame()
|
||||
else:
|
||||
return []
|
||||
# Removed sync _execute_query; use async methods exclusively
|
||||
|
||||
async def get_table_schema_async(self, table_name: str, db_name: str = None, catalog_name: str = None) -> List[Dict[str, Any]]:
|
||||
"""Asynchronously get table schema information"""
|
||||
try:
|
||||
# Use async query method
|
||||
effective_catalog = catalog_name or self.catalog_name
|
||||
effective_db = db_name or self.db_name
|
||||
|
||||
# Build query statement
|
||||
# SECURITY FIX: Validate identifiers
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
if effective_db:
|
||||
validate_identifier(effective_db, "database name")
|
||||
if effective_catalog and effective_catalog != "internal":
|
||||
query = f"DESCRIBE `{effective_catalog}`.`{db_name or self.db_name}`.`{table_name}`"
|
||||
validate_identifier(effective_catalog, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected: {e}")
|
||||
return []
|
||||
|
||||
# Build query statement using safe identifiers
|
||||
safe_table = quote_identifier(table_name, "table name")
|
||||
safe_db = quote_identifier(effective_db, "database name") if effective_db else None
|
||||
|
||||
if effective_catalog and effective_catalog != "internal":
|
||||
safe_catalog = quote_identifier(effective_catalog, "catalog name")
|
||||
query = f"DESCRIBE {safe_catalog}.{safe_db}.{safe_table}"
|
||||
else:
|
||||
query = f"DESCRIBE `{db_name or self.db_name}`.`{table_name}`"
|
||||
query = f"DESCRIBE {safe_db}.{safe_table}"
|
||||
|
||||
# Execute async query
|
||||
result = await self._execute_query_async(query, db_name)
|
||||
@@ -1302,8 +1299,15 @@ class MetadataExtractor:
|
||||
try:
|
||||
effective_catalog = catalog_name or self.catalog_name
|
||||
|
||||
# SECURITY FIX: Validate catalog name if provided
|
||||
if effective_catalog and effective_catalog != "internal":
|
||||
query = f"SHOW DATABASES FROM `{effective_catalog}`"
|
||||
try:
|
||||
validate_identifier(effective_catalog, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid catalog name rejected: {e}")
|
||||
return []
|
||||
safe_catalog = quote_identifier(effective_catalog, "catalog name")
|
||||
query = f"SHOW DATABASES FROM {safe_catalog}"
|
||||
else:
|
||||
query = "SHOW DATABASES"
|
||||
|
||||
@@ -1333,10 +1337,23 @@ class MetadataExtractor:
|
||||
effective_catalog = catalog_name or self.catalog_name
|
||||
effective_db = db_name or self.db_name
|
||||
|
||||
# SECURITY FIX: Validate identifiers
|
||||
try:
|
||||
if effective_db:
|
||||
validate_identifier(effective_db, "database name")
|
||||
if effective_catalog and effective_catalog != "internal":
|
||||
query = f"SHOW TABLES FROM `{effective_catalog}`.`{effective_db}`"
|
||||
validate_identifier(effective_catalog, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected: {e}")
|
||||
return []
|
||||
|
||||
safe_db = quote_identifier(effective_db, "database name") if effective_db else None
|
||||
|
||||
if effective_catalog and effective_catalog != "internal":
|
||||
safe_catalog = quote_identifier(effective_catalog, "catalog name")
|
||||
query = f"SHOW TABLES FROM {safe_catalog}.{safe_db}"
|
||||
else:
|
||||
query = f"SHOW TABLES FROM `{effective_db}`"
|
||||
query = f"SHOW TABLES FROM {safe_db}"
|
||||
|
||||
result = await self._execute_query_async(query, effective_db)
|
||||
|
||||
@@ -1389,6 +1406,162 @@ class MetadataExtractor:
|
||||
logger.error(f"Failed to get catalog list: {e}")
|
||||
return []
|
||||
|
||||
async def get_table_comment_async(self, table_name: str, db_name: str = None, catalog_name: str = None) -> str:
|
||||
"""Async version: get the comment for a table."""
|
||||
try:
|
||||
effective_db = db_name or self.db_name
|
||||
effective_catalog = catalog_name or self.catalog_name
|
||||
|
||||
# SECURITY FIX: Validate identifiers
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
if effective_db:
|
||||
validate_identifier(effective_db, "database name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected: {e}")
|
||||
return ""
|
||||
|
||||
query = f"""
|
||||
SELECT
|
||||
TABLE_COMMENT
|
||||
FROM
|
||||
information_schema.tables
|
||||
WHERE
|
||||
TABLE_SCHEMA = '{effective_db}'
|
||||
AND TABLE_NAME = '{table_name}'
|
||||
"""
|
||||
|
||||
result = await self._execute_query_with_catalog_async(query, effective_db, effective_catalog)
|
||||
if not result or not result[0]:
|
||||
return ""
|
||||
return result[0].get("TABLE_COMMENT", "") or ""
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get table comment asynchronously: {e}")
|
||||
return ""
|
||||
|
||||
async def get_column_comments_async(self, table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, str]:
|
||||
"""Async version: get comments for all columns in a table."""
|
||||
try:
|
||||
effective_db = db_name or self.db_name
|
||||
effective_catalog = catalog_name or self.catalog_name
|
||||
|
||||
# SECURITY FIX: Validate identifiers
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
if effective_db:
|
||||
validate_identifier(effective_db, "database name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected: {e}")
|
||||
return {}
|
||||
|
||||
query = f"""
|
||||
SELECT
|
||||
COLUMN_NAME,
|
||||
COLUMN_COMMENT
|
||||
FROM
|
||||
information_schema.columns
|
||||
WHERE
|
||||
TABLE_SCHEMA = '{effective_db}'
|
||||
AND TABLE_NAME = '{table_name}'
|
||||
ORDER BY
|
||||
ORDINAL_POSITION
|
||||
"""
|
||||
|
||||
rows = await self._execute_query_with_catalog_async(query, effective_db, effective_catalog)
|
||||
comments: Dict[str, str] = {}
|
||||
for col in rows or []:
|
||||
name = col.get("COLUMN_NAME", "")
|
||||
if name:
|
||||
comments[name] = col.get("COLUMN_COMMENT", "") or ""
|
||||
return comments
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get column comments asynchronously: {e}")
|
||||
return {}
|
||||
|
||||
async def get_table_indexes_async(self, table_name: str, db_name: str = None, catalog_name: str = None) -> List[Dict[str, Any]]:
|
||||
"""Async version: get index information for a table."""
|
||||
try:
|
||||
effective_db = db_name or self.db_name
|
||||
effective_catalog = catalog_name or self.catalog_name
|
||||
|
||||
# SECURITY FIX: Validate identifiers
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
if effective_db:
|
||||
validate_identifier(effective_db, "database name")
|
||||
if effective_catalog:
|
||||
validate_identifier(effective_catalog, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected: {e}")
|
||||
return []
|
||||
|
||||
# Build query with catalog prefix if specified (using safe identifiers)
|
||||
safe_table = quote_identifier(table_name, "table name")
|
||||
safe_db = quote_identifier(effective_db, "database name") if effective_db else None
|
||||
|
||||
if effective_catalog:
|
||||
safe_catalog = quote_identifier(effective_catalog, "catalog name")
|
||||
query = f"SHOW INDEX FROM {safe_catalog}.{safe_db}.{safe_table}"
|
||||
logger.info(f"Using three-part naming for async index query: {query}")
|
||||
else:
|
||||
query = f"SHOW INDEX FROM {safe_db}.{safe_table}"
|
||||
|
||||
rows = await self._execute_query_async(query, effective_db)
|
||||
indexes: List[Dict[str, Any]] = []
|
||||
if rows:
|
||||
# Group by Key_name
|
||||
current_index: Dict[str, Any] | None = None
|
||||
for r in rows:
|
||||
try:
|
||||
index_name = r.get('Key_name')
|
||||
column_name = r.get('Column_name')
|
||||
if current_index is None or current_index.get('name') != index_name:
|
||||
if current_index is not None:
|
||||
indexes.append(current_index)
|
||||
current_index = {
|
||||
'name': index_name,
|
||||
'columns': [column_name] if column_name else [],
|
||||
'unique': r.get('Non_unique', 1) == 0,
|
||||
'type': r.get('Index_type', '')
|
||||
}
|
||||
else:
|
||||
if column_name:
|
||||
current_index['columns'].append(column_name)
|
||||
except Exception as row_error:
|
||||
logger.warning(f"Failed to process async index row data: {row_error}")
|
||||
continue
|
||||
if current_index is not None:
|
||||
indexes.append(current_index)
|
||||
|
||||
return indexes
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting index information asynchronously: {str(e)}")
|
||||
return []
|
||||
|
||||
async def get_recent_audit_logs_async(self, days: int = 7, limit: int = 100):
|
||||
"""Async version: get recent audit logs and return a pandas DataFrame."""
|
||||
try:
|
||||
start_date = (datetime.now() - timedelta(days=days)).strftime('%Y-%m-%d')
|
||||
query = f"""
|
||||
SELECT client_ip, user, db, time, stmt_id, stmt, state, error_code
|
||||
FROM `__internal_schema`.`audit_log`
|
||||
WHERE `time` >= '{start_date}'
|
||||
AND state = 'EOF' AND error_code = 0
|
||||
AND `stmt` NOT LIKE 'SHOW%'
|
||||
AND `stmt` NOT LIKE 'DESC%'
|
||||
AND `stmt` NOT LIKE 'EXPLAIN%'
|
||||
AND `stmt` NOT LIKE 'SELECT 1%'
|
||||
ORDER BY time DESC
|
||||
LIMIT {limit}
|
||||
"""
|
||||
rows = await self._execute_query_async(query)
|
||||
import pandas as pd
|
||||
return pd.DataFrame(rows or [])
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting audit logs asynchronously: {str(e)}")
|
||||
import pandas as pd
|
||||
return pd.DataFrame()
|
||||
|
||||
# ==================== Business layer methods (original metadata_tools.py functionality) ====================
|
||||
|
||||
def _format_response(self, success: bool, result: Any = None, error: str = None, message: str = "") -> Dict[str, Any]:
|
||||
@@ -1417,6 +1590,9 @@ class MetadataExtractor:
|
||||
"""
|
||||
Execute SQL query and return results, supports catalog federation queries
|
||||
Unified interface for MCP tools
|
||||
|
||||
FIX for Issue #62 Bug 1: Now retrieves auth_context from context variable to support token-bound database configuration
|
||||
FIX for Issue #62 Bug 3: Now uses db_name and catalog_name parameters to switch database context
|
||||
"""
|
||||
logger.info(f"Executing SQL query: {sql}, DB: {db_name}, Catalog: {catalog_name}, MaxRows: {max_rows}, Timeout: {timeout}")
|
||||
|
||||
@@ -1424,15 +1600,86 @@ class MetadataExtractor:
|
||||
if not sql:
|
||||
return self._format_response(success=False, error="No SQL statement provided", message="Please provide SQL statement to execute")
|
||||
|
||||
# FIX for Issue #62 Bug 3: Build context switching SQL if db_name or catalog_name is specified
|
||||
# SECURITY FIX: Validate catalog_name and db_name to prevent SQL injection
|
||||
final_sql = sql
|
||||
if catalog_name or db_name:
|
||||
context_statements = []
|
||||
|
||||
# Validate and sanitize catalog_name
|
||||
if catalog_name:
|
||||
try:
|
||||
validate_identifier(catalog_name, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid catalog name rejected: {e}")
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid catalog name: {catalog_name}",
|
||||
message="Catalog name contains invalid characters"
|
||||
)
|
||||
# Use quote_identifier to safely escape the catalog name
|
||||
safe_catalog = quote_identifier(catalog_name, "catalog name")
|
||||
context_statements.append(f"USE CATALOG {safe_catalog}")
|
||||
logger.debug(f"Switching to catalog: {catalog_name}")
|
||||
|
||||
# Validate and sanitize db_name
|
||||
if db_name:
|
||||
try:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid database name rejected: {e}")
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid database name: {db_name}",
|
||||
message="Database name contains invalid characters"
|
||||
)
|
||||
# Use quote_identifier to safely escape the database name
|
||||
safe_db = quote_identifier(db_name, "database name")
|
||||
if catalog_name:
|
||||
safe_catalog = quote_identifier(catalog_name, "catalog name")
|
||||
context_statements.append(f"USE {safe_catalog}.{safe_db}")
|
||||
else:
|
||||
context_statements.append(f"USE {safe_db}")
|
||||
logger.debug(f"Switching to database: {db_name}")
|
||||
|
||||
# Combine context switching with original SQL
|
||||
if context_statements:
|
||||
# Remove trailing semicolon from context statements if present
|
||||
context_sql = "; ".join(context_statements)
|
||||
# Ensure original SQL doesn't start with semicolon
|
||||
sql_clean = sql.lstrip(";").strip()
|
||||
final_sql = f"{context_sql}; {sql_clean}"
|
||||
logger.debug(f"Modified SQL with context switching: {final_sql[:200]}...")
|
||||
|
||||
# FIX: Try to get auth_context from context variable (set by HTTP middleware)
|
||||
# This allows token-bound database configuration to work
|
||||
# CRITICAL: Use the global ContextVar from security.py to ensure same instance is used everywhere
|
||||
auth_context = None
|
||||
try:
|
||||
from .security import mcp_auth_context_var
|
||||
|
||||
# Get auth_context from the global context variable
|
||||
# This will be set by the HTTP request handler in main.py
|
||||
auth_context = mcp_auth_context_var.get()
|
||||
|
||||
if auth_context:
|
||||
logger.debug(f"Retrieved auth_context from context variable with token: {bool(hasattr(auth_context, 'token') and auth_context.token)}")
|
||||
else:
|
||||
logger.debug("No auth_context found in context variable, using default")
|
||||
except Exception as ctx_error:
|
||||
logger.debug(f"Could not retrieve auth_context from context variable: {ctx_error}")
|
||||
auth_context = None
|
||||
|
||||
# Import query executor
|
||||
from .query_executor import execute_sql_query
|
||||
|
||||
# Call execute_sql_query to execute query
|
||||
# Call execute_sql_query to execute query with auth_context
|
||||
exec_result = await execute_sql_query(
|
||||
sql=sql,
|
||||
sql=final_sql, # Use modified SQL with context switching
|
||||
connection_manager=self.connection_manager,
|
||||
limit=max_rows,
|
||||
timeout=timeout
|
||||
timeout=timeout,
|
||||
auth_context=auth_context # FIX: Pass auth_context with token
|
||||
)
|
||||
|
||||
return exec_result
|
||||
@@ -1453,6 +1700,36 @@ class MetadataExtractor:
|
||||
if not table_name:
|
||||
return self._format_response(success=False, error="Missing table_name parameter")
|
||||
|
||||
# SECURITY: Validate identifiers before processing
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid table name: {table_name}",
|
||||
message="Table name contains invalid characters"
|
||||
)
|
||||
|
||||
if db_name:
|
||||
try:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid database name: {db_name}",
|
||||
message="Database name contains invalid characters"
|
||||
)
|
||||
|
||||
if catalog_name and catalog_name != "internal":
|
||||
try:
|
||||
validate_identifier(catalog_name, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid catalog name: {catalog_name}",
|
||||
message="Catalog name contains invalid characters"
|
||||
)
|
||||
|
||||
try:
|
||||
schema = await self.get_table_schema_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
|
||||
@@ -1476,6 +1753,27 @@ class MetadataExtractor:
|
||||
"""Get list of all table names in specified database - MCP interface"""
|
||||
logger.info(f"Getting database table list: DB: {db_name}, Catalog: {catalog_name}")
|
||||
|
||||
# SECURITY: Validate identifiers
|
||||
if db_name:
|
||||
try:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid database name: {db_name}",
|
||||
message="Database name contains invalid characters"
|
||||
)
|
||||
|
||||
if catalog_name and catalog_name != "internal":
|
||||
try:
|
||||
validate_identifier(catalog_name, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid catalog name: {catalog_name}",
|
||||
message="Catalog name contains invalid characters"
|
||||
)
|
||||
|
||||
try:
|
||||
tables = await self.get_database_tables_async(db_name=db_name, catalog_name=catalog_name)
|
||||
return self._format_response(success=True, result=tables)
|
||||
@@ -1506,8 +1804,38 @@ class MetadataExtractor:
|
||||
if not table_name:
|
||||
return self._format_response(success=False, error="Missing table_name parameter")
|
||||
|
||||
# SECURITY: Validate identifiers
|
||||
try:
|
||||
comment = self.get_table_comment(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
validate_identifier(table_name, "table name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid table name: {table_name}",
|
||||
message="Table name contains invalid characters"
|
||||
)
|
||||
|
||||
if db_name:
|
||||
try:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid database name: {db_name}",
|
||||
message="Database name contains invalid characters"
|
||||
)
|
||||
|
||||
if catalog_name and catalog_name != "internal":
|
||||
try:
|
||||
validate_identifier(catalog_name, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid catalog name: {catalog_name}",
|
||||
message="Catalog name contains invalid characters"
|
||||
)
|
||||
|
||||
try:
|
||||
comment = await self.get_table_comment_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
return self._format_response(success=True, result=comment)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get table comment: {str(e)}", exc_info=True)
|
||||
@@ -1525,8 +1853,38 @@ class MetadataExtractor:
|
||||
if not table_name:
|
||||
return self._format_response(success=False, error="Missing table_name parameter")
|
||||
|
||||
# SECURITY: Validate identifiers
|
||||
try:
|
||||
comments = self.get_column_comments(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
validate_identifier(table_name, "table name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid table name: {table_name}",
|
||||
message="Table name contains invalid characters"
|
||||
)
|
||||
|
||||
if db_name:
|
||||
try:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid database name: {db_name}",
|
||||
message="Database name contains invalid characters"
|
||||
)
|
||||
|
||||
if catalog_name and catalog_name != "internal":
|
||||
try:
|
||||
validate_identifier(catalog_name, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid catalog name: {catalog_name}",
|
||||
message="Catalog name contains invalid characters"
|
||||
)
|
||||
|
||||
try:
|
||||
comments = await self.get_column_comments_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
return self._format_response(success=True, result=comments)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get table column comments: {str(e)}", exc_info=True)
|
||||
@@ -1544,8 +1902,38 @@ class MetadataExtractor:
|
||||
if not table_name:
|
||||
return self._format_response(success=False, error="Missing table_name parameter")
|
||||
|
||||
# SECURITY: Validate identifiers
|
||||
try:
|
||||
indexes = self.get_table_indexes(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
validate_identifier(table_name, "table name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid table name: {table_name}",
|
||||
message="Table name contains invalid characters"
|
||||
)
|
||||
|
||||
if db_name:
|
||||
try:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid database name: {db_name}",
|
||||
message="Database name contains invalid characters"
|
||||
)
|
||||
|
||||
if catalog_name and catalog_name != "internal":
|
||||
try:
|
||||
validate_identifier(catalog_name, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid catalog name: {catalog_name}",
|
||||
message="Catalog name contains invalid characters"
|
||||
)
|
||||
|
||||
try:
|
||||
indexes = await self.get_table_indexes_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
return self._format_response(success=True, result=indexes)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get table indexes: {str(e)}", exc_info=True)
|
||||
@@ -1569,7 +1957,7 @@ class MetadataExtractor:
|
||||
logger.info(f"Getting audit logs: Days: {days}, Limit: {limit}")
|
||||
|
||||
try:
|
||||
logs_df = self.get_recent_audit_logs(days=days, limit=limit)
|
||||
logs_df = await self.get_recent_audit_logs_async(days=days, limit=limit)
|
||||
|
||||
# Convert DataFrame to JSON format
|
||||
if hasattr(logs_df, 'to_dict'):
|
||||
|
||||
@@ -22,15 +22,23 @@ Implements enterprise-level authentication, authorization, SQL security validati
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from contextvars import ContextVar
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
import sqlparse
|
||||
from sqlparse.sql import Statement
|
||||
from sqlparse.tokens import Keyword, Name
|
||||
|
||||
from .logger import get_logger
|
||||
from .config import DatabaseConfig
|
||||
|
||||
# Global ContextVar for auth_context - must be a single instance shared across all modules
|
||||
# This allows token-bound database configuration to work correctly in concurrent requests
|
||||
mcp_auth_context_var: ContextVar['AuthContext'] = ContextVar('mcp_auth_context', default=None)
|
||||
|
||||
|
||||
class SecurityLevel(Enum):
|
||||
"""Security level enumeration"""
|
||||
@@ -43,15 +51,18 @@ class SecurityLevel(Enum):
|
||||
|
||||
@dataclass
|
||||
class AuthContext:
|
||||
"""Authentication context"""
|
||||
"""Authentication context for audit and session tracking"""
|
||||
|
||||
user_id: str
|
||||
roles: list[str]
|
||||
permissions: list[str]
|
||||
session_id: str
|
||||
login_time: datetime | None = None
|
||||
token_id: str = "" # Token identifier for audit logging
|
||||
user_id: str = "" # User identifier
|
||||
roles: list[str] = field(default_factory=list) # User roles
|
||||
permissions: list[str] = field(default_factory=list) # User permissions
|
||||
security_level: 'SecurityLevel' = field(default_factory=lambda: SecurityLevel.INTERNAL) # Security level
|
||||
client_ip: str = "unknown" # Client IP address
|
||||
session_id: str = "" # Session identifier
|
||||
login_time: datetime = field(default_factory=datetime.utcnow)
|
||||
last_activity: datetime | None = None
|
||||
security_level: SecurityLevel = SecurityLevel.INTERNAL
|
||||
token: str = "" # Raw token for token-bound database configuration
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -84,12 +95,13 @@ class DorisSecurityManager:
|
||||
Provides complete security control functionality, including authentication, authorization, SQL security validation and data masking
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, connection_manager=None):
|
||||
self.config = config
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.logger = get_logger(__name__)
|
||||
self.connection_manager = connection_manager
|
||||
|
||||
# Initialize security components
|
||||
self.auth_provider = AuthenticationProvider(config)
|
||||
self.auth_provider = AuthenticationProvider(config, self)
|
||||
self.authz_provider = AuthorizationProvider(config)
|
||||
self.sql_validator = SQLSecurityValidator(config)
|
||||
self.masking_processor = DataMaskingProcessor(config)
|
||||
@@ -99,6 +111,36 @@ class DorisSecurityManager:
|
||||
self.sensitive_tables = self._load_sensitive_tables()
|
||||
self.masking_rules = self._load_masking_rules()
|
||||
|
||||
# Track initialization state
|
||||
self._initialized = False
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize security manager components"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
try:
|
||||
# Initialize authentication provider (for JWT setup)
|
||||
await self.auth_provider.initialize()
|
||||
|
||||
self._initialized = True
|
||||
self.logger.info("DorisSecurityManager initialized successfully")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to initialize DorisSecurityManager: {e}")
|
||||
raise
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shutdown security manager components"""
|
||||
try:
|
||||
await self.auth_provider.shutdown()
|
||||
self._initialized = False
|
||||
self.logger.info("DorisSecurityManager shutdown completed")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error during DorisSecurityManager shutdown: {e}")
|
||||
raise
|
||||
|
||||
def _load_blocked_keywords(self) -> set[str]:
|
||||
"""Load blocked SQL keywords from configuration"""
|
||||
# Load keywords from configuration, unified source of truth
|
||||
@@ -182,8 +224,59 @@ class DorisSecurityManager:
|
||||
return default_rules
|
||||
|
||||
async def authenticate_request(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||
"""Validate request authentication information"""
|
||||
return await self.auth_provider.authenticate(auth_info)
|
||||
"""Validate request authentication information
|
||||
|
||||
Tries authentication methods in order: Token -> JWT -> OAuth
|
||||
Any one method succeeding allows access
|
||||
If all methods are disabled, returns anonymous context
|
||||
"""
|
||||
# Check if any authentication method is enabled
|
||||
if not (self.config.security.enable_token_auth or
|
||||
self.config.security.enable_jwt_auth or
|
||||
self.config.security.enable_oauth_auth):
|
||||
self.logger.debug("All authentication methods are disabled")
|
||||
# Return anonymous context when no authentication is enabled
|
||||
return AuthContext(
|
||||
token_id="anonymous",
|
||||
user_id="anonymous",
|
||||
roles=["anonymous"],
|
||||
permissions=["read"],
|
||||
security_level=SecurityLevel.PUBLIC,
|
||||
client_ip=auth_info.get("client_ip", "unknown"),
|
||||
session_id="anonymous_session"
|
||||
)
|
||||
|
||||
# Try authentication methods in order of preference
|
||||
last_error = None
|
||||
|
||||
# 1. Try Token authentication first (most common)
|
||||
if self.config.security.enable_token_auth:
|
||||
try:
|
||||
return await self.auth_provider.authenticate_token(auth_info)
|
||||
except Exception as e:
|
||||
self.logger.debug(f"Token authentication failed: {e}")
|
||||
last_error = e
|
||||
|
||||
# 2. Try JWT authentication
|
||||
if self.config.security.enable_jwt_auth:
|
||||
try:
|
||||
return await self.auth_provider.authenticate_jwt(auth_info)
|
||||
except Exception as e:
|
||||
self.logger.debug(f"JWT authentication failed: {e}")
|
||||
last_error = e
|
||||
|
||||
# 3. Try OAuth authentication
|
||||
if self.config.security.enable_oauth_auth:
|
||||
try:
|
||||
return await self.auth_provider.authenticate_oauth(auth_info)
|
||||
except Exception as e:
|
||||
self.logger.debug(f"OAuth authentication failed: {e}")
|
||||
last_error = e
|
||||
|
||||
# All enabled authentication methods failed
|
||||
error_message = f"Authentication failed: {str(last_error)}" if last_error else "No authentication method succeeded"
|
||||
self.logger.warning(f"Authentication failed for client {auth_info.get('client_ip', 'unknown')}: {error_message}")
|
||||
raise ValueError(error_message)
|
||||
|
||||
async def authorize_resource_access(
|
||||
self, auth_context: AuthContext, resource_uri: str
|
||||
@@ -205,44 +298,363 @@ class DorisSecurityManager:
|
||||
"""Apply data masking processing"""
|
||||
return await self.masking_processor.process(data, auth_context)
|
||||
|
||||
# OAuth-specific methods
|
||||
def get_oauth_authorization_url(self) -> tuple[str, str]:
|
||||
"""Get OAuth authorization URL
|
||||
|
||||
Returns:
|
||||
Tuple of (authorization_url, state)
|
||||
"""
|
||||
if not self.auth_provider.oauth_provider:
|
||||
raise ValueError("OAuth is not enabled")
|
||||
return self.auth_provider.oauth_provider.get_authorization_url()
|
||||
|
||||
async def handle_oauth_callback(self, code: str, state: str) -> AuthContext:
|
||||
"""Handle OAuth callback
|
||||
|
||||
Args:
|
||||
code: Authorization code from OAuth provider
|
||||
state: State parameter for CSRF protection
|
||||
|
||||
Returns:
|
||||
AuthContext for authenticated user
|
||||
"""
|
||||
if not self.auth_provider.oauth_provider:
|
||||
raise ValueError("OAuth is not enabled")
|
||||
return await self.auth_provider.oauth_provider.handle_callback(code, state)
|
||||
|
||||
def get_oauth_provider_info(self) -> dict[str, Any]:
|
||||
"""Get OAuth provider information
|
||||
|
||||
Returns:
|
||||
OAuth provider information
|
||||
"""
|
||||
if not self.auth_provider.oauth_provider:
|
||||
return {"enabled": False}
|
||||
return self.auth_provider.oauth_provider.get_provider_info()
|
||||
|
||||
# Token management methods
|
||||
async def create_token(
|
||||
self,
|
||||
token_id: str,
|
||||
expires_hours: Optional[int] = None,
|
||||
description: str = "",
|
||||
custom_token: Optional[str] = None,
|
||||
database_config: Optional[DatabaseConfig] = None
|
||||
) -> str:
|
||||
"""Create a new API access token
|
||||
|
||||
Args:
|
||||
token_id: Unique token identifier for audit and management
|
||||
expires_hours: Token expiration in hours (None for no expiration)
|
||||
description: Token description for management purposes
|
||||
custom_token: Custom token string (if None, generates random token)
|
||||
database_config: Optional database configuration for this token
|
||||
|
||||
Returns:
|
||||
Generated token string
|
||||
"""
|
||||
if not self.auth_provider.token_manager:
|
||||
raise ValueError("Token manager not initialized")
|
||||
|
||||
return await self.auth_provider.token_manager.create_token(
|
||||
token_id=token_id,
|
||||
expires_hours=expires_hours,
|
||||
description=description,
|
||||
custom_token=custom_token,
|
||||
database_config=database_config
|
||||
)
|
||||
|
||||
async def revoke_token(self, token_id: str) -> bool:
|
||||
"""Revoke a token by token ID
|
||||
|
||||
Args:
|
||||
token_id: Token ID to revoke
|
||||
|
||||
Returns:
|
||||
True if token was revoked successfully
|
||||
"""
|
||||
if not self.auth_provider.token_manager:
|
||||
raise ValueError("Token manager not initialized")
|
||||
|
||||
return await self.auth_provider.token_manager.revoke_token(token_id)
|
||||
|
||||
async def list_tokens(self) -> list[dict[str, Any]]:
|
||||
"""List all tokens (without sensitive data)
|
||||
|
||||
Returns:
|
||||
List of token information
|
||||
"""
|
||||
if not self.auth_provider.token_manager:
|
||||
raise ValueError("Token manager not initialized")
|
||||
|
||||
return await self.auth_provider.token_manager.list_tokens()
|
||||
|
||||
async def cleanup_expired_tokens(self) -> int:
|
||||
"""Remove expired tokens and return count
|
||||
|
||||
Returns:
|
||||
Number of expired tokens removed
|
||||
"""
|
||||
if not self.auth_provider.token_manager:
|
||||
return 0
|
||||
|
||||
return await self.auth_provider.token_manager.cleanup_expired_tokens()
|
||||
|
||||
def get_token_stats(self) -> dict[str, Any]:
|
||||
"""Get token statistics
|
||||
|
||||
Returns:
|
||||
Token statistics dictionary
|
||||
"""
|
||||
if not self.auth_provider.token_manager:
|
||||
return {"error": "Token manager not initialized"}
|
||||
|
||||
return self.auth_provider.token_manager.get_token_stats()
|
||||
|
||||
async def _validate_token_database_config(self, token: str, token_info) -> None:
|
||||
"""Validate database configuration for token immediately during authentication
|
||||
|
||||
This ensures database connectivity issues are caught at authentication time,
|
||||
not during query execution, providing better user experience.
|
||||
|
||||
Args:
|
||||
token: Raw authentication token
|
||||
token_info: TokenInfo object from token validation
|
||||
|
||||
Raises:
|
||||
ValueError: If database configuration is invalid or connection fails
|
||||
"""
|
||||
try:
|
||||
if not self.connection_manager:
|
||||
self.logger.warning("Connection manager not available for immediate database validation")
|
||||
return
|
||||
|
||||
# Configure and test database connection for this token
|
||||
success, config_source = await self.connection_manager.configure_for_token(token)
|
||||
|
||||
if success:
|
||||
self.logger.info(f"Database configuration validated successfully for token {token_info.token_id} (source: {config_source})")
|
||||
else:
|
||||
raise ValueError("Database configuration validation failed")
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Database configuration validation failed for token {token_info.token_id}: {str(e)}"
|
||||
self.logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
|
||||
class AuthenticationProvider:
|
||||
"""Authentication provider"""
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, security_manager=None):
|
||||
self.config = config
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.logger = get_logger(__name__)
|
||||
self.session_cache = {}
|
||||
self.jwt_manager = None
|
||||
self.oauth_provider = None
|
||||
self.token_manager = None
|
||||
self.security_manager = security_manager
|
||||
|
||||
async def authenticate(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||
"""Perform identity authentication"""
|
||||
auth_type = auth_info.get("type", "token")
|
||||
# Initialize authentication providers based on individual switches
|
||||
auth_methods_enabled = []
|
||||
|
||||
if auth_type == "token":
|
||||
return await self._authenticate_token(auth_info)
|
||||
elif auth_type == "basic":
|
||||
return await self._authenticate_basic(auth_info)
|
||||
# Initialize Token manager if enabled
|
||||
if config.security.enable_token_auth:
|
||||
self._initialize_token_manager()
|
||||
auth_methods_enabled.append("Token")
|
||||
|
||||
# Initialize JWT manager if enabled
|
||||
if config.security.enable_jwt_auth:
|
||||
self._initialize_jwt_manager()
|
||||
auth_methods_enabled.append("JWT")
|
||||
|
||||
# Initialize OAuth provider if enabled
|
||||
if config.security.enable_oauth_auth or (hasattr(config.security, 'oauth_enabled') and config.security.oauth_enabled):
|
||||
self._initialize_oauth_provider()
|
||||
auth_methods_enabled.append("OAuth")
|
||||
|
||||
if auth_methods_enabled:
|
||||
self.logger.info(f"Authentication enabled with methods: {', '.join(auth_methods_enabled)}")
|
||||
else:
|
||||
raise ValueError(f"Unsupported authentication type: {auth_type}")
|
||||
self.logger.info("All authentication methods are disabled - anonymous access allowed")
|
||||
|
||||
def _initialize_jwt_manager(self):
|
||||
"""Initialize JWT manager"""
|
||||
try:
|
||||
from ..auth.jwt_manager import JWTManager
|
||||
self.jwt_manager = JWTManager(self.config)
|
||||
self.logger.info("JWT manager initialized")
|
||||
except ImportError as e:
|
||||
self.logger.error(f"Failed to import JWT manager: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to initialize JWT manager: {e}")
|
||||
raise
|
||||
|
||||
def _initialize_token_manager(self):
|
||||
"""Initialize Token manager"""
|
||||
try:
|
||||
from ..auth.token_manager import TokenManager
|
||||
self.token_manager = TokenManager(self.config)
|
||||
self.logger.info("Token manager initialized")
|
||||
except ImportError as e:
|
||||
self.logger.error(f"Failed to import Token manager: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to initialize Token manager: {e}")
|
||||
raise
|
||||
|
||||
def _initialize_oauth_provider(self):
|
||||
"""Initialize OAuth provider"""
|
||||
try:
|
||||
from ..auth.oauth_provider import OAuthAuthenticationProvider
|
||||
self.oauth_provider = OAuthAuthenticationProvider(self.config)
|
||||
self.logger.info("OAuth provider initialized")
|
||||
except ImportError as e:
|
||||
self.logger.error(f"Failed to import OAuth provider: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to initialize OAuth provider: {e}")
|
||||
raise
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize authentication provider asynchronously"""
|
||||
if self.jwt_manager:
|
||||
success = await self.jwt_manager.initialize()
|
||||
if not success:
|
||||
raise RuntimeError("Failed to initialize JWT manager")
|
||||
self.logger.info("JWT authentication provider initialized successfully")
|
||||
|
||||
if self.token_manager:
|
||||
# Token manager doesn't need async initialization, just log success
|
||||
self.logger.info("Token authentication provider initialized successfully")
|
||||
|
||||
if self.oauth_provider:
|
||||
success = await self.oauth_provider.initialize()
|
||||
if not success:
|
||||
raise RuntimeError("Failed to initialize OAuth provider")
|
||||
self.logger.info("OAuth authentication provider initialized successfully")
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shutdown authentication provider"""
|
||||
if self.jwt_manager:
|
||||
await self.jwt_manager.shutdown()
|
||||
self.logger.info("JWT authentication provider shutdown completed")
|
||||
|
||||
if self.token_manager:
|
||||
# Token manager doesn't need async shutdown, just log
|
||||
self.logger.info("Token authentication provider shutdown completed")
|
||||
|
||||
if self.oauth_provider:
|
||||
await self.oauth_provider.shutdown()
|
||||
self.logger.info("OAuth authentication provider shutdown completed")
|
||||
|
||||
async def authenticate_token(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||
"""Perform token authentication"""
|
||||
if not self.config.security.enable_token_auth:
|
||||
raise ValueError("Token authentication is not enabled")
|
||||
return await self._authenticate_token(auth_info)
|
||||
|
||||
async def authenticate_jwt(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||
"""Perform JWT authentication"""
|
||||
if not self.config.security.enable_jwt_auth:
|
||||
raise ValueError("JWT authentication is not enabled")
|
||||
return await self._authenticate_jwt(auth_info)
|
||||
|
||||
async def authenticate_oauth(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||
"""Perform OAuth authentication"""
|
||||
if not self.config.security.enable_oauth_auth:
|
||||
raise ValueError("OAuth authentication is not enabled")
|
||||
return await self._authenticate_oauth(auth_info)
|
||||
|
||||
async def _authenticate_jwt(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||
"""JWT authentication"""
|
||||
if not self.jwt_manager:
|
||||
raise ValueError("JWT manager not initialized")
|
||||
|
||||
token = auth_info.get("token")
|
||||
if not token:
|
||||
# Try to extract from Authorization header
|
||||
authorization = auth_info.get("authorization")
|
||||
if authorization and authorization.startswith('Bearer '):
|
||||
token = authorization[7:]
|
||||
|
||||
if not token:
|
||||
raise ValueError("Missing JWT token")
|
||||
|
||||
try:
|
||||
# Use JWT middleware for authentication
|
||||
from ..auth.auth_middleware import AuthMiddleware
|
||||
middleware = AuthMiddleware(self.jwt_manager)
|
||||
return await middleware.authenticate_request(auth_info)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"JWT authentication failed: {e}")
|
||||
raise ValueError(f"JWT authentication failed: {str(e)}")
|
||||
|
||||
async def _authenticate_oauth(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||
"""OAuth authentication"""
|
||||
if not self.oauth_provider:
|
||||
raise ValueError("OAuth provider not initialized")
|
||||
|
||||
# Handle different OAuth authentication scenarios
|
||||
if "access_token" in auth_info:
|
||||
# Direct OAuth access token authentication
|
||||
return await self.oauth_provider.authenticate_with_token(auth_info["access_token"])
|
||||
elif "code" in auth_info and "state" in auth_info:
|
||||
# OAuth callback authentication
|
||||
return await self.oauth_provider.handle_callback(auth_info["code"], auth_info["state"])
|
||||
else:
|
||||
raise ValueError("OAuth authentication requires either access_token or code+state")
|
||||
|
||||
async def _authenticate_token(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||
"""Token authentication"""
|
||||
if not self.token_manager:
|
||||
raise ValueError("Token manager not initialized")
|
||||
|
||||
token = auth_info.get("token")
|
||||
if not token:
|
||||
# Try to extract from Authorization header
|
||||
authorization = auth_info.get("authorization")
|
||||
if authorization and authorization.startswith('Bearer '):
|
||||
token = authorization[7:]
|
||||
elif authorization and authorization.startswith('Token '):
|
||||
token = authorization[6:]
|
||||
|
||||
if not token:
|
||||
raise ValueError("Missing authentication token")
|
||||
|
||||
# Validate token (simplified implementation, should validate JWT or query authentication service in practice)
|
||||
user_info = await self._validate_token(token)
|
||||
try:
|
||||
# Validate token using TokenManager
|
||||
validation_result = await self.token_manager.validate_token(token)
|
||||
|
||||
if not validation_result.is_valid:
|
||||
raise ValueError(f"Token validation failed: {validation_result.error_message}")
|
||||
|
||||
token_info = validation_result.token_info
|
||||
|
||||
# Immediately validate database configuration for this token
|
||||
if self.security_manager:
|
||||
await self.security_manager._validate_token_database_config(token, token_info)
|
||||
|
||||
return AuthContext(
|
||||
user_id=user_info["user_id"],
|
||||
roles=user_info["roles"],
|
||||
permissions=user_info["permissions"],
|
||||
session_id=auth_info.get("session_id", "default"),
|
||||
token_id=token_info.token_id,
|
||||
user_id=token_info.token_id, # Use token_id as user_id for token auth
|
||||
roles=["token_user"], # Default role for token users
|
||||
permissions=["read", "write"], # Default permissions for token users
|
||||
security_level=SecurityLevel.INTERNAL,
|
||||
client_ip=auth_info.get("client_ip", "unknown"),
|
||||
session_id=auth_info.get("session_id", f"session_{token_info.token_id}"),
|
||||
login_time=datetime.utcnow(),
|
||||
security_level=SecurityLevel(user_info.get("security_level", "internal")),
|
||||
last_activity=token_info.last_used,
|
||||
token=token # Store raw token for token-bound database configuration
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Token authentication failed: {e}")
|
||||
raise ValueError(f"Token authentication failed: {str(e)}")
|
||||
|
||||
async def _authenticate_basic(self, auth_info: dict[str, Any]) -> AuthContext:
|
||||
"""Basic authentication (username password)"""
|
||||
username = auth_info.get("username")
|
||||
@@ -321,7 +733,7 @@ class AuthorizationProvider:
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.logger = get_logger(__name__)
|
||||
self.permission_cache = {}
|
||||
|
||||
# Load sensitive tables configuration
|
||||
@@ -464,7 +876,7 @@ class SQLSecurityValidator:
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
# Handle DorisConfig object or dictionary configuration
|
||||
if hasattr(config, 'get'):
|
||||
@@ -496,27 +908,47 @@ class SQLSecurityValidator:
|
||||
return ValidationResult(is_valid=True)
|
||||
|
||||
try:
|
||||
# Parse SQL statement
|
||||
parsed = sqlparse.parse(sql)[0]
|
||||
# SECURITY FIX: Parse ALL SQL statements, not just the first one
|
||||
# This prevents bypassing security checks by injecting additional statements
|
||||
all_statements = sqlparse.parse(sql)
|
||||
|
||||
if not all_statements:
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
error_message="Empty or invalid SQL statement",
|
||||
risk_level="medium"
|
||||
)
|
||||
|
||||
# SECURITY FIX: Validate each statement individually
|
||||
for idx, parsed in enumerate(all_statements):
|
||||
# Skip empty statements (e.g., from trailing semicolons)
|
||||
if not parsed.tokens or str(parsed).strip() == '':
|
||||
continue
|
||||
|
||||
self.logger.debug(f"Validating SQL statement {idx + 1}/{len(all_statements)}: {str(parsed)[:100]}...")
|
||||
|
||||
# Check blocked operations first (more specific)
|
||||
keyword_result = await self._check_blocked_keywords(parsed)
|
||||
if not keyword_result.is_valid:
|
||||
keyword_result.error_message = f"Statement {idx + 1}: {keyword_result.error_message}"
|
||||
return keyword_result
|
||||
|
||||
# Check SQL injection risks
|
||||
injection_result = await self._check_sql_injection(sql, parsed)
|
||||
if not injection_result.is_valid:
|
||||
injection_result.error_message = f"Statement {idx + 1}: {injection_result.error_message}"
|
||||
return injection_result
|
||||
|
||||
# Check query complexity
|
||||
complexity_result = await self._check_query_complexity(parsed)
|
||||
if not complexity_result.is_valid:
|
||||
complexity_result.error_message = f"Statement {idx + 1}: {complexity_result.error_message}"
|
||||
return complexity_result
|
||||
|
||||
# Check table access permissions
|
||||
table_result = await self._check_table_access(parsed, auth_context)
|
||||
if not table_result.is_valid:
|
||||
table_result.error_message = f"Statement {idx + 1}: {table_result.error_message}"
|
||||
return table_result
|
||||
|
||||
return ValidationResult(is_valid=True)
|
||||
@@ -532,28 +964,69 @@ class SQLSecurityValidator:
|
||||
async def _check_sql_injection(
|
||||
self, sql: str, parsed: Statement
|
||||
) -> ValidationResult:
|
||||
"""Check SQL injection risks"""
|
||||
# Check common SQL injection patterns
|
||||
"""Check SQL injection risks with improved pattern detection
|
||||
|
||||
FIX for Issue #62 Bug 2: Improved patterns to reduce false positives
|
||||
Now better distinguishes between legitimate SQL (like BETWEEN...AND) and injection attempts
|
||||
"""
|
||||
# Improved injection patterns that are more specific and less prone to false positives
|
||||
injection_patterns = [
|
||||
r"(\s|^)(union|select|insert|update|delete|drop|create|alter)\s+.*\s+(union|select|insert|update|delete|drop|create|alter)",
|
||||
r"(\s|^)(or|and)\s+\d+\s*=\s*\d+",
|
||||
r"(\s|^)(or|and)\s+['\"].*['\"]",
|
||||
r";\s*(drop|delete|truncate|alter|create)",
|
||||
r"(exec|execute|sp_|xp_)",
|
||||
r"(script|javascript|vbscript)",
|
||||
r"(char|ascii|substring|concat)\s*\(",
|
||||
# Stacked queries with dangerous operations (true injection risk)
|
||||
r";\s*(DROP|DELETE|TRUNCATE|ALTER|CREATE|INSERT|UPDATE)\s+",
|
||||
|
||||
# UNION-based injection (but allow legitimate UNION queries)
|
||||
# Only flag if UNION is followed by suspicious patterns like SELECT with WHERE 1=1
|
||||
r"UNION\s+(ALL\s+)?SELECT\s+.*\s+(WHERE|AND|OR)\s+\d+\s*=\s*\d+",
|
||||
|
||||
# Boolean-based blind injection with comments (true injection pattern)
|
||||
r"(WHERE|AND|OR)\s+\d+\s*=\s*\d+\s*(--|#|/\*)",
|
||||
|
||||
# Quote-based injection attempts (but not in legitimate strings)
|
||||
r"(WHERE|AND|OR)\s+(['\"])[^\2]*\2\s*=\s*\2[^\2]*\2",
|
||||
|
||||
# Time-based blind injection
|
||||
r"(SLEEP|WAITFOR|BENCHMARK)\s*\(",
|
||||
|
||||
# System stored procedure injection
|
||||
r"(EXEC|EXECUTE|SP_|XP_)\s*\(",
|
||||
|
||||
# Script injection attempts
|
||||
r"<\s*(SCRIPT|JAVASCRIPT|VBSCRIPT)",
|
||||
]
|
||||
|
||||
sql_lower = sql.lower()
|
||||
# FIX: Don't flag legitimate SQL functions and keywords
|
||||
# These patterns are too broad and cause false positives:
|
||||
# - REMOVED: r"(char|ascii|substring|concat)\s*\(" - These are legitimate SQL functions
|
||||
# - REMOVED: r"(\s|^)(or|and)\s+\d+\s*=\s*\d+" - This flags BETWEEN...AND constructs
|
||||
# - REMOVED: r"(\s|^)(or|and)\s+['\"].*['\"]" - This is too broad
|
||||
|
||||
sql_upper = sql.upper()
|
||||
|
||||
# Special case: Allow BETWEEN...AND which is legitimate SQL
|
||||
# This prevents false positives like "WHERE dt BETWEEN '2025-01-01' AND '2025-01-31'"
|
||||
if "BETWEEN" in sql_upper and "AND" in sql_upper:
|
||||
# This is likely a BETWEEN clause, not injection
|
||||
# Check if AND appears in a BETWEEN context
|
||||
between_pattern = r"BETWEEN\s+[^\s]+\s+AND\s+[^\s]+"
|
||||
if re.search(between_pattern, sql_upper, re.IGNORECASE):
|
||||
# Remove BETWEEN clauses before checking other patterns
|
||||
sql_cleaned = re.sub(between_pattern, "BETWEEN_CLAUSE", sql_upper, flags=re.IGNORECASE)
|
||||
sql_to_check = sql_cleaned
|
||||
else:
|
||||
sql_to_check = sql_upper
|
||||
else:
|
||||
sql_to_check = sql_upper
|
||||
|
||||
for pattern in injection_patterns:
|
||||
if re.search(pattern, sql_lower, re.IGNORECASE):
|
||||
if re.search(pattern, sql_to_check, re.IGNORECASE):
|
||||
self.logger.warning(f"Potential SQL injection pattern detected: {pattern}")
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
error_message="Potential SQL injection risk detected",
|
||||
risk_level="high",
|
||||
)
|
||||
|
||||
# Check suspicious quotes and comments
|
||||
# Check suspicious quotes and comments (with improved detection)
|
||||
if self._has_suspicious_quotes_or_comments(sql):
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
@@ -564,20 +1037,68 @@ class SQLSecurityValidator:
|
||||
return ValidationResult(is_valid=True)
|
||||
|
||||
def _has_suspicious_quotes_or_comments(self, sql: str) -> bool:
|
||||
"""Check suspicious quote and comment patterns"""
|
||||
# Check unmatched quotes
|
||||
single_quotes = sql.count("'")
|
||||
double_quotes = sql.count('"')
|
||||
"""Check suspicious quote and comment patterns with improved detection
|
||||
|
||||
FIX for Issue #62 Bug 2: Improved detection to reduce false positives
|
||||
Now distinguishes between legitimate comments/strings and injection attempts
|
||||
"""
|
||||
try:
|
||||
# Use sqlparse to parse the SQL and distinguish between code and comments/strings
|
||||
import sqlparse
|
||||
from sqlparse.tokens import Comment, String
|
||||
|
||||
# Parse the SQL
|
||||
parsed = sqlparse.parse(sql)
|
||||
if not parsed:
|
||||
# If parsing fails, be conservative
|
||||
return True
|
||||
|
||||
statement = parsed[0]
|
||||
|
||||
# Check for unmatched quotes ONLY in non-string tokens
|
||||
# This prevents false positives from legitimate string content
|
||||
non_string_content = []
|
||||
has_string_tokens = False
|
||||
|
||||
for token in statement.flatten():
|
||||
if token.ttype in (String.Single, String.Double):
|
||||
has_string_tokens = True
|
||||
# Skip string content - quotes inside strings are legitimate
|
||||
continue
|
||||
elif token.ttype in (Comment.Single, Comment.Multi):
|
||||
# Comments are generally OK, but check for suspicious injection patterns
|
||||
comment_value = str(token).lower()
|
||||
# Check if comment contains dangerous SQL keywords
|
||||
dangerous_in_comments = ['drop', 'delete', 'insert', 'update', 'union', 'exec', 'execute']
|
||||
if any(keyword in comment_value for keyword in dangerous_in_comments):
|
||||
self.logger.warning(f"Suspicious SQL keyword in comment: {token}")
|
||||
return True
|
||||
# Normal comments are OK
|
||||
continue
|
||||
else:
|
||||
# Accumulate non-string, non-comment content
|
||||
non_string_content.append(str(token))
|
||||
|
||||
# Check for unmatched quotes in non-string content
|
||||
non_string_text = ''.join(non_string_content)
|
||||
single_quotes = non_string_text.count("'")
|
||||
double_quotes = non_string_text.count('"')
|
||||
|
||||
# Only flag if there are unmatched quotes in actual SQL code (not in strings)
|
||||
if single_quotes % 2 != 0 or double_quotes % 2 != 0:
|
||||
return True
|
||||
|
||||
# Check SQL comments
|
||||
if "--" in sql or "/*" in sql:
|
||||
return True
|
||||
# FIX: Don't flag legitimate SQL comments
|
||||
# Comments are OK as long as they don't contain dangerous patterns (already checked above)
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.logger.debug(f"SQL parsing error in quote/comment check: {e}")
|
||||
# On parsing error, fall back to conservative check
|
||||
# But be more lenient than before
|
||||
return False # Don't flag on parse errors to reduce false positives
|
||||
|
||||
async def _check_blocked_keywords(self, parsed: Statement) -> ValidationResult:
|
||||
"""Check blocked keywords"""
|
||||
blocked_operations = []
|
||||
@@ -638,6 +1159,10 @@ class SQLSecurityValidator:
|
||||
self, parsed: Statement, auth_context: AuthContext
|
||||
) -> ValidationResult:
|
||||
"""Check table access permissions"""
|
||||
# If no auth_context, skip table access checks (rely on other security checks)
|
||||
if auth_context is None:
|
||||
return ValidationResult(is_valid=True)
|
||||
|
||||
# Extract table names from query
|
||||
tables = self._extract_table_names(parsed)
|
||||
|
||||
@@ -686,7 +1211,7 @@ class DataMaskingProcessor:
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.logger = get_logger(__name__)
|
||||
self.masking_algorithms = self._init_masking_algorithms()
|
||||
self.masking_rules = self._load_masking_rules()
|
||||
|
||||
|
||||
788
doris_mcp_server/utils/security_analytics_tools.py
Normal file
788
doris_mcp_server/utils/security_analytics_tools.py
Normal file
@@ -0,0 +1,788 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Security Analytics Tools Module
|
||||
Provides data access analysis, user behavior monitoring, and security insights
|
||||
"""
|
||||
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Optional
|
||||
from collections import Counter, defaultdict
|
||||
|
||||
from .db import DorisConnectionManager
|
||||
from .logger import get_logger
|
||||
from .sql_security_utils import get_auth_context
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class SecurityAnalyticsTools:
|
||||
"""Security analytics tools for access pattern analysis and user monitoring"""
|
||||
|
||||
def __init__(self, connection_manager: DorisConnectionManager):
|
||||
self.connection_manager = connection_manager
|
||||
logger.info("SecurityAnalyticsTools initialized")
|
||||
|
||||
async def analyze_data_access_patterns(
|
||||
self,
|
||||
days: int = 7,
|
||||
include_system_users: bool = False,
|
||||
min_query_threshold: int = 5
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Analyze data access patterns for users and roles
|
||||
|
||||
Args:
|
||||
days: Number of days to analyze
|
||||
include_system_users: Whether to include system/service users
|
||||
min_query_threshold: Minimum queries for a user to be included in analysis
|
||||
|
||||
Returns:
|
||||
Comprehensive access pattern analysis
|
||||
"""
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
# 🚀 PROGRESS: Initialize security analysis
|
||||
logger.info("=" * 70)
|
||||
logger.info(f"🔒 Starting Data Access Pattern Analysis")
|
||||
logger.info(f"📅 Analysis period: {days} days")
|
||||
logger.info(f"👥 Include system users: {include_system_users}")
|
||||
logger.info(f"🎯 Min query threshold: {min_query_threshold}")
|
||||
logger.info("=" * 70)
|
||||
|
||||
connection = await self.connection_manager.get_connection("query")
|
||||
|
||||
# Define analysis period
|
||||
end_date = datetime.now()
|
||||
start_date = end_date - timedelta(days=days)
|
||||
|
||||
logger.info(f"📊 Period: {start_date.strftime('%Y-%m-%d')} to {end_date.strftime('%Y-%m-%d')}")
|
||||
|
||||
# 🚀 PROGRESS: Step 1 - Get audit log data
|
||||
logger.info("📋 Step 1/5: Retrieving audit log data...")
|
||||
audit_start = time.time()
|
||||
audit_data = await self._get_audit_log_data(connection, start_date, end_date, include_system_users)
|
||||
audit_time = time.time() - audit_start
|
||||
|
||||
if not audit_data:
|
||||
logger.warning("⚠️ No audit data available for the specified period")
|
||||
return {
|
||||
"error": "No audit data available for the specified period",
|
||||
"analysis_period": {
|
||||
"start_date": start_date.isoformat(),
|
||||
"end_date": end_date.isoformat(),
|
||||
"days": days
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(f"✅ Retrieved {len(audit_data)} audit records in {audit_time:.2f}s")
|
||||
|
||||
# 🚀 PROGRESS: Step 2 - Analyze user access patterns
|
||||
logger.info("👤 Step 2/5: Analyzing user access patterns...")
|
||||
user_start = time.time()
|
||||
user_access_analysis = await self._analyze_user_access_patterns(
|
||||
audit_data, min_query_threshold
|
||||
)
|
||||
user_time = time.time() - user_start
|
||||
logger.info(f"✅ Analyzed {len(user_access_analysis)} users in {user_time:.2f}s")
|
||||
|
||||
# 🚀 PROGRESS: Step 3 - Analyze role-based access
|
||||
logger.info("🎭 Step 3/5: Analyzing role-based access patterns...")
|
||||
role_start = time.time()
|
||||
role_access_analysis = await self._analyze_role_access_patterns(
|
||||
connection, user_access_analysis
|
||||
)
|
||||
role_time = time.time() - role_start
|
||||
logger.info(f"✅ Role analysis completed in {role_time:.2f}s")
|
||||
|
||||
# 🚀 PROGRESS: Step 4 - Detect security anomalies
|
||||
logger.info("🚨 Step 4/5: Detecting security anomalies...")
|
||||
anomaly_start = time.time()
|
||||
security_alerts = await self._detect_security_anomalies(
|
||||
audit_data, user_access_analysis
|
||||
)
|
||||
anomaly_time = time.time() - anomaly_start
|
||||
logger.info(f"✅ Found {len(security_alerts)} security alerts in {anomaly_time:.2f}s")
|
||||
|
||||
# Log alert summary
|
||||
if security_alerts:
|
||||
high_alerts = sum(1 for alert in security_alerts if alert.get("severity") == "high")
|
||||
medium_alerts = sum(1 for alert in security_alerts if alert.get("severity") == "medium")
|
||||
logger.info(f"🚨 Alert breakdown: {high_alerts} high, {medium_alerts} medium")
|
||||
|
||||
# 🚀 PROGRESS: Step 5 - Generate access insights
|
||||
logger.info("💡 Step 5/5: Generating access insights...")
|
||||
insights_start = time.time()
|
||||
access_insights = await self._generate_access_insights(
|
||||
user_access_analysis, role_access_analysis
|
||||
)
|
||||
insights_time = time.time() - insights_start
|
||||
logger.info(f"✅ Access insights generated in {insights_time:.2f}s")
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
return {
|
||||
"analysis_period": {
|
||||
"start_date": start_date.isoformat(),
|
||||
"end_date": end_date.isoformat(),
|
||||
"days": days
|
||||
},
|
||||
"analysis_timestamp": datetime.now().isoformat(),
|
||||
"execution_time_seconds": round(execution_time, 3),
|
||||
"user_access_summary": self._generate_user_access_summary(user_access_analysis),
|
||||
"user_access_details": user_access_analysis,
|
||||
"role_analysis": role_access_analysis,
|
||||
"security_alerts": security_alerts,
|
||||
"access_insights": access_insights,
|
||||
"recommendations": self._generate_security_recommendations(security_alerts, access_insights)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Data access pattern analysis failed: {str(e)}")
|
||||
return {
|
||||
"error": str(e),
|
||||
"analysis_timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# ==================== Private Helper Methods ====================
|
||||
|
||||
async def _get_audit_log_data(self, connection, start_date: datetime, end_date: datetime, include_system_users: bool) -> List[Dict]:
|
||||
"""Retrieve audit log data for the specified period"""
|
||||
try:
|
||||
# System users filter
|
||||
system_user_filter = ""
|
||||
if not include_system_users:
|
||||
system_users = ['root', 'admin', 'system', 'doris', 'information_schema']
|
||||
user_list = ','.join([f'"{user}"' for user in system_users])
|
||||
system_user_filter = f"AND `user` NOT IN ({user_list})"
|
||||
|
||||
audit_sql = f"""
|
||||
SELECT
|
||||
`user` as user_name,
|
||||
`client_ip` as host,
|
||||
`time` as query_time,
|
||||
`stmt` as sql_statement,
|
||||
`state` as query_status,
|
||||
`scan_bytes` as scan_bytes,
|
||||
`scan_rows` as scan_rows,
|
||||
`return_rows` as return_rows,
|
||||
`query_time` as execution_time_ms
|
||||
FROM internal.__internal_schema.audit_log
|
||||
WHERE `time` >= '{start_date.strftime('%Y-%m-%d %H:%M:%S')}'
|
||||
AND `time` <= '{end_date.strftime('%Y-%m-%d %H:%M:%S')}'
|
||||
AND `stmt` IS NOT NULL
|
||||
AND `stmt` != ''
|
||||
{system_user_filter}
|
||||
ORDER BY `time` DESC
|
||||
LIMIT 10000
|
||||
"""
|
||||
|
||||
# SECURITY FIX: Pass auth_context to execute
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(audit_sql, auth_context=auth_context)
|
||||
return result.data if result.data else []
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get audit log data: {str(e)}")
|
||||
# Try alternative method without detailed metrics
|
||||
try:
|
||||
simple_audit_sql = f"""
|
||||
SELECT
|
||||
`user` as user_name,
|
||||
`client_ip` as host,
|
||||
`time` as query_time,
|
||||
`stmt` as sql_statement,
|
||||
`state` as query_status
|
||||
FROM internal.__internal_schema.audit_log
|
||||
WHERE `time` >= '{start_date.strftime('%Y-%m-%d %H:%M:%S')}'
|
||||
AND `time` <= '{end_date.strftime('%Y-%m-%d %H:%M:%S')}'
|
||||
AND `stmt` IS NOT NULL
|
||||
{system_user_filter}
|
||||
ORDER BY `time` DESC
|
||||
LIMIT 10000
|
||||
"""
|
||||
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(simple_audit_sql, auth_context=auth_context)
|
||||
return result.data if result.data else []
|
||||
|
||||
except Exception as e2:
|
||||
logger.error(f"Failed to get simplified audit log data: {str(e2)}")
|
||||
return []
|
||||
|
||||
async def _analyze_user_access_patterns(self, audit_data: List[Dict], min_query_threshold: int) -> List[Dict]:
|
||||
"""Analyze access patterns for individual users"""
|
||||
user_stats = defaultdict(lambda: {
|
||||
"total_queries": 0,
|
||||
"unique_tables_accessed": set(),
|
||||
"hosts": set(),
|
||||
"query_types": Counter(),
|
||||
"query_times": [],
|
||||
"failed_queries": 0,
|
||||
"data_volume_read_bytes": 0,
|
||||
"data_volume_read_rows": 0,
|
||||
"hourly_pattern": [0] * 24,
|
||||
"daily_pattern": [0] * 7,
|
||||
"query_statements": []
|
||||
})
|
||||
|
||||
# Process audit data
|
||||
for entry in audit_data:
|
||||
user_name = entry.get("user_name", "unknown")
|
||||
query_time = entry.get("query_time")
|
||||
sql_statement = entry.get("sql_statement", "")
|
||||
query_status = entry.get("query_status", "")
|
||||
|
||||
stats = user_stats[user_name]
|
||||
stats["total_queries"] += 1
|
||||
|
||||
# Extract table names from SQL
|
||||
tables = self._extract_table_names_from_sql(sql_statement)
|
||||
stats["unique_tables_accessed"].update(tables)
|
||||
|
||||
# Host tracking
|
||||
if entry.get("host"):
|
||||
stats["hosts"].add(entry["host"])
|
||||
|
||||
# Query type analysis
|
||||
query_type = self._classify_query_type(sql_statement)
|
||||
stats["query_types"][query_type] += 1
|
||||
|
||||
# Query time patterns
|
||||
if query_time:
|
||||
try:
|
||||
if isinstance(query_time, str):
|
||||
query_dt = datetime.fromisoformat(query_time.replace('Z', '+00:00'))
|
||||
else:
|
||||
query_dt = query_time
|
||||
|
||||
stats["query_times"].append(query_dt)
|
||||
stats["hourly_pattern"][query_dt.hour] += 1
|
||||
stats["daily_pattern"][query_dt.weekday()] += 1
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Error tracking
|
||||
if query_status and "error" in query_status.lower():
|
||||
stats["failed_queries"] += 1
|
||||
|
||||
# Data volume tracking
|
||||
if entry.get("scan_bytes"):
|
||||
try:
|
||||
stats["data_volume_read_bytes"] += int(entry["scan_bytes"])
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
if entry.get("scan_rows"):
|
||||
try:
|
||||
stats["data_volume_read_rows"] += int(entry["scan_rows"])
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
# Store sample queries
|
||||
if len(stats["query_statements"]) < 10:
|
||||
stats["query_statements"].append({
|
||||
"sql": sql_statement[:200] + "..." if len(sql_statement) > 200 else sql_statement,
|
||||
"timestamp": str(query_time),
|
||||
"type": query_type
|
||||
})
|
||||
|
||||
# Convert to analysis results
|
||||
user_analysis = []
|
||||
for user_name, stats in user_stats.items():
|
||||
if stats["total_queries"] >= min_query_threshold:
|
||||
# Calculate patterns and insights
|
||||
access_pattern = self._classify_access_pattern(stats["hourly_pattern"])
|
||||
table_access_frequency = dict(Counter(
|
||||
table for entry in audit_data
|
||||
if entry.get("user_name") == user_name
|
||||
for table in self._extract_table_names_from_sql(entry.get("sql_statement", ""))
|
||||
).most_common(10))
|
||||
|
||||
user_analysis.append({
|
||||
"user_name": user_name,
|
||||
"access_stats": {
|
||||
"total_queries": stats["total_queries"],
|
||||
"unique_tables_accessed": len(stats["unique_tables_accessed"]),
|
||||
"unique_hosts": len(stats["hosts"]),
|
||||
"data_volume_read_gb": round(stats["data_volume_read_bytes"] / (1024**3), 3),
|
||||
"data_volume_read_rows": stats["data_volume_read_rows"],
|
||||
"failed_queries": stats["failed_queries"],
|
||||
"success_rate": round((stats["total_queries"] - stats["failed_queries"]) / stats["total_queries"], 3) if stats["total_queries"] > 0 else 0,
|
||||
"peak_access_hour": stats["hourly_pattern"].index(max(stats["hourly_pattern"])) if max(stats["hourly_pattern"]) > 0 else None,
|
||||
"access_pattern": access_pattern
|
||||
},
|
||||
"query_type_distribution": dict(stats["query_types"]),
|
||||
"table_access_frequency": table_access_frequency,
|
||||
"hosts_used": list(stats["hosts"]),
|
||||
"sample_queries": stats["query_statements"],
|
||||
"temporal_patterns": {
|
||||
"hourly_distribution": stats["hourly_pattern"],
|
||||
"daily_distribution": stats["daily_pattern"]
|
||||
}
|
||||
})
|
||||
|
||||
return sorted(user_analysis, key=lambda x: x["access_stats"]["total_queries"], reverse=True)
|
||||
|
||||
def _extract_table_names_from_sql(self, sql: str) -> List[str]:
|
||||
"""Extract table names from SQL statement (simplified implementation)"""
|
||||
if not sql:
|
||||
return []
|
||||
|
||||
import re
|
||||
|
||||
# Simple regex patterns to match table names
|
||||
patterns = [
|
||||
r'\bFROM\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)',
|
||||
r'\bJOIN\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)',
|
||||
r'\bINTO\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)',
|
||||
r'\bUPDATE\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)',
|
||||
r'\bDELETE\s+FROM\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)'
|
||||
]
|
||||
|
||||
tables = []
|
||||
for pattern in patterns:
|
||||
matches = re.findall(pattern, sql, re.IGNORECASE)
|
||||
tables.extend(matches)
|
||||
|
||||
# Clean up table names (remove quotes, aliases, etc.)
|
||||
cleaned_tables = []
|
||||
for table in tables:
|
||||
# Remove backticks, quotes, and get just the table name
|
||||
clean_table = table.strip('`"\'').split(' ')[0]
|
||||
if clean_table and not clean_table.upper() in ['SELECT', 'WHERE', 'AND', 'OR']:
|
||||
cleaned_tables.append(clean_table)
|
||||
|
||||
return list(set(cleaned_tables))
|
||||
|
||||
def _classify_query_type(self, sql: str) -> str:
|
||||
"""Classify SQL query type"""
|
||||
if not sql:
|
||||
return "unknown"
|
||||
|
||||
sql_upper = sql.upper().strip()
|
||||
|
||||
if sql_upper.startswith('SELECT'):
|
||||
return "SELECT"
|
||||
elif sql_upper.startswith('INSERT'):
|
||||
return "INSERT"
|
||||
elif sql_upper.startswith('UPDATE'):
|
||||
return "UPDATE"
|
||||
elif sql_upper.startswith('DELETE'):
|
||||
return "DELETE"
|
||||
elif sql_upper.startswith('CREATE'):
|
||||
return "CREATE"
|
||||
elif sql_upper.startswith('ALTER'):
|
||||
return "ALTER"
|
||||
elif sql_upper.startswith('DROP'):
|
||||
return "DROP"
|
||||
elif sql_upper.startswith('SHOW'):
|
||||
return "SHOW"
|
||||
elif sql_upper.startswith('DESCRIBE') or sql_upper.startswith('DESC'):
|
||||
return "DESCRIBE"
|
||||
else:
|
||||
return "OTHER"
|
||||
|
||||
def _classify_access_pattern(self, hourly_pattern: List[int]) -> str:
|
||||
"""Classify user access pattern based on hourly distribution"""
|
||||
if not hourly_pattern or max(hourly_pattern) == 0:
|
||||
return "no_pattern"
|
||||
|
||||
# Find peak hours
|
||||
max_queries = max(hourly_pattern)
|
||||
peak_hours = [i for i, count in enumerate(hourly_pattern) if count == max_queries]
|
||||
|
||||
# Business hours: 9-17
|
||||
business_hours = set(range(9, 18))
|
||||
peak_in_business_hours = any(hour in business_hours for hour in peak_hours)
|
||||
|
||||
# Night hours: 22-6
|
||||
night_hours = set(list(range(22, 24)) + list(range(0, 7)))
|
||||
peak_in_night_hours = any(hour in night_hours for hour in peak_hours)
|
||||
|
||||
if peak_in_business_hours and not peak_in_night_hours:
|
||||
return "regular_business_hours"
|
||||
elif peak_in_night_hours:
|
||||
return "night_shift_or_batch"
|
||||
elif len(peak_hours) > 6: # Distributed throughout day
|
||||
return "distributed_access"
|
||||
else:
|
||||
return "irregular_pattern"
|
||||
|
||||
async def _analyze_role_access_patterns(self, connection, user_access_analysis: List[Dict]) -> Dict[str, Any]:
|
||||
"""Analyze access patterns by role"""
|
||||
try:
|
||||
# Get user roles information
|
||||
user_roles = await self._get_user_roles(connection)
|
||||
|
||||
# Group users by roles
|
||||
role_stats = defaultdict(lambda: {
|
||||
"user_count": 0,
|
||||
"total_queries": 0,
|
||||
"unique_tables": set(),
|
||||
"query_types": Counter(),
|
||||
"avg_queries_per_user": 0,
|
||||
"users": []
|
||||
})
|
||||
|
||||
# Process user access data
|
||||
for user_data in user_access_analysis:
|
||||
user_name = user_data["user_name"]
|
||||
user_stats = user_data["access_stats"]
|
||||
query_types = user_data["query_type_distribution"]
|
||||
|
||||
# Get user roles (default to 'unknown' if not found)
|
||||
roles = user_roles.get(user_name, ["unknown"])
|
||||
|
||||
for role in roles:
|
||||
stats = role_stats[role]
|
||||
stats["user_count"] += 1
|
||||
stats["total_queries"] += user_stats["total_queries"]
|
||||
stats["users"].append(user_name)
|
||||
|
||||
# Aggregate query types
|
||||
for query_type, count in query_types.items():
|
||||
stats["query_types"][query_type] += count
|
||||
|
||||
# Calculate role analysis
|
||||
role_analysis = {}
|
||||
for role, stats in role_stats.items():
|
||||
if stats["user_count"] > 0:
|
||||
avg_queries = stats["total_queries"] / stats["user_count"]
|
||||
|
||||
# Calculate privilege usage (simplified)
|
||||
total_role_queries = sum(stats["query_types"].values())
|
||||
privilege_usage = {}
|
||||
if total_role_queries > 0:
|
||||
privilege_usage = {
|
||||
query_type: round(count / total_role_queries, 3)
|
||||
for query_type, count in stats["query_types"].items()
|
||||
}
|
||||
|
||||
role_analysis[role] = {
|
||||
"user_count": stats["user_count"],
|
||||
"users": stats["users"],
|
||||
"total_queries": stats["total_queries"],
|
||||
"avg_queries_per_user": round(avg_queries, 1),
|
||||
"query_type_distribution": dict(stats["query_types"]),
|
||||
"privilege_usage": privilege_usage,
|
||||
"activity_level": self._classify_role_activity_level(avg_queries)
|
||||
}
|
||||
|
||||
return role_analysis
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to analyze role access patterns: {str(e)}")
|
||||
return {}
|
||||
|
||||
async def _get_user_roles(self, connection) -> Dict[str, List[str]]:
|
||||
"""Get user roles mapping"""
|
||||
try:
|
||||
# Try to get user role information
|
||||
roles_sql = """
|
||||
SELECT
|
||||
User as user_name,
|
||||
COALESCE(Default_role, 'default') as role_name
|
||||
FROM mysql.user
|
||||
"""
|
||||
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(roles_sql, auth_context=auth_context)
|
||||
|
||||
user_roles = defaultdict(list)
|
||||
if result.data:
|
||||
for row in result.data:
|
||||
user_name = row.get("user_name", "")
|
||||
role_name = row.get("role_name", "default")
|
||||
if user_name:
|
||||
user_roles[user_name].append(role_name)
|
||||
|
||||
return dict(user_roles)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get user roles: {str(e)}")
|
||||
return {}
|
||||
|
||||
def _classify_role_activity_level(self, avg_queries: float) -> str:
|
||||
"""Classify role activity level based on average queries"""
|
||||
if avg_queries > 100:
|
||||
return "high"
|
||||
elif avg_queries > 20:
|
||||
return "medium"
|
||||
elif avg_queries > 5:
|
||||
return "low"
|
||||
else:
|
||||
return "minimal"
|
||||
|
||||
async def _detect_security_anomalies(self, audit_data: List[Dict], user_access_analysis: List[Dict]) -> List[Dict]:
|
||||
"""Detect potential security anomalies"""
|
||||
alerts = []
|
||||
|
||||
# 1. Detect unusual access times
|
||||
for user_data in user_access_analysis:
|
||||
user_name = user_data["user_name"]
|
||||
hourly_pattern = user_data["temporal_patterns"]["hourly_distribution"]
|
||||
|
||||
# Check for significant night-time activity
|
||||
night_queries = sum(hourly_pattern[22:24]) + sum(hourly_pattern[0:6])
|
||||
total_queries = sum(hourly_pattern)
|
||||
|
||||
if total_queries > 0 and night_queries / total_queries > 0.3: # >30% night activity
|
||||
alerts.append({
|
||||
"alert_type": "unusual_access_time",
|
||||
"severity": "medium",
|
||||
"user": user_name,
|
||||
"description": f"User {user_name} has {night_queries/total_queries:.1%} of queries during night hours",
|
||||
"night_query_percentage": round(night_queries/total_queries, 3),
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
# 2. Detect users with high failure rates
|
||||
for user_data in user_access_analysis:
|
||||
user_name = user_data["user_name"]
|
||||
success_rate = user_data["access_stats"]["success_rate"]
|
||||
total_queries = user_data["access_stats"]["total_queries"]
|
||||
|
||||
if total_queries > 10 and success_rate < 0.8: # <80% success rate
|
||||
alerts.append({
|
||||
"alert_type": "high_failure_rate",
|
||||
"severity": "medium",
|
||||
"user": user_name,
|
||||
"description": f"User {user_name} has low query success rate ({success_rate:.1%})",
|
||||
"success_rate": success_rate,
|
||||
"total_queries": total_queries,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
# 3. Detect unusual data volume access
|
||||
data_volumes = [user["access_stats"]["data_volume_read_gb"] for user in user_access_analysis]
|
||||
if data_volumes:
|
||||
avg_volume = sum(data_volumes) / len(data_volumes)
|
||||
std_dev = (sum((x - avg_volume) ** 2 for x in data_volumes) / len(data_volumes)) ** 0.5
|
||||
threshold = avg_volume + 2 * std_dev # 2 standard deviations above mean
|
||||
|
||||
for user_data in user_access_analysis:
|
||||
user_name = user_data["user_name"]
|
||||
volume = user_data["access_stats"]["data_volume_read_gb"]
|
||||
|
||||
if volume > threshold and volume > 1.0: # >1GB and above threshold
|
||||
alerts.append({
|
||||
"alert_type": "unusual_data_volume",
|
||||
"severity": "high" if volume > threshold * 2 else "medium",
|
||||
"user": user_name,
|
||||
"description": f"User {user_name} read {volume:.2f}GB (threshold: {threshold:.2f}GB)",
|
||||
"data_volume_gb": volume,
|
||||
"threshold_gb": round(threshold, 2),
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
# 4. Detect users accessing many different tables
|
||||
for user_data in user_access_analysis:
|
||||
user_name = user_data["user_name"]
|
||||
unique_tables = user_data["access_stats"]["unique_tables_accessed"]
|
||||
total_queries = user_data["access_stats"]["total_queries"]
|
||||
|
||||
# High table diversity might indicate privilege escalation or data mining
|
||||
if unique_tables > 20 and total_queries > 50:
|
||||
alerts.append({
|
||||
"alert_type": "broad_table_access",
|
||||
"severity": "medium",
|
||||
"user": user_name,
|
||||
"description": f"User {user_name} accessed {unique_tables} different tables",
|
||||
"unique_tables_count": unique_tables,
|
||||
"total_queries": total_queries,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
return sorted(alerts, key=lambda x: {"high": 3, "medium": 2, "low": 1}.get(x["severity"], 0), reverse=True)
|
||||
|
||||
async def _generate_access_insights(self, user_access_analysis: List[Dict], role_analysis: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Generate access insights and patterns"""
|
||||
insights = {
|
||||
"user_behavior_patterns": {},
|
||||
"role_effectiveness": {},
|
||||
"security_posture": {}
|
||||
}
|
||||
|
||||
# User behavior patterns
|
||||
if user_access_analysis:
|
||||
total_users = len(user_access_analysis)
|
||||
active_users = len([u for u in user_access_analysis if u["access_stats"]["total_queries"] > 10])
|
||||
power_users = len([u for u in user_access_analysis if u["access_stats"]["total_queries"] > 100])
|
||||
|
||||
# Access pattern distribution
|
||||
pattern_distribution = Counter(
|
||||
user["access_stats"]["access_pattern"] for user in user_access_analysis
|
||||
)
|
||||
|
||||
insights["user_behavior_patterns"] = {
|
||||
"total_users_analyzed": total_users,
|
||||
"active_users": active_users,
|
||||
"power_users": power_users,
|
||||
"access_pattern_distribution": dict(pattern_distribution),
|
||||
"avg_queries_per_user": round(
|
||||
sum(u["access_stats"]["total_queries"] for u in user_access_analysis) / total_users, 1
|
||||
) if total_users > 0 else 0
|
||||
}
|
||||
|
||||
# Role effectiveness
|
||||
if role_analysis:
|
||||
most_active_role = max(role_analysis.items(), key=lambda x: x[1]["total_queries"])
|
||||
least_active_role = min(role_analysis.items(), key=lambda x: x[1]["total_queries"])
|
||||
|
||||
insights["role_effectiveness"] = {
|
||||
"total_roles": len(role_analysis),
|
||||
"most_active_role": {
|
||||
"role": most_active_role[0],
|
||||
"total_queries": most_active_role[1]["total_queries"],
|
||||
"user_count": most_active_role[1]["user_count"]
|
||||
},
|
||||
"least_active_role": {
|
||||
"role": least_active_role[0],
|
||||
"total_queries": least_active_role[1]["total_queries"],
|
||||
"user_count": least_active_role[1]["user_count"]
|
||||
},
|
||||
"avg_users_per_role": round(
|
||||
sum(role_info["user_count"] for role_info in role_analysis.values()) / len(role_analysis), 1
|
||||
)
|
||||
}
|
||||
|
||||
# Security posture assessment
|
||||
if user_access_analysis:
|
||||
users_with_failures = len([u for u in user_access_analysis if u["access_stats"]["failed_queries"] > 0])
|
||||
users_night_access = len([
|
||||
u for u in user_access_analysis
|
||||
if any(u["temporal_patterns"]["hourly_distribution"][hour] > 0 for hour in list(range(22, 24)) + list(range(0, 6)))
|
||||
])
|
||||
|
||||
insights["security_posture"] = {
|
||||
"users_with_query_failures": users_with_failures,
|
||||
"users_with_night_access": users_night_access,
|
||||
"security_score": self._calculate_security_score(user_access_analysis),
|
||||
"risk_level": self._assess_overall_risk_level(user_access_analysis)
|
||||
}
|
||||
|
||||
return insights
|
||||
|
||||
def _calculate_security_score(self, user_access_analysis: List[Dict]) -> float:
|
||||
"""Calculate overall security score (0-1, higher is better)"""
|
||||
if not user_access_analysis:
|
||||
return 0.0
|
||||
|
||||
total_users = len(user_access_analysis)
|
||||
|
||||
# Factors that contribute to security score
|
||||
users_with_high_success_rate = len([u for u in user_access_analysis if u["access_stats"]["success_rate"] > 0.9])
|
||||
users_with_normal_patterns = len([u for u in user_access_analysis if u["access_stats"]["access_pattern"] == "regular_business_hours"])
|
||||
|
||||
success_rate_score = users_with_high_success_rate / total_users
|
||||
pattern_score = users_with_normal_patterns / total_users
|
||||
|
||||
# Combined score
|
||||
overall_score = (success_rate_score * 0.6 + pattern_score * 0.4)
|
||||
return round(overall_score, 3)
|
||||
|
||||
def _assess_overall_risk_level(self, user_access_analysis: List[Dict]) -> str:
|
||||
"""Assess overall security risk level"""
|
||||
security_score = self._calculate_security_score(user_access_analysis)
|
||||
|
||||
if security_score > 0.8:
|
||||
return "low"
|
||||
elif security_score > 0.6:
|
||||
return "medium"
|
||||
else:
|
||||
return "high"
|
||||
|
||||
def _generate_user_access_summary(self, user_access_analysis: List[Dict]) -> Dict[str, Any]:
|
||||
"""Generate summary statistics for user access"""
|
||||
if not user_access_analysis:
|
||||
return {
|
||||
"total_users": 0,
|
||||
"active_users": 0,
|
||||
"high_activity_users": 0,
|
||||
"dormant_users": 0
|
||||
}
|
||||
|
||||
total_users = len(user_access_analysis)
|
||||
active_users = len([u for u in user_access_analysis if u["access_stats"]["total_queries"] > 10])
|
||||
high_activity_users = len([u for u in user_access_analysis if u["access_stats"]["total_queries"] > 100])
|
||||
dormant_users = total_users - active_users
|
||||
|
||||
return {
|
||||
"total_users": total_users,
|
||||
"active_users": active_users,
|
||||
"high_activity_users": high_activity_users,
|
||||
"dormant_users": dormant_users,
|
||||
"activity_distribution": {
|
||||
"high": high_activity_users,
|
||||
"medium": active_users - high_activity_users,
|
||||
"low": dormant_users
|
||||
}
|
||||
}
|
||||
|
||||
def _generate_security_recommendations(self, security_alerts: List[Dict], access_insights: Dict[str, Any]) -> List[Dict]:
|
||||
"""Generate security recommendations based on analysis"""
|
||||
recommendations = []
|
||||
|
||||
# Recommendations based on alerts
|
||||
if security_alerts:
|
||||
high_severity_alerts = [alert for alert in security_alerts if alert["severity"] == "high"]
|
||||
if high_severity_alerts:
|
||||
recommendations.append({
|
||||
"type": "urgent_security_review",
|
||||
"priority": "high",
|
||||
"description": f"Found {len(high_severity_alerts)} high-severity security alerts",
|
||||
"action": "Immediate review of flagged users and access patterns required",
|
||||
"affected_users": list(set(alert["user"] for alert in high_severity_alerts if "user" in alert))
|
||||
})
|
||||
|
||||
# Night access recommendations
|
||||
night_access_alerts = [alert for alert in security_alerts if alert["alert_type"] == "unusual_access_time"]
|
||||
if night_access_alerts:
|
||||
recommendations.append({
|
||||
"type": "access_time_policy",
|
||||
"priority": "medium",
|
||||
"description": f"{len(night_access_alerts)} users have significant night-time access",
|
||||
"action": "Review access time policies and consider time-based restrictions",
|
||||
"affected_users": [alert["user"] for alert in night_access_alerts]
|
||||
})
|
||||
|
||||
# Recommendations based on insights
|
||||
security_posture = access_insights.get("security_posture", {})
|
||||
risk_level = security_posture.get("risk_level", "unknown")
|
||||
|
||||
if risk_level == "high":
|
||||
recommendations.append({
|
||||
"type": "overall_security_improvement",
|
||||
"priority": "high",
|
||||
"description": "Overall security posture indicates high risk",
|
||||
"action": "Comprehensive security audit and policy review recommended"
|
||||
})
|
||||
|
||||
# Role-based recommendations
|
||||
role_effectiveness = access_insights.get("role_effectiveness", {})
|
||||
if role_effectiveness and role_effectiveness.get("total_roles", 0) < 3:
|
||||
recommendations.append({
|
||||
"type": "role_management",
|
||||
"priority": "medium",
|
||||
"description": "Limited role diversity detected",
|
||||
"action": "Consider implementing more granular role-based access control"
|
||||
})
|
||||
|
||||
return recommendations
|
||||
301
doris_mcp_server/utils/sql_security_utils.py
Normal file
301
doris_mcp_server/utils/sql_security_utils.py
Normal file
@@ -0,0 +1,301 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
SQL Security Utilities Module
|
||||
|
||||
Provides SQL identifier validation, escaping, and safe query building utilities
|
||||
to prevent SQL injection attacks.
|
||||
"""
|
||||
|
||||
import re
|
||||
from contextvars import ContextVar
|
||||
from typing import Optional, Tuple, List, Any
|
||||
|
||||
from .logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Context variable for auth_context (set by HTTP middleware)
|
||||
auth_context_var: ContextVar = ContextVar('mcp_auth_context', default=None)
|
||||
|
||||
|
||||
class SQLSecurityError(Exception):
|
||||
"""Exception raised for SQL security validation failures"""
|
||||
pass
|
||||
|
||||
|
||||
class SQLSecurityUtils:
|
||||
"""
|
||||
SQL Security Utilities for preventing SQL injection attacks.
|
||||
|
||||
Provides:
|
||||
- Identifier validation (database names, table names, column names)
|
||||
- Safe identifier quoting with backticks
|
||||
- Safe table reference building
|
||||
- Auth context retrieval from context variables
|
||||
"""
|
||||
|
||||
# Valid SQL identifier pattern: letters, numbers, underscores
|
||||
# Must start with letter or underscore, not a number
|
||||
# Supports Unicode letters for international database/table names
|
||||
IDENTIFIER_PATTERN = re.compile(r'^[a-zA-Z_\u4e00-\u9fff][a-zA-Z0-9_\u4e00-\u9fff]*$')
|
||||
|
||||
# Maximum identifier length (MySQL/Doris standard)
|
||||
MAX_IDENTIFIER_LENGTH = 64
|
||||
|
||||
# SQL reserved keywords that should be quoted
|
||||
SQL_KEYWORDS = {
|
||||
'SELECT', 'FROM', 'WHERE', 'INSERT', 'UPDATE', 'DELETE', 'DROP',
|
||||
'CREATE', 'ALTER', 'TABLE', 'DATABASE', 'INDEX', 'VIEW', 'AND',
|
||||
'OR', 'NOT', 'NULL', 'TRUE', 'FALSE', 'IN', 'LIKE', 'BETWEEN',
|
||||
'JOIN', 'LEFT', 'RIGHT', 'INNER', 'OUTER', 'ON', 'AS', 'ORDER',
|
||||
'BY', 'GROUP', 'HAVING', 'LIMIT', 'OFFSET', 'UNION', 'ALL',
|
||||
'DISTINCT', 'INTO', 'VALUES', 'SET', 'DEFAULT', 'PRIMARY', 'KEY',
|
||||
'FOREIGN', 'REFERENCES', 'CHECK', 'UNIQUE', 'CONSTRAINT'
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def validate_identifier(cls, name: str, identifier_type: str = "identifier") -> str:
|
||||
"""
|
||||
Validate a SQL identifier (database name, table name, column name, etc.)
|
||||
|
||||
Args:
|
||||
name: The identifier to validate
|
||||
identifier_type: Type description for error messages (e.g., "database name", "table name")
|
||||
|
||||
Returns:
|
||||
The validated identifier (unchanged if valid)
|
||||
|
||||
Raises:
|
||||
SQLSecurityError: If the identifier is invalid
|
||||
"""
|
||||
if not name:
|
||||
raise SQLSecurityError(f"Empty {identifier_type} is not allowed")
|
||||
|
||||
if not isinstance(name, str):
|
||||
raise SQLSecurityError(f"Invalid {identifier_type}: must be a string, got {type(name).__name__}")
|
||||
|
||||
# Strip whitespace
|
||||
name = name.strip()
|
||||
|
||||
if not name:
|
||||
raise SQLSecurityError(f"Empty {identifier_type} is not allowed")
|
||||
|
||||
# Check length
|
||||
if len(name) > cls.MAX_IDENTIFIER_LENGTH:
|
||||
raise SQLSecurityError(
|
||||
f"Invalid {identifier_type}: '{name[:20]}...' exceeds maximum length of {cls.MAX_IDENTIFIER_LENGTH} characters"
|
||||
)
|
||||
|
||||
# Check for dangerous characters that could be SQL injection
|
||||
dangerous_chars = ["'", '"', ';', '--', '/*', '*/', '\\', '\x00']
|
||||
for char in dangerous_chars:
|
||||
if char in name:
|
||||
raise SQLSecurityError(
|
||||
f"Invalid {identifier_type}: '{name}' contains forbidden character '{char}'"
|
||||
)
|
||||
|
||||
# Validate pattern
|
||||
if not cls.IDENTIFIER_PATTERN.match(name):
|
||||
raise SQLSecurityError(
|
||||
f"Invalid {identifier_type}: '{name}' contains invalid characters. "
|
||||
f"Only letters, numbers, and underscores are allowed, and must start with a letter or underscore."
|
||||
)
|
||||
|
||||
logger.debug(f"Validated {identifier_type}: {name}")
|
||||
return name
|
||||
|
||||
@classmethod
|
||||
def quote_identifier(cls, name: str, identifier_type: str = "identifier") -> str:
|
||||
"""
|
||||
Safely quote a SQL identifier using backticks.
|
||||
|
||||
Args:
|
||||
name: The identifier to quote
|
||||
identifier_type: Type description for error messages
|
||||
|
||||
Returns:
|
||||
The quoted identifier (e.g., `table_name`)
|
||||
|
||||
Raises:
|
||||
SQLSecurityError: If the identifier is invalid
|
||||
"""
|
||||
# First validate the identifier
|
||||
validated_name = cls.validate_identifier(name, identifier_type)
|
||||
|
||||
# Escape any backticks within the name (double them)
|
||||
escaped_name = validated_name.replace('`', '``')
|
||||
|
||||
return f"`{escaped_name}`"
|
||||
|
||||
@classmethod
|
||||
def build_table_reference(
|
||||
cls,
|
||||
table_name: str,
|
||||
db_name: Optional[str] = None,
|
||||
catalog_name: Optional[str] = None,
|
||||
quote: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
Build a safe, fully-qualified table reference.
|
||||
|
||||
Args:
|
||||
table_name: The table name (required)
|
||||
db_name: The database name (optional)
|
||||
catalog_name: The catalog name (optional)
|
||||
quote: Whether to quote identifiers with backticks (default: True)
|
||||
|
||||
Returns:
|
||||
A safe table reference string (e.g., `catalog`.`db`.`table`)
|
||||
|
||||
Raises:
|
||||
SQLSecurityError: If any identifier is invalid
|
||||
"""
|
||||
parts = []
|
||||
|
||||
if catalog_name:
|
||||
if quote:
|
||||
parts.append(cls.quote_identifier(catalog_name, "catalog name"))
|
||||
else:
|
||||
parts.append(cls.validate_identifier(catalog_name, "catalog name"))
|
||||
|
||||
if db_name:
|
||||
if quote:
|
||||
parts.append(cls.quote_identifier(db_name, "database name"))
|
||||
else:
|
||||
parts.append(cls.validate_identifier(db_name, "database name"))
|
||||
|
||||
if quote:
|
||||
parts.append(cls.quote_identifier(table_name, "table name"))
|
||||
else:
|
||||
parts.append(cls.validate_identifier(table_name, "table name"))
|
||||
|
||||
return '.'.join(parts)
|
||||
|
||||
@classmethod
|
||||
def build_column_reference(
|
||||
cls,
|
||||
column_name: str,
|
||||
table_name: Optional[str] = None,
|
||||
quote: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
Build a safe column reference.
|
||||
|
||||
Args:
|
||||
column_name: The column name (required)
|
||||
table_name: The table name (optional, for qualified references)
|
||||
quote: Whether to quote identifiers with backticks (default: True)
|
||||
|
||||
Returns:
|
||||
A safe column reference string (e.g., `table`.`column`)
|
||||
|
||||
Raises:
|
||||
SQLSecurityError: If any identifier is invalid
|
||||
"""
|
||||
parts = []
|
||||
|
||||
if table_name:
|
||||
if quote:
|
||||
parts.append(cls.quote_identifier(table_name, "table name"))
|
||||
else:
|
||||
parts.append(cls.validate_identifier(table_name, "table name"))
|
||||
|
||||
if quote:
|
||||
parts.append(cls.quote_identifier(column_name, "column name"))
|
||||
else:
|
||||
parts.append(cls.validate_identifier(column_name, "column name"))
|
||||
|
||||
return '.'.join(parts)
|
||||
|
||||
@classmethod
|
||||
def validate_and_build_where_condition(
|
||||
cls,
|
||||
column_name: str,
|
||||
operator: str = "=",
|
||||
use_param: bool = True
|
||||
) -> Tuple[str, bool]:
|
||||
"""
|
||||
Build a safe WHERE condition for a column.
|
||||
|
||||
Args:
|
||||
column_name: The column name
|
||||
operator: The comparison operator (=, !=, <, >, <=, >=, LIKE, IN)
|
||||
use_param: Whether to use parameterized placeholder (%s)
|
||||
|
||||
Returns:
|
||||
Tuple of (condition_string, needs_param)
|
||||
e.g., ("`column` = %s", True) or ("`column` = DATABASE()", False)
|
||||
|
||||
Raises:
|
||||
SQLSecurityError: If column name is invalid or operator is not allowed
|
||||
"""
|
||||
# Validate column name
|
||||
quoted_column = cls.quote_identifier(column_name, "column name")
|
||||
|
||||
# Validate operator
|
||||
allowed_operators = {'=', '!=', '<>', '<', '>', '<=', '>=', 'LIKE', 'IN', 'IS'}
|
||||
if operator.upper() not in allowed_operators:
|
||||
raise SQLSecurityError(f"Invalid operator: '{operator}'. Allowed: {allowed_operators}")
|
||||
|
||||
if use_param:
|
||||
return f"{quoted_column} {operator} %s", True
|
||||
else:
|
||||
return f"{quoted_column} {operator}", False
|
||||
|
||||
@staticmethod
|
||||
def get_auth_context():
|
||||
"""
|
||||
Get auth_context from the context variable.
|
||||
|
||||
This retrieves the auth_context that was set by the HTTP middleware
|
||||
during request processing.
|
||||
|
||||
Returns:
|
||||
The auth_context object, or None if not available
|
||||
"""
|
||||
try:
|
||||
auth_context = auth_context_var.get()
|
||||
if auth_context:
|
||||
logger.debug(f"Retrieved auth_context from context variable")
|
||||
return auth_context
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not retrieve auth_context: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def set_auth_context(auth_context):
|
||||
"""
|
||||
Set auth_context in the context variable.
|
||||
|
||||
This is typically called by the HTTP middleware during request processing.
|
||||
|
||||
Args:
|
||||
auth_context: The auth_context object to set
|
||||
"""
|
||||
auth_context_var.set(auth_context)
|
||||
logger.debug("Set auth_context in context variable")
|
||||
|
||||
|
||||
# Convenience functions for direct use
|
||||
validate_identifier = SQLSecurityUtils.validate_identifier
|
||||
quote_identifier = SQLSecurityUtils.quote_identifier
|
||||
build_table_reference = SQLSecurityUtils.build_table_reference
|
||||
build_column_reference = SQLSecurityUtils.build_column_reference
|
||||
get_auth_context = SQLSecurityUtils.get_auth_context
|
||||
set_auth_context = SQLSecurityUtils.set_auth_context
|
||||
|
||||
147
examples/cursor/README.md
Normal file
147
examples/cursor/README.md
Normal file
@@ -0,0 +1,147 @@
|
||||
<!--
|
||||
Licensed to the Apache Software Foundation (ASF) under one
|
||||
or more contributor license agreements. See the NOTICE file
|
||||
distributed with this work for additional information
|
||||
regarding copyright ownership. The ASF licenses this file
|
||||
to you under the Apache License, Version 2.0 (the
|
||||
"License"); you may not use this file except in compliance
|
||||
with the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing,
|
||||
software distributed under the License is distributed on an
|
||||
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations
|
||||
under the License.
|
||||
-->
|
||||
|
||||
|
||||
# Cursor Example: Integrating Doris MCP Server
|
||||
|
||||
This guide provides step-by-step instructions on how to integrate the `doris-mcp-server` with the [Cursor](https://cursor.sh/) IDE. This integration allows you to interact with your Apache Doris database using natural language queries directly within Cursor's AI chat.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
* [Prerequisites](#prerequisites)
|
||||
* [Step 1: Set Up the Project](#step-1-set-up-the-project)
|
||||
* [Step 2: Configure the MCP Server in Cursor](#step-2-configure-the-mcp-server-in-cursor)
|
||||
* [Step 3: Verify the Integration](#step-3-verify-the-integration)
|
||||
* [Step 4: Query Your Database](#step-4-query-your-database)
|
||||
|
||||
* [Example 1: List Tables](#example-1-list-tables)
|
||||
* [Example 2: Analyze Sales Trends](#example-2-analyze-sales-trends)
|
||||
|
||||
---
|
||||
|
||||
### Prerequisites
|
||||
|
||||
Before you begin, ensure you have the following installed and configured:
|
||||
|
||||
* The **Cursor** IDE
|
||||
* **Git** for cloning the repository
|
||||
* Access to an **Apache Doris** cluster (FE host, port, username, and password)
|
||||
* **uv**, a fast Python package installer and runner
|
||||
|
||||
You can install `uv` with one of the following commands:
|
||||
|
||||
```bash
|
||||
# For macOS (recommended)
|
||||
brew install uv
|
||||
|
||||
# For other systems using pipx
|
||||
pipx install uv
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Step 1: Set Up the Project
|
||||
|
||||
First, clone the `doris-mcp-server` repository to your local machine:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/apache/doris-mcp-server.git
|
||||
|
||||
cd doris-mcp-server
|
||||
```
|
||||
|
||||
The necessary dependencies are listed in `requirements.txt` and will be managed automatically by `uv` in the next step.
|
||||
|
||||
---
|
||||
|
||||
### Step 2: Configure the MCP Server in Cursor
|
||||
|
||||
1. Open the cloned `doris-mcp-server` directory in Cursor.
|
||||
2. Click the ⚙️ icon (top-right), then go to **Tools & Integrations**.
|
||||

|
||||
3. Click **Add a custom MCP Server**.
|
||||
4. Paste the following JSON configuration:
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"doris-mcp": {
|
||||
"command": "uv",
|
||||
"args": [
|
||||
"run",
|
||||
"--project",
|
||||
"/path/to/your/doris-mcp-server",
|
||||
"doris-mcp-server"
|
||||
],
|
||||
"env": {
|
||||
"DORIS_HOST": "your_doris_fe_host",
|
||||
"DORIS_PORT": "9030",
|
||||
"DORIS_USER": "your_username",
|
||||
"DORIS_PASSWORD": "your_password",
|
||||
"DORIS_DATABASE": "ssb"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
> ⚠️ **Important:**
|
||||
>
|
||||
> * Replace `"/path/to/your/doris-mcp-server"` with the **absolute path** to your local project directory.
|
||||
> * Fill in your actual Doris FE host, username, password, and database name.
|
||||
|
||||
---
|
||||
|
||||
### Step 3: Verify the Integration
|
||||
|
||||
Once saved, go back to the **Settings** panel. If everything is configured correctly, you’ll see a green status dot next to `doris-mcp-server`, along with available tools like `exec_query`.
|
||||
|
||||

|
||||
|
||||
---
|
||||
|
||||
### Step 4: Query Your Database
|
||||
|
||||
You can now chat with Cursor Agent to run SQL queries against your Doris database.
|
||||
|
||||
1. Open the chat panel using `Cmd + K` (macOS) or `Ctrl + K` (Windows/Linux), or click the chat icon in the top-right.
|
||||
2. Switch to **Agent Mode**.
|
||||
3. Start asking questions using natural language.
|
||||
|
||||

|
||||
|
||||
---
|
||||
|
||||
#### Example 1: List Tables
|
||||
|
||||
> **Prompt:** What tables are in the `ssb` database?
|
||||
|
||||
The agent will call the `get_db_table_list` tool and return the results.
|
||||
|
||||

|
||||
|
||||
---
|
||||
|
||||
#### Example 2: Analyze Sales Trends
|
||||
|
||||
> **Prompt:** What has been the sales trend over the past ten years in the `ssb` database, and which year had the fastest growth?
|
||||
|
||||
The agent will generate an appropriate SQL query, send it to the MCP server, and interpret the results to give you growth trends and highlights.
|
||||
|
||||

|
||||
@@ -103,6 +103,9 @@ If your Dify deployment requires a publicly accessible endpoint, you can use the
|
||||
2. Select **Agent** as the template and set the **App Name** (e.g., `Doris ChatBI`).
|
||||

|
||||
|
||||
|
||||
3. Import from DSL,[dify_doris_dsl.yml](dify_doris_dsl.yml)
|
||||
|
||||
-----
|
||||
|
||||
## Instructions & Tool Configuration
|
||||
127
examples/dify/dify_doris_dsl.yml
Normal file
127
examples/dify/dify_doris_dsl.yml
Normal file
@@ -0,0 +1,127 @@
|
||||
app:
|
||||
description: ''
|
||||
icon: 🤖
|
||||
icon_background: '#FFEAD5'
|
||||
mode: agent-chat
|
||||
name: doris
|
||||
use_icon_as_answer_icon: false
|
||||
dependencies:
|
||||
- current_identifier: null
|
||||
type: marketplace
|
||||
value:
|
||||
marketplace_plugin_unique_identifier: langgenius/deepseek:0.0.5@21408d5c48cd9f18d66b08883d0999fe89e6d049c891324c2229dea23b9665d5
|
||||
- current_identifier: null
|
||||
type: marketplace
|
||||
value:
|
||||
marketplace_plugin_unique_identifier: junjiem/mcp_sse:0.2.1@53cc613667fcf91dd7208dd5f6d2c8df3c7ff0af8b79e8f3c0a430f1b39bda4c
|
||||
kind: app
|
||||
model_config:
|
||||
agent_mode:
|
||||
enabled: true
|
||||
max_iteration: 10
|
||||
prompt: null
|
||||
strategy: function_call
|
||||
tools:
|
||||
- enabled: true
|
||||
isDeleted: false
|
||||
notAuthor: false
|
||||
provider_id: junjiem/mcp_sse/mcp_sse
|
||||
provider_name: junjiem/mcp_sse/mcp_sse
|
||||
provider_type: builtin
|
||||
tool_label: 获取 MCP 工具列表
|
||||
tool_name: mcp_sse_list_tools
|
||||
tool_parameters:
|
||||
prompts_as_tools: 1
|
||||
resources_as_tools: 1
|
||||
servers_config: null
|
||||
- enabled: true
|
||||
isDeleted: false
|
||||
notAuthor: false
|
||||
provider_id: junjiem/mcp_sse/mcp_sse
|
||||
provider_name: junjiem/mcp_sse/mcp_sse
|
||||
provider_type: builtin
|
||||
tool_label: 调用 MCP 工具
|
||||
tool_name: mcp_sse_call_tool
|
||||
tool_parameters:
|
||||
arguments: ''
|
||||
prompts_as_tools: ''
|
||||
resources_as_tools: ''
|
||||
servers_config: ''
|
||||
tool_name: ''
|
||||
annotation_reply:
|
||||
enabled: false
|
||||
chat_prompt_config: {}
|
||||
completion_prompt_config: {}
|
||||
dataset_configs:
|
||||
datasets:
|
||||
datasets: []
|
||||
reranking_enable: true
|
||||
reranking_mode: reranking_model
|
||||
reranking_model:
|
||||
reranking_model_name: ''
|
||||
reranking_provider_name: ''
|
||||
retrieval_model: multiple
|
||||
top_k: 4
|
||||
dataset_query_variable: ''
|
||||
external_data_tools: []
|
||||
file_upload:
|
||||
allowed_file_extensions:
|
||||
- .JPG
|
||||
- .JPEG
|
||||
- .PNG
|
||||
- .GIF
|
||||
- .WEBP
|
||||
- .SVG
|
||||
- .MP4
|
||||
- .MOV
|
||||
- .MPEG
|
||||
- .WEBM
|
||||
allowed_file_types: []
|
||||
allowed_file_upload_methods:
|
||||
- remote_url
|
||||
- local_file
|
||||
enabled: false
|
||||
image:
|
||||
detail: high
|
||||
enabled: false
|
||||
number_limits: 3
|
||||
transfer_methods:
|
||||
- remote_url
|
||||
- local_file
|
||||
number_limits: 3
|
||||
model:
|
||||
completion_params:
|
||||
stop: []
|
||||
mode: chat
|
||||
name: deepseek-chat
|
||||
provider: langgenius/deepseek/deepseek
|
||||
more_like_this:
|
||||
enabled: false
|
||||
opening_statement: ''
|
||||
pre_prompt: "<instruction>\nUse MCP tools to complete tasks as much as possible.\
|
||||
\ Carefully read the annotations, method names, and parameter descriptions of\
|
||||
\ each tool. Please follow these steps:\n1. Analyze the user's question and match\
|
||||
\ the most appropriate tool.\n2. Use tool names and parameters exactly as defined;\
|
||||
\ do not invent new ones.\n3. Pass parameters in the required JSON format.\n4.\
|
||||
\ When calling tools, use:\n {\"mcp_sse_call_tool\": {\"tool_name\": \"<tool_name>\"\
|
||||
, \"arguments\": \"{}\"}}\n5. Output plain text only—no XML tags.\n<input>\nUser\
|
||||
\ question: user_query\n</input>\n<output>\nReturn tool results or a final answer,\
|
||||
\ including analysis.\n</output>\n</instruction>"
|
||||
prompt_type: simple
|
||||
retriever_resource:
|
||||
enabled: true
|
||||
sensitive_word_avoidance:
|
||||
configs: []
|
||||
enabled: false
|
||||
type: ''
|
||||
speech_to_text:
|
||||
enabled: false
|
||||
suggested_questions: []
|
||||
suggested_questions_after_answer:
|
||||
enabled: false
|
||||
text_to_speech:
|
||||
enabled: false
|
||||
language: ''
|
||||
voice: ''
|
||||
user_input_form: []
|
||||
version: 0.3.0
|
||||
BIN
examples/images/cursor_add_mcp.png
Normal file
BIN
examples/images/cursor_add_mcp.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 323 KiB |
BIN
examples/images/cursor_agent.png
Normal file
BIN
examples/images/cursor_agent.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 673 KiB |
BIN
examples/images/cursor_ask1.png
Normal file
BIN
examples/images/cursor_ask1.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 118 KiB |
BIN
examples/images/cursor_ask2.png
Normal file
BIN
examples/images/cursor_ask2.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 232 KiB |
BIN
examples/images/cursor_doris-mcp.png
Normal file
BIN
examples/images/cursor_doris-mcp.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 80 KiB |
@@ -20,7 +20,7 @@ build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "doris-mcp-server"
|
||||
version = "0.4.2"
|
||||
version = "0.6.1"
|
||||
description = "Enterprise-grade Model Context Protocol (MCP) server implementation for Apache Doris"
|
||||
authors = [
|
||||
{name = "Yijia Su", email = "freeoneplus@apache.org"}
|
||||
@@ -46,6 +46,10 @@ dependencies = [
|
||||
# Database drivers
|
||||
"aiomysql>=0.2.0",
|
||||
"PyMySQL>=1.1.0",
|
||||
# ADBC (Arrow Flight SQL) dependencies
|
||||
"adbc-driver-manager>=0.8.0",
|
||||
"adbc-driver-flightsql>=0.8.0",
|
||||
"pyarrow>=14.0.0",
|
||||
# Async and utility libraries
|
||||
"asyncio-mqtt>=0.16.0",
|
||||
"aiofiles>=23.0.0",
|
||||
|
||||
@@ -5,6 +5,9 @@
|
||||
mcp>=1.8.0,<2.0.0
|
||||
aiomysql>=0.2.0
|
||||
PyMySQL>=1.1.0
|
||||
adbc-driver-manager>=0.8.0
|
||||
adbc-driver-flightsql>=0.8.0
|
||||
pyarrow>=14.0.0
|
||||
asyncio-mqtt>=0.16.0
|
||||
aiofiles>=23.0.0
|
||||
aiohttp>=3.9.0
|
||||
|
||||
@@ -64,9 +64,11 @@ else
|
||||
fi
|
||||
|
||||
# Set HTTP-specific environment variables
|
||||
# FIX for Issue #62 Bug 4: Use SERVER_PORT instead of MCP_PORT for consistency with code
|
||||
export MCP_TRANSPORT_TYPE="http"
|
||||
export MCP_HOST="${MCP_HOST:-0.0.0.0}"
|
||||
export MCP_PORT="${MCP_PORT:-3000}"
|
||||
export SERVER_PORT="${SERVER_PORT:-3000}" # Changed from MCP_PORT to SERVER_PORT
|
||||
export WORKERS="${WORKERS:-1}"
|
||||
export ALLOWED_ORIGINS="${ALLOWED_ORIGINS:-*}"
|
||||
export LOG_LEVEL="${LOG_LEVEL:-info}"
|
||||
export MCP_ALLOW_CREDENTIALS="${MCP_ALLOW_CREDENTIALS:-false}"
|
||||
@@ -76,14 +78,15 @@ export MCP_DEBUG_ADAPTER="true"
|
||||
export PYTHONPATH="$(pwd):$PYTHONPATH"
|
||||
|
||||
echo -e "${GREEN}Starting MCP server (Streamable HTTP mode)...${NC}"
|
||||
echo -e "${YELLOW}Service will run on http://${MCP_HOST}:${MCP_PORT}/mcp${NC}"
|
||||
echo -e "${YELLOW}Health Check: http://${MCP_HOST}:${MCP_PORT}/health${NC}"
|
||||
echo -e "${YELLOW}MCP Endpoint: http://${MCP_HOST}:${MCP_PORT}/mcp${NC}"
|
||||
echo -e "${YELLOW}Local access: http://localhost:${MCP_PORT}/mcp${NC}"
|
||||
echo -e "${YELLOW}Service will run on http://${MCP_HOST}:${SERVER_PORT}/mcp${NC}"
|
||||
echo -e "${YELLOW}Health Check: http://${MCP_HOST}:${SERVER_PORT}/health${NC}"
|
||||
echo -e "${YELLOW}MCP Endpoint: http://${MCP_HOST}:${SERVER_PORT}/mcp${NC}"
|
||||
echo -e "${YELLOW}Local access: http://localhost:${SERVER_PORT}/mcp${NC}"
|
||||
echo -e "${YELLOW}Workers: ${WORKERS}${NC}"
|
||||
echo -e "${YELLOW}Use Ctrl+C to stop the service${NC}"
|
||||
|
||||
# Start the server in HTTP mode (Streamable HTTP)
|
||||
python -m doris_mcp_server.main --transport http --host ${MCP_HOST} --port ${MCP_PORT}
|
||||
python -m doris_mcp_server.main --transport http --host ${MCP_HOST} --port ${SERVER_PORT} --workers ${WORKERS}
|
||||
|
||||
# Check exit status
|
||||
if [ $? -ne 0 ]; then
|
||||
@@ -95,4 +98,4 @@ fi
|
||||
echo -e "${YELLOW}Tip: If the page displays abnormally, please clear your browser cache or use incognito mode${NC}"
|
||||
echo -e "${YELLOW}Chrome browser clear cache shortcut: Ctrl+Shift+Del (Windows) or Cmd+Shift+Del (Mac)${NC}"
|
||||
echo -e "${CYAN}For testing HTTP endpoints, you can use:${NC}"
|
||||
echo -e "${CYAN} curl -X POST http://localhost:${MCP_PORT}/mcp -H 'Content-Type: application/json' -d '{\"method\":\"tools/list\"}'${NC}"
|
||||
echo -e "${CYAN} curl -X POST http://localhost:${SERVER_PORT}/mcp -H 'Content-Type: application/json' -d '{\"method\":\"tools/list\"}'${NC}"
|
||||
@@ -59,7 +59,6 @@ def test_config():
|
||||
config.database.password = "test_password"
|
||||
config.database.database = "test_db"
|
||||
config.database.health_check_interval = 60
|
||||
config.database.min_connections = 5
|
||||
config.database.max_connections = 20
|
||||
config.database.connection_timeout = 30
|
||||
config.database.max_connection_age = 3600
|
||||
|
||||
@@ -34,7 +34,7 @@ class TestEndToEndIntegration:
|
||||
@pytest.fixture
|
||||
def mock_config(self):
|
||||
"""Create mock configuration"""
|
||||
from doris_mcp_server.utils.config import DatabaseConfig, SecurityConfig
|
||||
from doris_mcp_server.utils.config import ADBCConfig, DatabaseConfig, SecurityConfig
|
||||
|
||||
config = Mock(spec=DorisConfig)
|
||||
|
||||
@@ -46,7 +46,6 @@ class TestEndToEndIntegration:
|
||||
config.database.password = "test_password"
|
||||
config.database.database = "test_db"
|
||||
config.database.health_check_interval = 60
|
||||
config.database.min_connections = 5
|
||||
config.database.max_connections = 20
|
||||
config.database.connection_timeout = 30
|
||||
config.database.max_connection_age = 3600
|
||||
@@ -57,6 +56,11 @@ class TestEndToEndIntegration:
|
||||
config.security.auth_type = "token"
|
||||
config.security.token_secret = "test_secret"
|
||||
config.security.token_expiry = 3600
|
||||
config.security.blocked_keywords = ["DROP"]
|
||||
|
||||
# Add adbc config
|
||||
config.adbc = Mock(spec=ADBCConfig)
|
||||
config.adbc.enabled = True
|
||||
|
||||
return config
|
||||
|
||||
@@ -231,7 +235,7 @@ class TestEndToEndIntegration:
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_execution_with_security(self, doris_server):
|
||||
"""Test tool execution with security checks"""
|
||||
with patch.object(doris_server.tools_manager.query_executor, 'execute_query') as mock_execute:
|
||||
with patch.object(doris_server.tools_manager.connection_manager, 'execute_query') as mock_execute:
|
||||
mock_execute.return_value = [{"Database": "test_db"}]
|
||||
|
||||
# Test tool execution through tools manager
|
||||
@@ -258,7 +262,7 @@ class TestEndToEndIntegration:
|
||||
@pytest.mark.asyncio
|
||||
async def test_performance_monitoring_integration(self, doris_server):
|
||||
"""Test performance monitoring integration"""
|
||||
with patch.object(doris_server.tools_manager.query_executor, 'execute_query') as mock_execute:
|
||||
with patch.object(doris_server.tools_manager.connection_manager, 'execute_query') as mock_execute:
|
||||
mock_execute.return_value = [
|
||||
{
|
||||
"query_count": 1500,
|
||||
|
||||
367
test/security/test_sql_injection.py
Normal file
367
test/security/test_sql_injection.py
Normal file
@@ -0,0 +1,367 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
SQL Security Test Suite for Apache Doris MCP Server
|
||||
|
||||
Tests for:
|
||||
1. SQL injection prevention via identifier validation
|
||||
2. Multi-statement SQL parsing in security validator
|
||||
3. auth_context enforcement
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
|
||||
class TestSQLSecurityUtils:
|
||||
"""Test cases for sql_security_utils module"""
|
||||
|
||||
def test_validate_identifier_accepts_valid_names(self):
|
||||
"""Test that valid identifiers are accepted"""
|
||||
from doris_mcp_server.utils.sql_security_utils import validate_identifier
|
||||
|
||||
valid_names = [
|
||||
"users",
|
||||
"my_table",
|
||||
"Table123",
|
||||
"_private_table",
|
||||
"CamelCaseTable",
|
||||
"table_with_numbers_123",
|
||||
]
|
||||
|
||||
for name in valid_names:
|
||||
result = validate_identifier(name, "table")
|
||||
assert result == name, f"Valid identifier '{name}' should be accepted"
|
||||
|
||||
def test_validate_identifier_rejects_sql_injection(self):
|
||||
"""Test that SQL injection attempts are rejected"""
|
||||
from doris_mcp_server.utils.sql_security_utils import (
|
||||
validate_identifier,
|
||||
SQLSecurityError
|
||||
)
|
||||
|
||||
injection_attempts = [
|
||||
# Basic SQL injection
|
||||
"'; DROP TABLE users; --",
|
||||
"table' OR '1'='1",
|
||||
"table'; DELETE FROM users; --",
|
||||
|
||||
# Union-based injection
|
||||
"table' UNION SELECT * FROM passwords --",
|
||||
|
||||
# Comment injection
|
||||
"table/**/OR/**/1=1",
|
||||
"table--comment",
|
||||
|
||||
# Special characters
|
||||
"table`; DROP TABLE users;",
|
||||
'table"; DROP TABLE users;',
|
||||
"table\"; DELETE FROM",
|
||||
|
||||
# Backtick escape attempt
|
||||
"analytics`; SELECT * FROM sensitive_table;--",
|
||||
|
||||
# Whitespace injection
|
||||
"table name with spaces",
|
||||
"table\ttab",
|
||||
"table\nnewline",
|
||||
]
|
||||
|
||||
for injection in injection_attempts:
|
||||
with pytest.raises(SQLSecurityError):
|
||||
validate_identifier(injection, "table")
|
||||
|
||||
def test_validate_identifier_rejects_empty(self):
|
||||
"""Test that empty identifiers are rejected"""
|
||||
from doris_mcp_server.utils.sql_security_utils import (
|
||||
validate_identifier,
|
||||
SQLSecurityError
|
||||
)
|
||||
|
||||
with pytest.raises(SQLSecurityError):
|
||||
validate_identifier("", "table")
|
||||
|
||||
with pytest.raises(SQLSecurityError):
|
||||
validate_identifier(None, "table")
|
||||
|
||||
def test_validate_identifier_rejects_too_long(self):
|
||||
"""Test that identifiers exceeding max length are rejected"""
|
||||
from doris_mcp_server.utils.sql_security_utils import (
|
||||
validate_identifier,
|
||||
SQLSecurityError
|
||||
)
|
||||
|
||||
# Doris identifier max length is typically 64 characters
|
||||
long_name = "a" * 100
|
||||
with pytest.raises(SQLSecurityError):
|
||||
validate_identifier(long_name, "table")
|
||||
|
||||
def test_quote_identifier_adds_backticks(self):
|
||||
"""Test that quote_identifier properly escapes identifiers"""
|
||||
from doris_mcp_server.utils.sql_security_utils import quote_identifier
|
||||
|
||||
assert quote_identifier("my_table", "table") == "`my_table`"
|
||||
assert quote_identifier("users", "table") == "`users`"
|
||||
assert quote_identifier("Table123", "table") == "`Table123`"
|
||||
|
||||
def test_quote_identifier_validates_first(self):
|
||||
"""Test that quote_identifier validates before quoting"""
|
||||
from doris_mcp_server.utils.sql_security_utils import (
|
||||
quote_identifier,
|
||||
SQLSecurityError
|
||||
)
|
||||
|
||||
with pytest.raises(SQLSecurityError):
|
||||
quote_identifier("'; DROP TABLE users; --", "table")
|
||||
|
||||
|
||||
class TestSQLSecurityValidator:
|
||||
"""Test cases for SQLSecurityValidator multi-statement parsing"""
|
||||
|
||||
@pytest.fixture
|
||||
def dict_config(self):
|
||||
"""Create dictionary configuration"""
|
||||
return {
|
||||
"blocked_keywords": [
|
||||
"DROP", "CREATE", "ALTER", "TRUNCATE",
|
||||
"DELETE", "INSERT", "UPDATE",
|
||||
"GRANT", "REVOKE", "EXEC", "EXECUTE"
|
||||
],
|
||||
"max_query_complexity": 100,
|
||||
"enable_security_check": True
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def mock_auth_context(self):
|
||||
"""Create mock auth context"""
|
||||
from doris_mcp_server.utils.security import AuthContext, SecurityLevel
|
||||
return AuthContext(
|
||||
user_id="test_user",
|
||||
roles=["user"],
|
||||
security_level=SecurityLevel.INTERNAL
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validates_all_statements(self, dict_config, mock_auth_context):
|
||||
"""Test that validator checks ALL SQL statements, not just the first"""
|
||||
from doris_mcp_server.utils.security import SQLSecurityValidator
|
||||
|
||||
validator = SQLSecurityValidator(dict_config)
|
||||
|
||||
# Multi-statement with injection in second statement
|
||||
# This should be BLOCKED
|
||||
malicious_sql = "SELECT 1; DROP TABLE users; SELECT 2"
|
||||
|
||||
result = await validator.validate(malicious_sql, mock_auth_context)
|
||||
|
||||
assert not result.is_valid, "Multi-statement injection should be blocked"
|
||||
# Check for either DROP keyword detection or SQL injection detection
|
||||
error_upper = result.error_message.upper()
|
||||
assert ("DROP" in error_upper or
|
||||
"INJECTION" in error_upper or
|
||||
"BLOCKED" in error_upper), f"Expected DROP/injection/blocked in: {result.error_message}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_blocks_hidden_dangerous_statement(self, dict_config, mock_auth_context):
|
||||
"""Test that dangerous statements hidden after safe ones are blocked"""
|
||||
from doris_mcp_server.utils.security import SQLSecurityValidator
|
||||
|
||||
validator = SQLSecurityValidator(dict_config)
|
||||
|
||||
# Safe statement followed by dangerous one
|
||||
malicious_sql = """
|
||||
SELECT * FROM users WHERE id = 1;
|
||||
DELETE FROM audit_log;
|
||||
SELECT 1;
|
||||
"""
|
||||
|
||||
result = await validator.validate(malicious_sql, mock_auth_context)
|
||||
|
||||
assert not result.is_valid, "Hidden DELETE statement should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allows_safe_multi_statement(self, dict_config, mock_auth_context):
|
||||
"""Test that multiple safe SELECT statements are allowed"""
|
||||
from doris_mcp_server.utils.security import SQLSecurityValidator
|
||||
|
||||
validator = SQLSecurityValidator(dict_config)
|
||||
|
||||
safe_sql = """
|
||||
SELECT * FROM users;
|
||||
SELECT COUNT(*) FROM orders;
|
||||
SELECT id, name FROM products;
|
||||
"""
|
||||
|
||||
result = await validator.validate(safe_sql, mock_auth_context)
|
||||
|
||||
assert result.is_valid, f"Multiple safe SELECT statements should be allowed, got: {result.error_message}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_switch_injection_blocked(self, dict_config, mock_auth_context):
|
||||
"""Test that context switch SQL injection is blocked"""
|
||||
from doris_mcp_server.utils.security import SQLSecurityValidator
|
||||
|
||||
validator = SQLSecurityValidator(dict_config)
|
||||
|
||||
# Simulating the exec_query_for_mcp attack vector
|
||||
injected_sql = """
|
||||
USE `analytics`; SELECT * FROM sensitive_table;-- `;
|
||||
SELECT * FROM public_table;
|
||||
"""
|
||||
|
||||
result = await validator.validate(injected_sql, mock_auth_context)
|
||||
|
||||
# The validator should process all statements
|
||||
# Even if USE is allowed, subsequent unauthorized access should be caught
|
||||
# by table access checks (if configured)
|
||||
|
||||
|
||||
class TestExecQueryForMCP:
|
||||
"""Test cases for exec_query_for_mcp function"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_malicious_db_name(self):
|
||||
"""Test that malicious db_name is rejected"""
|
||||
from doris_mcp_server.utils.sql_security_utils import (
|
||||
validate_identifier,
|
||||
SQLSecurityError
|
||||
)
|
||||
|
||||
# The attack vector from security report
|
||||
malicious_db_name = "analytics`; SELECT * FROM sensitive_table;--"
|
||||
|
||||
with pytest.raises(SQLSecurityError):
|
||||
validate_identifier(malicious_db_name, "database name")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_malicious_catalog_name(self):
|
||||
"""Test that malicious catalog_name is rejected"""
|
||||
from doris_mcp_server.utils.sql_security_utils import (
|
||||
validate_identifier,
|
||||
SQLSecurityError
|
||||
)
|
||||
|
||||
malicious_catalog_name = "internal'; DROP DATABASE production;--"
|
||||
|
||||
with pytest.raises(SQLSecurityError):
|
||||
validate_identifier(malicious_catalog_name, "catalog name")
|
||||
|
||||
|
||||
class TestDependencyAnalysisTools:
|
||||
"""Test cases for dependency_analysis_tools security fixes"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_tables_metadata_rejects_injection(self):
|
||||
"""Test that _get_tables_metadata rejects SQL injection"""
|
||||
from doris_mcp_server.utils.sql_security_utils import (
|
||||
validate_identifier,
|
||||
SQLSecurityError
|
||||
)
|
||||
|
||||
# The attack vector from security report
|
||||
injection_db_name = "test_db' OR '1'='1' --"
|
||||
|
||||
with pytest.raises(SQLSecurityError):
|
||||
validate_identifier(injection_db_name, "database name")
|
||||
|
||||
|
||||
class TestAuthContextEnforcement:
|
||||
"""Test cases for auth_context enforcement"""
|
||||
|
||||
def test_execute_requires_auth_context_for_security(self):
|
||||
"""Test that security checks require auth_context"""
|
||||
# This test documents the expected behavior:
|
||||
# When auth_context is None, security checks are skipped
|
||||
# When auth_context is provided, security checks are performed
|
||||
|
||||
# The fix ensures all execute() calls pass auth_context
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_auth_context_returns_context(self):
|
||||
"""Test that get_auth_context retrieves context from ContextVar"""
|
||||
from doris_mcp_server.utils.sql_security_utils import get_auth_context
|
||||
|
||||
# When no context is set, should return None
|
||||
result = get_auth_context()
|
||||
# This is expected - context is set by HTTP middleware
|
||||
assert result is None or hasattr(result, 'user_id')
|
||||
|
||||
|
||||
class TestIntegrationScenarios:
|
||||
"""Integration test scenarios for security fixes"""
|
||||
|
||||
def test_attack_scenario_1_permission_bypass(self):
|
||||
"""
|
||||
Attack Scenario 1: Permission Bypass in Multi-Tenant Environment
|
||||
|
||||
Expected: User can only query their own database (db_name="tenant_a_db")
|
||||
Attack: Inject "tenant_a_db' OR '1'='1' --" to query ALL databases
|
||||
Result: Should be BLOCKED by validate_identifier()
|
||||
"""
|
||||
from doris_mcp_server.utils.sql_security_utils import (
|
||||
validate_identifier,
|
||||
SQLSecurityError
|
||||
)
|
||||
|
||||
with pytest.raises(SQLSecurityError):
|
||||
validate_identifier("tenant_a_db' OR '1'='1' --", "database name")
|
||||
|
||||
def test_attack_scenario_2_union_injection(self):
|
||||
"""
|
||||
Attack Scenario 2: UNION-based Information Disclosure
|
||||
|
||||
Attack: Inject UNION SELECT to extract sensitive data
|
||||
Result: Should be BLOCKED
|
||||
"""
|
||||
from doris_mcp_server.utils.sql_security_utils import (
|
||||
validate_identifier,
|
||||
SQLSecurityError
|
||||
)
|
||||
|
||||
with pytest.raises(SQLSecurityError):
|
||||
validate_identifier(
|
||||
"test' UNION SELECT password FROM users --",
|
||||
"database name"
|
||||
)
|
||||
|
||||
def test_attack_scenario_3_backtick_escape(self):
|
||||
"""
|
||||
Attack Scenario 3: Backtick Escape Attempt
|
||||
|
||||
Attack: Use backticks to break out of quoted identifier
|
||||
Result: Should be BLOCKED
|
||||
"""
|
||||
from doris_mcp_server.utils.sql_security_utils import (
|
||||
validate_identifier,
|
||||
SQLSecurityError
|
||||
)
|
||||
|
||||
with pytest.raises(SQLSecurityError):
|
||||
validate_identifier(
|
||||
"analytics`; SELECT * FROM sensitive_table;--",
|
||||
"database name"
|
||||
)
|
||||
|
||||
|
||||
# Run tests with: pytest tests/test_sql_security.py -v
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "--tb=short"])
|
||||
|
||||
871
test/security/test_sql_injection_api.py
Normal file
871
test/security/test_sql_injection_api.py
Normal file
@@ -0,0 +1,871 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
SQL Injection API Integration Tests
|
||||
|
||||
This module tests SQL injection prevention through the MCP HTTP API.
|
||||
It sends malicious payloads and verifies they are properly blocked.
|
||||
|
||||
Prerequisites:
|
||||
- MCP server running on localhost:3000
|
||||
- Run with: pytest test/security/test_sql_injection_api.py -v
|
||||
|
||||
Usage:
|
||||
# Start server first
|
||||
bash start_server.sh
|
||||
|
||||
# Run tests
|
||||
pytest test/security/test_sql_injection_api.py -v --no-cov
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import httpx
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
|
||||
# Server configuration
|
||||
MCP_BASE_URL = "http://localhost:3000"
|
||||
MCP_ENDPOINT = f"{MCP_BASE_URL}/mcp"
|
||||
HEALTH_ENDPOINT = f"{MCP_BASE_URL}/health"
|
||||
TIMEOUT = 30.0
|
||||
|
||||
|
||||
class MCPClient:
|
||||
"""Simple MCP HTTP client for testing"""
|
||||
|
||||
def __init__(self, base_url: str = MCP_BASE_URL):
|
||||
self.base_url = base_url
|
||||
self.mcp_endpoint = f"{base_url}/mcp"
|
||||
self.session_id: Optional[str] = None
|
||||
self.request_id = 0
|
||||
self.client = httpx.AsyncClient(timeout=TIMEOUT)
|
||||
|
||||
async def close(self):
|
||||
await self.client.aclose()
|
||||
|
||||
def _next_id(self) -> int:
|
||||
self.request_id += 1
|
||||
return self.request_id
|
||||
|
||||
async def initialize(self) -> dict:
|
||||
"""Initialize MCP session"""
|
||||
response = await self.client.post(
|
||||
self.mcp_endpoint,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json, text/event-stream"
|
||||
},
|
||||
json={
|
||||
"jsonrpc": "2.0",
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {
|
||||
"name": "sql-injection-test",
|
||||
"version": "1.0.0"
|
||||
}
|
||||
},
|
||||
"id": self._next_id()
|
||||
}
|
||||
)
|
||||
|
||||
# Extract session ID from response header
|
||||
self.session_id = response.headers.get("mcp-session-id")
|
||||
return self._parse_response(response.text)
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: dict) -> dict:
|
||||
"""Call an MCP tool"""
|
||||
if not self.session_id:
|
||||
await self.initialize()
|
||||
|
||||
response = await self.client.post(
|
||||
self.mcp_endpoint,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json, text/event-stream",
|
||||
"mcp-session-id": self.session_id
|
||||
},
|
||||
json={
|
||||
"jsonrpc": "2.0",
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": tool_name,
|
||||
"arguments": arguments
|
||||
},
|
||||
"id": self._next_id()
|
||||
}
|
||||
)
|
||||
|
||||
return self._parse_response(response.text)
|
||||
|
||||
def _parse_response(self, text: str) -> dict:
|
||||
"""Parse JSON response"""
|
||||
try:
|
||||
return json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
# Try SSE format
|
||||
lines = text.strip().split("\n")
|
||||
for line in lines:
|
||||
if line.startswith("data: "):
|
||||
try:
|
||||
return json.loads(line[6:])
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
return {"raw": text}
|
||||
|
||||
|
||||
def print_result(test_name: str, payload: dict, result: dict):
|
||||
"""Print test result in a readable format"""
|
||||
print(f"\n{'='*60}")
|
||||
print(f"TEST: {test_name}")
|
||||
print(f"{'='*60}")
|
||||
print(f"PAYLOAD: {json.dumps(payload, ensure_ascii=False)}")
|
||||
print(f"{'-'*60}")
|
||||
|
||||
# Extract inner result content
|
||||
if "result" in result and "content" in result.get("result", {}):
|
||||
for item in result["result"]["content"]:
|
||||
if item.get("type") == "text":
|
||||
try:
|
||||
inner = json.loads(item["text"])
|
||||
print("RESPONSE:")
|
||||
print(f" success: {inner.get('success')}")
|
||||
if inner.get('error'):
|
||||
print(f" error: {inner.get('error')}")
|
||||
if inner.get('error_type'):
|
||||
print(f" error_type: {inner.get('error_type')}")
|
||||
if inner.get('risk_level'):
|
||||
print(f" risk_level: {inner.get('risk_level')}")
|
||||
if inner.get('message'):
|
||||
print(f" message: {inner.get('message')}")
|
||||
if inner.get('data') is not None and inner.get('success'):
|
||||
data_str = json.dumps(inner.get('data'), ensure_ascii=False)
|
||||
if len(data_str) > 200:
|
||||
data_str = data_str[:200] + "..."
|
||||
print(f" data: {data_str}")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
print(f"RESPONSE (raw): {item.get('text', '')[:500]}")
|
||||
elif "error" in result:
|
||||
print(f"RESPONSE ERROR: {result['error']}")
|
||||
else:
|
||||
print(f"RESPONSE (raw): {json.dumps(result, ensure_ascii=False)[:500]}")
|
||||
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
|
||||
class TestSQLInjectionAPI:
|
||||
"""Test SQL injection prevention through MCP API"""
|
||||
|
||||
@pytest.fixture
|
||||
async def mcp_client(self):
|
||||
"""Create MCP client instance"""
|
||||
client = MCPClient()
|
||||
yield client
|
||||
await client.close()
|
||||
|
||||
@pytest.fixture
|
||||
def is_server_running(self):
|
||||
"""Check if MCP server is running"""
|
||||
import httpx
|
||||
try:
|
||||
response = httpx.get(HEALTH_ENDPOINT, timeout=5.0)
|
||||
return response.status_code == 200
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_server_health(self):
|
||||
"""Test that MCP server is running and healthy"""
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(HEALTH_ENDPOINT)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "healthy"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_query_with_drop_injection(self, mcp_client):
|
||||
"""Test exec_query rejects DROP TABLE injection"""
|
||||
# Classic SQL injection: append DROP TABLE
|
||||
payload = {"sql": "SELECT * FROM users; DROP TABLE users; --"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("DROP TABLE Injection", payload, result)
|
||||
|
||||
# Should return error, not execute the DROP
|
||||
assert self._is_blocked_or_error(result), \
|
||||
f"DROP TABLE injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_query_with_union_injection(self, mcp_client):
|
||||
"""Test exec_query blocks UNION-based injection attempts"""
|
||||
# UNION injection to extract data from other tables
|
||||
payload = {"sql": "SELECT id FROM users UNION SELECT password FROM admin_users"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("UNION Injection", payload, result)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_query_with_delete_injection(self, mcp_client):
|
||||
"""Test exec_query rejects DELETE injection"""
|
||||
payload = {"sql": "SELECT 1; DELETE FROM users WHERE 1=1; SELECT 2"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("DELETE Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
f"DELETE injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_query_with_update_injection(self, mcp_client):
|
||||
"""Test exec_query rejects UPDATE injection"""
|
||||
payload = {"sql": "SELECT 1; UPDATE users SET role='admin' WHERE id=1; --"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("UPDATE Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
f"UPDATE injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_query_db_name_injection(self, mcp_client):
|
||||
"""Test exec_query rejects SQL injection via db_name parameter"""
|
||||
# Attack vector: inject SQL via db_name parameter
|
||||
payload = {"sql": "SELECT 1", "db_name": "test'; DROP TABLE users; --"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("db_name Parameter Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
f"db_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_query_catalog_name_injection(self, mcp_client):
|
||||
"""Test exec_query rejects SQL injection via catalog_name parameter"""
|
||||
# Attack vector: inject SQL via catalog_name parameter
|
||||
payload = {"sql": "SELECT 1", "catalog_name": "internal`; SELECT * FROM mysql.user; --"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("catalog_name Parameter Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
f"catalog_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_table_schema_injection(self, mcp_client):
|
||||
"""Test get_table_schema rejects SQL injection via table_name"""
|
||||
# Attack vector: inject SQL via table_name parameter
|
||||
payload = {"table_name": "users'; DROP TABLE users; --"}
|
||||
result = await mcp_client.call_tool("get_table_schema", payload)
|
||||
print_result("table_name Injection (get_table_schema)", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
f"table_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_table_schema_db_injection(self, mcp_client):
|
||||
"""Test get_table_schema rejects SQL injection via db_name"""
|
||||
payload = {"table_name": "users", "db_name": "test' OR '1'='1"}
|
||||
result = await mcp_client.call_tool("get_table_schema", payload)
|
||||
print_result("db_name Injection (get_table_schema)", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
f"db_name injection in get_table_schema should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analyze_dependencies_injection(self, mcp_client):
|
||||
"""Test analyze_dependencies rejects SQL injection"""
|
||||
# This was the original vulnerability reported
|
||||
payload = {"table_name": "users", "db_name": "test_db' OR '1'='1' --"}
|
||||
result = await mcp_client.call_tool("analyze_dependencies", payload)
|
||||
print_result("analyze_dependencies Injection (Original Report)", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
f"analyze_dependencies db_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stacked_queries_injection(self, mcp_client):
|
||||
"""Test that stacked queries (multiple statements) are blocked"""
|
||||
# Multiple statements injection
|
||||
payload = {"sql": "SELECT * FROM users WHERE id = 1; INSERT INTO audit_log VALUES (NULL, 'hacked', NOW()); SELECT 1;"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("Stacked Queries (INSERT) Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
f"Stacked queries with INSERT should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_comment_based_injection(self, mcp_client):
|
||||
"""Test that comment-based injection is blocked"""
|
||||
# Using comments to bypass filters
|
||||
payload = {"sql": "SELECT * FROM users WHERE id = 1/**/OR/**/1=1"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("Comment-based Injection", payload, result)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hex_encoded_injection(self, mcp_client):
|
||||
"""Test that hex-encoded injection attempts are handled"""
|
||||
# Hex-encoded 'DROP' attempt
|
||||
payload = {"sql": "SELECT 0x44524F50205441424C4520757365727320"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("Hex Encoded Injection", payload, result)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backtick_escape_injection(self, mcp_client):
|
||||
"""Test backtick escape injection is blocked"""
|
||||
# Attempt to escape backtick quoting
|
||||
payload = {"sql": "SELECT 1", "db_name": "analytics`; SELECT * FROM sensitive_table;--"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("Backtick Escape Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
f"Backtick escape injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_query_succeeds(self, mcp_client):
|
||||
"""Test that valid queries still work"""
|
||||
# Simple valid query should work
|
||||
payload = {"sql": "SELECT 1 AS test_value"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("Valid Query (should succeed)", payload, result)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_show_databases(self, mcp_client):
|
||||
"""Test that SHOW DATABASES works"""
|
||||
payload = {"sql": "SHOW DATABASES"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("SHOW DATABASES (should succeed)", payload, result)
|
||||
|
||||
def _is_blocked_or_error(self, result: dict) -> bool:
|
||||
"""Check if result indicates blocked or error"""
|
||||
if not result:
|
||||
return True
|
||||
|
||||
# Check for JSON-RPC error
|
||||
if "error" in result:
|
||||
return True
|
||||
|
||||
# Check for error in result content
|
||||
if "result" in result:
|
||||
result_content = result.get("result", {})
|
||||
if isinstance(result_content, dict):
|
||||
# Check for isError flag
|
||||
if result_content.get("isError"):
|
||||
return True
|
||||
# Check content array for error messages
|
||||
content = result_content.get("content", [])
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
text = item.get("text", "")
|
||||
# Parse the JSON text content
|
||||
try:
|
||||
text_data = json.loads(text)
|
||||
# Check for success: false
|
||||
if text_data.get("success") is False:
|
||||
return True
|
||||
# Check for error field
|
||||
if text_data.get("error"):
|
||||
return True
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
# Check text for security keywords
|
||||
if any(keyword in text.lower() for keyword in [
|
||||
"error", "blocked", "invalid", "security",
|
||||
"injection", "denied", "forbidden", "not allowed",
|
||||
"security_violation", "risk_level"
|
||||
]):
|
||||
return True
|
||||
|
||||
# Check raw text response
|
||||
raw = result.get("raw", "")
|
||||
if isinstance(raw, str) and any(keyword in raw.lower() for keyword in [
|
||||
"error", "blocked", "invalid", "security"
|
||||
]):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class TestIdentifierInjectionAPI:
|
||||
"""Test identifier-based SQL injection prevention"""
|
||||
|
||||
@pytest.fixture
|
||||
async def mcp_client(self):
|
||||
"""Create MCP client instance"""
|
||||
client = MCPClient()
|
||||
yield client
|
||||
await client.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_table_name_with_semicolon(self, mcp_client):
|
||||
"""Test table name containing semicolon is rejected"""
|
||||
payload = {"table_name": "users; DROP TABLE users"}
|
||||
result = await mcp_client.call_tool("get_table_schema", payload)
|
||||
print_result("Table Name with Semicolon", payload, result)
|
||||
|
||||
# Should be blocked by identifier validation
|
||||
assert self._contains_error_indicator(result), \
|
||||
f"Table name with semicolon should be rejected"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_table_name_with_quotes(self, mcp_client):
|
||||
"""Test table name containing quotes is rejected"""
|
||||
payload = {"table_name": "users' OR '1'='1"}
|
||||
result = await mcp_client.call_tool("get_table_schema", payload)
|
||||
print_result("Table Name with Quotes", payload, result)
|
||||
|
||||
assert self._contains_error_indicator(result), \
|
||||
f"Table name with quotes should be rejected"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_db_name_with_special_chars(self, mcp_client):
|
||||
"""Test database name with special characters is rejected"""
|
||||
special_chars = [
|
||||
"test;db",
|
||||
"test'db",
|
||||
"test\"db",
|
||||
"test`db",
|
||||
"test--db",
|
||||
"test/*db*/",
|
||||
]
|
||||
|
||||
for db_name in special_chars:
|
||||
payload = {"table_name": "users", "db_name": db_name}
|
||||
result = await mcp_client.call_tool("get_table_schema", payload)
|
||||
print_result(f"Special Char in db_name: {db_name}", payload, result)
|
||||
|
||||
assert self._contains_error_indicator(result), \
|
||||
f"db_name '{db_name}' should be rejected"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_identifiers_accepted(self, mcp_client):
|
||||
"""Test that valid identifiers are accepted"""
|
||||
valid_names = [
|
||||
"users",
|
||||
"my_table",
|
||||
"Table123",
|
||||
"_internal_table",
|
||||
]
|
||||
|
||||
for table_name in valid_names:
|
||||
payload = {"table_name": table_name}
|
||||
result = await mcp_client.call_tool("get_table_schema", payload)
|
||||
print_result(f"Valid Identifier: {table_name}", payload, result)
|
||||
|
||||
def _contains_error_indicator(self, result: dict) -> bool:
|
||||
"""Check if result contains error indicators"""
|
||||
if not result:
|
||||
return True
|
||||
|
||||
# Check for JSON-RPC error
|
||||
if "error" in result:
|
||||
return True
|
||||
|
||||
# Check result content
|
||||
result_str = json.dumps(result).lower()
|
||||
error_keywords = [
|
||||
"error", "invalid", "illegal", "blocked",
|
||||
"security", "injection", "denied", "forbidden"
|
||||
]
|
||||
|
||||
return any(keyword in result_str for keyword in error_keywords)
|
||||
|
||||
|
||||
class TestMultiStatementInjectionAPI:
|
||||
"""Test multi-statement SQL injection prevention"""
|
||||
|
||||
@pytest.fixture
|
||||
async def mcp_client(self):
|
||||
"""Create MCP client instance"""
|
||||
client = MCPClient()
|
||||
yield client
|
||||
await client.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hidden_drop_after_select(self, mcp_client):
|
||||
"""Test DROP hidden after legitimate SELECT is blocked"""
|
||||
payload = {"sql": "SELECT id, name FROM users WHERE status = 'active'; DROP TABLE audit_log; SELECT 1;"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("Hidden DROP after SELECT", payload, result)
|
||||
|
||||
assert self._is_dangerous_blocked(result), \
|
||||
f"Hidden DROP statement should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hidden_truncate_after_select(self, mcp_client):
|
||||
"""Test TRUNCATE hidden after SELECT is blocked"""
|
||||
payload = {"sql": "SELECT 1; TRUNCATE TABLE users"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("Hidden TRUNCATE after SELECT", payload, result)
|
||||
|
||||
assert self._is_dangerous_blocked(result), \
|
||||
f"Hidden TRUNCATE should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hidden_grant_after_select(self, mcp_client):
|
||||
"""Test GRANT hidden after SELECT is blocked"""
|
||||
payload = {"sql": "SELECT 1; GRANT ALL ON *.* TO 'hacker'@'%'"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("Hidden GRANT after SELECT", payload, result)
|
||||
|
||||
assert self._is_dangerous_blocked(result), \
|
||||
f"Hidden GRANT should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_safe_selects_allowed(self, mcp_client):
|
||||
"""Test that multiple SELECT statements may be allowed"""
|
||||
payload = {"sql": "SELECT 1; SELECT 2; SELECT 3;"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("Multiple Safe SELECTs", payload, result)
|
||||
|
||||
def _is_dangerous_blocked(self, result: dict) -> bool:
|
||||
"""Check if dangerous operation was blocked"""
|
||||
if not result:
|
||||
return True
|
||||
|
||||
# Check for error
|
||||
if "error" in result:
|
||||
return True
|
||||
|
||||
# Check result content for blocking indicators
|
||||
result_str = json.dumps(result).lower()
|
||||
block_indicators = [
|
||||
"drop", "truncate", "grant", "revoke",
|
||||
"blocked", "denied", "forbidden", "not allowed",
|
||||
"security", "error"
|
||||
]
|
||||
|
||||
return any(indicator in result_str for indicator in block_indicators)
|
||||
|
||||
|
||||
class TestADBCQueryInjectionAPI:
|
||||
"""Test ADBC query SQL injection prevention"""
|
||||
|
||||
@pytest.fixture
|
||||
async def mcp_client(self):
|
||||
"""Create MCP client instance"""
|
||||
client = MCPClient()
|
||||
yield client
|
||||
await client.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_adbc_query_drop_injection(self, mcp_client):
|
||||
"""Test exec_adbc_query rejects DROP TABLE injection"""
|
||||
payload = {"sql": "SELECT * FROM users; DROP TABLE users; --"}
|
||||
result = await mcp_client.call_tool("exec_adbc_query", payload)
|
||||
print_result("ADBC DROP TABLE Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"ADBC DROP TABLE injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_adbc_query_delete_injection(self, mcp_client):
|
||||
"""Test exec_adbc_query rejects DELETE injection"""
|
||||
payload = {"sql": "SELECT 1; DELETE FROM users; --"}
|
||||
result = await mcp_client.call_tool("exec_adbc_query", payload)
|
||||
print_result("ADBC DELETE Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"ADBC DELETE injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_adbc_query_valid(self, mcp_client):
|
||||
"""Test exec_adbc_query allows valid queries"""
|
||||
payload = {"sql": "SELECT 1 AS test"}
|
||||
result = await mcp_client.call_tool("exec_adbc_query", payload)
|
||||
print_result("ADBC Valid Query", payload, result)
|
||||
|
||||
def _is_blocked_or_error(self, result: dict) -> bool:
|
||||
"""Check if result indicates blocked or error"""
|
||||
if not result:
|
||||
return True
|
||||
if "error" in result:
|
||||
return True
|
||||
result_str = json.dumps(result).lower()
|
||||
return any(kw in result_str for kw in ["error", "blocked", "invalid", "security", "injection"])
|
||||
|
||||
|
||||
class TestMetadataToolsInjectionAPI:
|
||||
"""Test metadata tools SQL injection prevention"""
|
||||
|
||||
@pytest.fixture
|
||||
async def mcp_client(self):
|
||||
"""Create MCP client instance"""
|
||||
client = MCPClient()
|
||||
yield client
|
||||
await client.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_db_table_list_db_injection(self, mcp_client):
|
||||
"""Test get_db_table_list rejects db_name injection"""
|
||||
payload = {"db_name": "test'; DROP TABLE users; --"}
|
||||
result = await mcp_client.call_tool("get_db_table_list", payload)
|
||||
print_result("get_db_table_list db_name Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"db_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_db_table_list_catalog_injection(self, mcp_client):
|
||||
"""Test get_db_table_list rejects catalog_name injection"""
|
||||
payload = {"catalog_name": "internal`; SELECT * FROM mysql.user; --"}
|
||||
result = await mcp_client.call_tool("get_db_table_list", payload)
|
||||
print_result("get_db_table_list catalog_name Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"catalog_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_table_comment_injection(self, mcp_client):
|
||||
"""Test get_table_comment rejects table_name injection"""
|
||||
payload = {"table_name": "users'; DROP TABLE users; --"}
|
||||
result = await mcp_client.call_tool("get_table_comment", payload)
|
||||
print_result("get_table_comment table_name Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"table_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_table_column_comments_injection(self, mcp_client):
|
||||
"""Test get_table_column_comments rejects injection"""
|
||||
payload = {"table_name": "users'; DROP TABLE users; --", "db_name": "test"}
|
||||
result = await mcp_client.call_tool("get_table_column_comments", payload)
|
||||
print_result("get_table_column_comments Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"table_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_table_indexes_injection(self, mcp_client):
|
||||
"""Test get_table_indexes rejects table_name injection"""
|
||||
payload = {"table_name": "users; DROP TABLE users", "db_name": "test"}
|
||||
result = await mcp_client.call_tool("get_table_indexes", payload)
|
||||
print_result("get_table_indexes Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"table_name injection should be blocked"
|
||||
|
||||
def _is_blocked_or_error(self, result: dict) -> bool:
|
||||
"""Check if result indicates blocked or error"""
|
||||
if not result:
|
||||
return True
|
||||
if "error" in result:
|
||||
return True
|
||||
result_str = json.dumps(result).lower()
|
||||
return any(kw in result_str for kw in ["error", "blocked", "invalid", "security", "injection"])
|
||||
|
||||
|
||||
class TestAnalyticsToolsInjectionAPI:
|
||||
"""Test analytics tools SQL injection prevention"""
|
||||
|
||||
@pytest.fixture
|
||||
async def mcp_client(self):
|
||||
"""Create MCP client instance"""
|
||||
client = MCPClient()
|
||||
yield client
|
||||
await client.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analyze_columns_table_injection(self, mcp_client):
|
||||
"""Test analyze_columns rejects table_name injection"""
|
||||
payload = {"table_name": "users'; DROP TABLE users; --"}
|
||||
result = await mcp_client.call_tool("analyze_columns", payload)
|
||||
print_result("analyze_columns table_name Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"table_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analyze_columns_db_injection(self, mcp_client):
|
||||
"""Test analyze_columns rejects db_name injection"""
|
||||
payload = {"table_name": "users", "db_name": "test' OR '1'='1"}
|
||||
result = await mcp_client.call_tool("analyze_columns", payload)
|
||||
print_result("analyze_columns db_name Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"db_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_table_basic_info_injection(self, mcp_client):
|
||||
"""Test get_table_basic_info rejects injection"""
|
||||
payload = {"table_name": "users; DROP TABLE audit_log"}
|
||||
result = await mcp_client.call_tool("get_table_basic_info", payload)
|
||||
print_result("get_table_basic_info Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"table_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analyze_table_storage_injection(self, mcp_client):
|
||||
"""Test analyze_table_storage rejects injection"""
|
||||
payload = {"table_name": "users`; SELECT * FROM sensitive; --"}
|
||||
result = await mcp_client.call_tool("analyze_table_storage", payload)
|
||||
print_result("analyze_table_storage Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"table_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_sql_explain_injection(self, mcp_client):
|
||||
"""Test get_sql_explain rejects SQL injection"""
|
||||
payload = {"sql": "SELECT 1; DROP TABLE users; --"}
|
||||
result = await mcp_client.call_tool("get_sql_explain", payload)
|
||||
print_result("get_sql_explain SQL Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"SQL injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_sql_profile_injection(self, mcp_client):
|
||||
"""Test get_sql_profile rejects SQL injection"""
|
||||
payload = {"sql": "SELECT 1; DELETE FROM audit_log; --"}
|
||||
result = await mcp_client.call_tool("get_sql_profile", payload)
|
||||
print_result("get_sql_profile SQL Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"SQL injection should be blocked"
|
||||
|
||||
def _is_blocked_or_error(self, result: dict) -> bool:
|
||||
"""Check if result indicates blocked or error"""
|
||||
if not result:
|
||||
return True
|
||||
if "error" in result:
|
||||
return True
|
||||
result_str = json.dumps(result).lower()
|
||||
return any(kw in result_str for kw in ["error", "blocked", "invalid", "security", "injection"])
|
||||
|
||||
|
||||
class TestGovernanceToolsInjectionAPI:
|
||||
"""Test data governance tools SQL injection prevention"""
|
||||
|
||||
@pytest.fixture
|
||||
async def mcp_client(self):
|
||||
"""Create MCP client instance"""
|
||||
client = MCPClient()
|
||||
yield client
|
||||
await client.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trace_column_lineage_table_injection(self, mcp_client):
|
||||
"""Test trace_column_lineage rejects table_name injection"""
|
||||
payload = {"table_name": "users'; DROP TABLE users; --", "column_name": "id"}
|
||||
result = await mcp_client.call_tool("trace_column_lineage", payload)
|
||||
print_result("trace_column_lineage table_name Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"table_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trace_column_lineage_column_injection(self, mcp_client):
|
||||
"""Test trace_column_lineage rejects column_name injection"""
|
||||
payload = {"table_name": "users", "column_name": "id; DROP TABLE users"}
|
||||
result = await mcp_client.call_tool("trace_column_lineage", payload)
|
||||
print_result("trace_column_lineage column_name Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"column_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_monitor_data_freshness_injection(self, mcp_client):
|
||||
"""Test monitor_data_freshness rejects table_name injection"""
|
||||
payload = {"table_name": "users`; SELECT * FROM passwords; --"}
|
||||
result = await mcp_client.call_tool("monitor_data_freshness", payload)
|
||||
print_result("monitor_data_freshness Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"table_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analyze_data_access_patterns_injection(self, mcp_client):
|
||||
"""Test analyze_data_access_patterns rejects injection"""
|
||||
payload = {"table_name": "users' UNION SELECT password FROM admin --"}
|
||||
result = await mcp_client.call_tool("analyze_data_access_patterns", payload)
|
||||
print_result("analyze_data_access_patterns Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"table_name injection should be blocked"
|
||||
|
||||
def _is_blocked_or_error(self, result: dict) -> bool:
|
||||
"""Check if result indicates blocked or error"""
|
||||
if not result:
|
||||
return True
|
||||
if "error" in result:
|
||||
return True
|
||||
result_str = json.dumps(result).lower()
|
||||
return any(kw in result_str for kw in ["error", "blocked", "invalid", "security", "injection"])
|
||||
|
||||
|
||||
class TestPerformanceToolsInjectionAPI:
|
||||
"""Test performance analytics tools SQL injection prevention"""
|
||||
|
||||
@pytest.fixture
|
||||
async def mcp_client(self):
|
||||
"""Create MCP client instance"""
|
||||
client = MCPClient()
|
||||
yield client
|
||||
await client.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analyze_slow_queries_db_injection(self, mcp_client):
|
||||
"""Test analyze_slow_queries_topn rejects db_name injection"""
|
||||
payload = {"db_name": "test'; DROP TABLE audit_log; --"}
|
||||
result = await mcp_client.call_tool("analyze_slow_queries_topn", payload)
|
||||
print_result("analyze_slow_queries_topn db_name Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"db_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analyze_resource_growth_db_injection(self, mcp_client):
|
||||
"""Test analyze_resource_growth_curves rejects db_name injection"""
|
||||
payload = {"db_name": "test`; GRANT ALL ON *.* TO 'hacker'; --"}
|
||||
result = await mcp_client.call_tool("analyze_resource_growth_curves", payload)
|
||||
print_result("analyze_resource_growth_curves db_name Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"db_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_table_data_size_injection(self, mcp_client):
|
||||
"""Test get_table_data_size rejects table_name injection"""
|
||||
payload = {"table_name": "users; TRUNCATE TABLE logs"}
|
||||
result = await mcp_client.call_tool("get_table_data_size", payload)
|
||||
print_result("get_table_data_size Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"table_name injection should be blocked"
|
||||
|
||||
def _is_blocked_or_error(self, result: dict) -> bool:
|
||||
"""Check if result indicates blocked or error"""
|
||||
if not result:
|
||||
return True
|
||||
if "error" in result:
|
||||
return True
|
||||
result_str = json.dumps(result).lower()
|
||||
return any(kw in result_str for kw in ["error", "blocked", "invalid", "security", "injection"])
|
||||
|
||||
|
||||
# Pytest configuration for async tests
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop():
|
||||
"""Create event loop for async tests"""
|
||||
loop = asyncio.new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "--tb=short", "-x"])
|
||||
|
||||
@@ -44,22 +44,31 @@
|
||||
}
|
||||
},
|
||||
"expected_tools": [
|
||||
"analyze_columns",
|
||||
"analyze_data_access_patterns",
|
||||
"analyze_data_flow_dependencies",
|
||||
"analyze_resource_growth_curves",
|
||||
"analyze_slow_queries_topn",
|
||||
"analyze_table_storage",
|
||||
"exec_adbc_query",
|
||||
"exec_query",
|
||||
"get_adbc_connection_info",
|
||||
"get_catalog_list",
|
||||
"get_db_list",
|
||||
"get_db_table_list",
|
||||
"get_table_schema",
|
||||
"get_table_comment",
|
||||
"get_table_column_comments",
|
||||
"get_table_indexes",
|
||||
"get_memory_stats",
|
||||
"get_monitoring_metrics",
|
||||
"get_recent_audit_logs",
|
||||
"get_catalog_list",
|
||||
"get_sql_explain",
|
||||
"get_sql_profile",
|
||||
"get_table_basic_info",
|
||||
"get_table_column_comments",
|
||||
"get_table_comment",
|
||||
"get_table_data_size",
|
||||
"get_monitoring_metrics_info",
|
||||
"get_monitoring_metrics_data",
|
||||
"get_realtime_memory_stats",
|
||||
"get_historical_memory_stats"
|
||||
"get_table_indexes",
|
||||
"get_table_schema",
|
||||
"monitor_data_freshness",
|
||||
"trace_column_lineage"
|
||||
],
|
||||
"expected_resources": [
|
||||
"database",
|
||||
|
||||
@@ -185,8 +185,9 @@ async def test_server_connectivity(transport: Optional[str] = None) -> bool:
|
||||
logger.error(f"Connectivity test failed: {e}")
|
||||
return False
|
||||
|
||||
result = await client.connect_and_run(test_connection)
|
||||
return result
|
||||
await client.connect_and_run(test_connection)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to test server connectivity: {e}")
|
||||
return False
|
||||
|
||||
@@ -72,8 +72,7 @@ class TestToolsClientServer:
|
||||
|
||||
return tools
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert len(result) > 0
|
||||
await client.connect_and_run(test_callback)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_exec_query_via_client(self, client, test_config):
|
||||
@@ -91,14 +90,13 @@ class TestToolsClientServer:
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
|
||||
if result["success"]:
|
||||
assert "result" in result, "Successful result should contain 'result' field"
|
||||
assert "data" in result, "Successful result should contain 'data' field"
|
||||
else:
|
||||
assert "error" in result, "Failed result should contain 'error' field"
|
||||
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
# Don't assert success=True as it depends on actual server state
|
||||
await client.connect_and_run(test_callback)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_get_db_list_via_client(self, client, test_config):
|
||||
@@ -115,8 +113,7 @@ class TestToolsClientServer:
|
||||
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
await client.connect_and_run(test_callback)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_get_table_schema_via_client(self, client, test_config):
|
||||
@@ -133,10 +130,7 @@ class TestToolsClientServer:
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
|
||||
|
||||
await client.connect_and_run(test_callback)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_error_handling_via_client(self, client, test_config):
|
||||
@@ -151,8 +145,7 @@ class TestToolsClientServer:
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
await client.connect_and_run(test_callback)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_with_auth_token_via_client(self, client, test_config):
|
||||
@@ -171,5 +164,4 @@ class TestToolsClientServer:
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
await client.connect_and_run(test_callback)
|
||||
|
||||
@@ -45,7 +45,6 @@ class TestDorisToolsManager:
|
||||
config.database.password = "test_password"
|
||||
config.database.database = "test_db"
|
||||
config.database.health_check_interval = 60
|
||||
config.database.min_connections = 5
|
||||
config.database.max_connections = 20
|
||||
config.database.connection_timeout = 30
|
||||
config.database.max_connection_age = 3600
|
||||
|
||||
78
test/utils/test_db.py
Normal file
78
test/utils/test_db.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from unittest.mock import MagicMock
|
||||
import pytest
|
||||
|
||||
from doris_mcp_server.utils.db import DorisConnection, DorisSessionCache
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session_cache():
|
||||
"""Provides a DorisSessionCache instance with a mock connection manager."""
|
||||
connection_manager = MagicMock()
|
||||
cache = DorisSessionCache(connection_manager=connection_manager)
|
||||
yield cache, connection_manager
|
||||
|
||||
|
||||
class TestDorisSessionCache:
|
||||
|
||||
def test_initialization(self, session_cache):
|
||||
cache, _ = session_cache
|
||||
assert cache.cache_system_session is True
|
||||
assert cache.cache_user_session is False
|
||||
assert not cache.cached
|
||||
|
||||
def test_should_cache(self, session_cache):
|
||||
cache, _ = session_cache
|
||||
assert cache._should_cache("query") is True
|
||||
assert cache._should_cache("system") is True
|
||||
assert cache._should_cache("user-test-session-id") is False
|
||||
|
||||
cache.cache_user_session = True
|
||||
assert cache._should_cache("user-test-session-id") is True
|
||||
|
||||
def test_save_and_get_session(self, session_cache):
|
||||
cache, _ = session_cache
|
||||
mock_connection = MagicMock(spec=DorisConnection)
|
||||
mock_connection.session_id = "query"
|
||||
|
||||
cache.save(mock_connection)
|
||||
retrieved_conn = cache.get("query")
|
||||
assert retrieved_conn is mock_connection
|
||||
|
||||
mock_user_connection = MagicMock(spec=DorisConnection)
|
||||
mock_user_connection.session_id = "user-test-session-id"
|
||||
cache.save(mock_user_connection)
|
||||
assert cache.get("user-test-session-id") is None
|
||||
|
||||
cache.cache_user_session = True
|
||||
cache.save(mock_user_connection)
|
||||
retrieved_user_conn = cache.get("user-test-session-id")
|
||||
assert retrieved_user_conn is mock_user_connection
|
||||
|
||||
def test_remove_session(self, session_cache):
|
||||
cache, _ = session_cache
|
||||
mock_connection = MagicMock(spec=DorisConnection)
|
||||
mock_connection.session_id = "system"
|
||||
|
||||
cache.save(mock_connection)
|
||||
assert cache.get("system") is not None
|
||||
|
||||
cache.remove("system")
|
||||
assert cache.get("system") is None
|
||||
|
||||
def test_clear_cache(self, session_cache):
|
||||
cache, connection_manager = session_cache
|
||||
mock_conn1 = MagicMock(spec=DorisConnection)
|
||||
mock_conn1.session_id = "query"
|
||||
mock_conn2 = MagicMock(spec=DorisConnection)
|
||||
mock_conn2.session_id = "system"
|
||||
|
||||
cache.save(mock_conn1)
|
||||
cache.save(mock_conn2)
|
||||
assert len(cache.cached) == 2
|
||||
|
||||
cache.clear()
|
||||
|
||||
assert not cache.cached
|
||||
connection_manager.release_connection.assert_any_call("query", mock_conn1)
|
||||
connection_manager.release_connection.assert_any_call("system", mock_conn2)
|
||||
assert connection_manager.release_connection.call_count == 2
|
||||
@@ -44,7 +44,6 @@ class TestDorisQueryExecutor:
|
||||
config.database.password = "test_password"
|
||||
config.database.database = "test_db"
|
||||
config.database.health_check_interval = 60
|
||||
config.database.min_connections = 5
|
||||
config.database.max_connections = 20
|
||||
config.database.connection_timeout = 30
|
||||
config.database.max_connection_age = 3600
|
||||
@@ -202,3 +201,73 @@ class TestDorisQueryExecutor:
|
||||
if result["success"]:
|
||||
assert "data" in result
|
||||
assert "row_count" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_multi_sql_statements(self, query_executor):
|
||||
"""Test execution of multiple SQL statements"""
|
||||
from doris_mcp_server.utils.query_executor import QueryResult
|
||||
|
||||
# Disable security check for this test
|
||||
query_executor.connection_manager.config.security.enable_security_check = False
|
||||
|
||||
with patch.object(query_executor, 'execute_query') as mock_execute:
|
||||
# Mock results for three SQL statements
|
||||
mock_execute.side_effect = [
|
||||
QueryResult(
|
||||
data=[{"id": 1, "name": "张三"}],
|
||||
row_count=1,
|
||||
execution_time=0.1,
|
||||
sql="SELECT id, name FROM users WHERE id = 1",
|
||||
metadata={"columns": ["id", "name"]}
|
||||
),
|
||||
QueryResult(
|
||||
data=[{"id": 2, "name": "李四"}],
|
||||
row_count=1,
|
||||
execution_time=0.12,
|
||||
sql="SELECT id, name FROM users WHERE id = 2",
|
||||
metadata={"columns": ["id", "name"]}
|
||||
),
|
||||
QueryResult(
|
||||
data=[{"count": 100}],
|
||||
row_count=1,
|
||||
execution_time=0.08,
|
||||
sql="SELECT COUNT(*) as count FROM users",
|
||||
metadata={"columns": ["count"]}
|
||||
)
|
||||
]
|
||||
|
||||
# Execute multiple SQL statements separated by semicolons
|
||||
multi_sql = """
|
||||
SELECT id, name FROM users WHERE id = 1;
|
||||
SELECT id, name FROM users WHERE id = 2;
|
||||
SELECT COUNT(*) as count FROM users;
|
||||
"""
|
||||
|
||||
result = await query_executor.execute_sql_for_mcp(multi_sql)
|
||||
|
||||
# Verify the result structure for multiple statements
|
||||
assert result["success"] is True
|
||||
assert result["multiple_results"] is True
|
||||
assert "results" in result
|
||||
assert len(result["results"]) == 3
|
||||
|
||||
# Verify first query result
|
||||
assert result["results"][0]["data"] == [{"id": 1, "name": "张三"}]
|
||||
assert result["results"][0]["row_count"] == 1
|
||||
assert result["results"][0]["metadata"]["columns"] == ["id", "name"]
|
||||
assert result["results"][0]["metadata"]["query"] == "SELECT id, name FROM users WHERE id = 1"
|
||||
|
||||
# Verify second query result
|
||||
assert result["results"][1]["data"] == [{"id": 2, "name": "李四"}]
|
||||
assert result["results"][1]["row_count"] == 1
|
||||
assert result["results"][1]["metadata"]["columns"] == ["id", "name"]
|
||||
assert result["results"][1]["metadata"]["query"] == "SELECT id, name FROM users WHERE id = 2"
|
||||
|
||||
# Verify third query result
|
||||
assert result["results"][2]["data"] == [{"count": 100}]
|
||||
assert result["results"][2]["row_count"] == 1
|
||||
assert result["results"][2]["metadata"]["columns"] == ["count"]
|
||||
assert result["results"][2]["metadata"]["query"] == "SELECT COUNT(*) as count FROM users"
|
||||
|
||||
# Verify execute_query was called three times
|
||||
assert mock_execute.call_count == 3
|
||||
|
||||
@@ -21,8 +21,6 @@ Tests the query execution functionality through actual MCP client-server communi
|
||||
Assumes the server is already running and configured properly
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import pytest
|
||||
import os
|
||||
import sys
|
||||
@@ -66,14 +64,13 @@ class TestQueryExecutorClientServer:
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
|
||||
if result["success"]:
|
||||
assert "result" in result, "Successful result should contain 'result' field"
|
||||
assert "data" in result, "Successful result should contain 'data' field"
|
||||
else:
|
||||
assert "error" in result, "Failed result should contain 'error' field"
|
||||
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
await client.connect_and_run(test_callback)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_show_databases_query_via_client(self, client, test_config):
|
||||
@@ -87,8 +84,7 @@ class TestQueryExecutorClientServer:
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
await client.connect_and_run(test_callback)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_information_schema_query_via_client(self, client, test_config):
|
||||
@@ -102,8 +98,7 @@ class TestQueryExecutorClientServer:
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
await client.connect_and_run(test_callback)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_with_max_rows_parameter_via_client(self, client, test_config):
|
||||
@@ -118,8 +113,7 @@ class TestQueryExecutorClientServer:
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
await client.connect_and_run(test_callback)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_error_handling_via_client(self, client, test_config):
|
||||
@@ -131,8 +125,7 @@ class TestQueryExecutorClientServer:
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
await client.connect_and_run(test_callback)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_with_auth_token_via_client(self, client, test_config):
|
||||
@@ -152,5 +145,4 @@ class TestQueryExecutorClientServer:
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
await client.connect_and_run(test_callback)
|
||||
|
||||
64
tokens.json
Normal file
64
tokens.json
Normal file
@@ -0,0 +1,64 @@
|
||||
{
|
||||
"version": "1.0",
|
||||
"description": "Doris MCP Server Token configuration file",
|
||||
"created_at": "2025-09-01T00:00:00Z",
|
||||
"tokens": [
|
||||
{
|
||||
"token_id": "admin-token",
|
||||
"token": "doris_admin_token_123456",
|
||||
"description": "Doris admin API access token",
|
||||
"expires_hours": null,
|
||||
"is_active": true,
|
||||
"database_config": {
|
||||
"host": "127.0.0.1",
|
||||
"port": 9030,
|
||||
"user": "root",
|
||||
"password": "",
|
||||
"database": "information_schema",
|
||||
"charset": "UTF8",
|
||||
"fe_http_port": 8030
|
||||
}
|
||||
},
|
||||
{
|
||||
"token_id": "analyst-token",
|
||||
"token": "doris_analyst_token_123456",
|
||||
"description": "Doris analyst API access token",
|
||||
"expires_hours": 8760,
|
||||
"is_active": true,
|
||||
"database_config": {
|
||||
"host": "127.0.0.1",
|
||||
"port": 9030,
|
||||
"user": "root",
|
||||
"password": "",
|
||||
"database": "information_schema",
|
||||
"charset": "UTF8",
|
||||
"fe_http_port": 8030
|
||||
}
|
||||
},
|
||||
{
|
||||
"token_id": "readonly-token",
|
||||
"token": "doris_readonly_token_123456",
|
||||
"description": "Doris readonly API access token",
|
||||
"expires_hours": 4320,
|
||||
"is_active": true,
|
||||
"database_config": {
|
||||
"host": "127.0.0.1",
|
||||
"port": 9030,
|
||||
"user": "root",
|
||||
"password": "",
|
||||
"database": "information_schema",
|
||||
"charset": "UTF8",
|
||||
"fe_http_port": 8030
|
||||
}
|
||||
}
|
||||
],
|
||||
"notes": [
|
||||
"The admin_token, analyst_token, readonly_token is default token,Please change the token before using in production!",
|
||||
"The token_id is the key of the token,Please use the token_id to identify the token",
|
||||
"The token is the value of the token,Please use the token to identify the token",
|
||||
"The description is the description of the token,Please use the description to identify the token",
|
||||
"The expires_hours is the expires hours of the token,Please use the expires_hours to identify the token",
|
||||
"The is_active is the is active of the token,Please use the is_active to identify the token",
|
||||
"The token_id, token, description, expires_hours, is_active is the metadata of the token,Please use the metadata to identify the token"
|
||||
]
|
||||
}
|
||||
420
uv.lock
generated
420
uv.lock
generated
@@ -6,6 +6,48 @@ resolution-markers = [
|
||||
"python_full_version < '3.13'",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "adbc-driver-flightsql"
|
||||
version = "1.7.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "adbc-driver-manager" },
|
||||
{ name = "importlib-resources" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b8/d4/ebd3eed981c771565677084474cdf465141455b5deb1ca409c616609bfd7/adbc_driver_flightsql-1.7.0.tar.gz", hash = "sha256:5dca460a2c66e45b29208eaf41a7206f252177435fa48b16f19833b12586f7a0", size = 21247 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/36/20/807fca9d904b7e0d3020439828d6410db7fd7fd635824a80cab113d9fad1/adbc_driver_flightsql-1.7.0-py3-none-macosx_10_15_x86_64.whl", hash = "sha256:a5658f9bc3676bd122b26138e9b9ce56b8bf37387efe157b4c66d56f942361c6", size = 7749664 },
|
||||
{ url = "https://files.pythonhosted.org/packages/cd/e6/9e50f6497819c911b9cc1962ffde610b60f7d8e951d6bb3fa145dcfb50a7/adbc_driver_flightsql-1.7.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:65e21df86b454d8db422c8ee22db31be217d88c42d9d6dd89119f06813037c91", size = 7302476 },
|
||||
{ url = "https://files.pythonhosted.org/packages/27/82/e51af85e7cc8c87bc8ce4fae8ca7ee1d3cf39c926be0aeab789cedc93f0a/adbc_driver_flightsql-1.7.0-py3-none-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:3282fdc7b73c712780cc777975288c88b1e3a555355bbe09df101aa954f8f105", size = 7686056 },
|
||||
{ url = "https://files.pythonhosted.org/packages/8b/c9/591c8ecbaf010ba3f4b360db602050ee5880cd077a573c9e90fcb270ab71/adbc_driver_flightsql-1.7.0-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e0c5737ae6ee3bbfba44dcbc28ba1ff8cf3ab6521888c4b0f10dd6a482482161", size = 7050275 },
|
||||
{ url = "https://files.pythonhosted.org/packages/10/14/f339e9a5d8dbb3e3040215514cea9cca0a58640964aaccc6532f18003a03/adbc_driver_flightsql-1.7.0-py3-none-win_amd64.whl", hash = "sha256:f8b5290b322304b7d944ca823754e6354c1868dbbe94ddf84236f3e0329545da", size = 14312858 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "adbc-driver-manager"
|
||||
version = "1.7.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/bb/bf/2986a2cd3e1af658d2597f7e2308564e5c11e036f9736d5c256f1e00d578/adbc_driver_manager-1.7.0.tar.gz", hash = "sha256:e3edc5d77634b5925adf6eb4fbcd01676b54acb2f5b1d6864b6a97c6a899591a", size = 198128 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/74/3a/72bd9c45d55f1f5f4c549e206de8cfe3313b31f7b95fbcb180da05c81044/adbc_driver_manager-1.7.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:8da1ac4c19bcbf30b3bd54247ec889dfacc9b44147c70b4da79efe2e9ba93600", size = 524210 },
|
||||
{ url = "https://files.pythonhosted.org/packages/33/29/e1a8d8dde713a287f8021f3207127f133ddce578711a4575218bdf78ef27/adbc_driver_manager-1.7.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:408bc23bad1a6823b364e2388f85f96545e82c3b2db97d7828a4b94839d3f29e", size = 505902 },
|
||||
{ url = "https://files.pythonhosted.org/packages/59/00/773ece64a58c0ade797ab4577e7cdc4c71ebf800b86d2d5637e3bfe605e9/adbc_driver_manager-1.7.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cf38294320c23e47ed3455348e910031ad8289c3f9167ae35519ac957b7add01", size = 2974883 },
|
||||
{ url = "https://files.pythonhosted.org/packages/7c/ad/1568da6ae9ab70983f1438503d3906c6b1355601230e891d16e272376a04/adbc_driver_manager-1.7.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:689f91b62c18a9f86f892f112786fb157cacc4729b4d81666db4ca778eade2a8", size = 2997781 },
|
||||
{ url = "https://files.pythonhosted.org/packages/19/66/2b6ea5afded25a3fa009873c2bbebcd9283910877cc10b9453d680c00b9a/adbc_driver_manager-1.7.0-cp312-cp312-win_amd64.whl", hash = "sha256:f936cfc8d098898a47ef60396bd7a73926ec3068f2d6d92a2be4e56e4aaf3770", size = 690041 },
|
||||
{ url = "https://files.pythonhosted.org/packages/b2/3b/91154c83a98f103a3d97c9e2cb838c3842aef84ca4f4b219164b182d9516/adbc_driver_manager-1.7.0-cp313-cp313-macosx_10_15_x86_64.whl", hash = "sha256:ab9ee36683fd54f61b0db0f4a96f70fe1932223e61df9329290370b145abb0a9", size = 522737 },
|
||||
{ url = "https://files.pythonhosted.org/packages/9c/52/4bc80c3388d5e2a3b6e504ba9656dd9eb3d8dbe822d07af38db1b8c96fb1/adbc_driver_manager-1.7.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:4ec03d94177f71a8d3a149709f4111e021f9950229b35c0a803aadb1a1855a4b", size = 503896 },
|
||||
{ url = "https://files.pythonhosted.org/packages/e1/f3/46052ca11224f661cef4721e19138bc73e750ba6aea54f22606950491606/adbc_driver_manager-1.7.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:700c79dac08a620018c912ede45a6dc7851819bc569a53073ab652dc0bd0c92f", size = 2972586 },
|
||||
{ url = "https://files.pythonhosted.org/packages/a2/22/44738b41bb5ca30f94b5f4c00c71c20be86d7eb4ddc389d4cf3c7b8b69ef/adbc_driver_manager-1.7.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:98db0f5d0aa1635475f63700a7b6f677390beb59c69c7ba9d388bc8ce3779388", size = 2992001 },
|
||||
{ url = "https://files.pythonhosted.org/packages/1b/2b/5184fe5a529feb019582cc90d0f65e0021d52c34ca20620551532340645a/adbc_driver_manager-1.7.0-cp313-cp313-win_amd64.whl", hash = "sha256:4b7e5e9a163acb21804647cc7894501df51cdcd780ead770557112a26ca01ca6", size = 688789 },
|
||||
{ url = "https://files.pythonhosted.org/packages/3f/e0/b283544e1bb7864bf5a5ac9cd330f111009eff9180ec5000420510cf9342/adbc_driver_manager-1.7.0-cp313-cp313t-macosx_10_15_x86_64.whl", hash = "sha256:ac83717965b83367a8ad6c0536603acdcfa66e0592d783f8940f55fda47d963e", size = 538625 },
|
||||
{ url = "https://files.pythonhosted.org/packages/77/5a/dc244264bd8d0c331a418d2bdda5cb6e26c30493ff075d706aa81d4e3b30/adbc_driver_manager-1.7.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:4c234cf81b00eaf7e7c65dbd0f0ddf7bdae93dfcf41e9d8543f9ecf4b10590f6", size = 523627 },
|
||||
{ url = "https://files.pythonhosted.org/packages/e9/ff/a499a00367fd092edb20dc6e36c81e3c7a437671c70481cae97f46c8156a/adbc_driver_manager-1.7.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ad8aa4b039cc50722a700b544773388c6b1dea955781a01f79cd35d0a1e6edbf", size = 3037517 },
|
||||
{ url = "https://files.pythonhosted.org/packages/25/6e/9dfdb113294dcb24b4f53924cd4a9c9af3fbe45a9790c1327048df731246/adbc_driver_manager-1.7.0-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4409ff53578e01842a8f57787ebfbfee790c1da01a6bd57fcb7701ed5d4dd4f7", size = 3016543 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aiofiles"
|
||||
version = "24.1.0"
|
||||
@@ -518,6 +560,176 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/8f/d7/9322c609343d929e75e7e5e6255e614fcc67572cfd083959cdef3b7aad79/docutils-0.21.2-py3-none-any.whl", hash = "sha256:dafca5b9e384f0e419294eb4d2ff9fa826435bf15f15b7bd45723e8ad76811b2", size = 587408 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "doris-mcp-server"
|
||||
version = "0.6.1"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "adbc-driver-flightsql" },
|
||||
{ name = "adbc-driver-manager" },
|
||||
{ name = "aiofiles" },
|
||||
{ name = "aiohttp" },
|
||||
{ name = "aiomysql" },
|
||||
{ name = "aioredis" },
|
||||
{ name = "asyncio-mqtt" },
|
||||
{ name = "bcrypt" },
|
||||
{ name = "click" },
|
||||
{ name = "cryptography" },
|
||||
{ name = "fastapi" },
|
||||
{ name = "httpx" },
|
||||
{ name = "mcp" },
|
||||
{ name = "numpy" },
|
||||
{ name = "orjson" },
|
||||
{ name = "pandas" },
|
||||
{ name = "passlib", extra = ["bcrypt"] },
|
||||
{ name = "prometheus-client" },
|
||||
{ name = "pyarrow" },
|
||||
{ name = "pydantic" },
|
||||
{ name = "pydantic-settings" },
|
||||
{ name = "pyjwt" },
|
||||
{ name = "pymysql" },
|
||||
{ name = "pytest" },
|
||||
{ name = "pytest-asyncio" },
|
||||
{ name = "pytest-cov" },
|
||||
{ name = "python-dateutil" },
|
||||
{ name = "python-dotenv" },
|
||||
{ name = "python-jose", extra = ["cryptography"] },
|
||||
{ name = "python-multipart" },
|
||||
{ name = "pyyaml" },
|
||||
{ name = "requests" },
|
||||
{ name = "rich" },
|
||||
{ name = "sqlparse" },
|
||||
{ name = "starlette" },
|
||||
{ name = "structlog" },
|
||||
{ name = "toml" },
|
||||
{ name = "tqdm" },
|
||||
{ name = "typer" },
|
||||
{ name = "uvicorn", extra = ["standard"] },
|
||||
{ name = "websockets" },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
dev = [
|
||||
{ name = "bandit" },
|
||||
{ name = "black" },
|
||||
{ name = "flake8" },
|
||||
{ name = "isort" },
|
||||
{ name = "mypy" },
|
||||
{ name = "myst-parser" },
|
||||
{ name = "pre-commit" },
|
||||
{ name = "pytest" },
|
||||
{ name = "pytest-asyncio" },
|
||||
{ name = "pytest-cov" },
|
||||
{ name = "pytest-mock" },
|
||||
{ name = "pytest-xdist" },
|
||||
{ name = "ruff" },
|
||||
{ name = "safety" },
|
||||
{ name = "sphinx" },
|
||||
{ name = "sphinx-rtd-theme" },
|
||||
{ name = "tox" },
|
||||
]
|
||||
docs = [
|
||||
{ name = "myst-parser" },
|
||||
{ name = "sphinx" },
|
||||
{ name = "sphinx-autoapi" },
|
||||
{ name = "sphinx-rtd-theme" },
|
||||
]
|
||||
monitoring = [
|
||||
{ name = "grafana-client" },
|
||||
{ name = "jaeger-client" },
|
||||
{ name = "opentelemetry-api" },
|
||||
{ name = "opentelemetry-sdk" },
|
||||
{ name = "prometheus-client" },
|
||||
]
|
||||
performance = [
|
||||
{ name = "cchardet" },
|
||||
{ name = "orjson" },
|
||||
{ name = "uvloop" },
|
||||
]
|
||||
|
||||
[package.dev-dependencies]
|
||||
dev = [
|
||||
{ name = "ruff" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "adbc-driver-flightsql", specifier = ">=0.8.0" },
|
||||
{ name = "adbc-driver-manager", specifier = ">=0.8.0" },
|
||||
{ name = "aiofiles", specifier = ">=23.0.0" },
|
||||
{ name = "aiohttp", specifier = ">=3.9.0" },
|
||||
{ name = "aiomysql", specifier = ">=0.2.0" },
|
||||
{ name = "aioredis", specifier = ">=2.0.0" },
|
||||
{ name = "asyncio-mqtt", specifier = ">=0.16.0" },
|
||||
{ name = "bandit", marker = "extra == 'dev'", specifier = ">=1.7.0" },
|
||||
{ name = "bcrypt", specifier = ">=4.1.0" },
|
||||
{ name = "black", marker = "extra == 'dev'", specifier = ">=23.12.0" },
|
||||
{ name = "cchardet", marker = "extra == 'performance'", specifier = ">=2.1.0" },
|
||||
{ name = "click", specifier = ">=8.1.0" },
|
||||
{ name = "cryptography", specifier = ">=41.0.0" },
|
||||
{ name = "fastapi", specifier = ">=0.108.0" },
|
||||
{ name = "flake8", marker = "extra == 'dev'", specifier = ">=7.0.0" },
|
||||
{ name = "grafana-client", marker = "extra == 'monitoring'", specifier = ">=3.5.0" },
|
||||
{ name = "httpx", specifier = ">=0.26.0" },
|
||||
{ name = "isort", marker = "extra == 'dev'", specifier = ">=5.13.0" },
|
||||
{ name = "jaeger-client", marker = "extra == 'monitoring'", specifier = ">=4.8.0" },
|
||||
{ name = "mcp", specifier = ">=1.8.0,<2.0.0" },
|
||||
{ name = "mypy", marker = "extra == 'dev'", specifier = ">=1.8.0" },
|
||||
{ name = "myst-parser", marker = "extra == 'dev'", specifier = ">=2.0.0" },
|
||||
{ name = "myst-parser", marker = "extra == 'docs'", specifier = ">=2.0.0" },
|
||||
{ name = "numpy", specifier = ">=1.24.0" },
|
||||
{ name = "opentelemetry-api", marker = "extra == 'monitoring'", specifier = ">=1.21.0" },
|
||||
{ name = "opentelemetry-sdk", marker = "extra == 'monitoring'", specifier = ">=1.21.0" },
|
||||
{ name = "orjson", specifier = ">=3.9.0" },
|
||||
{ name = "orjson", marker = "extra == 'performance'", specifier = ">=3.9.0" },
|
||||
{ name = "pandas", specifier = ">=2.0.0" },
|
||||
{ name = "passlib", extras = ["bcrypt"], specifier = ">=1.7.0" },
|
||||
{ name = "pre-commit", marker = "extra == 'dev'", specifier = ">=3.6.0" },
|
||||
{ name = "prometheus-client", specifier = ">=0.19.0" },
|
||||
{ name = "prometheus-client", marker = "extra == 'monitoring'", specifier = ">=0.19.0" },
|
||||
{ name = "pyarrow", specifier = ">=14.0.0" },
|
||||
{ name = "pydantic", specifier = ">=2.5.0" },
|
||||
{ name = "pydantic-settings", specifier = ">=2.1.0" },
|
||||
{ name = "pyjwt", specifier = ">=2.8.0" },
|
||||
{ name = "pymysql", specifier = ">=1.1.0" },
|
||||
{ name = "pytest", specifier = ">=8.4.0" },
|
||||
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=7.4.0" },
|
||||
{ name = "pytest-asyncio", specifier = ">=1.0.0" },
|
||||
{ name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.23.0" },
|
||||
{ name = "pytest-cov", specifier = ">=6.1.1" },
|
||||
{ name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.1.0" },
|
||||
{ name = "pytest-mock", marker = "extra == 'dev'", specifier = ">=3.12.0" },
|
||||
{ name = "pytest-xdist", marker = "extra == 'dev'", specifier = ">=3.5.0" },
|
||||
{ name = "python-dateutil", specifier = ">=2.8.0" },
|
||||
{ name = "python-dotenv", specifier = ">=1.0.0" },
|
||||
{ name = "python-jose", extras = ["cryptography"], specifier = ">=3.3.0" },
|
||||
{ name = "python-multipart", specifier = ">=0.0.6" },
|
||||
{ name = "pyyaml", specifier = ">=6.0.0" },
|
||||
{ name = "requests", specifier = ">=2.31.0" },
|
||||
{ name = "rich", specifier = ">=13.7.0" },
|
||||
{ name = "ruff", marker = "extra == 'dev'", specifier = ">=0.1.0" },
|
||||
{ name = "safety", marker = "extra == 'dev'", specifier = ">=2.3.0" },
|
||||
{ name = "sphinx", marker = "extra == 'dev'", specifier = ">=7.2.0" },
|
||||
{ name = "sphinx", marker = "extra == 'docs'", specifier = ">=7.2.0" },
|
||||
{ name = "sphinx-autoapi", marker = "extra == 'docs'", specifier = ">=3.0.0" },
|
||||
{ name = "sphinx-rtd-theme", marker = "extra == 'dev'", specifier = ">=2.0.0" },
|
||||
{ name = "sphinx-rtd-theme", marker = "extra == 'docs'", specifier = ">=2.0.0" },
|
||||
{ name = "sqlparse", specifier = ">=0.4.4" },
|
||||
{ name = "starlette", specifier = ">=0.27.0" },
|
||||
{ name = "structlog", specifier = ">=23.2.0" },
|
||||
{ name = "toml", specifier = ">=0.10.0" },
|
||||
{ name = "tox", marker = "extra == 'dev'", specifier = ">=4.11.0" },
|
||||
{ name = "tqdm", specifier = ">=4.66.0" },
|
||||
{ name = "typer", specifier = ">=0.9.0" },
|
||||
{ name = "uvicorn", extras = ["standard"], specifier = ">=0.25.0" },
|
||||
{ name = "uvloop", marker = "extra == 'performance'", specifier = ">=0.19.0" },
|
||||
{ name = "websockets", specifier = ">=12.0" },
|
||||
]
|
||||
provides-extras = ["dev", "docs", "performance", "monitoring"]
|
||||
|
||||
[package.metadata.requires-dev]
|
||||
dev = [{ name = "ruff", specifier = ">=0.11.13" }]
|
||||
|
||||
[[package]]
|
||||
name = "dparse"
|
||||
version = "0.6.4"
|
||||
@@ -768,6 +980,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/20/b0/36bd937216ec521246249be3bf9855081de4c5e06a0c9b4219dbeda50373/importlib_metadata-8.7.0-py3-none-any.whl", hash = "sha256:e5dd1551894c77868a30651cef00984d50e1002d06942a7101d34870c5f02afd", size = 27656 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "importlib-resources"
|
||||
version = "6.5.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/cf/8c/f834fbf984f691b4f7ff60f50b514cc3de5cc08abfc3295564dd89c5e2e7/importlib_resources-6.5.2.tar.gz", hash = "sha256:185f87adef5bcc288449d98fb4fba07cea78bc036455dd44c5fc4a2fe78fed2c", size = 44693 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/a4/ed/1f1afb2e9e7f38a545d628f864d562a5ae64fe6f7a10e28ffb9b185b4e89/importlib_resources-6.5.2-py3-none-any.whl", hash = "sha256:789cfdc3ed28c78b67a06acb8126751ced69a3d5f79c095a98298cd8a760ccec", size = 37461 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "iniconfig"
|
||||
version = "2.1.0"
|
||||
@@ -946,170 +1167,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/79/45/823ad05504bea55cb0feb7470387f151252127ad5c72f8882e8fe6cf5c0e/mcp-1.9.3-py3-none-any.whl", hash = "sha256:69b0136d1ac9927402ed4cf221d4b8ff875e7132b0b06edd446448766f34f9b9", size = 131063 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mcp-doris-server"
|
||||
version = "0.4.2"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "aiofiles" },
|
||||
{ name = "aiohttp" },
|
||||
{ name = "aiomysql" },
|
||||
{ name = "aioredis" },
|
||||
{ name = "asyncio-mqtt" },
|
||||
{ name = "bcrypt" },
|
||||
{ name = "click" },
|
||||
{ name = "cryptography" },
|
||||
{ name = "fastapi" },
|
||||
{ name = "httpx" },
|
||||
{ name = "mcp" },
|
||||
{ name = "numpy" },
|
||||
{ name = "orjson" },
|
||||
{ name = "pandas" },
|
||||
{ name = "passlib", extra = ["bcrypt"] },
|
||||
{ name = "prometheus-client" },
|
||||
{ name = "pydantic" },
|
||||
{ name = "pydantic-settings" },
|
||||
{ name = "pyjwt" },
|
||||
{ name = "pymysql" },
|
||||
{ name = "pytest" },
|
||||
{ name = "pytest-asyncio" },
|
||||
{ name = "pytest-cov" },
|
||||
{ name = "python-dateutil" },
|
||||
{ name = "python-dotenv" },
|
||||
{ name = "python-jose", extra = ["cryptography"] },
|
||||
{ name = "python-multipart" },
|
||||
{ name = "pyyaml" },
|
||||
{ name = "requests" },
|
||||
{ name = "rich" },
|
||||
{ name = "sqlparse" },
|
||||
{ name = "starlette" },
|
||||
{ name = "structlog" },
|
||||
{ name = "toml" },
|
||||
{ name = "tqdm" },
|
||||
{ name = "typer" },
|
||||
{ name = "uvicorn", extra = ["standard"] },
|
||||
{ name = "websockets" },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
dev = [
|
||||
{ name = "bandit" },
|
||||
{ name = "black" },
|
||||
{ name = "flake8" },
|
||||
{ name = "isort" },
|
||||
{ name = "mypy" },
|
||||
{ name = "myst-parser" },
|
||||
{ name = "pre-commit" },
|
||||
{ name = "pytest" },
|
||||
{ name = "pytest-asyncio" },
|
||||
{ name = "pytest-cov" },
|
||||
{ name = "pytest-mock" },
|
||||
{ name = "pytest-xdist" },
|
||||
{ name = "ruff" },
|
||||
{ name = "safety" },
|
||||
{ name = "sphinx" },
|
||||
{ name = "sphinx-rtd-theme" },
|
||||
{ name = "tox" },
|
||||
]
|
||||
docs = [
|
||||
{ name = "myst-parser" },
|
||||
{ name = "sphinx" },
|
||||
{ name = "sphinx-autoapi" },
|
||||
{ name = "sphinx-rtd-theme" },
|
||||
]
|
||||
monitoring = [
|
||||
{ name = "grafana-client" },
|
||||
{ name = "jaeger-client" },
|
||||
{ name = "opentelemetry-api" },
|
||||
{ name = "opentelemetry-sdk" },
|
||||
{ name = "prometheus-client" },
|
||||
]
|
||||
performance = [
|
||||
{ name = "cchardet" },
|
||||
{ name = "orjson" },
|
||||
{ name = "uvloop" },
|
||||
]
|
||||
|
||||
[package.dev-dependencies]
|
||||
dev = [
|
||||
{ name = "ruff" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "aiofiles", specifier = ">=23.0.0" },
|
||||
{ name = "aiohttp", specifier = ">=3.9.0" },
|
||||
{ name = "aiomysql", specifier = ">=0.2.0" },
|
||||
{ name = "aioredis", specifier = ">=2.0.0" },
|
||||
{ name = "asyncio-mqtt", specifier = ">=0.16.0" },
|
||||
{ name = "bandit", marker = "extra == 'dev'", specifier = ">=1.7.0" },
|
||||
{ name = "bcrypt", specifier = ">=4.1.0" },
|
||||
{ name = "black", marker = "extra == 'dev'", specifier = ">=23.12.0" },
|
||||
{ name = "cchardet", marker = "extra == 'performance'", specifier = ">=2.1.0" },
|
||||
{ name = "click", specifier = ">=8.1.0" },
|
||||
{ name = "cryptography", specifier = ">=41.0.0" },
|
||||
{ name = "fastapi", specifier = ">=0.108.0" },
|
||||
{ name = "flake8", marker = "extra == 'dev'", specifier = ">=7.0.0" },
|
||||
{ name = "grafana-client", marker = "extra == 'monitoring'", specifier = ">=3.5.0" },
|
||||
{ name = "httpx", specifier = ">=0.26.0" },
|
||||
{ name = "isort", marker = "extra == 'dev'", specifier = ">=5.13.0" },
|
||||
{ name = "jaeger-client", marker = "extra == 'monitoring'", specifier = ">=4.8.0" },
|
||||
{ name = "mcp", specifier = ">=1.8.0,<2.0.0" },
|
||||
{ name = "mypy", marker = "extra == 'dev'", specifier = ">=1.8.0" },
|
||||
{ name = "myst-parser", marker = "extra == 'dev'", specifier = ">=2.0.0" },
|
||||
{ name = "myst-parser", marker = "extra == 'docs'", specifier = ">=2.0.0" },
|
||||
{ name = "numpy", specifier = ">=1.24.0" },
|
||||
{ name = "opentelemetry-api", marker = "extra == 'monitoring'", specifier = ">=1.21.0" },
|
||||
{ name = "opentelemetry-sdk", marker = "extra == 'monitoring'", specifier = ">=1.21.0" },
|
||||
{ name = "orjson", specifier = ">=3.9.0" },
|
||||
{ name = "orjson", marker = "extra == 'performance'", specifier = ">=3.9.0" },
|
||||
{ name = "pandas", specifier = ">=2.0.0" },
|
||||
{ name = "passlib", extras = ["bcrypt"], specifier = ">=1.7.0" },
|
||||
{ name = "pre-commit", marker = "extra == 'dev'", specifier = ">=3.6.0" },
|
||||
{ name = "prometheus-client", specifier = ">=0.19.0" },
|
||||
{ name = "prometheus-client", marker = "extra == 'monitoring'", specifier = ">=0.19.0" },
|
||||
{ name = "pydantic", specifier = ">=2.5.0" },
|
||||
{ name = "pydantic-settings", specifier = ">=2.1.0" },
|
||||
{ name = "pyjwt", specifier = ">=2.8.0" },
|
||||
{ name = "pymysql", specifier = ">=1.1.0" },
|
||||
{ name = "pytest", specifier = ">=8.4.0" },
|
||||
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=7.4.0" },
|
||||
{ name = "pytest-asyncio", specifier = ">=1.0.0" },
|
||||
{ name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.23.0" },
|
||||
{ name = "pytest-cov", specifier = ">=6.1.1" },
|
||||
{ name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.1.0" },
|
||||
{ name = "pytest-mock", marker = "extra == 'dev'", specifier = ">=3.12.0" },
|
||||
{ name = "pytest-xdist", marker = "extra == 'dev'", specifier = ">=3.5.0" },
|
||||
{ name = "python-dateutil", specifier = ">=2.8.0" },
|
||||
{ name = "python-dotenv", specifier = ">=1.0.0" },
|
||||
{ name = "python-jose", extras = ["cryptography"], specifier = ">=3.3.0" },
|
||||
{ name = "python-multipart", specifier = ">=0.0.6" },
|
||||
{ name = "pyyaml", specifier = ">=6.0.0" },
|
||||
{ name = "requests", specifier = ">=2.31.0" },
|
||||
{ name = "rich", specifier = ">=13.7.0" },
|
||||
{ name = "ruff", marker = "extra == 'dev'", specifier = ">=0.1.0" },
|
||||
{ name = "safety", marker = "extra == 'dev'", specifier = ">=2.3.0" },
|
||||
{ name = "sphinx", marker = "extra == 'dev'", specifier = ">=7.2.0" },
|
||||
{ name = "sphinx", marker = "extra == 'docs'", specifier = ">=7.2.0" },
|
||||
{ name = "sphinx-autoapi", marker = "extra == 'docs'", specifier = ">=3.0.0" },
|
||||
{ name = "sphinx-rtd-theme", marker = "extra == 'dev'", specifier = ">=2.0.0" },
|
||||
{ name = "sphinx-rtd-theme", marker = "extra == 'docs'", specifier = ">=2.0.0" },
|
||||
{ name = "sqlparse", specifier = ">=0.4.4" },
|
||||
{ name = "starlette", specifier = ">=0.27.0" },
|
||||
{ name = "structlog", specifier = ">=23.2.0" },
|
||||
{ name = "toml", specifier = ">=0.10.0" },
|
||||
{ name = "tox", marker = "extra == 'dev'", specifier = ">=4.11.0" },
|
||||
{ name = "tqdm", specifier = ">=4.66.0" },
|
||||
{ name = "typer", specifier = ">=0.9.0" },
|
||||
{ name = "uvicorn", extras = ["standard"], specifier = ">=0.25.0" },
|
||||
{ name = "uvloop", marker = "extra == 'performance'", specifier = ">=0.19.0" },
|
||||
{ name = "websockets", specifier = ">=12.0" },
|
||||
]
|
||||
provides-extras = ["dev", "docs", "performance", "monitoring"]
|
||||
|
||||
[package.metadata.requires-dev]
|
||||
dev = [{ name = "ruff", specifier = ">=0.11.13" }]
|
||||
|
||||
[[package]]
|
||||
name = "mdit-py-plugins"
|
||||
version = "0.4.2"
|
||||
@@ -1605,6 +1662,41 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/7b/d7/7831438e6c3ebbfa6e01a927127a6cb42ad3ab844247f3c5b96bea25d73d/psutil-6.1.1-cp37-abi3-win_amd64.whl", hash = "sha256:f35cfccb065fff93529d2afb4a2e89e363fe63ca1e4a5da22b603a85833c2649", size = 254444 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyarrow"
|
||||
version = "20.0.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/a2/ee/a7810cb9f3d6e9238e61d312076a9859bf3668fd21c69744de9532383912/pyarrow-20.0.0.tar.gz", hash = "sha256:febc4a913592573c8d5805091a6c2b5064c8bd6e002131f01061797d91c783c1", size = 1125187 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/a1/d6/0c10e0d54f6c13eb464ee9b67a68b8c71bcf2f67760ef5b6fbcddd2ab05f/pyarrow-20.0.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:75a51a5b0eef32727a247707d4755322cb970be7e935172b6a3a9f9ae98404ba", size = 30815067 },
|
||||
{ url = "https://files.pythonhosted.org/packages/7e/e2/04e9874abe4094a06fd8b0cbb0f1312d8dd7d707f144c2ec1e5e8f452ffa/pyarrow-20.0.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:211d5e84cecc640c7a3ab900f930aaff5cd2702177e0d562d426fb7c4f737781", size = 32297128 },
|
||||
{ url = "https://files.pythonhosted.org/packages/31/fd/c565e5dcc906a3b471a83273039cb75cb79aad4a2d4a12f76cc5ae90a4b8/pyarrow-20.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ba3cf4182828be7a896cbd232aa8dd6a31bd1f9e32776cc3796c012855e1199", size = 41334890 },
|
||||
{ url = "https://files.pythonhosted.org/packages/af/a9/3bdd799e2c9b20c1ea6dc6fa8e83f29480a97711cf806e823f808c2316ac/pyarrow-20.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c3a01f313ffe27ac4126f4c2e5ea0f36a5fc6ab51f8726cf41fee4b256680bd", size = 42421775 },
|
||||
{ url = "https://files.pythonhosted.org/packages/10/f7/da98ccd86354c332f593218101ae56568d5dcedb460e342000bd89c49cc1/pyarrow-20.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:a2791f69ad72addd33510fec7bb14ee06c2a448e06b649e264c094c5b5f7ce28", size = 40687231 },
|
||||
{ url = "https://files.pythonhosted.org/packages/bb/1b/2168d6050e52ff1e6cefc61d600723870bf569cbf41d13db939c8cf97a16/pyarrow-20.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:4250e28a22302ce8692d3a0e8ec9d9dde54ec00d237cff4dfa9c1fbf79e472a8", size = 42295639 },
|
||||
{ url = "https://files.pythonhosted.org/packages/b2/66/2d976c0c7158fd25591c8ca55aee026e6d5745a021915a1835578707feb3/pyarrow-20.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:89e030dc58fc760e4010148e6ff164d2f44441490280ef1e97a542375e41058e", size = 42908549 },
|
||||
{ url = "https://files.pythonhosted.org/packages/31/a9/dfb999c2fc6911201dcbf348247f9cc382a8990f9ab45c12eabfd7243a38/pyarrow-20.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:6102b4864d77102dbbb72965618e204e550135a940c2534711d5ffa787df2a5a", size = 44557216 },
|
||||
{ url = "https://files.pythonhosted.org/packages/a0/8e/9adee63dfa3911be2382fb4d92e4b2e7d82610f9d9f668493bebaa2af50f/pyarrow-20.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:96d6a0a37d9c98be08f5ed6a10831d88d52cac7b13f5287f1e0f625a0de8062b", size = 25660496 },
|
||||
{ url = "https://files.pythonhosted.org/packages/9b/aa/daa413b81446d20d4dad2944110dcf4cf4f4179ef7f685dd5a6d7570dc8e/pyarrow-20.0.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:a15532e77b94c61efadde86d10957950392999503b3616b2ffcef7621a002893", size = 30798501 },
|
||||
{ url = "https://files.pythonhosted.org/packages/ff/75/2303d1caa410925de902d32ac215dc80a7ce7dd8dfe95358c165f2adf107/pyarrow-20.0.0-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:dd43f58037443af715f34f1322c782ec463a3c8a94a85fdb2d987ceb5658e061", size = 32277895 },
|
||||
{ url = "https://files.pythonhosted.org/packages/92/41/fe18c7c0b38b20811b73d1bdd54b1fccba0dab0e51d2048878042d84afa8/pyarrow-20.0.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aa0d288143a8585806e3cc7c39566407aab646fb9ece164609dac1cfff45f6ae", size = 41327322 },
|
||||
{ url = "https://files.pythonhosted.org/packages/da/ab/7dbf3d11db67c72dbf36ae63dcbc9f30b866c153b3a22ef728523943eee6/pyarrow-20.0.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b6953f0114f8d6f3d905d98e987d0924dabce59c3cda380bdfaa25a6201563b4", size = 42411441 },
|
||||
{ url = "https://files.pythonhosted.org/packages/90/c3/0c7da7b6dac863af75b64e2f827e4742161128c350bfe7955b426484e226/pyarrow-20.0.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:991f85b48a8a5e839b2128590ce07611fae48a904cae6cab1f089c5955b57eb5", size = 40677027 },
|
||||
{ url = "https://files.pythonhosted.org/packages/be/27/43a47fa0ff9053ab5203bb3faeec435d43c0d8bfa40179bfd076cdbd4e1c/pyarrow-20.0.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:97c8dc984ed09cb07d618d57d8d4b67a5100a30c3818c2fb0b04599f0da2de7b", size = 42281473 },
|
||||
{ url = "https://files.pythonhosted.org/packages/bc/0b/d56c63b078876da81bbb9ba695a596eabee9b085555ed12bf6eb3b7cab0e/pyarrow-20.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9b71daf534f4745818f96c214dbc1e6124d7daf059167330b610fc69b6f3d3e3", size = 42893897 },
|
||||
{ url = "https://files.pythonhosted.org/packages/92/ac/7d4bd020ba9145f354012838692d48300c1b8fe5634bfda886abcada67ed/pyarrow-20.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e8b88758f9303fa5a83d6c90e176714b2fd3852e776fc2d7e42a22dd6c2fb368", size = 44543847 },
|
||||
{ url = "https://files.pythonhosted.org/packages/9d/07/290f4abf9ca702c5df7b47739c1b2c83588641ddfa2cc75e34a301d42e55/pyarrow-20.0.0-cp313-cp313-win_amd64.whl", hash = "sha256:30b3051b7975801c1e1d387e17c588d8ab05ced9b1e14eec57915f79869b5031", size = 25653219 },
|
||||
{ url = "https://files.pythonhosted.org/packages/95/df/720bb17704b10bd69dde086e1400b8eefb8f58df3f8ac9cff6c425bf57f1/pyarrow-20.0.0-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:ca151afa4f9b7bc45bcc791eb9a89e90a9eb2772767d0b1e5389609c7d03db63", size = 30853957 },
|
||||
{ url = "https://files.pythonhosted.org/packages/d9/72/0d5f875efc31baef742ba55a00a25213a19ea64d7176e0fe001c5d8b6e9a/pyarrow-20.0.0-cp313-cp313t-macosx_12_0_x86_64.whl", hash = "sha256:4680f01ecd86e0dd63e39eb5cd59ef9ff24a9d166db328679e36c108dc993d4c", size = 32247972 },
|
||||
{ url = "https://files.pythonhosted.org/packages/d5/bc/e48b4fa544d2eea72f7844180eb77f83f2030b84c8dad860f199f94307ed/pyarrow-20.0.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7f4c8534e2ff059765647aa69b75d6543f9fef59e2cd4c6d18015192565d2b70", size = 41256434 },
|
||||
{ url = "https://files.pythonhosted.org/packages/c3/01/974043a29874aa2cf4f87fb07fd108828fc7362300265a2a64a94965e35b/pyarrow-20.0.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e1f8a47f4b4ae4c69c4d702cfbdfe4d41e18e5c7ef6f1bb1c50918c1e81c57b", size = 42353648 },
|
||||
{ url = "https://files.pythonhosted.org/packages/68/95/cc0d3634cde9ca69b0e51cbe830d8915ea32dda2157560dda27ff3b3337b/pyarrow-20.0.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:a1f60dc14658efaa927f8214734f6a01a806d7690be4b3232ba526836d216122", size = 40619853 },
|
||||
{ url = "https://files.pythonhosted.org/packages/29/c2/3ad40e07e96a3e74e7ed7cc8285aadfa84eb848a798c98ec0ad009eb6bcc/pyarrow-20.0.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:204a846dca751428991346976b914d6d2a82ae5b8316a6ed99789ebf976551e6", size = 42241743 },
|
||||
{ url = "https://files.pythonhosted.org/packages/eb/cb/65fa110b483339add6a9bc7b6373614166b14e20375d4daa73483755f830/pyarrow-20.0.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:f3b117b922af5e4c6b9a9115825726cac7d8b1421c37c2b5e24fbacc8930612c", size = 42839441 },
|
||||
{ url = "https://files.pythonhosted.org/packages/98/7b/f30b1954589243207d7a0fbc9997401044bf9a033eec78f6cb50da3f304a/pyarrow-20.0.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:e724a3fd23ae5b9c010e7be857f4405ed5e679db5c93e66204db1a69f733936a", size = 44503279 },
|
||||
{ url = "https://files.pythonhosted.org/packages/37/40/ad395740cd641869a13bcf60851296c89624662575621968dcfafabaa7f6/pyarrow-20.0.0-cp313-cp313t-win_amd64.whl", hash = "sha256:82f1ee5133bd8f49d31be1299dc07f585136679666b502540db854968576faf9", size = 25944982 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyasn1"
|
||||
version = "0.6.1"
|
||||
|
||||
Reference in New Issue
Block a user