58 Commits

Author SHA1 Message Date
The Apache Software Foundation
a6f893628b Set up default protection ruleset for default and release branches 2026-05-15 15:28:08 -05:00
bingquanzhao
81305ffbf9 [fix]fix token auth (#69)
* fix tocken auth

* Further fixes to the token overwriting issue and restoration of hot reloading of tokens.json.
2025-12-24 20:39:16 +08:00
zzzzwc
43143f0b30 feat: add batch SQL execution support for MCP (#70)
* feat: add batch SQL execution support for MCP

- Add sql field to QueryResult to track executed query
- Implement execute_batch_sqls_for_mcp for executing multiple SQL
- Use sqlparse to split and execute multiple SQL in single request
- Improve error handling in execute_batch_queries
- Return multiple results format when batch queries are detected

* test: add multi-SQL statements test for query executor
2025-12-24 12:45:29 +08:00
bingquanzhao
e58361e04b fix some security issues (#68) 2025-12-10 09:11:03 +08:00
Yijia Su
a125a2f5f8 [fix]Fixed five known issues, including token authentication and multi-worker operation. (#63)
* 0.6.1Version

* fix 0.5.1 schema async bug

* fix security bug

* fix security bug

* Add complete Token, JWT, OAuth authentication system

* Add complete Token, JWT, OAuth authentication system

* Add complete Token, JWT, OAuth authentication system

* Add complete Token, JWT, OAuth authentication system

* Add a controllable MCP Server DB Pool permission authentication system, connect it with the Doris permission system, and provide it to enterprise-level applications concurrently with the multi-Worker mode.

* Add Tokens Management

* change version

* fix stdio start bug

* fix stdio start bug

* fix stdio start bug
2025-11-04 14:45:38 +08:00
Yijia Su
2613912df3 [Performance]Optimize Stdio and Streamable HTTP startup solutions (#60)
* 0.5.1 Version

* fix 0.5.1 schema async bug

* fix security bug

* fix security bug

* Add complete Token, JWT, OAuth authentication system

* Add complete Token, JWT, OAuth authentication system

* Add complete Token, JWT, OAuth authentication system

* Add complete Token, JWT, OAuth authentication system

* Add a controllable MCP Server DB Pool permission authentication system, connect it with the Doris permission system, and provide it to enterprise-level applications concurrently with the multi-Worker mode.

* Add Tokens Management

* change version

* fix stdio start bug

* fix stdio start bug
2025-09-23 12:21:30 +08:00
Yijia Su
067f160b3e [fix]Release Version Change (#56)
* 0.5.1 Version

* fix 0.5.1 schema async bug

* fix security bug

* fix security bug

* Add complete Token, JWT, OAuth authentication system

* Add complete Token, JWT, OAuth authentication system

* Add complete Token, JWT, OAuth authentication system

* Add complete Token, JWT, OAuth authentication system

* Add a controllable MCP Server DB Pool permission authentication system, connect it with the Doris permission system, and provide it to enterprise-level applications concurrently with the multi-Worker mode.

* Add Tokens Management

* change version
2025-09-03 12:41:38 +08:00
Yijia Su
9ba4cc6f45 [Performance]Add Token Management (#55)
* 0.5.1 Version

* fix 0.5.1 schema async bug

* fix security bug

* fix security bug

* Add complete Token, JWT, OAuth authentication system

* Add complete Token, JWT, OAuth authentication system

* Add complete Token, JWT, OAuth authentication system

* Add complete Token, JWT, OAuth authentication system

* Add a controllable MCP Server DB Pool permission authentication system, connect it with the Doris permission system, and provide it to enterprise-level applications concurrently with the multi-Worker mode.

* Add Tokens Management
2025-09-03 11:55:38 +08:00
Yijia Su
f99399c6c7 [Performance]Add a controllable MCP Server DB Pool permission authentication system (#53)
* 0.5.1 Version

* fix 0.5.1 schema async bug

* fix security bug

* fix security bug

* Add complete Token, JWT, OAuth authentication system

* Add complete Token, JWT, OAuth authentication system

* Add complete Token, JWT, OAuth authentication system

* Add complete Token, JWT, OAuth authentication system

* Add a controllable MCP Server DB Pool permission authentication system, connect it with the Doris permission system, and provide it to enterprise-level applications concurrently with the multi-Worker mode.
2025-09-02 18:40:48 +08:00
Yijia Su
c3d487ccdd [Performance]Add complete Token, JWT, OAuth authentication system (#52)
* 0.5.1 Version

* fix 0.5.1 schema async bug

* fix security bug

* fix security bug

* Add complete Token, JWT, OAuth authentication system

* Add complete Token, JWT, OAuth authentication system

* Add complete Token, JWT, OAuth authentication system

* Add complete Token, JWT, OAuth authentication system
2025-09-02 17:01:43 +08:00
Yijia Su
c1e3b13851 [Performance]Optimizing concurrent startup capabilities (#48)
* 0.5.1 Version

* fix 0.5.1 schema async bug

* fix security bug

* fix security bug
2025-09-02 13:39:05 +08:00
Yijia Su
5923cc1c89 [BUG]Fix security bug (#50)
* 0.5.1 Version

* fix 0.5.1 schema async bug

* fix security bug
2025-08-29 08:48:38 +08:00
Yijia Su
9b5ac8533d [BUG]Fix schema async bug (#49)
* 0.5.1 Version

* fix 0.5.1 schema async bug
2025-08-19 10:30:09 +08:00
ivin
cc84d605e5 [feature]Implement session cache for Doris connections (#44)
* [feature]Implement session cache for Doris connections

This PR introduces a `DorisSessionCache` to cache and reuse `DorisConnection` objects in memory.
This helps to reduce the overhead of creating new connections, especially for frequently used system sessions like "query"
and "system", and avoid not calling release_connection leads to `Connection acquisition timed out` when the number of
connection pools reaches the maximum value.

The PR #34 fixed the issue when calling the tool `exec_query`, but in the codebase, a large number of other tools directly
using get_connection ("query") to get connection object but without calling the release_connection method will cause the
connection to fail to be obtained after a certain number of times.

Key changes:
- Added `DorisSessionCache` class to manage the lifecycle of cached sessions.
- The cache is configurable to store system sessions, user sessions, or both. By default, only system sessions are cached.
- Integrated the session cache into `DorisConnectionManager`.
- `get_connection` now checks the cache before creating a new connection.
- `release_connection` removes the connection from the cache.

* Add tests
2025-08-11 13:39:30 +08:00
drgnchan
55dbdd5e14 [improvement] Enhance SQL injection detection patterns in SQLSecurityValidator (#46) 2025-08-11 13:29:51 +08:00
ivin
affa4a0319 [Test]Update tests (#29) 2025-08-07 23:27:36 +08:00
大痴小乙
ecb5db8137 [bugfix]Fix line ending issues in start_server.sh script for Docker container execution (#39)
Problem:
When running the start_server.sh script in a Docker container, the following errors occurred:
- : not foundserver.sh: 18: (and other lines)
- /app/start_server.sh: 35: Syntax error: "elif" unexpected (expecting "then")

Root cause:
The script file was using Windows-style line endings (CRLF) instead of Unix/Linux-style line endings (LF),
which caused syntax errors when executed in a Linux environment.

Solution:
1. Ensure start_server.sh file uses proper Unix line endings (LF)
2. Add dos2unix command in Dockerfile to convert line ending format of the script file
3. Automatically fix line ending issues during image build to ensure proper script execution in containers

This fixes the issue with starting Doris MCP Server in Linux-based Docker containers.
2025-08-05 17:31:44 +08:00
ivin
5d15f6f3a4 [feat] Refactor configuration management and fix default host (#41)
This PR refactors the configuration management to unify the handling of command-line arguments, environment variables, and default settings.

Key changes:
- All configuration is now consistently managed through the `DorisConfig` object.
- Command-line arguments now correctly override environment variables and default settings.
- The default server host/port is now correctly read from the configuration.
- Wrong environment variable loading in `schema_extractor.py` has been removed.

closes #37
2025-08-04 14:48:02 +08:00
ivin
6247d49192 [bugfix]Release connection after executing query (#34) 2025-07-29 14:05:44 +08:00
ivin
fb5e864a24 [improvement]Add bucket information in the output of analyze_table_storage (#33) 2025-07-29 14:04:36 +08:00
ivin
9bb5b17199 [Chore]Fixes client startup errors (#27) 2025-07-15 13:58:44 +08:00
Yijia Su
6d3c128f54 0.5.1 Version (#28)
0.5.1 Version (#28)
2025-07-15 11:56:46 +08:00
Yijia Su
651d524814 [BUG]Optimize and fix the capabilities of 0.5.0 tools (#26)
1. **Unified Naming for CLI Arguments and Environment Variables** 
- All database-related CLI arguments now use the `--doris-*` prefix, and environment variables use `DORIS_*` for consistency and maintainability. 
- Backward compatibility: old `--db-*` arguments are still supported.

2. **Automatic Filtering of System SQL in Slow Query TopN** 
- Slow query analysis now automatically excludes SQL statements involving `__internal_schema`, `information_schema`, and `mysql` system databases, ensuring only business-related slow queries are counted. 
- Filtering is performed at the SQL level using `NOT LIKE` and `state != 'ERR'` for efficiency and safety.

3. **Unified Query Timeout Configuration** 
- If no `timeout` is specified for query execution, the system will use the `config.performance.query_timeout` value as the default, falling back to 30 seconds if not configured.
- This avoids hardcoding and makes timeout management more flexible.

4. **Tool execution optimization**
- Significantly reduce the execution time of some data governance and operation and maintenance tools
- Optimize execution logic and reduce data scanning
- Enable concurrent scanning to speed up retrieval

5. **Log system optimization**
- Fix the Console log printing logic and output the log content correctly
- Add advanced tool execution process log output to facilitate further positioning of error locations

6. **DB Connection optimization**
- Fixed a connection pool acquisition exception caused by deadlock

7. **Other Improvements**
- Help documentation and CLI examples updated to reflect new and legacy parameter compatibility.
- Code comments and documentation further standardized for better team collaboration and open-source community understanding.
2025-07-14 19:04:11 +08:00
Yijia Su
54572d0861 [Feature]Add 9 New Tools (#23)
release 0.5.0
2025-07-11 12:03:13 +08:00
Yijia Su
d12dfbd014 [improvement]Optimize and refactor the log system (#21)
* add logger system AND fix Readme
2025-07-10 14:02:10 +08:00
Yijia Su
4052b7e938 [BUG]Completely solve the at_eof problem (#20)
* fix at_eof bug

* update uv.lock

* fix bug and change pool min values

* Fixed startup errors caused by multiple versions of MCP services

* fix connection bug
2025-07-10 13:08:32 +08:00
Yijia Su
693c48d5ee [BUG]Fixed startup errors caused by multiple versions of MCP services (#13)
* fix at_eof bug

* update uv.lock

* fix bug and change pool min values

* Fixed startup errors caused by multiple versions of MCP services
2025-07-03 15:04:16 +08:00
Yijia Su
c1ce9a5cc7 [Config]Delete the minimum data pool variable (#11)
* fix at_eof bug

* update uv.lock

* fix bug and change pool min values
2025-07-02 19:57:45 +08:00
Yijia Su
282a1c0bd9 [BUG]Further fix the at_eof problem caused by aiomysql (#9)
* fix at_eof bug

* update uv.lock
2025-07-02 19:29:37 +08:00
Yijia Su
e3b9bf96ab Update .asf.yaml (#10) 2025-07-02 19:26:30 +08:00
Gerry Qi
667cecbbe0 Add .gitignore file (#7)
* Add dify dsl demo

* Deploying on docker

* Add .gitignore file

---------

Co-authored-by: Gerry.qi 齐晓明 <Gerry.qi@pousheng.com>
2025-07-02 18:30:46 +08:00
haijun huang
c777905bd3 fix the cofig of doris-mcp-server (#6) 2025-07-02 18:29:28 +08:00
haijun huang
d4ea125e35 add cursor demo (#4)
* add cursor demo

* fix image
2025-07-02 10:00:22 +08:00
Gerry Qi
f135d9b949 Add dify dsl demo (#3)
* Add dify dsl demo

* Deploying on docker

---------

Co-authored-by: Gerry.qi 齐晓明 <Gerry.qi@pousheng.com>
Co-authored-by: Gerry.qi <Gerry.qi@outlook.com>
2025-06-27 16:28:58 +08:00
Yijia Su
124dd0da88 Update .asf.yaml 2025-06-27 12:54:52 +08:00
Yijia Su
775b4cb630 Update .asf.yaml 2025-06-27 12:53:00 +08:00
FreeOnePlus
26e8bc1149 change pipy project name 2025-06-27 12:44:57 +08:00
FreeOnePlus
8526cb75fe v0.4.2 preview 2025-06-26 20:23:54 +08:00
FreeOnePlus
97006a756d v0.4.1 preview 2025-06-26 18:55:30 +08:00
Yijia Su
72865654e2 Merge pull request #2 from echo-hhj/example
[demo]add dify demo
2025-06-23 12:44:23 +08:00
HuangHaijun
050c09f902 fix contents 2025-06-19 13:10:32 +08:00
HuangHaijun
159399bd38 fix the way to start server 2025-06-19 13:03:17 +08:00
HuangHaijun
e859fbb778 fix1 2025-06-18 12:54:14 +08:00
HuangHaijun
1b9cb29f5f fix 2025-06-18 12:53:18 +08:00
HuangHaijun
c95c0fe03c add dify demo 2025-06-18 12:44:05 +08:00
FreeOnePlus
1e2e79d90d v0.4.0 preview 2025-06-12 19:36:16 +08:00
FreeOnePlus
609816bc4a fix doc bug 2025-06-12 05:10:07 +08:00
FreeOnePlus
5d46d153e1 1. Fix DB Connection BUG
2. Modify the global default configuration items and obtain them from Config
2025-06-11 11:52:15 +08:00
FreeOnePlus
0a81d5693b add readme QA module 2025-06-10 21:28:39 +08:00
FreeOnePlus
a4306867f6 Merge remote-tracking branch 'origin/master' 2025-06-10 21:11:05 +08:00
FreeOnePlus
a22ff3ae9b add readme QA module 2025-06-10 21:04:46 +08:00
Yijia Su
2c5f26889c Update .asf.yaml
open issue,discussions
2025-06-10 14:04:14 +08:00
FreeOnePlus
e47534c296 Merge remote-tracking branch 'origin/master' 2025-06-09 23:07:54 +08:00
FreeOnePlus
0f52591259 Add pip install command & fix pyproject.toml bug 2025-06-09 22:54:01 +08:00
Yijia Su
3b429f37b3 Merge pull request #1 from iouAkira/patch-1
fix Dockerfile
2025-06-09 18:48:58 +08:00
FreeOnePlus
f5a4c8abbe Add pip install command & fix pyproject.toml bug 2025-06-09 18:42:34 +08:00
FreeOnePlus
87563ef6e1 Add pip install command 2025-06-09 18:37:41 +08:00
Akira
b6157c500b fix Dockerfile 2025-06-09 09:21:42 +08:00
75 changed files with 24757 additions and 1462 deletions

View File

@@ -24,9 +24,28 @@ github:
- olap - olap
- lakehouse - lakehouse
- mcp - mcp
- ai
enabled_merge_buttons: enabled_merge_buttons:
squash: true squash: true
merge: false merge: false
rebase: false rebase: false
features:
issues: true
projects: true
rulesets:
- name: "Default Branch Protection"
type: branch
branches:
includes:
- "~DEFAULT_BRANCH"
- "release/*"
- "rel/*"
excludes: []
bypass_teams:
- root
restrict_deletion: true
restrict_force_push: true
notifications: notifications:
pullrequests_status: commits@doris.apache.org issues: commits@doris.apache.org
commits: commits@doris.apache.org
pullrequests: commits@doris.apache.org

2
.dockerignore Normal file
View File

@@ -0,0 +1,2 @@
**/.venv
**/venv

View File

@@ -14,58 +14,512 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
# ===================================================================
# Doris MCP Server Environment Configuration Example
# ===================================================================
# Copy this file to .env and modify the configuration values as needed
# Doris MCP Server Environment Configuration # ===================================================================
# Copy this file to .env and modify the values as needed # Database Connection Configuration
# ===================================================================
# Database Configuration # Doris FE (Frontend) connection settings
DORIS_HOST=localhost DORIS_HOST=localhost
DORIS_PORT=9030 DORIS_PORT=9030
DORIS_USER=root DORIS_USER=root
DORIS_PASSWORD=your_password_here DORIS_PASSWORD=
DORIS_DATABASE=your_database_name DORIS_DATABASE=information_schema
# Connection Pool Settings # Doris FE HTTP API port (for Profile and other HTTP APIs)
DORIS_MIN_CONNECTIONS=5 DORIS_FE_HTTP_PORT=8030
# Doris BE (Backend) nodes configuration (optional, for external access)
# Format: host1,host2,host3 (if empty, will use "show backends" to get BE nodes)
DORIS_BE_HOSTS=
DORIS_BE_WEBSERVER_PORT=8040
# Connection pool configuration
DORIS_MAX_CONNECTIONS=20 DORIS_MAX_CONNECTIONS=20
DORIS_CONNECTION_TIMEOUT=30 DORIS_CONNECTION_TIMEOUT=30
DORIS_HEALTH_CHECK_INTERVAL=60 DORIS_HEALTH_CHECK_INTERVAL=60
DORIS_MAX_CONNECTION_AGE=3600 DORIS_MAX_CONNECTION_AGE=3600
# Security Settings # Arrow Flight SQL Configuration (Required for ADBC tools)
# FE_ARROW_FLIGHT_SQL_PORT=
# BE_ARROW_FLIGHT_SQL_PORT=
# ===================================================================
# Security Configuration
# ===================================================================
# Independent Authentication Switches - NEW DESIGN!
# Each authentication method can be enabled/disabled independently
# Any enabled method that succeeds will allow access
# If all methods are disabled, anonymous access is allowed
# Legacy configuration - kept for backward compatibility
# AUTH_TYPE is now deprecated - use individual switches above
AUTH_TYPE=token AUTH_TYPE=token
TOKEN_SECRET=your_256_bit_secret_key_here
# Token Authentication (Default method - simple and effective)
ENABLE_TOKEN_AUTH=false
# JWT Authentication (For stateless applications)
ENABLE_JWT_AUTH=false
# OAuth 2.0/OIDC Authentication (For enterprise integration)
ENABLE_OAUTH_AUTH=false
# ===================================================================
# Token Authentication Configuration (Enable with ENABLE_TOKEN_AUTH=true)
# ===================================================================
# Basic token authentication settings
TOKEN_FILE_PATH=tokens.json
ENABLE_TOKEN_EXPIRY=true
DEFAULT_TOKEN_EXPIRY_HOURS=720
TOKEN_HASH_ALGORITHM=sha256
# ===================================================================
# Token Management Security Configuration (NEW in v0.6.0) - CRITICAL SECURITY SETTINGS
# ===================================================================
# HTTP Token Management Endpoints (DISABLED BY DEFAULT FOR SECURITY)
# WARNING: These endpoints allow creation, deletion, and management of authentication tokens
# Only enable if you need HTTP-based token management and understand the security implications
ENABLE_HTTP_TOKEN_MANAGEMENT=true
# Admin Authentication Token (REQUIRED if HTTP token management is enabled)
# This token is required to access HTTP token management endpoints
# SECURITY: Generate a secure random token in production - NEVER use default values
TOKEN_MANAGEMENT_ADMIN_TOKEN=
# IP Address Restrictions for Token Management (CRITICAL SECURITY CONTROL)
# Only these IP addresses/networks can access token management endpoints
# DEFAULT: localhost only (most secure) - add other IPs/networks only if necessary
# Format: comma-separated list of IPs and CIDR networks
# Examples:
# - Localhost only: 127.0.0.1,::1
# - Private network: 127.0.0.1,192.168.1.0/24,10.0.0.0/8
# - Specific IPs: 127.0.0.1,192.168.1.10,192.168.1.11
TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,::1
# Require Admin Authentication (ENABLED BY DEFAULT FOR SECURITY)
# When true, all token management operations require valid admin token
# When false, only IP restrictions apply (NOT RECOMMENDED for production)
REQUIRE_ADMIN_AUTH=true
# ===================================================================
# JWT Authentication Configuration (Enable with ENABLE_JWT_AUTH=true)
# ===================================================================
# JWT token settings (when ENABLE_JWT_AUTH=true)
JWT_SECRET_KEY=your_jwt_secret_key_here_change_in_production
JWT_ALGORITHM=HS256
JWT_EXPIRATION_HOURS=24
JWT_ISSUER=doris-mcp-server
JWT_AUDIENCE=doris-mcp-client
# JWT token validation settings
JWT_VERIFY_SIGNATURE=true
JWT_VERIFY_EXPIRATION=true
JWT_VERIFY_AUDIENCE=true
JWT_VERIFY_ISSUER=true
# JWT refresh token settings
ENABLE_JWT_REFRESH=true
JWT_REFRESH_EXPIRATION_DAYS=30
JWT_REFRESH_SECRET_KEY=your_jwt_refresh_secret_key_here
# JWT user claims configuration
JWT_USER_ID_CLAIM=user_id
JWT_ROLES_CLAIM=roles
JWT_PERMISSIONS_CLAIM=permissions
JWT_SECURITY_LEVEL_CLAIM=security_level
# ===================================================================
# OAuth 2.0 / OpenID Connect Configuration (Enable with ENABLE_OAUTH_AUTH=true)
# ===================================================================
# OAuth provider settings (when ENABLE_OAUTH_AUTH=true)
OAUTH_PROVIDER_TYPE=generic
OAUTH_CLIENT_ID=your_oauth_client_id
OAUTH_CLIENT_SECRET=your_oauth_client_secret
OAUTH_REDIRECT_URI=http://localhost:3000/auth/callback
# OAuth endpoints (for generic provider)
OAUTH_AUTHORIZATION_URL=https://your-provider.com/auth
OAUTH_TOKEN_URL=https://your-provider.com/token
OAUTH_USERINFO_URL=https://your-provider.com/userinfo
OAUTH_JWKS_URL=https://your-provider.com/.well-known/jwks.json
# OAuth scope and claims
OAUTH_SCOPE=openid profile email
OAUTH_USER_ID_CLAIM=sub
OAUTH_USERNAME_CLAIM=preferred_username
OAUTH_EMAIL_CLAIM=email
OAUTH_ROLES_CLAIM=roles
OAUTH_GROUPS_CLAIM=groups
# OAuth session settings
OAUTH_SESSION_SECRET=your_oauth_session_secret_here
OAUTH_SESSION_EXPIRY=3600
OAUTH_STATE_EXPIRY=300
# Popular OAuth providers presets (uncomment and configure as needed)
# Google OAuth Configuration
# OAUTH_PROVIDER_TYPE=google
# OAUTH_CLIENT_ID=your_google_client_id.apps.googleusercontent.com
# OAUTH_CLIENT_SECRET=your_google_client_secret
# OAUTH_AUTHORIZATION_URL=https://accounts.google.com/o/oauth2/auth
# OAUTH_TOKEN_URL=https://oauth2.googleapis.com/token
# OAUTH_USERINFO_URL=https://www.googleapis.com/oauth2/v1/userinfo
# OAUTH_JWKS_URL=https://www.googleapis.com/oauth2/v3/certs
# OAUTH_SCOPE=openid profile email
# Microsoft Azure AD Configuration
# OAUTH_PROVIDER_TYPE=azure
# OAUTH_CLIENT_ID=your_azure_client_id
# OAUTH_CLIENT_SECRET=your_azure_client_secret
# OAUTH_TENANT_ID=your_tenant_id
# OAUTH_AUTHORIZATION_URL=https://login.microsoftonline.com/{tenant}/oauth2/v2.0/authorize
# OAUTH_TOKEN_URL=https://login.microsoftonline.com/{tenant}/oauth2/v2.0/token
# OAUTH_USERINFO_URL=https://graph.microsoft.com/v1.0/me
# OAUTH_JWKS_URL=https://login.microsoftonline.com/{tenant}/discovery/v2.0/keys
# OAUTH_SCOPE=openid profile email
# GitHub OAuth Configuration
# OAUTH_PROVIDER_TYPE=github
# OAUTH_CLIENT_ID=your_github_client_id
# OAUTH_CLIENT_SECRET=your_github_client_secret
# OAUTH_AUTHORIZATION_URL=https://github.com/login/oauth/authorize
# OAUTH_TOKEN_URL=https://github.com/login/oauth/access_token
# OAUTH_USERINFO_URL=https://api.github.com/user
# OAUTH_SCOPE=user:email
# GitLab OAuth Configuration
# OAUTH_PROVIDER_TYPE=gitlab
# OAUTH_CLIENT_ID=your_gitlab_client_id
# OAUTH_CLIENT_SECRET=your_gitlab_client_secret
# OAUTH_AUTHORIZATION_URL=https://gitlab.com/oauth/authorize
# OAUTH_TOKEN_URL=https://gitlab.com/oauth/token
# OAUTH_USERINFO_URL=https://gitlab.com/api/v4/user
# OAUTH_SCOPE=read_user
# Keycloak OAuth Configuration
# OAUTH_PROVIDER_TYPE=keycloak
# OAUTH_CLIENT_ID=your_keycloak_client_id
# OAUTH_CLIENT_SECRET=your_keycloak_client_secret
# OAUTH_REALM=your_realm
# OAUTH_SERVER_URL=https://your-keycloak-server.com
# OAUTH_AUTHORIZATION_URL=https://your-keycloak-server.com/auth/realms/{realm}/protocol/openid-connect/auth
# OAUTH_TOKEN_URL=https://your-keycloak-server.com/auth/realms/{realm}/protocol/openid-connect/token
# OAUTH_USERINFO_URL=https://your-keycloak-server.com/auth/realms/{realm}/protocol/openid-connect/userinfo
# OAUTH_JWKS_URL=https://your-keycloak-server.com/auth/realms/{realm}/protocol/openid-connect/certs
# OAUTH_SCOPE=openid profile email
# Legacy token settings (for backward compatibility)
TOKEN_SECRET=your_secret_key_here
TOKEN_EXPIRY=3600 TOKEN_EXPIRY=3600
# SQL security check
ENABLE_SECURITY_CHECK=true
# Blocked keywords (comma separated)
BLOCKED_KEYWORDS=DROP,CREATE,ALTER,TRUNCATE,DELETE,INSERT,UPDATE,GRANT,REVOKE,EXEC,EXECUTE,SHUTDOWN,KILL
# Query limits
MAX_QUERY_COMPLEXITY=100
MAX_RESULT_ROWS=10000 MAX_RESULT_ROWS=10000
# Data masking
ENABLE_MASKING=true ENABLE_MASKING=true
# Performance Settings # ===================================================================
# Performance Configuration
# ===================================================================
# Query cache
ENABLE_QUERY_CACHE=true ENABLE_QUERY_CACHE=true
CACHE_TTL=300 CACHE_TTL=300
MAX_CACHE_SIZE=1000 MAX_CACHE_SIZE=1000
# Concurrency control
MAX_CONCURRENT_QUERIES=50 MAX_CONCURRENT_QUERIES=50
QUERY_TIMEOUT=300 QUERY_TIMEOUT=300
# Logging Configuration # Response content size limit (characters)
LOG_LEVEL=INFO MAX_RESPONSE_CONTENT_SIZE=4096
LOG_FILE_PATH=./log/doris-mcp-server.log
ENABLE_AUDIT=true
AUDIT_FILE_PATH=./log/doris-mcp-audit.log
# Monitoring Settings # ===================================================================
# ADBC (Arrow Flight SQL) Configuration
# ===================================================================
# Enable/disable ADBC tools
ADBC_ENABLED=true
# Default ADBC query parameters
ADBC_DEFAULT_MAX_ROWS=100000
ADBC_DEFAULT_TIMEOUT=60
# Format: "arrow", "pandas", "dict"
ADBC_DEFAULT_RETURN_FORMAT=arrow
# ADBC connection timeout
ADBC_CONNECTION_TIMEOUT=300
# ===================================================================
# Logging Configuration
# ===================================================================
# Basic logging configuration
LOG_LEVEL=INFO
LOG_FILE_PATH=
# Audit logging
ENABLE_AUDIT=true
AUDIT_FILE_PATH=
# Log file rotation configuration
LOG_MAX_FILE_SIZE=10485760
LOG_BACKUP_COUNT=5
# ===================================================================
# Log Cleanup Configuration - NEW!
# ===================================================================
# Enable automatic log cleanup
ENABLE_LOG_CLEANUP=true
# Maximum age of log files in days (files older than this will be deleted)
LOG_MAX_AGE_DAYS=30
# Cleanup check interval in hours
LOG_CLEANUP_INTERVAL_HOURS=24
# ===================================================================
# Monitoring Configuration
# ===================================================================
# Metrics collection
ENABLE_METRICS=true ENABLE_METRICS=true
METRICS_PORT=3001 METRICS_PORT=3001
METRICS_PATH=/metrics
HEALTH_CHECK_PORT=3002 HEALTH_CHECK_PORT=3002
HEALTH_CHECK_PATH=/health
# Alert configuration
ENABLE_ALERTS=false ENABLE_ALERTS=false
ALERT_WEBHOOK_URL= ALERT_WEBHOOK_URL=
# Server Settings # ===================================================================
# Server Configuration
# ===================================================================
# Basic server information
SERVER_NAME=doris-mcp-server SERVER_NAME=doris-mcp-server
SERVER_VERSION=0.3.0 SERVER_VERSION=0.6.0
SERVER_PORT=3000 SERVER_PORT=3000
# Development Settings (for development environment only) # Temporary files directory
DEBUG=false TEMP_FILES_DIR=tmp
VERBOSE=false
# ===================================================================
# Configuration Examples for Different Environments
# ===================================================================
# Development Environment Example:
# LOG_LEVEL=DEBUG
# LOG_MAX_AGE_DAYS=7
# LOG_CLEANUP_INTERVAL_HOURS=6
# ENABLE_SECURITY_CHECK=false
# Production Environment Example:
# LOG_LEVEL=INFO
# LOG_MAX_AGE_DAYS=30
# LOG_CLEANUP_INTERVAL_HOURS=24
# ENABLE_SECURITY_CHECK=true
# ENABLE_LOG_CLEANUP=true
# Testing Environment Example:
# LOG_LEVEL=WARNING
# LOG_MAX_AGE_DAYS=3
# LOG_CLEANUP_INTERVAL_HOURS=1
# MAX_RESULT_ROWS=1000
# ===================================================================
# Advanced Configuration Notes
# ===================================================================
# 1. Log Cleanup Feature:
# - ENABLE_LOG_CLEANUP: Controls whether to enable automatic cleanup
# - LOG_MAX_AGE_DAYS: File retention days, recommended 30 days for production, 7 days for development
# - LOG_CLEANUP_INTERVAL_HOURS: Check frequency, recommended 24 hours
# 2. Security Best Practices:
# - NEW: Enable individual authentication methods using ENABLE_TOKEN_AUTH, ENABLE_JWT_AUTH, ENABLE_OAUTH_AUTH
# - When all methods are disabled, ALL requests are allowed with anonymous access
# - Authentication methods work independently - any one succeeding allows access
# - Token Auth: Change default tokens (DEFAULT_ADMIN_TOKEN, etc.) in production
# - JWT Auth: Change JWT_SECRET_KEY and JWT_REFRESH_SECRET_KEY in production
# - OAuth Auth: Configure OAuth provider settings and secure client secrets
# - Must change TOKEN_SECRET in production environment (legacy compatibility)
# - Adjust BLOCKED_KEYWORDS according to business needs
# - Enable ENABLE_SECURITY_CHECK and ENABLE_MASKING
# - NEW v0.6.0: Token Management Security (CRITICAL):
# * ENABLE_HTTP_TOKEN_MANAGEMENT=false by default (SECURE BY DEFAULT)
# * Only enable if you need HTTP token management endpoints
# * TOKEN_MANAGEMENT_ADMIN_TOKEN: Use secure random token in production
# * TOKEN_MANAGEMENT_ALLOWED_IPS: Restrict to localhost (127.0.0.1,::1) only
# * REQUIRE_ADMIN_AUTH=true: Always require admin authentication
# * Never expose token management endpoints to external networks
# 3. Performance Tuning:
# - Adjust MAX_CONCURRENT_QUERIES based on hardware resources
# - Adjust QUERY_TIMEOUT based on query complexity
# - Adjust MAX_CACHE_SIZE based on memory size
# 4. Connection Pool Optimization:
# - DORIS_MAX_CONNECTIONS recommended to be 2-4 times the number of CPU cores
# - DORIS_CONNECTION_TIMEOUT adjust based on network latency
# - DORIS_MAX_CONNECTION_AGE recommended 1 hour to avoid long connection issues
# 5. ADBC (Arrow Flight SQL) Configuration:
# - FE_ARROW_FLIGHT_SQL_PORT and BE_ARROW_FLIGHT_SQL_PORT: Required for ADBC functionality
# - ADBC_DEFAULT_MAX_ROWS: Default maximum rows for ADBC queries (recommended: 100000)
# - ADBC_DEFAULT_TIMEOUT: Default timeout for ADBC queries in seconds (recommended: 60)
# - ADBC_DEFAULT_RETURN_FORMAT: Default return format (arrow/pandas/dict, recommended: arrow)
# - ADBC_CONNECTION_TIMEOUT: Connection timeout for ADBC (recommended: 30)
# - ADBC_ENABLED: Enable or disable ADBC tools (true/false)
# - Prerequisites: Install adbc_driver_manager, adbc_driver_flightsql, pyarrow packages
# 6. Authentication Configuration Guide - UPDATED DESIGN!
#
# Independent Authentication Control (NEW):
# - ENABLE_TOKEN_AUTH=false (default): Disable token authentication
# - ENABLE_JWT_AUTH=false (default): Disable JWT authentication
# - ENABLE_OAUTH_AUTH=false (default): Disable OAuth authentication
# - When all methods are disabled, no authentication is required (anonymous access)
# - When multiple methods are enabled, any one succeeding allows access
# - Recommended for development/testing: all false, production: enable needed methods
#
# Token Authentication (ENABLE_TOKEN_AUTH=true) - Recommended for most use cases:
# - Simple and secure token-based authentication
# - Configurable default tokens via environment variables
# - Support for custom tokens via TOKEN_* environment variables
# - Token file configuration via tokens.json
# - Built-in token management HTTP endpoints
# - No user management complexity - pure API access control
#
# JWT Authentication (ENABLE_JWT_AUTH=true) - For stateless applications:
# - JSON Web Token based authentication
# - Configurable token expiration and refresh
# - Support for standard JWT claims
# - RSA/ECDSA/HS256 algorithm support
# - Suitable for microservices and distributed systems
#
# OAuth 2.0/OIDC (ENABLE_OAUTH_AUTH=true) - For enterprise integration:
# - Integration with external identity providers
# - Support for popular providers (Google, Microsoft, GitHub, GitLab, Keycloak)
# - OpenID Connect compatibility
# - Automatic user provisioning from provider
# - Secure authorization code flow
#
# Authentication Method Selection Guide:
# - No Auth (all switches false): Development, testing, trusted networks
# - Token Auth only: Small teams, simple deployment, direct API access
# - JWT Auth only: Stateless apps, microservices, mobile clients
# - OAuth Auth only: Enterprise SSO, large teams, external identity providers
# - Multiple methods: Flexible access, different client types, migration scenarios
# 7. Token Management Security Configuration Guide (NEW in v0.6.0) - CRITICAL!
#
# ⚠️ SECURITY WARNING: Token management endpoints are POWERFUL and DANGEROUS
# They allow creation, revocation, and management of authentication tokens.
# Improper configuration can lead to complete system compromise.
#
# 🔒 SECURE BY DEFAULT:
# - ENABLE_HTTP_TOKEN_MANAGEMENT=false (disabled by default)
# - REQUIRE_ADMIN_AUTH=true (admin auth required by default)
# - TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,::1 (localhost only by default)
#
# 🛡️ SECURITY LAYERS (Applied in order):
# 1. Configuration Check: HTTP token management must be explicitly enabled
# 2. IP Restrictions: Only allowed IP addresses/networks can access endpoints
# 3. Admin Authentication: Valid admin token required for all operations
#
# 📋 CONFIGURATION OPTIONS:
#
# Disable Token Management (RECOMMENDED for most deployments):
# ENABLE_HTTP_TOKEN_MANAGEMENT=false
# # All token management endpoints will return 403 Forbidden
#
# Enable with Maximum Security (Production):
# ENABLE_HTTP_TOKEN_MANAGEMENT=true
# TOKEN_MANAGEMENT_ADMIN_TOKEN=<secure-random-token-256-bit>
# TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,::1
# REQUIRE_ADMIN_AUTH=true
#
# Enable for Private Network (Advanced):
# ENABLE_HTTP_TOKEN_MANAGEMENT=true
# TOKEN_MANAGEMENT_ADMIN_TOKEN=<secure-random-token-256-bit>
# TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,192.168.1.0/24,10.0.0.0/8
# REQUIRE_ADMIN_AUTH=true
#
# 🔑 ADMIN TOKEN GENERATION:
# # Generate secure admin token (Linux/macOS):
# openssl rand -hex 32
#
# # Generate secure admin token (Python):
# python -c "import secrets; print(secrets.token_urlsafe(32))"
#
# 🌐 IP CONFIGURATION EXAMPLES:
# # Localhost only (most secure):
# TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,::1
#
# # Private network + localhost:
# TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,::1,192.168.1.0/24,10.0.0.0/8
#
# # Specific servers only:
# TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,192.168.1.10,192.168.1.11
#
# # Corporate network (be careful):
# TOKEN_MANAGEMENT_ALLOWED_IPS=127.0.0.1,172.16.0.0/12,192.168.0.0/16
#
# 🚫 NEVER DO THIS (Security Anti-Patterns):
# # NEVER allow all IPs:
# # TOKEN_MANAGEMENT_ALLOWED_IPS=0.0.0.0/0 # DANGEROUS!
#
# # NEVER disable admin auth in production:
# # REQUIRE_ADMIN_AUTH=false # DANGEROUS!
#
# # NEVER use weak admin tokens:
# # TOKEN_MANAGEMENT_ADMIN_TOKEN=admin # DANGEROUS!
# # TOKEN_MANAGEMENT_ADMIN_TOKEN=123456 # DANGEROUS!
#
# 📊 ENDPOINT SECURITY TESTING:
# # Test security (should fail):
# curl -X POST http://external-ip:3000/token/create
# # Expected: 403 Forbidden (IP not allowed)
#
# # Test without auth (should fail):
# curl -X POST http://127.0.0.1:3000/token/create
# # Expected: 401 Unauthorized (missing admin token)
#
# # Test with valid auth (should succeed if enabled):
# curl -H "Authorization: Bearer your-admin-token" http://127.0.0.1:3000/token/stats
# # Expected: 200 OK with token statistics
#
# 🔍 MONITORING & AUDITING:
# # All token management access attempts are logged:
# tail -f logs/doris_mcp_server_audit.log | grep "token management"
#
# # Monitor security events:
# tail -f logs/doris_mcp_server_info.log | grep -E "(access denied|token management)"
#
# ✅ SECURITY BEST PRACTICES:
# - Keep ENABLE_HTTP_TOKEN_MANAGEMENT=false unless absolutely necessary
# - Use file-based token management (tokens.json) instead of HTTP endpoints
# - Generate strong admin tokens using cryptographically secure methods
# - Restrict access to localhost (127.0.0.1,::1) whenever possible
# - Never expose token management endpoints to public internet
# - Regularly audit token management access logs
# - Use firewall rules as additional protection layer
# - Consider VPN access for remote token management needs

23
.gitignore vendored Normal file
View File

@@ -0,0 +1,23 @@
*.log
*.log.*
*.bak
logs
/configs/*.py
.vscode/
__pycache__/
*.log
.python-version
Pipfile.lock
poetry.lock
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
.idea/
.coverage
coverage.xml

View File

@@ -32,6 +32,7 @@ RUN apt-get update && apt-get install -y \
g++ \ g++ \
pkg-config \ pkg-config \
default-libmysqlclient-dev \ default-libmysqlclient-dev \
dos2unix \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
# Copy requirements file # Copy requirements file
@@ -43,12 +44,13 @@ RUN pip install --no-cache-dir -r requirements.txt
# Copy application code # Copy application code
COPY . . COPY . .
# Convert line endings for shell scripts and ensure proper execution format
RUN find . -name "*.sh" -exec dos2unix {} \; && \
find . -name "*.sh" -exec chmod +x {} \;
# Create necessary directories # Create necessary directories
RUN mkdir -p /app/logs /app/config /app/data RUN mkdir -p /app/logs /app/config /app/data
# Set permissions
RUN chmod +x /app/start.sh
# Create non-root user # Create non-root user
RUN groupadd -r doris && useradd -r -g doris doris RUN groupadd -r doris && useradd -r -g doris doris
RUN chown -R doris:doris /app RUN chown -R doris:doris /app
@@ -62,4 +64,4 @@ HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
EXPOSE 3000 3001 3002 EXPOSE 3000 3001 3002
# Start command # Start command
CMD ["/app/start.sh"] CMD ["/app/start_server.sh"]

1090
README.md

File diff suppressed because it is too large Load Diff

View File

@@ -133,9 +133,6 @@ async def database_operations(client):
# Get table schema # Get table schema
schema = await client.get_table_schema("table_name", "db_name") schema = await client.get_table_schema("table_name", "db_name")
# Column data analysis
analysis = await client.analyze_column("table", "column", "basic")
``` ```
## 🧪 Testing ## 🧪 Testing
@@ -177,7 +174,6 @@ python test_unified_client.py benchmark
2. get_table_list: Get table list for specified database 2. get_table_list: Get table list for specified database
3. get_table_schema: Get table structure information 3. get_table_schema: Get table structure information
4. exec_query: Execute SQL query 4. exec_query: Execute SQL query
5. column_analysis: Analyze column data distribution and statistics
... ...
🧪 Testing basic functionality... 🧪 Testing basic functionality...
@@ -189,8 +185,6 @@ python test_unified_client.py benchmark
✅ SSB query successful ✅ SSB query successful
4⃣ Getting table structure... 4⃣ Getting table structure...
✅ Table structure retrieved successfully ✅ Table structure retrieved successfully
5⃣ Column data analysis...
✅ Column analysis successful
✅ HTTP mode testing completed! ✅ HTTP mode testing completed!
``` ```
@@ -256,12 +250,6 @@ async def comprehensive_example():
schema_result = await client.get_table_schema("lineorder", "ssb") schema_result = await client.get_table_schema("lineorder", "ssb")
print(f"Table schema: {schema_result}") print(f"Table schema: {schema_result}")
# Column analysis
analysis_result = await client.analyze_column(
"lineorder", "lo_orderkey", "basic"
)
print(f"Column analysis: {analysis_result}")
await client.connect_and_run(demo_operations) await client.connect_and_run(demo_operations)
# Run the example # Run the example

View File

@@ -323,7 +323,7 @@ class DorisUnifiedClient:
async with streamablehttp_client( async with streamablehttp_client(
self.config.server_url, self.config.server_url,
timeout=timedelta(seconds=self.config.timeout) timeout=timedelta(seconds=self.config.timeout)
) as (read, write): ) as (read, write, _):
async with ClientSession(read, write) as session: async with ClientSession(read, write) as session:
self.session = session self.session = session
self._init_sub_clients() self._init_sub_clients()
@@ -422,18 +422,14 @@ class DorisUnifiedClient:
return await self.call_tool(tool_name, kwargs) return await self.call_tool(tool_name, kwargs)
async def analyze_column(self, table_name: str, column_name: str, analysis_type: str = "basic", **kwargs) -> dict[str, Any]: async def get_memory_stats(self, tracker_type: str = "overview", include_details: bool = True, **kwargs) -> dict[str, Any]:
"""Analyze column""" """Get memory statistics"""
tool_name = await self._find_tool_by_pattern(["column_analysis", "analyze_column", "column"]) tool_name = await self._find_tool_by_pattern(["memory", "realtime_memory"])
if not tool_name: if not tool_name:
return {"success": False, "error": "Column analysis tool not found"} return {"success": False, "error": "Memory stats tool not found"}
arguments = { arguments = {"tracker_type": tracker_type, "include_details": include_details}
"table_name": table_name, arguments.update(kwargs)
"column_name": column_name,
"analysis_type": analysis_type,
**kwargs
}
return await self.call_tool(tool_name, arguments) return await self.call_tool(tool_name, arguments)
async def call_tool_by_function(self, function_description: str, arguments: dict[str, Any]) -> dict[str, Any]: async def call_tool_by_function(self, function_description: str, arguments: dict[str, Any]) -> dict[str, Any]:
@@ -467,7 +463,7 @@ async def create_http_client(server_url: str, timeout: int = 60) -> DorisUnified
# Example usage # Example usage
async def example_stdio(): async def example_stdio():
"""stdio mode example""" """stdio mode example"""
client = await create_stdio_client("python", ["doris_mcp_server/main.py"]) client = await create_stdio_client("python", ["-m", "doris_mcp_server.main", "--transport", "stdio"])
async def test_client(client: DorisUnifiedClient): async def test_client(client: DorisUnifiedClient):
# Get server capabilities # Get server capabilities

View File

@@ -0,0 +1,56 @@
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Doris MCP Server Authentication Module
Provides JWT-based, Token-based, and OAuth 2.0/OIDC authentication and authorization services
"""
from .jwt_manager import JWTManager
from .key_manager import KeyManager
from .token_validators import TokenValidator, TokenBlacklist
from .auth_middleware import AuthMiddleware
from .token_manager import TokenManager, TokenInfo, TokenValidationResult
from .token_handlers import TokenHandlers
from .oauth_client import OAuthClient, OAuthStateManager
from .oauth_provider import OAuthAuthenticationProvider
from .oauth_types import (
OAuthProvider, OAuthState, OAuthTokens, OAuthUserInfo,
OIDCDiscovery, OAuthError, OAuthProviderConfig
)
__all__ = [
"JWTManager",
"KeyManager",
"TokenValidator",
"TokenBlacklist",
"AuthMiddleware",
"TokenManager",
"TokenInfo",
"TokenValidationResult",
"TokenHandlers",
"OAuthClient",
"OAuthStateManager",
"OAuthAuthenticationProvider",
"OAuthProvider",
"OAuthState",
"OAuthTokens",
"OAuthUserInfo",
"OIDCDiscovery",
"OAuthError",
"OAuthProviderConfig"
]

View File

@@ -0,0 +1,271 @@
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Authentication Middleware Module
Provides middleware for JWT authentication in HTTP and MCP contexts
"""
from typing import Optional, Dict, Any, Callable, Awaitable
from datetime import datetime
from .jwt_manager import JWTManager
from ..utils.security import AuthContext, SecurityLevel
from ..utils.logger import get_logger
logger = get_logger(__name__)
class AuthMiddleware:
"""Authentication Middleware
Provides JWT authentication functionality for HTTP and MCP requests
"""
def __init__(self, jwt_manager: JWTManager):
"""Initialize authentication middleware
Args:
jwt_manager: JWT manager instance
"""
self.jwt_manager = jwt_manager
logger.info("AuthMiddleware initialized")
def extract_token_from_header(self, authorization: str) -> Optional[str]:
"""Extract JWT token from Authorization header
Args:
authorization: Authorization header value
Returns:
JWT token string, or None if not found
"""
if not authorization:
return None
# Support Bearer format
if authorization.startswith('Bearer '):
return authorization[7:] # Remove "Bearer " prefix
# Support direct token format
if not authorization.startswith('Basic '):
return authorization
return None
async def authenticate_request(self, auth_info: Dict[str, Any]) -> AuthContext:
"""Authenticate request and return authentication context
Args:
auth_info: Authentication information dictionary
Returns:
AuthContext authentication context
Raises:
ValueError: Authentication failed
"""
try:
auth_type = auth_info.get("type", "jwt")
if auth_type == "jwt" or auth_type == "token":
return await self._authenticate_jwt(auth_info)
else:
raise ValueError(f"Unsupported authentication type: {auth_type}")
except Exception as e:
logger.error(f"Request authentication failed: {e}")
raise
async def _authenticate_jwt(self, auth_info: Dict[str, Any]) -> AuthContext:
"""JWT authentication processing
Args:
auth_info: Authentication information containing JWT token
Returns:
AuthContext authentication context
"""
# Get token
token = auth_info.get("token")
if not token:
# Try to get from Authorization header
authorization = auth_info.get("authorization")
token = self.extract_token_from_header(authorization)
if not token:
raise ValueError("Missing JWT token")
try:
# Validate token
validation_result = await self.jwt_manager.validate_token(token, 'access')
payload = validation_result['payload']
# Build authentication context
auth_context = AuthContext(
token_id=payload.get('jti', ''),
user_id=payload.get('sub'),
roles=payload.get('roles', []),
permissions=payload.get('permissions', []),
security_level=SecurityLevel(payload.get('security_level', 'internal')),
session_id=payload.get('jti'), # Use JWT ID as session ID
login_time=datetime.fromtimestamp(payload.get('iat', 0)),
last_activity=datetime.utcnow(),
token=token # Store raw token for token-bound database configuration
)
logger.info(f"JWT authentication successful for user: {auth_context.user_id}")
return auth_context
except Exception as e:
logger.error(f"JWT authentication failed: {e}")
raise ValueError(f"JWT authentication failed: {str(e)}")
async def create_auth_response_headers(self, auth_context: AuthContext) -> Dict[str, str]:
"""Create authentication response headers
Args:
auth_context: Authentication context
Returns:
Response headers dictionary
"""
return {
'X-Auth-User': auth_context.user_id,
'X-Auth-Roles': ','.join(auth_context.roles),
'X-Auth-Session': auth_context.session_id,
'X-Auth-Security-Level': auth_context.security_level.value
}
def create_http_middleware(self, skip_paths: Optional[list] = None):
"""Create HTTP middleware function
Args:
skip_paths: List of paths to skip authentication
Returns:
ASGI middleware function
"""
skip_paths = skip_paths or ['/health', '/docs', '/openapi.json']
async def middleware(scope, receive, send):
"""HTTP authentication middleware"""
if scope['type'] != 'http':
# Pass through non-HTTP requests directly
return await self.app(scope, receive, send)
path = scope.get('path', '')
# Check if authentication should be skipped
if any(path.startswith(skip) for skip in skip_paths):
return await self.app(scope, receive, send)
# Extract authentication information
headers = dict(scope.get('headers', []))
authorization = headers.get(b'authorization', b'').decode()
try:
# Perform authentication
auth_info = {
'type': 'jwt',
'authorization': authorization
}
auth_context = await self.authenticate_request(auth_info)
# Add authentication context to scope
scope['auth_context'] = auth_context
# Create response wrapper to add authentication headers
async def send_wrapper(message):
if message['type'] == 'http.response.start':
headers = dict(message.get('headers', []))
auth_headers = await self.create_auth_response_headers(auth_context)
for key, value in auth_headers.items():
headers[key.encode()] = value.encode()
message['headers'] = list(headers.items())
await send(message)
return await self.app(scope, receive, send_wrapper)
except Exception as e:
# Authentication failed, return 401 error
response_body = f'{{"error": "Authentication failed", "message": "{str(e)}"}}'
await send({
'type': 'http.response.start',
'status': 401,
'headers': [
(b'content-type', b'application/json'),
(b'www-authenticate', b'Bearer')
]
})
await send({
'type': 'http.response.body',
'body': response_body.encode()
})
return middleware
async def authenticate_mcp_request(self, headers: Dict[str, str]) -> AuthContext:
"""Authenticate MCP request
Args:
headers: MCP request headers
Returns:
AuthContext authentication context
"""
try:
# Extract authentication information from multiple possible header fields
authorization = (
headers.get('Authorization') or
headers.get('authorization') or
headers.get('X-Auth-Token') or
headers.get('x-auth-token')
)
auth_info = {
'type': 'jwt',
'authorization': authorization
}
return await self.authenticate_request(auth_info)
except Exception as e:
logger.error(f"MCP request authentication failed: {e}")
raise
class AuthenticationError(Exception):
"""Authentication error exception"""
def __init__(self, message: str, error_code: str = "AUTH_FAILED"):
self.message = message
self.error_code = error_code
super().__init__(message)
class AuthorizationError(Exception):
"""Authorization error exception"""
def __init__(self, message: str, error_code: str = "ACCESS_DENIED"):
self.message = message
self.error_code = error_code
super().__init__(message)

View File

@@ -0,0 +1,471 @@
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
JWT Manager Module
Provides comprehensive JWT token management including generation, validation, refresh and revocation
"""
import time
import uuid
import asyncio
from typing import Dict, Any, Optional, Tuple
from datetime import datetime, timedelta
try:
import jwt
except ImportError:
raise ImportError("PyJWT is required for JWT functionality. Install with: pip install PyJWT[crypto]")
from .key_manager import KeyManager
from .token_validators import TokenValidator, TokenBlacklist
from ..utils.logger import get_logger
logger = get_logger(__name__)
class JWTManager:
"""JWT Token Manager
Provides comprehensive JWT token lifecycle management, including:
- Token generation and signing
- Token validation and parsing
- Token refresh mechanism
- Token revocation and blacklist
- Automatic key rotation
"""
def __init__(self, config):
"""Initialize JWT manager
Args:
config: DorisConfig configuration object (with security attribute)
"""
self.config = config
# Access JWT settings through the security configuration
if hasattr(config, 'security'):
security_config = config.security
else:
# Fallback if config is passed directly as SecurityConfig
security_config = config
self.algorithm = security_config.jwt_algorithm
self.issuer = security_config.jwt_issuer
self.audience = security_config.jwt_audience
self.access_token_expiry = security_config.jwt_access_token_expiry
self.refresh_token_expiry = security_config.jwt_refresh_token_expiry
self.enable_refresh = security_config.enable_token_refresh
self.enable_revocation = security_config.enable_token_revocation
# Initialize components
self.key_manager = KeyManager(config)
self.token_blacklist = TokenBlacklist()
self.validator = TokenValidator(config, self.token_blacklist)
# Automatic key rotation task
self._key_rotation_task = None
logger.info(f"JWTManager initialized with algorithm: {self.algorithm}")
async def initialize(self) -> bool:
"""Initialize JWT manager"""
try:
# Initialize key manager
if not await self.key_manager.initialize():
logger.error("Failed to initialize key manager")
return False
# Start token validator
await self.validator.start()
# Start automatic key rotation
if self.key_manager.key_rotation_interval > 0:
self._key_rotation_task = asyncio.create_task(self._auto_key_rotation())
logger.info("JWTManager initialization completed")
return True
except Exception as e:
logger.error(f"Failed to initialize JWTManager: {e}")
return False
async def shutdown(self):
"""Shutdown JWT manager"""
try:
# Stop key rotation task
if self._key_rotation_task:
self._key_rotation_task.cancel()
try:
await self._key_rotation_task
except asyncio.CancelledError:
pass
# Stop validator
await self.validator.stop()
logger.info("JWTManager shutdown completed")
except Exception as e:
logger.error(f"Error during JWTManager shutdown: {e}")
async def generate_tokens(self, user_info: Dict[str, Any],
custom_claims: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""Generate access token and refresh token
Args:
user_info: User information dictionary, containing user_id, roles, permissions, etc.
custom_claims: Custom claims
Returns:
Dictionary containing access_token and refresh_token
"""
try:
current_time = int(time.time())
jti = str(uuid.uuid4())
# Build base payload
base_payload = {
'iss': self.issuer,
'aud': self.audience,
'iat': current_time,
'jti': jti,
'sub': user_info.get('user_id'),
'roles': user_info.get('roles', []),
'permissions': user_info.get('permissions', []),
'security_level': user_info.get('security_level', 'internal')
}
# Add custom claims
if custom_claims:
base_payload.update(custom_claims)
# Generate access token
access_payload = base_payload.copy()
access_payload.update({
'exp': current_time + self.access_token_expiry,
'token_type': 'access'
})
access_token = await self._sign_token(access_payload)
result = {
'access_token': access_token,
'token_type': 'Bearer',
'expires_in': self.access_token_expiry,
'user_id': user_info.get('user_id'),
'issued_at': current_time
}
# Generate refresh token (if enabled)
if self.enable_refresh:
refresh_jti = str(uuid.uuid4())
refresh_payload = {
'iss': self.issuer,
'aud': self.audience,
'iat': current_time,
'exp': current_time + self.refresh_token_expiry,
'jti': refresh_jti,
'sub': user_info.get('user_id'),
'token_type': 'refresh',
'access_jti': jti # Associated access token ID
}
refresh_token = await self._sign_token(refresh_payload)
result.update({
'refresh_token': refresh_token,
'refresh_expires_in': self.refresh_token_expiry
})
logger.info(f"Generated tokens for user: {user_info.get('user_id')}")
return result
except Exception as e:
logger.error(f"Failed to generate tokens: {e}")
raise
async def _sign_token(self, payload: Dict[str, Any]) -> str:
"""Sign JWT token
Args:
payload: JWT payload
Returns:
Signed JWT token
"""
try:
signing_key = self.key_manager.get_private_key()
if self.algorithm == "HS256":
# Symmetric key signing
token = jwt.encode(payload, signing_key, algorithm=self.algorithm)
else:
# Asymmetric key signing
token = jwt.encode(payload, signing_key, algorithm=self.algorithm)
return token
except Exception as e:
logger.error(f"Failed to sign token: {e}")
raise
async def validate_token(self, token: str, token_type: str = 'access') -> Dict[str, Any]:
"""Validate JWT token
Args:
token: JWT token string
token_type: Token type ('access' or 'refresh')
Returns:
Validation result and user information
Raises:
ValueError: Token validation failed
"""
try:
# Decode token
verification_key = self.key_manager.get_public_key()
# Get security configuration
if hasattr(self.config, 'security'):
security_config = self.config.security
else:
security_config = self.config
# JWT decoding options
options = {
'verify_signature': security_config.jwt_verify_signature,
'verify_exp': security_config.jwt_require_exp,
'verify_iat': security_config.jwt_require_iat,
'verify_nbf': security_config.jwt_require_nbf,
'verify_aud': security_config.jwt_verify_audience,
'verify_iss': security_config.jwt_verify_issuer,
}
# Decode JWT
payload = jwt.decode(
token,
verification_key,
algorithms=[self.algorithm],
audience=self.audience if security_config.jwt_verify_audience else None,
issuer=self.issuer if security_config.jwt_verify_issuer else None,
leeway=security_config.jwt_leeway,
options=options
)
# Check token type
if payload.get('token_type') != token_type:
raise ValueError(f"Invalid token type: expected {token_type}")
# Use validator for additional checks
validation_result = await self.validator.validate_claims(payload)
logger.info(f"Token validation successful for user: {payload.get('sub')}")
return validation_result
except jwt.ExpiredSignatureError:
raise ValueError("Token has expired")
except jwt.InvalidTokenError as e:
raise ValueError(f"Invalid token: {str(e)}")
except Exception as e:
logger.error(f"Token validation failed: {e}")
raise ValueError(f"Token validation failed: {str(e)}")
async def refresh_token(self, refresh_token: str) -> Dict[str, Any]:
"""Refresh access token
Args:
refresh_token: Refresh token
Returns:
New token pair
"""
if not self.enable_refresh:
raise ValueError("Token refresh is disabled")
try:
# Validate refresh token
refresh_result = await self.validate_token(refresh_token, 'refresh')
refresh_payload = refresh_result['payload']
# Revoke associated access token (if revocation is enabled)
if self.enable_revocation:
access_jti = refresh_payload.get('access_jti')
if access_jti:
# Should revoke old access token here, but since we don't know its expiration time,
# in practice might need to store more information or use different strategy
pass
# Build new user information
user_info = {
'user_id': refresh_payload.get('sub'),
'roles': refresh_payload.get('roles', []),
'permissions': refresh_payload.get('permissions', []),
'security_level': refresh_payload.get('security_level', 'internal')
}
# Generate new token pair
new_tokens = await self.generate_tokens(user_info)
logger.info(f"Token refreshed for user: {user_info['user_id']}")
return new_tokens
except Exception as e:
logger.error(f"Token refresh failed: {e}")
raise
async def revoke_token(self, token: str) -> bool:
"""Revoke token
Args:
token: Token to revoke
Returns:
Whether revocation was successful
"""
if not self.enable_revocation:
logger.warning("Token revocation is disabled")
return False
try:
# Decode token to get JTI and expiration time
verification_key = self.key_manager.get_public_key()
payload = jwt.decode(
token,
verification_key,
algorithms=[self.algorithm],
options={'verify_exp': False} # Allow decoding expired tokens
)
jti = payload.get('jti')
exp = payload.get('exp')
if not jti or not exp:
logger.error("Token missing required claims for revocation")
return False
# Add to blacklist
await self.validator.revoke_token(jti, exp)
logger.info(f"Token {jti} revoked successfully")
return True
except Exception as e:
logger.error(f"Token revocation failed: {e}")
return False
async def decode_token_unsafe(self, token: str) -> Dict[str, Any]:
"""Decode token without verifying signature (for debugging only)
Args:
token: JWT token
Returns:
Token payload
"""
try:
payload = jwt.decode(token, options={'verify_signature': False})
return payload
except Exception as e:
logger.error(f"Failed to decode token: {e}")
raise
async def get_token_info(self, token: str) -> Dict[str, Any]:
"""Get token information (without verifying signature)
Args:
token: JWT token
Returns:
Token information
"""
try:
payload = await self.decode_token_unsafe(token)
return {
'jti': payload.get('jti'),
'sub': payload.get('sub'),
'iss': payload.get('iss'),
'aud': payload.get('aud'),
'iat': payload.get('iat'),
'exp': payload.get('exp'),
'token_type': payload.get('token_type'),
'roles': payload.get('roles'),
'permissions': payload.get('permissions'),
'security_level': payload.get('security_level'),
'is_expired': payload.get('exp', 0) < time.time() if payload.get('exp') else None
}
except Exception as e:
logger.error(f"Failed to get token info: {e}")
raise
async def _auto_key_rotation(self):
"""Automatic key rotation task"""
while True:
try:
# Check if key rotation is needed
if await self.key_manager.is_key_expired():
logger.info("Key rotation needed, rotating keys...")
await self.key_manager.rotate_keys()
# Wait until next check
await asyncio.sleep(3600) # Check every hour
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Error in auto key rotation: {e}")
# Wait longer before retry after error
await asyncio.sleep(3600)
async def get_public_key_info(self) -> Dict[str, Any]:
"""Get public key information (for client verification)
Returns:
Public key information
"""
key_info = await self.key_manager.get_key_info()
public_key_pem = await self.key_manager.export_public_key_pem()
return {
'algorithm': self.algorithm,
'public_key_pem': public_key_pem,
'key_info': key_info
}
async def get_manager_stats(self) -> Dict[str, Any]:
"""Get manager statistics
Returns:
Statistics information
"""
key_info = await self.key_manager.get_key_info()
validation_stats = await self.validator.get_validation_stats()
return {
'jwt_config': {
'algorithm': self.algorithm,
'issuer': self.issuer,
'audience': self.audience,
'access_token_expiry': self.access_token_expiry,
'refresh_token_expiry': self.refresh_token_expiry,
'enable_refresh': self.enable_refresh,
'enable_revocation': self.enable_revocation
},
'key_manager': key_info,
'validator': validation_stats
}

View File

@@ -0,0 +1,343 @@
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
JWT Key Management Module
Provides secure key generation, loading, rotation and management for JWT tokens
"""
import os
import time
import secrets
from pathlib import Path
from typing import Optional, Tuple, Union
from datetime import datetime, timedelta
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa, ec
from cryptography.hazmat.backends import default_backend
from ..utils.logger import get_logger
logger = get_logger(__name__)
class KeyManager:
"""JWT Key Manager
Responsible for generating, loading, rotating and securely storing JWT signing keys
Supports RSA and EC algorithms, provides automatic key rotation functionality
"""
def __init__(self, config):
"""Initialize key manager
Args:
config: DorisConfig configuration object (with security attribute)
"""
self.config = config
# Access JWT settings through the security configuration
if hasattr(config, 'security'):
security_config = config.security
else:
# Fallback if config is passed directly as SecurityConfig
security_config = config
self.algorithm = security_config.jwt_algorithm
self.key_rotation_interval = security_config.key_rotation_interval
self.private_key_path = security_config.jwt_private_key_path
self.public_key_path = security_config.jwt_public_key_path
self.secret_key = security_config.jwt_secret_key
# Key storage
self._private_key = None
self._public_key = None
self._secret_key = None
self._key_generated_at = None
logger.info(f"KeyManager initialized with algorithm: {self.algorithm}")
async def initialize(self) -> bool:
"""Initialize key manager, load or generate keys"""
try:
if self.algorithm == "HS256":
await self._initialize_symmetric_key()
else:
await self._initialize_asymmetric_keys()
logger.info("KeyManager initialization completed")
return True
except Exception as e:
logger.error(f"Failed to initialize KeyManager: {e}")
return False
async def _initialize_symmetric_key(self):
"""Initialize symmetric key (HS256)"""
if self.secret_key:
# Use configured key
self._secret_key = self.secret_key.encode()
logger.info("Loaded symmetric key from configuration")
else:
# Generate new key
self._secret_key = await self.generate_symmetric_key()
logger.info("Generated new symmetric key")
self._key_generated_at = datetime.utcnow()
async def _initialize_asymmetric_keys(self):
"""Initialize asymmetric key pair (RS256/ES256)"""
# Try to load keys from files
if await self._load_keys_from_files():
logger.info("Loaded asymmetric keys from files")
return
# Try to load from environment variables
if await self._load_keys_from_env():
logger.info("Loaded asymmetric keys from environment")
return
# Generate new key pair
await self.generate_key_pair()
logger.info("Generated new asymmetric key pair")
async def _load_keys_from_files(self) -> bool:
"""Load keys from files"""
try:
if not self.private_key_path or not self.public_key_path:
return False
private_path = Path(self.private_key_path)
public_path = Path(self.public_key_path)
if not (private_path.exists() and public_path.exists()):
return False
# Read private key
with open(private_path, 'rb') as f:
private_key_data = f.read()
self._private_key = serialization.load_pem_private_key(
private_key_data, password=None, backend=default_backend()
)
# Read public key
with open(public_path, 'rb') as f:
public_key_data = f.read()
self._public_key = serialization.load_pem_public_key(
public_key_data, backend=default_backend()
)
# Get key generation time (using file modification time)
self._key_generated_at = datetime.fromtimestamp(private_path.stat().st_mtime)
return True
except Exception as e:
logger.error(f"Failed to load keys from files: {e}")
return False
async def _load_keys_from_env(self) -> bool:
"""Load keys from environment variables"""
try:
private_key_env = os.getenv('JWT_PRIVATE_KEY')
public_key_env = os.getenv('JWT_PUBLIC_KEY')
if not (private_key_env and public_key_env):
return False
# Parse private key
self._private_key = serialization.load_pem_private_key(
private_key_env.encode(), password=None, backend=default_backend()
)
# Parse public key
self._public_key = serialization.load_pem_public_key(
public_key_env.encode(), backend=default_backend()
)
self._key_generated_at = datetime.utcnow()
return True
except Exception as e:
logger.error(f"Failed to load keys from environment: {e}")
return False
async def generate_symmetric_key(self, length: int = 32) -> bytes:
"""Generate symmetric key
Args:
length: Key length (bytes), default 32 bytes (256 bits)
Returns:
Generated key
"""
return secrets.token_bytes(length)
async def generate_key_pair(self) -> Tuple[bytes, bytes]:
"""Generate asymmetric key pair
Returns:
(private key PEM, public key PEM) tuple
"""
try:
if self.algorithm == "RS256":
private_key = rsa.generate_private_key(
public_exponent=65537,
key_size=2048,
backend=default_backend()
)
elif self.algorithm == "ES256":
private_key = ec.generate_private_key(
ec.SECP256R1(), backend=default_backend()
)
else:
raise ValueError(f"Unsupported algorithm for key generation: {self.algorithm}")
# Get public key
public_key = private_key.public_key()
# Serialize private key
private_pem = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption()
)
# Serialize public key
public_pem = public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo
)
# Store keys
self._private_key = private_key
self._public_key = public_key
self._key_generated_at = datetime.utcnow()
# If file paths are configured, save to files
if self.private_key_path and self.public_key_path:
await self._save_keys_to_files(private_pem, public_pem)
logger.info(f"Generated new {self.algorithm} key pair")
return private_pem, public_pem
except Exception as e:
logger.error(f"Failed to generate key pair: {e}")
raise
async def _save_keys_to_files(self, private_pem: bytes, public_pem: bytes):
"""Save keys to files"""
try:
# Ensure directories exist
private_path = Path(self.private_key_path)
public_path = Path(self.public_key_path)
private_path.parent.mkdir(parents=True, exist_ok=True)
public_path.parent.mkdir(parents=True, exist_ok=True)
# Save private key (set secure permissions)
with open(private_path, 'wb') as f:
f.write(private_pem)
os.chmod(private_path, 0o600) # Only owner can read/write
# Save public key
with open(public_path, 'wb') as f:
f.write(public_pem)
os.chmod(public_path, 0o644) # Owner read/write, others read only
logger.info(f"Saved keys to files: {private_path}, {public_path}")
except Exception as e:
logger.error(f"Failed to save keys to files: {e}")
raise
def get_private_key(self):
"""Get private key for signing"""
if self.algorithm == "HS256":
return self._secret_key
else:
return self._private_key
def get_public_key(self):
"""Get public key for verification"""
if self.algorithm == "HS256":
return self._secret_key
else:
return self._public_key
def get_algorithm(self) -> str:
"""Get signing algorithm"""
return self.algorithm
async def is_key_expired(self) -> bool:
"""Check if key is expired"""
if not self._key_generated_at:
return True
expiry_time = self._key_generated_at + timedelta(seconds=self.key_rotation_interval)
return datetime.utcnow() > expiry_time
async def rotate_keys(self) -> bool:
"""Rotate keys"""
try:
logger.info("Starting key rotation")
if self.algorithm == "HS256":
# Generate new symmetric key
self._secret_key = await self.generate_symmetric_key()
self._key_generated_at = datetime.utcnow()
else:
# Generate new asymmetric key pair
await self.generate_key_pair()
logger.info("Key rotation completed successfully")
return True
except Exception as e:
logger.error(f"Key rotation failed: {e}")
return False
async def get_key_info(self) -> dict:
"""Get key information"""
return {
"algorithm": self.algorithm,
"key_generated_at": self._key_generated_at.isoformat() if self._key_generated_at else None,
"key_expires_at": (
self._key_generated_at + timedelta(seconds=self.key_rotation_interval)
).isoformat() if self._key_generated_at else None,
"is_expired": await self.is_key_expired(),
"has_private_key": self._private_key is not None or self._secret_key is not None,
"has_public_key": self._public_key is not None or self._secret_key is not None
}
async def export_public_key_pem(self) -> Optional[str]:
"""Export public key in PEM format"""
if self.algorithm == "HS256":
return None # Symmetric key not exported
if not self._public_key:
return None
try:
public_pem = self._public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo
)
return public_pem.decode()
except Exception as e:
logger.error(f"Failed to export public key: {e}")
return None

View File

@@ -0,0 +1,536 @@
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
OAuth 2.0/OIDC Client Manager
Provides OAuth authentication client implementation with PKCE and OIDC support
"""
import base64
import hashlib
import secrets
import uuid
from datetime import datetime, timedelta
from typing import Dict, Optional, Any, Tuple
from urllib.parse import urlencode, parse_qs, urlparse
import asyncio
import json
try:
import aiohttp
except ImportError:
raise ImportError("aiohttp is required for OAuth functionality. Install with: pip install aiohttp")
from .oauth_types import (
OAuthProvider, OAuthState, OAuthTokens, OAuthUserInfo,
OIDCDiscovery, OAuthError, OAuthProviderConfig, OAUTH_PROVIDERS
)
from ..utils.logger import get_logger
logger = get_logger(__name__)
class OAuthStateManager:
"""Manages OAuth state parameters for CSRF protection"""
def __init__(self, state_expiry: int = 600):
"""Initialize state manager
Args:
state_expiry: State expiry time in seconds
"""
self.state_expiry = state_expiry
self._states: Dict[str, OAuthState] = {}
self._cleanup_task = None
logger.info("OAuthStateManager initialized")
async def start(self):
"""Start periodic cleanup task"""
self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
logger.info("OAuth state manager started")
async def stop(self):
"""Stop periodic cleanup task"""
if self._cleanup_task:
self._cleanup_task.cancel()
try:
await self._cleanup_task
except asyncio.CancelledError:
pass
logger.info("OAuth state manager stopped")
def create_state(self, redirect_uri: str, pkce_enabled: bool = True,
nonce_enabled: bool = True) -> OAuthState:
"""Create new OAuth state
Args:
redirect_uri: OAuth redirect URI
pkce_enabled: Whether to enable PKCE
nonce_enabled: Whether to enable nonce (for OIDC)
Returns:
OAuth state object
"""
state = secrets.token_urlsafe(32)
nonce = secrets.token_urlsafe(32) if nonce_enabled else None
pkce_verifier = None
pkce_challenge = None
if pkce_enabled:
pkce_verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).decode('utf-8').rstrip('=')
challenge_bytes = hashlib.sha256(pkce_verifier.encode()).digest()
pkce_challenge = base64.urlsafe_b64encode(challenge_bytes).decode('utf-8').rstrip('=')
oauth_state = OAuthState(
state=state,
nonce=nonce,
pkce_verifier=pkce_verifier,
pkce_challenge=pkce_challenge,
redirect_uri=redirect_uri,
created_at=datetime.utcnow(),
expires_at=datetime.utcnow() + timedelta(seconds=self.state_expiry)
)
self._states[state] = oauth_state
logger.debug(f"Created OAuth state: {state}")
return oauth_state
def get_state(self, state: str) -> Optional[OAuthState]:
"""Get OAuth state by state parameter
Args:
state: State parameter
Returns:
OAuth state object or None if not found/expired
"""
oauth_state = self._states.get(state)
if oauth_state and oauth_state.expires_at > datetime.utcnow():
return oauth_state
elif oauth_state:
# Remove expired state
del self._states[state]
logger.debug(f"Removed expired OAuth state: {state}")
return None
def consume_state(self, state: str) -> Optional[OAuthState]:
"""Get and remove OAuth state
Args:
state: State parameter
Returns:
OAuth state object or None if not found/expired
"""
oauth_state = self.get_state(state)
if oauth_state:
del self._states[state]
logger.debug(f"Consumed OAuth state: {state}")
return oauth_state
async def _periodic_cleanup(self):
"""Periodic cleanup of expired states"""
while True:
try:
await asyncio.sleep(300) # Clean up every 5 minutes
current_time = datetime.utcnow()
expired_states = [
state for state, oauth_state in self._states.items()
if oauth_state.expires_at <= current_time
]
for state in expired_states:
del self._states[state]
if expired_states:
logger.info(f"Cleaned up {len(expired_states)} expired OAuth states")
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Error during OAuth state cleanup: {e}")
class OAuthClient:
"""OAuth 2.0/OIDC Client implementation"""
def __init__(self, config):
"""Initialize OAuth client
Args:
config: DorisConfig with OAuth configuration
"""
self.config = config
# Access OAuth settings through security configuration
if hasattr(config, 'security'):
security_config = config.security
else:
security_config = config
self.enabled = security_config.oauth_enabled
if not self.enabled:
logger.info("OAuth client disabled by configuration")
return
# Build provider configuration
self.provider_config = self._build_provider_config(security_config)
self.state_manager = OAuthStateManager(security_config.oauth_state_expiry)
# HTTP client session
self._session: Optional[aiohttp.ClientSession] = None
# Discovery cache
self._discovery_cache: Optional[OIDCDiscovery] = None
self._discovery_cache_time: Optional[datetime] = None
logger.info(f"OAuthClient initialized for provider: {self.provider_config.provider.value}")
def _build_provider_config(self, security_config) -> OAuthProviderConfig:
"""Build OAuth provider configuration
Args:
security_config: Security configuration object
Returns:
OAuth provider configuration
"""
try:
provider = OAuthProvider(security_config.oauth_provider)
except ValueError:
provider = OAuthProvider.CUSTOM
# Get default configuration for known providers
defaults = OAUTH_PROVIDERS.get(provider, {})
return OAuthProviderConfig(
provider=provider,
client_id=security_config.oauth_client_id,
client_secret=security_config.oauth_client_secret,
redirect_uri=security_config.oauth_redirect_uri,
scopes=security_config.oauth_scopes or defaults.get("scopes", ["openid", "email", "profile"]),
# Endpoints (use configured or defaults)
authorization_endpoint=security_config.oauth_authorization_endpoint or defaults.get("authorization_endpoint", ""),
token_endpoint=security_config.oauth_token_endpoint or defaults.get("token_endpoint", ""),
userinfo_endpoint=security_config.oauth_userinfo_endpoint or defaults.get("userinfo_endpoint"),
jwks_uri=security_config.oauth_jwks_uri or defaults.get("jwks_uri"),
# Discovery
discovery_url=security_config.oidc_discovery_url or defaults.get("discovery_url"),
# Settings
pkce_enabled=security_config.oauth_pkce_enabled,
nonce_enabled=security_config.oauth_nonce_enabled,
# User mapping
user_id_claim=security_config.oauth_user_id_claim or defaults.get("user_id_claim", "sub"),
email_claim=security_config.oauth_email_claim or defaults.get("email_claim", "email"),
name_claim=security_config.oauth_name_claim or defaults.get("name_claim", "name"),
roles_claim=security_config.oauth_roles_claim,
default_roles=security_config.oauth_default_roles
)
async def initialize(self) -> bool:
"""Initialize OAuth client
Returns:
True if initialization successful
"""
if not self.enabled:
return True
try:
# Create HTTP session
self._session = aiohttp.ClientSession()
# Start state manager
await self.state_manager.start()
# Perform OIDC discovery if configured
if self.provider_config.discovery_url:
await self._discover_oidc_endpoints()
logger.info("OAuth client initialization completed")
return True
except Exception as e:
logger.error(f"Failed to initialize OAuth client: {e}")
return False
async def shutdown(self):
"""Shutdown OAuth client"""
if not self.enabled:
return
try:
# Stop state manager
await self.state_manager.stop()
# Close HTTP session
if self._session:
await self._session.close()
logger.info("OAuth client shutdown completed")
except Exception as e:
logger.error(f"Error during OAuth client shutdown: {e}")
async def _discover_oidc_endpoints(self):
"""Discover OIDC endpoints using discovery URL"""
try:
# Check cache first
if (self._discovery_cache and self._discovery_cache_time and
datetime.utcnow() - self._discovery_cache_time < timedelta(hours=1)):
return self._discovery_cache
logger.info(f"Discovering OIDC endpoints: {self.provider_config.discovery_url}")
async with self._session.get(self.provider_config.discovery_url) as response:
response.raise_for_status()
data = await response.json()
discovery = OIDCDiscovery(
issuer=data["issuer"],
authorization_endpoint=data["authorization_endpoint"],
token_endpoint=data["token_endpoint"],
userinfo_endpoint=data.get("userinfo_endpoint"),
jwks_uri=data.get("jwks_uri"),
scopes_supported=data.get("scopes_supported"),
response_types_supported=data.get("response_types_supported"),
subject_types_supported=data.get("subject_types_supported"),
id_token_signing_alg_values_supported=data.get("id_token_signing_alg_values_supported")
)
# Update provider configuration with discovered endpoints
if not self.provider_config.authorization_endpoint:
self.provider_config.authorization_endpoint = discovery.authorization_endpoint
if not self.provider_config.token_endpoint:
self.provider_config.token_endpoint = discovery.token_endpoint
if not self.provider_config.userinfo_endpoint:
self.provider_config.userinfo_endpoint = discovery.userinfo_endpoint
if not self.provider_config.jwks_uri:
self.provider_config.jwks_uri = discovery.jwks_uri
# Cache discovery result
self._discovery_cache = discovery
self._discovery_cache_time = datetime.utcnow()
logger.info("OIDC endpoint discovery completed successfully")
return discovery
except Exception as e:
logger.error(f"OIDC endpoint discovery failed: {e}")
raise
def build_authorization_url(self) -> Tuple[str, OAuthState]:
"""Build OAuth authorization URL
Returns:
Tuple of (authorization_url, oauth_state)
"""
if not self.enabled:
raise ValueError("OAuth client is not enabled")
# Create state for CSRF protection
oauth_state = self.state_manager.create_state(
redirect_uri=self.provider_config.redirect_uri,
pkce_enabled=self.provider_config.pkce_enabled,
nonce_enabled=self.provider_config.nonce_enabled
)
# Build authorization parameters
params = {
'response_type': 'code',
'client_id': self.provider_config.client_id,
'redirect_uri': self.provider_config.redirect_uri,
'scope': ' '.join(self.provider_config.scopes),
'state': oauth_state.state
}
# Add PKCE challenge
if oauth_state.pkce_challenge:
params['code_challenge'] = oauth_state.pkce_challenge
params['code_challenge_method'] = 'S256'
# Add nonce for OIDC
if oauth_state.nonce:
params['nonce'] = oauth_state.nonce
# Build URL
authorization_url = f"{self.provider_config.authorization_endpoint}?{urlencode(params)}"
logger.info(f"Built OAuth authorization URL for state: {oauth_state.state}")
return authorization_url, oauth_state
async def exchange_code_for_tokens(self, code: str, state: str) -> Tuple[OAuthTokens, OAuthState]:
"""Exchange authorization code for tokens
Args:
code: Authorization code
state: State parameter
Returns:
Tuple of (OAuth tokens, OAuth state)
Raises:
ValueError: If state is invalid or exchange fails
"""
if not self.enabled:
raise ValueError("OAuth client is not enabled")
# Validate and consume state
oauth_state = self.state_manager.consume_state(state)
if not oauth_state:
raise ValueError("Invalid or expired state parameter")
try:
# Prepare token request
data = {
'grant_type': 'authorization_code',
'client_id': self.provider_config.client_id,
'client_secret': self.provider_config.client_secret,
'code': code,
'redirect_uri': oauth_state.redirect_uri
}
# Add PKCE verifier
if oauth_state.pkce_verifier:
data['code_verifier'] = oauth_state.pkce_verifier
# Make token request
async with self._session.post(
self.provider_config.token_endpoint,
data=data,
headers={'Content-Type': 'application/x-www-form-urlencoded'}
) as response:
response_data = await response.json()
if response.status != 200:
error_msg = response_data.get('error_description', response_data.get('error', 'Token exchange failed'))
raise ValueError(f"Token exchange failed: {error_msg}")
tokens = OAuthTokens(
access_token=response_data['access_token'],
token_type=response_data.get('token_type', 'Bearer'),
expires_in=response_data.get('expires_in'),
refresh_token=response_data.get('refresh_token'),
scope=response_data.get('scope'),
id_token=response_data.get('id_token')
)
logger.info("Successfully exchanged authorization code for tokens")
return tokens, oauth_state
except Exception as e:
logger.error(f"Token exchange failed: {e}")
raise ValueError(f"Token exchange failed: {str(e)}")
async def get_user_info(self, tokens: OAuthTokens) -> OAuthUserInfo:
"""Get user information from OAuth provider
Args:
tokens: OAuth tokens
Returns:
OAuth user information
"""
if not self.enabled:
raise ValueError("OAuth client is not enabled")
if not self.provider_config.userinfo_endpoint:
raise ValueError("Userinfo endpoint not configured")
try:
# Make userinfo request
headers = {'Authorization': f'{tokens.token_type} {tokens.access_token}'}
async with self._session.get(
self.provider_config.userinfo_endpoint,
headers=headers
) as response:
response.raise_for_status()
user_data = await response.json()
# Extract user information using configured claims
user_info = OAuthUserInfo(
sub=str(user_data.get(self.provider_config.user_id_claim, '')),
email=user_data.get(self.provider_config.email_claim),
name=user_data.get(self.provider_config.name_claim),
given_name=user_data.get('given_name'),
family_name=user_data.get('family_name'),
picture=user_data.get('picture'),
locale=user_data.get('locale'),
email_verified=user_data.get('email_verified'),
roles=user_data.get(self.provider_config.roles_claim, self.provider_config.default_roles.copy()),
raw_claims=user_data
)
logger.info(f"Retrieved user info for user: {user_info.sub}")
return user_info
except Exception as e:
logger.error(f"Failed to get user info: {e}")
raise ValueError(f"Failed to get user info: {str(e)}")
async def refresh_tokens(self, refresh_token: str) -> OAuthTokens:
"""Refresh OAuth tokens
Args:
refresh_token: Refresh token
Returns:
New OAuth tokens
"""
if not self.enabled:
raise ValueError("OAuth client is not enabled")
try:
data = {
'grant_type': 'refresh_token',
'client_id': self.provider_config.client_id,
'client_secret': self.provider_config.client_secret,
'refresh_token': refresh_token
}
async with self._session.post(
self.provider_config.token_endpoint,
data=data,
headers={'Content-Type': 'application/x-www-form-urlencoded'}
) as response:
response_data = await response.json()
if response.status != 200:
error_msg = response_data.get('error_description', response_data.get('error', 'Token refresh failed'))
raise ValueError(f"Token refresh failed: {error_msg}")
tokens = OAuthTokens(
access_token=response_data['access_token'],
token_type=response_data.get('token_type', 'Bearer'),
expires_in=response_data.get('expires_in'),
refresh_token=response_data.get('refresh_token', refresh_token), # Keep old if not provided
scope=response_data.get('scope'),
id_token=response_data.get('id_token')
)
logger.info("Successfully refreshed OAuth tokens")
return tokens
except Exception as e:
logger.error(f"Token refresh failed: {e}")
raise ValueError(f"Token refresh failed: {str(e)}")

View File

@@ -0,0 +1,312 @@
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
OAuth HTTP Handlers
Provides HTTP endpoints for OAuth authentication flow
"""
from typing import Dict, Any
from urllib.parse import parse_qs, urlparse
import json
from starlette.responses import JSONResponse, RedirectResponse, HTMLResponse
from starlette.requests import Request
from ..utils.logger import get_logger
logger = get_logger(__name__)
class OAuthHandlers:
"""OAuth HTTP request handlers"""
def __init__(self, security_manager):
"""Initialize OAuth handlers
Args:
security_manager: DorisSecurityManager instance
"""
self.security_manager = security_manager
logger.info("OAuth handlers initialized")
async def handle_login(self, request: Request) -> JSONResponse:
"""Handle OAuth login initiation
Returns JSON with authorization URL and state
"""
try:
# Check if OAuth is enabled
oauth_info = self.security_manager.get_oauth_provider_info()
if not oauth_info.get("enabled"):
return JSONResponse(
{"error": "OAuth authentication is not enabled"},
status_code=400
)
# Get authorization URL
authorization_url, state = self.security_manager.get_oauth_authorization_url()
return JSONResponse({
"authorization_url": authorization_url,
"state": state,
"provider": oauth_info.get("provider"),
"message": "Navigate to authorization_url to complete OAuth login"
})
except Exception as e:
logger.error(f"OAuth login initiation failed: {e}")
return JSONResponse(
{"error": f"OAuth login failed: {str(e)}"},
status_code=500
)
async def handle_callback(self, request: Request) -> JSONResponse:
"""Handle OAuth callback
Processes the OAuth callback and returns authentication result
"""
try:
# Get query parameters
query_params = dict(request.query_params)
# Check for error in callback
if "error" in query_params:
error_description = query_params.get("error_description", "Unknown error")
logger.warning(f"OAuth callback error: {query_params['error']} - {error_description}")
return JSONResponse(
{
"error": query_params["error"],
"error_description": error_description,
"error_uri": query_params.get("error_uri")
},
status_code=400
)
# Extract required parameters
code = query_params.get("code")
state = query_params.get("state")
if not code or not state:
return JSONResponse(
{"error": "Missing required parameters: code and state"},
status_code=400
)
# Handle OAuth callback
auth_context = await self.security_manager.handle_oauth_callback(code, state)
# Return successful authentication response
return JSONResponse({
"success": True,
"user_id": auth_context.user_id,
"roles": auth_context.roles,
"permissions": auth_context.permissions,
"security_level": auth_context.security_level.value,
"session_id": auth_context.session_id,
"message": "OAuth authentication successful"
})
except Exception as e:
logger.error(f"OAuth callback handling failed: {e}")
return JSONResponse(
{"error": f"OAuth callback failed: {str(e)}"},
status_code=500
)
async def handle_provider_info(self, request: Request) -> JSONResponse:
"""Handle OAuth provider information request
Returns information about the configured OAuth provider
"""
try:
provider_info = self.security_manager.get_oauth_provider_info()
return JSONResponse(provider_info)
except Exception as e:
logger.error(f"Failed to get OAuth provider info: {e}")
return JSONResponse(
{"error": f"Failed to get provider info: {str(e)}"},
status_code=500
)
async def handle_demo_page(self, request: Request) -> HTMLResponse:
"""Handle OAuth demo page
Returns a simple HTML page for testing OAuth flow
"""
oauth_info = self.security_manager.get_oauth_provider_info()
if not oauth_info.get("enabled"):
return HTMLResponse("""
<html>
<head><title>OAuth Demo</title></head>
<body>
<h1>OAuth Demo</h1>
<p style="color: red;">OAuth authentication is not enabled.</p>
<p>Please configure OAuth settings in your security configuration.</p>
</body>
</html>
""")
html_content = f"""
<!DOCTYPE html>
<html>
<head>
<title>Doris MCP Server - OAuth Demo</title>
<style>
body {{
font-family: Arial, sans-serif;
max-width: 800px;
margin: 0 auto;
padding: 20px;
}}
.info {{
background-color: #f0f8ff;
padding: 15px;
border-left: 4px solid #0066cc;
margin: 20px 0;
}}
.error {{
background-color: #ffe6e6;
padding: 15px;
border-left: 4px solid #cc0000;
margin: 20px 0;
}}
.success {{
background-color: #e6ffe6;
padding: 15px;
border-left: 4px solid #00cc00;
margin: 20px 0;
}}
button {{
background-color: #0066cc;
color: white;
padding: 10px 20px;
border: none;
border-radius: 4px;
cursor: pointer;
font-size: 16px;
}}
button:hover {{
background-color: #0052a3;
}}
pre {{
background-color: #f5f5f5;
padding: 10px;
border-radius: 4px;
overflow-x: auto;
}}
</style>
</head>
<body>
<h1>Doris MCP Server - OAuth Demo</h1>
<div class="info">
<h3>OAuth Configuration</h3>
<p><strong>Provider:</strong> {oauth_info.get('provider', 'N/A')}</p>
<p><strong>Client ID:</strong> {oauth_info.get('client_id', 'N/A')}</p>
<p><strong>Scopes:</strong> {', '.join(oauth_info.get('scopes', []))}</p>
<p><strong>PKCE Enabled:</strong> {oauth_info.get('pkce_enabled', False)}</p>
</div>
<div>
<h3>OAuth Authentication Test</h3>
<p>Click the button below to start OAuth authentication flow:</p>
<button onclick="startOAuthFlow()">Start OAuth Login</button>
</div>
<div id="result" style="margin-top: 20px;"></div>
<div>
<h3>API Endpoints</h3>
<ul>
<li><code>GET /auth/login</code> - Initiate OAuth login</li>
<li><code>GET /auth/callback</code> - OAuth callback handler</li>
<li><code>GET /auth/provider</code> - Provider information</li>
</ul>
</div>
<script>
async function startOAuthFlow() {{
const resultDiv = document.getElementById('result');
resultDiv.innerHTML = '<div class="info">Initiating OAuth flow...</div>';
try {{
const response = await fetch('/auth/login');
const data = await response.json();
if (response.ok) {{
resultDiv.innerHTML = `
<div class="success">
<h4>OAuth URL Generated Successfully</h4>
<p><strong>State:</strong> ${{data.state}}</p>
<p><strong>Provider:</strong> ${{data.provider}}</p>
<p><a href="${{data.authorization_url}}" target="_blank">Click here to authenticate</a></p>
<p><em>Note: After authentication, you will be redirected to the callback URL.</em></p>
</div>
`;
// Automatically redirect to OAuth provider
// window.open(data.authorization_url, '_blank');
}} else {{
resultDiv.innerHTML = `
<div class="error">
<h4>Error</h4>
<p>${{data.error}}</p>
</div>
`;
}}
}} catch (error) {{
resultDiv.innerHTML = `
<div class="error">
<h4>Network Error</h4>
<p>${{error.message}}</p>
</div>
`;
}}
}}
// Handle OAuth callback result if present in URL
window.addEventListener('load', function() {{
const urlParams = new URLSearchParams(window.location.search);
if (urlParams.has('code') && urlParams.has('state')) {{
const resultDiv = document.getElementById('result');
resultDiv.innerHTML = `
<div class="success">
<h4>OAuth Callback Received</h4>
<p>Code: ${{urlParams.get('code')}}</p>
<p>State: ${{urlParams.get('state')}}</p>
<p>The authentication was successful!</p>
</div>
`;
}} else if (urlParams.has('error')) {{
const resultDiv = document.getElementById('result');
resultDiv.innerHTML = `
<div class="error">
<h4>OAuth Error</h4>
<p>Error: ${{urlParams.get('error')}}</p>
<p>Description: ${{urlParams.get('error_description') || 'No description'}}</p>
</div>
`;
}}
}});
</script>
</body>
</html>
"""
return HTMLResponse(html_content)

View File

@@ -0,0 +1,289 @@
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
OAuth Authentication Provider
Integrates OAuth 2.0/OIDC authentication with the existing authentication framework
"""
from typing import Dict, Any, Optional, Tuple
from datetime import datetime
from .oauth_client import OAuthClient
from .oauth_types import OAuthTokens, OAuthUserInfo, OAuthState
from ..utils.security import AuthContext, SecurityLevel
from ..utils.logger import get_logger
logger = get_logger(__name__)
class OAuthAuthenticationProvider:
"""OAuth authentication provider for Doris MCP Server"""
def __init__(self, config):
"""Initialize OAuth authentication provider
Args:
config: DorisConfig with OAuth configuration
"""
self.config = config
self.oauth_client = OAuthClient(config)
self.enabled = self.oauth_client.enabled
logger.info(f"OAuthAuthenticationProvider initialized (enabled: {self.enabled})")
async def initialize(self) -> bool:
"""Initialize OAuth authentication provider
Returns:
True if initialization successful
"""
if not self.enabled:
return True
success = await self.oauth_client.initialize()
if success:
logger.info("OAuth authentication provider initialized successfully")
else:
logger.error("Failed to initialize OAuth authentication provider")
return success
async def shutdown(self):
"""Shutdown OAuth authentication provider"""
if self.enabled:
await self.oauth_client.shutdown()
logger.info("OAuth authentication provider shutdown completed")
def get_authorization_url(self) -> Tuple[str, str]:
"""Get OAuth authorization URL
Returns:
Tuple of (authorization_url, state)
"""
if not self.enabled:
raise ValueError("OAuth authentication is not enabled")
authorization_url, oauth_state = self.oauth_client.build_authorization_url()
return authorization_url, oauth_state.state
async def handle_callback(self, code: str, state: str) -> AuthContext:
"""Handle OAuth callback and create authentication context
Args:
code: Authorization code from OAuth provider
state: State parameter for CSRF protection
Returns:
AuthContext for the authenticated user
Raises:
ValueError: If authentication fails
"""
if not self.enabled:
raise ValueError("OAuth authentication is not enabled")
try:
# Exchange code for tokens
tokens, oauth_state = await self.oauth_client.exchange_code_for_tokens(code, state)
# Get user information
user_info = await self.oauth_client.get_user_info(tokens)
# Create authentication context
auth_context = await self._create_auth_context(user_info, tokens)
logger.info(f"OAuth authentication successful for user: {auth_context.user_id}")
return auth_context
except Exception as e:
logger.error(f"OAuth callback handling failed: {e}")
raise ValueError(f"OAuth authentication failed: {str(e)}")
async def authenticate_with_token(self, access_token: str) -> AuthContext:
"""Authenticate using OAuth access token
Args:
access_token: OAuth access token
Returns:
AuthContext for the authenticated user
"""
if not self.enabled:
raise ValueError("OAuth authentication is not enabled")
try:
# Create token object
tokens = OAuthTokens(access_token=access_token)
# Get user information
user_info = await self.oauth_client.get_user_info(tokens)
# Create authentication context
auth_context = await self._create_auth_context(user_info, tokens)
logger.info(f"OAuth token authentication successful for user: {auth_context.user_id}")
return auth_context
except Exception as e:
logger.error(f"OAuth token authentication failed: {e}")
raise ValueError(f"OAuth token authentication failed: {str(e)}")
async def refresh_authentication(self, refresh_token: str) -> Tuple[AuthContext, str]:
"""Refresh OAuth authentication
Args:
refresh_token: OAuth refresh token
Returns:
Tuple of (AuthContext, new_access_token)
"""
if not self.enabled:
raise ValueError("OAuth authentication is not enabled")
try:
# Refresh tokens
tokens = await self.oauth_client.refresh_tokens(refresh_token)
# Get updated user information
user_info = await self.oauth_client.get_user_info(tokens)
# Create authentication context
auth_context = await self._create_auth_context(user_info, tokens)
logger.info(f"OAuth refresh successful for user: {auth_context.user_id}")
return auth_context, tokens.access_token
except Exception as e:
logger.error(f"OAuth refresh failed: {e}")
raise ValueError(f"OAuth refresh failed: {str(e)}")
async def _create_auth_context(self, user_info: OAuthUserInfo, tokens: OAuthTokens) -> AuthContext:
"""Create authentication context from OAuth user info
Args:
user_info: OAuth user information
tokens: OAuth tokens
Returns:
AuthContext for the user
"""
# Determine security level based on roles or email domain
security_level = await self._determine_security_level(user_info)
# Map OAuth roles to application permissions
permissions = await self._map_permissions(user_info.roles)
# Generate session ID
session_id = f"oauth_{user_info.sub}_{datetime.utcnow().timestamp()}"
return AuthContext(
token_id=f"oauth_{user_info.sub}",
user_id=user_info.sub,
roles=user_info.roles,
permissions=permissions,
security_level=security_level,
session_id=session_id,
login_time=datetime.utcnow(),
last_activity=datetime.utcnow(),
token="" # OAuth doesn't have raw token, use empty string
)
async def _determine_security_level(self, user_info: OAuthUserInfo) -> SecurityLevel:
"""Determine security level for OAuth user
Args:
user_info: OAuth user information
Returns:
SecurityLevel for the user
"""
# Check if user has admin roles
admin_roles = {"admin", "administrator", "data_admin", "super_admin"}
if any(role.lower() in admin_roles for role in user_info.roles):
return SecurityLevel.SECRET
# Check email domain for internal users
if user_info.email:
# You can configure trusted domains for internal access
trusted_domains = ["yourcompany.com", "internal.org"] # Configure as needed
email_domain = user_info.email.split("@")[-1].lower()
if email_domain in trusted_domains:
return SecurityLevel.CONFIDENTIAL
# Check for special roles
elevated_roles = {"data_analyst", "developer", "manager"}
if any(role.lower() in elevated_roles for role in user_info.roles):
return SecurityLevel.CONFIDENTIAL
# Default to internal level for OAuth users
return SecurityLevel.INTERNAL
async def _map_permissions(self, roles: list[str]) -> list[str]:
"""Map OAuth roles to application permissions
Args:
roles: OAuth user roles
Returns:
List of application permissions
"""
permissions = set()
# Role to permission mapping
role_permissions = {
"admin": ["admin", "read_data", "write_data", "manage_users"],
"administrator": ["admin", "read_data", "write_data", "manage_users"],
"data_admin": ["admin", "read_data", "write_data"],
"super_admin": ["admin", "read_data", "write_data", "manage_users", "system_admin"],
"data_analyst": ["read_data", "query_database"],
"developer": ["read_data", "query_database", "debug"],
"viewer": ["read_data"],
"user": ["read_data"],
"oauth_user": ["read_data"] # Default OAuth user permission
}
# Map roles to permissions
for role in roles:
role_lower = role.lower()
if role_lower in role_permissions:
permissions.update(role_permissions[role_lower])
# Ensure OAuth users have at least basic permissions
if not permissions:
permissions.add("read_data")
return list(permissions)
def get_provider_info(self) -> Dict[str, Any]:
"""Get OAuth provider information
Returns:
Provider information dictionary
"""
if not self.enabled:
return {"enabled": False}
config = self.oauth_client.provider_config
return {
"enabled": True,
"provider": config.provider.value,
"client_id": config.client_id,
"scopes": config.scopes,
"redirect_uri": config.redirect_uri,
"pkce_enabled": config.pkce_enabled,
"nonce_enabled": config.nonce_enabled
}

View File

@@ -0,0 +1,196 @@
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
OAuth 2.0/OIDC Type Definitions
Provides data types and models for OAuth authentication flow
"""
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Dict, Any, Optional, List
class OAuthProvider(Enum):
"""OAuth provider enumeration"""
GOOGLE = "google"
MICROSOFT = "microsoft"
GITHUB = "github"
CUSTOM = "custom"
class OAuthGrantType(Enum):
"""OAuth grant type enumeration"""
AUTHORIZATION_CODE = "authorization_code"
REFRESH_TOKEN = "refresh_token"
@dataclass
class OAuthState:
"""OAuth state parameter for CSRF protection"""
state: str
nonce: Optional[str] = None
pkce_verifier: Optional[str] = None
pkce_challenge: Optional[str] = None
redirect_uri: str = ""
created_at: datetime = None
expires_at: datetime = None
def __post_init__(self):
if self.created_at is None:
self.created_at = datetime.utcnow()
@dataclass
class OAuthTokens:
"""OAuth token response"""
access_token: str
token_type: str = "Bearer"
expires_in: Optional[int] = None
refresh_token: Optional[str] = None
scope: Optional[str] = None
id_token: Optional[str] = None # OIDC ID token
created_at: datetime = None
def __post_init__(self):
if self.created_at is None:
self.created_at = datetime.utcnow()
@dataclass
class OAuthUserInfo:
"""OAuth/OIDC user information"""
sub: str # Subject identifier
email: Optional[str] = None
email_verified: Optional[bool] = None
name: Optional[str] = None
given_name: Optional[str] = None
family_name: Optional[str] = None
picture: Optional[str] = None
locale: Optional[str] = None
roles: List[str] = None
raw_claims: Dict[str, Any] = None
def __post_init__(self):
if self.roles is None:
self.roles = []
if self.raw_claims is None:
self.raw_claims = {}
@dataclass
class OIDCDiscovery:
"""OIDC Discovery document"""
issuer: str
authorization_endpoint: str
token_endpoint: str
userinfo_endpoint: Optional[str] = None
jwks_uri: Optional[str] = None
scopes_supported: List[str] = None
response_types_supported: List[str] = None
subject_types_supported: List[str] = None
id_token_signing_alg_values_supported: List[str] = None
def __post_init__(self):
if self.scopes_supported is None:
self.scopes_supported = ["openid"]
if self.response_types_supported is None:
self.response_types_supported = ["code"]
if self.subject_types_supported is None:
self.subject_types_supported = ["public"]
if self.id_token_signing_alg_values_supported is None:
self.id_token_signing_alg_values_supported = ["RS256"]
@dataclass
class OAuthError:
"""OAuth error response"""
error: str
error_description: Optional[str] = None
error_uri: Optional[str] = None
state: Optional[str] = None
@dataclass
class OAuthProviderConfig:
"""OAuth provider configuration"""
provider: OAuthProvider
client_id: str
client_secret: str
redirect_uri: str
scopes: List[str]
# Endpoints
authorization_endpoint: str
token_endpoint: str
userinfo_endpoint: Optional[str] = None
jwks_uri: Optional[str] = None
# Discovery
discovery_url: Optional[str] = None
# Settings
pkce_enabled: bool = True
nonce_enabled: bool = True
# User mapping
user_id_claim: str = "sub"
email_claim: str = "email"
name_claim: str = "name"
roles_claim: str = "roles"
default_roles: List[str] = None
def __post_init__(self):
if self.default_roles is None:
self.default_roles = ["oauth_user"]
# Pre-defined provider configurations
OAUTH_PROVIDERS = {
OAuthProvider.GOOGLE: {
"authorization_endpoint": "https://accounts.google.com/o/oauth2/auth",
"token_endpoint": "https://oauth2.googleapis.com/token",
"userinfo_endpoint": "https://openidconnect.googleapis.com/v1/userinfo",
"jwks_uri": "https://www.googleapis.com/oauth2/v3/certs",
"discovery_url": "https://accounts.google.com/.well-known/openid_configuration",
"scopes": ["openid", "email", "profile"],
"user_id_claim": "sub",
"email_claim": "email",
"name_claim": "name"
},
OAuthProvider.MICROSOFT: {
"authorization_endpoint": "https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
"token_endpoint": "https://login.microsoftonline.com/common/oauth2/v2.0/token",
"userinfo_endpoint": "https://graph.microsoft.com/v1.0/me",
"jwks_uri": "https://login.microsoftonline.com/common/discovery/v2.0/keys",
"discovery_url": "https://login.microsoftonline.com/common/v2.0/.well-known/openid_configuration",
"scopes": ["openid", "profile", "email", "User.Read"],
"user_id_claim": "sub",
"email_claim": "email",
"name_claim": "name"
},
OAuthProvider.GITHUB: {
"authorization_endpoint": "https://github.com/login/oauth/authorize",
"token_endpoint": "https://github.com/login/oauth/access_token",
"userinfo_endpoint": "https://api.github.com/user",
"scopes": ["user:email", "read:user"],
"user_id_claim": "id",
"email_claim": "email",
"name_claim": "name"
}
}

View File

@@ -0,0 +1,677 @@
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Token Authentication HTTP Handlers
Provides HTTP endpoints for token management including creation, revocation,
listing, and statistics. Used for administrative token management in HTTP mode.
"""
import json
from typing import Dict, Any
from starlette.requests import Request
from starlette.responses import JSONResponse, HTMLResponse
from ..utils.logger import get_logger
from ..utils.security import SecurityLevel
from ..utils.config import DatabaseConfig
from .token_security_middleware import TokenSecurityMiddleware
class TokenHandlers:
"""Token Authentication HTTP Handlers"""
def __init__(self, security_manager, config=None):
self.security_manager = security_manager
self.logger = get_logger(__name__)
# Initialize security middleware if config is provided
if config:
self.security_middleware = TokenSecurityMiddleware(config)
else:
self.security_middleware = None
self.logger.warning("Token handlers initialized without security middleware - access control disabled")
async def handle_create_token(self, request: Request) -> JSONResponse:
"""Handle token creation request"""
# Apply security checks
if self.security_middleware:
security_response = await self.security_middleware.check_token_management_access(request)
if security_response:
return security_response
try:
# Check if token manager is available
if not self.security_manager.auth_provider.token_manager:
return JSONResponse({
"error": "Token authentication is not enabled"
}, status_code=503)
# Parse request data
if request.method == "GET":
# GET request with query parameters
query_params = dict(request.query_params)
token_id = query_params.get("token_id")
expires_hours_str = query_params.get("expires_hours")
description = query_params.get("description", "")
custom_token = query_params.get("custom_token")
# Database configuration from query params
db_config = None
if query_params.get("db_host"):
db_config = DatabaseConfig(
host=query_params.get("db_host", "localhost"),
port=int(query_params.get("db_port", "9030")),
user=query_params.get("db_user", "root"),
password=query_params.get("db_password", ""),
database=query_params.get("db_database", "information_schema"),
fe_http_port=int(query_params.get("db_fe_http_port", "8030"))
)
else:
# POST request with JSON body
try:
body = await request.json()
except:
return JSONResponse({
"error": "Invalid JSON body"
}, status_code=400)
token_id = body.get("token_id")
expires_hours_str = body.get("expires_hours")
description = body.get("description", "")
custom_token = body.get("custom_token")
# Database configuration from JSON body
db_config = None
if body.get("database_config"):
db_data = body["database_config"]
try:
db_config = DatabaseConfig(
host=db_data.get("host", "localhost"),
port=int(db_data.get("port", 9030)),
user=db_data.get("user", "root"),
password=db_data.get("password", ""),
database=db_data.get("database", "information_schema"),
fe_http_port=int(db_data.get("fe_http_port", 8030))
)
except (ValueError, TypeError) as e:
return JSONResponse({
"error": f"Invalid database configuration: {str(e)}"
}, status_code=400)
# Validate required fields
if not token_id:
return JSONResponse({
"error": "token_id is required"
}, status_code=400)
# Parse expires_hours
expires_hours = None
if expires_hours_str:
try:
expires_hours = int(expires_hours_str)
except ValueError:
return JSONResponse({
"error": "expires_hours must be an integer"
}, status_code=400)
# Create token using the actual API
try:
token = await self.security_manager.create_token(
token_id=token_id,
expires_hours=expires_hours,
description=description,
custom_token=custom_token,
database_config=db_config
)
return JSONResponse({
"success": True,
"token_id": token_id,
"token": token,
"expires_hours": expires_hours,
"description": description,
"message": "Token created successfully"
})
except Exception as e:
self.logger.error(f"Token creation failed: {e}")
return JSONResponse({
"error": f"Token creation failed: {str(e)}"
}, status_code=400)
except Exception as e:
self.logger.error(f"Error in handle_create_token: {e}")
return JSONResponse({
"error": f"Internal server error: {str(e)}"
}, status_code=500)
async def handle_revoke_token(self, request: Request) -> JSONResponse:
"""Handle token revocation request"""
# Apply security checks
if self.security_middleware:
security_response = await self.security_middleware.check_token_management_access(request)
if security_response:
return security_response
try:
# Check if token manager is available
if not self.security_manager.auth_provider.token_manager:
return JSONResponse({
"error": "Token authentication is not enabled"
}, status_code=503)
# Get token_id from query parameters or path
token_id = request.query_params.get("token_id")
if not token_id and request.method == "DELETE":
# Try to get from path: /token/revoke/{token_id}
path_parts = str(request.url.path).split("/")
if len(path_parts) >= 4:
token_id = path_parts[-1]
if not token_id:
return JSONResponse({
"error": "token_id is required"
}, status_code=400)
# Revoke token
success = await self.security_manager.revoke_token(token_id)
if success:
return JSONResponse({
"success": True,
"token_id": token_id,
"message": "Token revoked successfully"
})
else:
return JSONResponse({
"success": False,
"token_id": token_id,
"message": "Token not found or already revoked"
}, status_code=404)
except Exception as e:
self.logger.error(f"Error in handle_revoke_token: {e}")
return JSONResponse({
"error": f"Internal server error: {str(e)}"
}, status_code=500)
async def handle_list_tokens(self, request: Request) -> JSONResponse:
"""Handle token listing request"""
# Apply security checks
if self.security_middleware:
security_response = await self.security_middleware.check_token_management_access(request)
if security_response:
return security_response
try:
# Check if token manager is available
if not self.security_manager.auth_provider.token_manager:
return JSONResponse({
"error": "Token authentication is not enabled"
}, status_code=503)
# Get tokens list
tokens = await self.security_manager.list_tokens()
return JSONResponse({
"success": True,
"count": len(tokens),
"tokens": tokens
})
except Exception as e:
self.logger.error(f"Error in handle_list_tokens: {e}")
return JSONResponse({
"error": f"Internal server error: {str(e)}"
}, status_code=500)
async def handle_token_stats(self, request: Request) -> JSONResponse:
"""Handle token statistics request"""
# Apply security checks
if self.security_middleware:
security_response = await self.security_middleware.check_token_management_access(request)
if security_response:
return security_response
try:
# Check if token manager is available
if not self.security_manager.auth_provider.token_manager:
return JSONResponse({
"error": "Token authentication is not enabled"
}, status_code=503)
# Get token statistics
stats = self.security_manager.get_token_stats()
return JSONResponse({
"success": True,
"stats": stats
})
except Exception as e:
self.logger.error(f"Error in handle_token_stats: {e}")
return JSONResponse({
"error": f"Internal server error: {str(e)}"
}, status_code=500)
async def handle_cleanup_tokens(self, request: Request) -> JSONResponse:
"""Handle expired tokens cleanup request"""
# Apply security checks
if self.security_middleware:
security_response = await self.security_middleware.check_token_management_access(request)
if security_response:
return security_response
try:
# Check if token manager is available
if not self.security_manager.auth_provider.token_manager:
return JSONResponse({
"error": "Token authentication is not enabled"
}, status_code=503)
# Cleanup expired tokens
cleaned_count = await self.security_manager.cleanup_expired_tokens()
return JSONResponse({
"success": True,
"cleaned_count": cleaned_count,
"message": f"Cleaned up {cleaned_count} expired tokens"
})
except Exception as e:
self.logger.error(f"Error in handle_cleanup_tokens: {e}")
return JSONResponse({
"error": f"Internal server error: {str(e)}"
}, status_code=500)
async def handle_management_page(self, request: Request) -> HTMLResponse:
"""Handle token management demo page"""
# Apply security checks
if self.security_middleware:
security_response = await self.security_middleware.check_token_management_access(request)
if security_response:
# Convert JSON response to HTML for demo page
error_data = security_response.body.decode('utf-8') if hasattr(security_response, 'body') else '{"error": "Access denied"}'
try:
error_info = json.loads(error_data)
except:
error_info = {"error": "Access denied"}
error_html = f"""
<!DOCTYPE html>
<html>
<head>
<title>Access Denied - Token Management</title>
<style>
body {{ font-family: Arial, sans-serif; margin: 50px; background: #f5f5f5; }}
.container {{ max-width: 600px; margin: 0 auto; background: white; padding: 30px; border-radius: 8px; }}
.error {{ color: #dc3545; background: #f8d7da; border: 1px solid #f5c6cb; padding: 15px; border-radius: 5px; }}
.security-info {{ background: #d1ecf1; border: 1px solid #bee5eb; padding: 15px; border-radius: 5px; margin-top: 20px; }}
</style>
</head>
<body>
<div class="container">
<h1>🔐 Token Management - Access Denied</h1>
<div class="error">
<h3>Access Denied</h3>
<p><strong>Error:</strong> {error_info.get('error', 'Access denied')}</p>
<p><strong>Message:</strong> {error_info.get('message', 'Token management access is restricted')}</p>
{'<p><strong>Your IP:</strong> ' + str(error_info.get('client_ip', 'Unknown')) + '</p>' if 'client_ip' in error_info else ''}
</div>
<div class="security-info">
<h3>🛡️ Security Information</h3>
<p>Token management endpoints are protected by the following security measures:</p>
<ul>
<li><strong>IP Restrictions:</strong> Only localhost/127.0.0.1 access allowed</li>
<li><strong>Admin Authentication:</strong> Valid admin token required</li>
<li><strong>Configuration Control:</strong> Must be explicitly enabled</li>
</ul>
<p>If you need access, please:</p>
<ol>
<li>Access from the server host (127.0.0.1)</li>
<li>Ensure HTTP token management is enabled in configuration</li>
<li>Provide valid admin authentication</li>
</ol>
</div>
</div>
</body>
</html>
"""
return HTMLResponse(error_html, status_code=security_response.status_code)
try:
# Check if token manager is available
if not self.security_manager.auth_provider.token_manager:
html_content = """
<!DOCTYPE html>
<html>
<head>
<title>Token Management - Not Available</title>
<style>
body { font-family: Arial, sans-serif; margin: 50px; }
.error { color: red; font-size: 18px; }
</style>
</head>
<body>
<h1>Token Management</h1>
<div class="error">Token authentication is not enabled on this server.</div>
</body>
</html>
"""
return HTMLResponse(html_content)
# Get current stats for demo
stats = self.security_manager.get_token_stats()
html_content = f"""
<!DOCTYPE html>
<html>
<head>
<title>Doris MCP Server - Token Management</title>
<style>
body {{ font-family: Arial, sans-serif; margin: 50px; background: #f5f5f5; }}
.container {{ max-width: 1200px; margin: 0 auto; background: white; padding: 30px; border-radius: 8px; }}
h1 {{ color: #333; }}
.section {{ margin: 30px 0; padding: 20px; border: 1px solid #ddd; border-radius: 5px; }}
.stats {{ display: grid; grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); gap: 15px; }}
.stat-item {{ padding: 15px; background: #f8f9fa; border-radius: 5px; text-align: center; }}
.stat-value {{ font-size: 24px; font-weight: bold; color: #007bff; }}
.form-group {{ margin: 15px 0; }}
.form-group label {{ display: block; margin-bottom: 5px; font-weight: bold; }}
.form-group input, .form-group textarea {{ width: 100%; padding: 8px; border: 1px solid #ddd; border-radius: 4px; }}
button {{ padding: 10px 20px; margin: 5px; border: none; border-radius: 4px; cursor: pointer; }}
.btn-primary {{ background: #007bff; color: white; }}
.btn-danger {{ background: #dc3545; color: white; }}
.btn-success {{ background: #28a745; color: white; }}
.response {{ margin: 15px 0; padding: 15px; border-radius: 5px; }}
.response.success {{ background: #d4edda; border: 1px solid #c3e6cb; }}
.response.error {{ background: #f8d7da; border: 1px solid #f5c6cb; }}
.token-list {{ margin: 15px 0; }}
.token-item {{ padding: 10px; margin: 5px 0; background: #f8f9fa; border-radius: 4px; }}
pre {{ background: #f8f9fa; padding: 10px; border-radius: 4px; overflow-x: auto; }}
</style>
</head>
<body>
<div class="container">
<h1>🔐 Doris MCP Server - Token Management</h1>
<div class="section">
<h2>📊 Token Statistics</h2>
<div class="stats">
<div class="stat-item">
<div class="stat-value">{stats.get('total_tokens', 0)}</div>
<div>Total Tokens</div>
</div>
<div class="stat-item">
<div class="stat-value">{stats.get('active_tokens', 0)}</div>
<div>Active Tokens</div>
</div>
<div class="stat-item">
<div class="stat-value">{stats.get('expired_tokens', 0)}</div>
<div>Expired Tokens</div>
</div>
</div>
<p><strong>Token Expiry:</strong> {'Enabled' if stats.get('expiry_enabled') else 'Disabled'}</p>
<p><strong>Default Expiry:</strong> {stats.get('default_expiry_hours', 0)} hours</p>
</div>
<div class="section">
<h2> Create New Token</h2>
<form id="createTokenForm">
<div class="form-group">
<label for="token_id">Token ID (required):</label>
<input type="text" id="token_id" name="token_id" placeholder="e.g., my-app-token" required>
</div>
<div class="form-group">
<label for="expires_hours">Expires Hours (optional):</label>
<input type="number" id="expires_hours" name="expires_hours" placeholder="e.g., 720 (30 days), leave empty for default">
</div>
<div class="form-group">
<label for="description">Description (optional):</label>
<textarea id="description" name="description" placeholder="Token description"></textarea>
</div>
<div class="form-group">
<label for="custom_token">Custom Token (optional):</label>
<input type="text" id="custom_token" name="custom_token" placeholder="Leave empty to auto-generate">
<small style="color: #666; display: block; margin-top: 5px;">If not provided, a secure token will be generated automatically</small>
</div>
<div class="section" style="margin: 20px 0; padding: 15px; background: #f8f9fa; border-radius: 5px;">
<h3>🗄️ Database Configuration (Optional)</h3>
<p style="color: #666; font-size: 14px; margin-bottom: 15px;">Configure database connection for this token. Leave empty to use system defaults.</p>
<div style="display: grid; grid-template-columns: 1fr 1fr; gap: 10px;">
<div class="form-group">
<label for="db_host">Host:</label>
<input type="text" id="db_host" name="db_host" placeholder="localhost">
</div>
<div class="form-group">
<label for="db_port">Port:</label>
<input type="number" id="db_port" name="db_port" placeholder="9030">
</div>
<div class="form-group">
<label for="db_user">User:</label>
<input type="text" id="db_user" name="db_user" placeholder="root">
</div>
<div class="form-group">
<label for="db_password">Password:</label>
<input type="password" id="db_password" name="db_password" placeholder="(optional)">
</div>
<div class="form-group">
<label for="db_database">Database:</label>
<input type="text" id="db_database" name="db_database" placeholder="information_schema">
</div>
<div class="form-group">
<label for="db_fe_http_port">FE HTTP Port:</label>
<input type="number" id="db_fe_http_port" name="db_fe_http_port" placeholder="8030">
</div>
</div>
</div>
<button type="submit" class="btn-primary">Create Token</button>
</form>
<div id="createTokenResponse"></div>
</div>
<div class="section">
<h2>📋 Token Management</h2>
<button id="listTokensBtn" class="btn-success">Refresh Token List</button>
<button id="cleanupTokensBtn" class="btn-primary">Cleanup Expired Tokens</button>
<div id="tokenListResponse"></div>
<h3>Revoke Token</h3>
<div class="form-group">
<input type="text" id="revokeTokenId" placeholder="Enter token ID to revoke">
<button id="revokeTokenBtn" class="btn-danger">Revoke Token</button>
</div>
<div id="revokeTokenResponse"></div>
</div>
<div class="section">
<h2>🔧 API Endpoints</h2>
<p>Use these endpoints for programmatic token management:</p>
<ul>
<li><strong>POST /token/create</strong> - Create new token</li>
<li><strong>DELETE /token/revoke?token_id=...</strong> - Revoke token</li>
<li><strong>GET /token/list</strong> - List all tokens</li>
<li><strong>GET /token/stats</strong> - Get token statistics</li>
<li><strong>POST /token/cleanup</strong> - Cleanup expired tokens</li>
</ul>
</div>
</div>
<script>
// Get admin token from URL parameters
const urlParams = new URLSearchParams(window.location.search);
const adminToken = urlParams.get('admin_token');
// Create request headers with admin token
function getAuthHeaders() {{
if (adminToken) {{
return {{
'Content-Type': 'application/json',
'Authorization': `Bearer ${{adminToken}}`
}};
}} else {{
return {{'Content-Type': 'application/json'}};
}}
}}
// Create URL with admin token parameter
function getAuthURL(baseUrl) {{
if (adminToken) {{
const separator = baseUrl.includes('?') ? '&' : '?';
return `${{baseUrl}}${{separator}}admin_token=${{encodeURIComponent(adminToken)}}`;
}}
return baseUrl;
}}
function showResponse(elementId, data, isSuccess = true) {{
const element = document.getElementById(elementId);
element.innerHTML = '<pre>' + JSON.stringify(data, null, 2) + '</pre>';
element.className = 'response ' + (isSuccess ? 'success' : 'error');
}}
// Create token form - updated to match actual API
document.getElementById('createTokenForm').addEventListener('submit', async (e) => {{
e.preventDefault();
const formData = new FormData(e.target);
const data = Object.fromEntries(formData.entries());
// Remove empty fields for optional parameters
if (!data.expires_hours) delete data.expires_hours;
if (!data.description) delete data.description;
if (!data.custom_token) delete data.custom_token;
// Handle database configuration
if (data.db_host) {{
data.database_config = {{
host: data.db_host,
port: data.db_port ? parseInt(data.db_port) : 9030,
user: data.db_user || 'root',
password: data.db_password || '',
database: data.db_database || 'information_schema',
fe_http_port: data.db_fe_http_port ? parseInt(data.db_fe_http_port) : 8030
}};
}}
// Remove individual database fields from data
delete data.db_host;
delete data.db_port;
delete data.db_user;
delete data.db_password;
delete data.db_database;
delete data.db_fe_http_port;
try {{
const response = await fetch(getAuthURL('/token/create'), {{
method: 'POST',
headers: getAuthHeaders(),
body: JSON.stringify(data)
}});
const result = await response.json();
showResponse('createTokenResponse', result, response.ok);
// Refresh token list if creation was successful
if (response.ok) {{
document.getElementById('listTokensBtn').click();
}}
}} catch (error) {{
showResponse('createTokenResponse', {{error: error.message}}, false);
}}
}});
// List tokens
document.getElementById('listTokensBtn').addEventListener('click', async () => {{
try {{
const response = await fetch(getAuthURL('/token/list'), {{
headers: getAuthHeaders()
}});
const result = await response.json();
showResponse('tokenListResponse', result, response.ok);
}} catch (error) {{
showResponse('tokenListResponse', {{error: error.message}}, false);
}}
}});
// Cleanup tokens
document.getElementById('cleanupTokensBtn').addEventListener('click', async () => {{
try {{
const response = await fetch(getAuthURL('/token/cleanup'), {{
method: 'POST',
headers: getAuthHeaders()
}});
const result = await response.json();
showResponse('tokenListResponse', result, response.ok);
}} catch (error) {{
showResponse('tokenListResponse', {{error: error.message}}, false);
}}
}});
// Revoke token
document.getElementById('revokeTokenBtn').addEventListener('click', async () => {{
const tokenId = document.getElementById('revokeTokenId').value;
if (!tokenId) {{
showResponse('revokeTokenResponse', {{error: 'Token ID is required'}}, false);
return;
}}
try {{
const response = await fetch(getAuthURL(`/token/revoke?token_id=${{encodeURIComponent(tokenId)}}`), {{
method: 'DELETE',
headers: getAuthHeaders()
}});
const result = await response.json();
showResponse('revokeTokenResponse', result, response.ok);
// Refresh token list if revocation was successful
if (response.ok) {{
document.getElementById('listTokensBtn').click();
document.getElementById('revokeTokenId').value = '';
}}
}} catch (error) {{
showResponse('revokeTokenResponse', {{error: error.message}}, false);
}}
}});
// Load token list on page load
document.getElementById('listTokensBtn').click();
</script>
</body>
</html>
"""
return HTMLResponse(html_content)
except Exception as e:
self.logger.error(f"Error in handle_demo_page: {e}")
error_html = f"""
<!DOCTYPE html>
<html>
<head>
<title>Token Management Error</title>
<style>body {{ font-family: Arial, sans-serif; margin: 50px; }}</style>
</head>
<body>
<h1>Token Management Error</h1>
<p>Error loading token management page: {str(e)}</p>
</body>
</html>
"""
return HTMLResponse(error_html, status_code=500)

View File

@@ -0,0 +1,827 @@
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Token Authentication Management Module
Provides enterprise-grade token authentication system with configurable tokens,
expiration management, role-based access control and secure token storage.
"""
import hashlib
import json
import os
import secrets
import time
import asyncio
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any
from pathlib import Path
from ..utils.logger import get_logger
from ..utils.security import SecurityLevel
@dataclass
class DatabaseConfig:
"""Database connection configuration for token binding"""
host: str
port: int = 9030
user: str = ""
password: str = ""
database: str = "information_schema"
charset: str = "UTF8"
fe_http_port: int = 8030
@dataclass
class TokenInfo:
"""Token information structure with optional database binding"""
token_id: str # Unique token identifier for audit and management
created_at: datetime = field(default_factory=datetime.utcnow)
expires_at: Optional[datetime] = None
last_used: Optional[datetime] = None
description: str = "" # Optional description for token purpose
is_active: bool = True
database_config: Optional[DatabaseConfig] = None # Optional database binding
@dataclass
class TokenValidationResult:
"""Token validation result"""
is_valid: bool
token_info: Optional[TokenInfo] = None
error_message: Optional[str] = None
class TokenManager:
"""Enterprise Token Authentication Manager
Features:
- Configurable token storage (file-based or environment variables)
- Token expiration management
- Secure token hashing
- Role-based access control
- Token lifecycle management
"""
def __init__(self, config):
self.config = config
self.logger = get_logger(__name__)
# Token storage
self._tokens: Dict[str, TokenInfo] = {} # token_hash -> TokenInfo
self._token_ids: Dict[str, str] = {} # token_id -> token_hash
# Configuration
self.token_file_path = getattr(config.security, 'token_file_path', 'tokens.json')
self.enable_token_expiry = getattr(config.security, 'enable_token_expiry', True)
self.default_token_expiry_hours = getattr(config.security, 'default_token_expiry_hours', 24 * 30) # 30 days
self.token_hash_algorithm = getattr(config.security, 'token_hash_algorithm', 'sha256')
# Hot reload configuration
self.enable_hot_reload = True
self.hot_reload_interval = 10 # Check every 10 seconds
self._file_last_modified = 0
self._hot_reload_task = None
# Initialize with default tokens if none exist
self._initialize_default_tokens()
# Load tokens from configuration
self._load_tokens()
# Start hot reload monitoring
if self.enable_hot_reload:
self._start_hot_reload()
self.logger.info(f"TokenManager initialized with {len(self._tokens)} tokens, hot reload: {self.enable_hot_reload}")
def _initialize_default_tokens(self):
"""Initialize default tokens for basic authentication (configurable via environment)"""
# Default token configurations (can be overridden by environment variables)
default_tokens = [
{
'token_id': 'admin-token',
'token': os.getenv('DEFAULT_ADMIN_TOKEN', 'doris_admin_token_123456'),
'description': os.getenv('DEFAULT_ADMIN_DESCRIPTION', 'Default admin API access token'),
'expires_hours': None # Never expires
},
{
'token_id': 'analyst-token',
'token': os.getenv('DEFAULT_ANALYST_TOKEN', 'doris_analyst_token_123456'),
'description': os.getenv('DEFAULT_ANALYST_DESCRIPTION', 'Default data analysis API access token'),
'expires_hours': None # Never expires
},
{
'token_id': 'readonly-token',
'token': os.getenv('DEFAULT_READONLY_TOKEN', 'doris_readonly_token_123456'),
'description': os.getenv('DEFAULT_READONLY_DESCRIPTION', 'Default read-only API access token'),
'expires_hours': None # Never expires
}
]
# Only add default tokens if no custom tokens are defined via environment variables
# Check if any TOKEN_* environment variables exist (excluding system and legacy configs)
excluded_prefixes = ('DEFAULT_', 'TOKEN_FILE_PATH', 'TOKEN_HASH_')
excluded_vars = {'TOKEN_SECRET', 'TOKEN_EXPIRY'}
custom_tokens_exist = any(
key.startswith('TOKEN_') and
not key.startswith(excluded_prefixes) and
not key.endswith(('_EXPIRES_HOURS', '_DESCRIPTION')) and
key not in excluded_vars
for key in os.environ.keys()
)
# Also check if token file exists and has content
token_file_exists = False
if os.path.exists(self.token_file_path):
try:
with open(self.token_file_path, 'r') as f:
content = f.read().strip()
if content and content != '{}':
token_file_exists = True
except:
pass
# Add default tokens only if no custom configuration exists
if not custom_tokens_exist and not token_file_exists:
for token_config in default_tokens:
self._add_token_from_config(token_config)
self.logger.info(f"Initialized {len(default_tokens)} default tokens (no custom config found)")
else:
self.logger.info("Skipped default tokens initialization (custom tokens detected)")
def _add_token_from_config(self, token_config: Dict[str, Any]):
"""Add token from configuration with optional database binding"""
try:
# Calculate expiration time
expires_at = None
if self.enable_token_expiry:
expires_hours = token_config.get('expires_hours', self.default_token_expiry_hours)
if expires_hours is not None:
expires_at = datetime.utcnow() + timedelta(hours=expires_hours)
# Parse database configuration if provided
database_config = None
if 'database_config' in token_config:
db_config = token_config['database_config']
database_config = DatabaseConfig(
host=db_config.get('host', 'localhost'),
port=db_config.get('port', 9030),
user=db_config.get('user', 'root'),
password=db_config.get('password', ''),
database=db_config.get('database', 'information_schema'),
charset=db_config.get('charset', 'UTF8'),
fe_http_port=db_config.get('fe_http_port', 8030)
)
# Create token info
token_info = TokenInfo(
token_id=token_config['token_id'],
expires_at=expires_at,
description=token_config.get('description', ''),
is_active=token_config.get('is_active', True),
database_config=database_config
)
# Hash the token
raw_token = token_config['token']
token_hash = self._hash_token(raw_token)
# Store token
self._tokens[token_hash] = token_info
self._token_ids[token_info.token_id] = token_hash
db_info = f" with DB binding ({database_config.host})" if database_config else ""
self.logger.debug(f"Added token '{token_info.token_id}'{db_info}")
except Exception as e:
self.logger.error(f"Failed to add token from config: {e}")
raise
def _load_tokens(self):
"""Load tokens from configuration sources"""
# 1. Load from environment variables
self._load_tokens_from_env()
# 2. Load from token file if exists
if os.path.exists(self.token_file_path):
self._load_tokens_from_file()
self.logger.info(f"Token loading completed, total tokens: {len(self._tokens)}")
def _load_tokens_from_env(self):
"""Load tokens from environment variables
Simplified format:
TOKEN_<ID>=<token>
TOKEN_<ID>_EXPIRES_HOURS=<hours>
TOKEN_<ID>_DESCRIPTION=<description>
"""
token_prefixes = set()
# Find all TOKEN_ environment variables (exclude legacy and system variables)
excluded_token_vars = {
'TOKEN_SECRET', # Legacy token secret
'TOKEN_EXPIRY', # Legacy token expiry
'TOKEN_FILE_PATH', # System config
'TOKEN_HASH_ALGORITHM' # System config
}
for key in os.environ:
if (key.startswith('TOKEN_') and
not key.endswith(('_EXPIRES_HOURS', '_DESCRIPTION')) and
key not in excluded_token_vars):
token_id = key[6:] # Remove 'TOKEN_' prefix
token_prefixes.add(token_id)
# Load each token
for token_id in token_prefixes:
try:
token = os.environ.get(f'TOKEN_{token_id}')
if not token:
continue
expires_hours_str = os.environ.get(f'TOKEN_{token_id}_EXPIRES_HOURS', str(self.default_token_expiry_hours))
description = os.environ.get(f'TOKEN_{token_id}_DESCRIPTION', f'Environment token {token_id}')
expires_hours = None
try:
if expires_hours_str and expires_hours_str.lower() != 'none':
expires_hours = int(expires_hours_str)
except ValueError:
expires_hours = self.default_token_expiry_hours
# Add token
token_config = {
'token_id': token_id.lower(),
'token': token,
'expires_hours': expires_hours,
'description': description
}
self._add_token_from_config(token_config)
except Exception as e:
self.logger.error(f"Failed to load token {token_id} from environment: {e}")
def _load_tokens_from_file(self):
"""Load tokens from JSON file"""
try:
with open(self.token_file_path, 'r', encoding='utf-8') as f:
tokens_data = json.load(f)
if isinstance(tokens_data, dict) and 'tokens' in tokens_data:
tokens_list = tokens_data['tokens']
elif isinstance(tokens_data, list):
tokens_list = tokens_data
else:
self.logger.error(f"Invalid token file format: {self.token_file_path}")
return
for token_config in tokens_list:
self._add_token_from_config(token_config)
self.logger.info(f"Loaded {len(tokens_list)} tokens from file: {self.token_file_path}")
except Exception as e:
self.logger.error(f"Failed to load tokens from file {self.token_file_path}: {e}")
def _hash_token(self, token: str) -> str:
"""Hash token for secure storage"""
if self.token_hash_algorithm == 'sha256':
return hashlib.sha256(token.encode('utf-8')).hexdigest()
elif self.token_hash_algorithm == 'sha512':
return hashlib.sha512(token.encode('utf-8')).hexdigest()
else:
# Fallback to sha256
return hashlib.sha256(token.encode('utf-8')).hexdigest()
async def validate_token(self, token: str) -> TokenValidationResult:
"""Validate token and return user information"""
try:
# Hash the provided token
token_hash = self._hash_token(token)
# Find token info
token_info = self._tokens.get(token_hash)
if not token_info:
return TokenValidationResult(
is_valid=False,
error_message="Invalid token"
)
# Check if token is active
if not token_info.is_active:
return TokenValidationResult(
is_valid=False,
error_message="Token is inactive"
)
# Check expiration
if token_info.expires_at and datetime.utcnow() > token_info.expires_at:
return TokenValidationResult(
is_valid=False,
error_message="Token has expired"
)
# Update last used time
token_info.last_used = datetime.utcnow()
return TokenValidationResult(
is_valid=True,
token_info=token_info
)
except Exception as e:
self.logger.error(f"Token validation error: {e}")
return TokenValidationResult(
is_valid=False,
error_message=f"Token validation failed: {str(e)}"
)
def generate_token(self, length: int = 32) -> str:
"""Generate a cryptographically secure random token"""
return secrets.token_urlsafe(length)
async def create_token(
self,
token_id: str,
expires_hours: Optional[int] = None,
description: str = "",
custom_token: Optional[str] = None,
database_config: Optional[DatabaseConfig] = None
) -> str:
"""Create a new token"""
try:
# Check if token_id already exists
if token_id in self._token_ids:
raise ValueError(f"Token ID '{token_id}' already exists")
# Generate or use provided token
if custom_token:
raw_token = custom_token
else:
raw_token = self.generate_token()
# Calculate expiration
expires_at = None
if expires_hours is not None:
expires_at = datetime.utcnow() + timedelta(hours=expires_hours)
elif self.enable_token_expiry:
expires_at = datetime.utcnow() + timedelta(hours=self.default_token_expiry_hours)
# Create token info
token_info = TokenInfo(
token_id=token_id,
expires_at=expires_at,
description=description,
database_config=database_config
)
# Hash and store token
token_hash = self._hash_token(raw_token)
self._tokens[token_hash] = token_info
self._token_ids[token_id] = token_hash
self.logger.info(f"Created new token '{token_id}'")
# Save token to file
self._save_token_to_file(token_id, raw_token, token_info)
return raw_token
except Exception as e:
self.logger.error(f"Failed to create token: {e}")
raise
async def revoke_token(self, token_id: str) -> bool:
"""Revoke a token by token ID"""
try:
if token_id not in self._token_ids:
self.logger.warning(f"Token ID '{token_id}' not found")
return False
# Get token hash and remove from storage
token_hash = self._token_ids[token_id]
if token_hash in self._tokens:
del self._tokens[token_hash]
del self._token_ids[token_id]
self.logger.info(f"Revoked token '{token_id}'")
# Save updated tokens to file
self._remove_token_from_file(token_id)
return True
except Exception as e:
self.logger.error(f"Failed to revoke token '{token_id}': {e}")
return False
def _save_tokens_to_file(self):
"""Save current tokens to JSON file"""
try:
# Convert current tokens to file format
tokens_list = []
for token_hash, token_info in self._tokens.items():
# Find the raw token for this token_info
raw_token = None
for tid, thash in self._token_ids.items():
if thash == token_hash and tid == token_info.token_id:
# We can't recover the original token from hash,
# so we'll create a placeholder for existing tokens
raw_token = f"<existing_token_hash_{token_hash[:8]}>"
break
if raw_token is None:
continue
token_config = {
"token_id": token_info.token_id,
"token": raw_token,
"description": token_info.description,
"expires_hours": None,
"is_active": token_info.is_active
}
# Add expiration info
if token_info.expires_at:
# Calculate remaining hours from now
remaining = token_info.expires_at - datetime.utcnow()
if remaining.total_seconds() > 0:
token_config["expires_hours"] = int(remaining.total_seconds() / 3600)
else:
token_config["expires_hours"] = 0
# Add database config if present
if token_info.database_config:
token_config["database_config"] = {
"host": token_info.database_config.host,
"port": token_info.database_config.port,
"user": token_info.database_config.user,
"password": token_info.database_config.password,
"database": token_info.database_config.database,
"charset": token_info.database_config.charset,
"fe_http_port": token_info.database_config.fe_http_port
}
tokens_list.append(token_config)
# Create file content
file_content = {
"version": "1.0",
"description": "Doris MCP Server Token configuration file",
"created_at": datetime.utcnow().isoformat() + "Z",
"tokens": tokens_list,
"notes": [
"This file is automatically updated when tokens are created or revoked",
"Please backup this file before making manual changes",
"Tokens with hash placeholders were loaded from previous configuration"
]
}
# Save to file
with open(self.token_file_path, 'w', encoding='utf-8') as f:
json.dump(file_content, f, indent=2, ensure_ascii=False)
self.logger.info(f"Saved {len(tokens_list)} tokens to file: {self.token_file_path}")
except Exception as e:
self.logger.error(f"Failed to save tokens to file {self.token_file_path}: {e}")
def _save_token_to_file(self, token_id: str, raw_token: str, token_info: TokenInfo):
"""Save a single new token to file (for newly created tokens only)"""
try:
# Load existing file
existing_data = {"tokens": []}
if os.path.exists(self.token_file_path):
try:
with open(self.token_file_path, 'r', encoding='utf-8') as f:
existing_data = json.load(f)
except Exception as e:
self.logger.warning(f"Could not load existing token file: {e}")
# Ensure tokens list exists
if 'tokens' not in existing_data or not isinstance(existing_data['tokens'], list):
existing_data['tokens'] = []
# Check if token already exists in file
token_exists = False
for i, token_config in enumerate(existing_data['tokens']):
if token_config.get('token_id') == token_id:
# Update existing token
existing_data['tokens'][i] = self._token_info_to_config(token_id, raw_token, token_info)
token_exists = True
break
# Add new token if it doesn't exist
if not token_exists:
new_token_config = self._token_info_to_config(token_id, raw_token, token_info)
existing_data['tokens'].append(new_token_config)
# Update metadata
existing_data.update({
"version": "1.0",
"description": "Doris MCP Server Token configuration file",
"updated_at": datetime.utcnow().isoformat() + "Z"
})
# Save to file
with open(self.token_file_path, 'w', encoding='utf-8') as f:
json.dump(existing_data, f, indent=2, ensure_ascii=False)
self.logger.info(f"Saved token '{token_id}' to file: {self.token_file_path}")
except Exception as e:
self.logger.error(f"Failed to save token '{token_id}' to file: {e}")
def _token_info_to_config(self, token_id: str, raw_token: str, token_info: TokenInfo) -> dict:
"""Convert TokenInfo to file configuration format"""
token_config = {
"token_id": token_id,
"token": raw_token,
"description": token_info.description,
"expires_hours": None,
"is_active": token_info.is_active
}
# Add expiration info
if token_info.expires_at:
# Calculate remaining hours from creation time
remaining = token_info.expires_at - token_info.created_at
token_config["expires_hours"] = int(remaining.total_seconds() / 3600) if remaining.total_seconds() > 0 else None
# Add database config if present
if token_info.database_config:
token_config["database_config"] = {
"host": token_info.database_config.host,
"port": token_info.database_config.port,
"user": token_info.database_config.user,
"password": token_info.database_config.password,
"database": token_info.database_config.database,
"charset": token_info.database_config.charset,
"fe_http_port": token_info.database_config.fe_http_port
}
return token_config
def _remove_token_from_file(self, token_id: str):
"""Remove a token from the JSON file"""
try:
if not os.path.exists(self.token_file_path):
return
# Load existing file
with open(self.token_file_path, 'r', encoding='utf-8') as f:
existing_data = json.load(f)
if 'tokens' not in existing_data or not isinstance(existing_data['tokens'], list):
return
# Remove the token
original_count = len(existing_data['tokens'])
existing_data['tokens'] = [
token for token in existing_data['tokens']
if token.get('token_id') != token_id
]
if len(existing_data['tokens']) < original_count:
# Update metadata
existing_data.update({
"version": "1.0",
"description": "Doris MCP Server Token configuration file",
"updated_at": datetime.utcnow().isoformat() + "Z"
})
# Save to file
with open(self.token_file_path, 'w', encoding='utf-8') as f:
json.dump(existing_data, f, indent=2, ensure_ascii=False)
self.logger.info(f"Removed token '{token_id}' from file: {self.token_file_path}")
except Exception as e:
self.logger.error(f"Failed to remove token '{token_id}' from file: {e}")
async def list_tokens(self) -> List[Dict[str, Any]]:
"""List all tokens (without sensitive data)"""
tokens = []
for token_hash, token_info in self._tokens.items():
token_data = {
'token_id': token_info.token_id,
'created_at': token_info.created_at.isoformat(),
'expires_at': token_info.expires_at.isoformat() if token_info.expires_at else None,
'last_used': token_info.last_used.isoformat() if token_info.last_used else None,
'is_active': token_info.is_active,
'description': token_info.description,
'is_expired': token_info.expires_at and datetime.utcnow() > token_info.expires_at if token_info.expires_at else False
}
# Add database binding info (without sensitive data)
if token_info.database_config:
token_data['database_binding'] = {
'host': token_info.database_config.host,
'port': token_info.database_config.port,
'user': token_info.database_config.user,
'database': token_info.database_config.database,
'has_password': bool(token_info.database_config.password)
}
else:
token_data['database_binding'] = None
tokens.append(token_data)
# Sort by creation time
tokens.sort(key=lambda x: x['created_at'], reverse=True)
return tokens
async def cleanup_expired_tokens(self) -> int:
"""Remove expired tokens and return count"""
if not self.enable_token_expiry:
return 0
now = datetime.utcnow()
expired_tokens = []
# Find expired tokens
for token_hash, token_info in self._tokens.items():
if token_info.expires_at and now > token_info.expires_at:
expired_tokens.append((token_hash, token_info.token_id))
# Remove expired tokens
for token_hash, token_id in expired_tokens:
del self._tokens[token_hash]
if token_id in self._token_ids:
del self._token_ids[token_id]
if expired_tokens:
self.logger.info(f"Cleaned up {len(expired_tokens)} expired tokens")
return len(expired_tokens)
async def save_tokens_to_file(self, file_path: Optional[str] = None) -> bool:
"""Save current tokens to JSON file"""
try:
file_path = file_path or self.token_file_path
tokens_list = await self.list_tokens()
tokens_data = {
'version': '1.0',
'created_at': datetime.utcnow().isoformat(),
'tokens': tokens_list
}
with open(file_path, 'w', encoding='utf-8') as f:
json.dump(tokens_data, f, indent=2, ensure_ascii=False)
self.logger.info(f"Saved {len(tokens_list)} tokens to file: {file_path}")
return True
except Exception as e:
self.logger.error(f"Failed to save tokens to file: {e}")
return False
def get_database_config_by_token(self, token: str) -> Optional[DatabaseConfig]:
"""Get database configuration bound to a token
Args:
token: The raw token string
Returns:
DatabaseConfig if token exists and has database binding, None otherwise
"""
try:
token_hash = self._hash_token(token)
token_info = self._tokens.get(token_hash)
if not token_info or not token_info.is_active:
return None
# Check expiration
if token_info.expires_at and datetime.utcnow() > token_info.expires_at:
return None
return token_info.database_config
except Exception as e:
self.logger.error(f"Failed to get database config for token: {e}")
return None
def get_token_stats(self) -> Dict[str, Any]:
"""Get token statistics"""
now = datetime.utcnow()
total_tokens = len(self._tokens)
active_tokens = sum(1 for info in self._tokens.values() if info.is_active)
expired_tokens = sum(1 for info in self._tokens.values()
if info.expires_at and now > info.expires_at)
tokens_with_db = sum(1 for info in self._tokens.values()
if info.database_config is not None)
return {
'total_tokens': total_tokens,
'active_tokens': active_tokens,
'expired_tokens': expired_tokens,
'tokens_with_database_binding': tokens_with_db,
'expiry_enabled': self.enable_token_expiry,
'default_expiry_hours': self.default_token_expiry_hours,
'hot_reload_enabled': self.enable_hot_reload,
'last_file_check': datetime.fromtimestamp(self._file_last_modified).isoformat() if self._file_last_modified else None
}
def _start_hot_reload(self):
"""Start hot reload monitoring task"""
if self._hot_reload_task:
return # Already running
# Update initial file modification time
self._update_file_modified_time()
# Start monitoring task
self._hot_reload_task = asyncio.create_task(self._hot_reload_monitor())
self.logger.info(f"Started hot reload monitoring for {self.token_file_path}")
def stop_hot_reload(self):
"""Stop hot reload monitoring"""
if self._hot_reload_task:
self._hot_reload_task.cancel()
self._hot_reload_task = None
self.logger.info("Stopped hot reload monitoring")
def _update_file_modified_time(self):
"""Update the last modified time of tokens file"""
try:
if os.path.exists(self.token_file_path):
self._file_last_modified = os.path.getmtime(self.token_file_path)
except Exception as e:
self.logger.debug(f"Failed to get file modification time: {e}")
async def _hot_reload_monitor(self):
"""Background task to monitor tokens.json file changes"""
while True:
try:
await asyncio.sleep(self.hot_reload_interval)
if not os.path.exists(self.token_file_path):
continue
# Check if file was modified
current_mtime = os.path.getmtime(self.token_file_path)
if current_mtime > self._file_last_modified:
self.logger.info(f"Detected changes in {self.token_file_path}, reloading tokens...")
try:
# Backup current tokens
old_tokens = self._tokens.copy()
old_token_ids = self._token_ids.copy()
# Clear and reload
self._tokens.clear()
self._token_ids.clear()
# Reinitialize default tokens
self._initialize_default_tokens()
# Load from file
self._load_tokens_from_file()
# Update modification time
self._file_last_modified = current_mtime
self.logger.info(f"Hot reload completed, {len(self._tokens)} tokens loaded")
except Exception as reload_error:
# Restore backup on failure
self.logger.error(f"Hot reload failed, restoring previous tokens: {reload_error}")
self._tokens = old_tokens
self._token_ids = old_token_ids
except asyncio.CancelledError:
self.logger.info("Hot reload monitor stopped")
break
except Exception as e:
self.logger.error(f"Error in hot reload monitor: {e}")

View File

@@ -0,0 +1,227 @@
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Token Management Security Middleware
Provides comprehensive security controls for token management endpoints including
IP restrictions, admin authentication, and configuration-based access control.
"""
import hashlib
import hmac
import ipaddress
import secrets
import time
from typing import Optional, List, Dict, Any
from starlette.requests import Request
from starlette.responses import JSONResponse
from ..utils.logger import get_logger
from ..utils.config import DorisConfig
class TokenSecurityMiddleware:
"""Security middleware for token management endpoints"""
def __init__(self, config: DorisConfig):
self.config = config
self.logger = get_logger(__name__)
# Initialize admin token hash if provided
self._admin_token_hash = None
if config.security.token_management_admin_token:
self._admin_token_hash = self._hash_token(config.security.token_management_admin_token)
# Normalize allowed IPs
self._allowed_networks = self._parse_allowed_networks(
config.security.token_management_allowed_ips
)
self.logger.info(f"Token management security initialized: "
f"HTTP endpoints {'enabled' if config.security.enable_http_token_management else 'disabled'}, "
f"Admin auth {'required' if config.security.require_admin_auth else 'optional'}, "
f"Allowed networks: {len(self._allowed_networks)}")
def _hash_token(self, token: str) -> str:
"""Hash token using SHA-256"""
return hashlib.sha256(token.encode('utf-8')).hexdigest()
def _parse_allowed_networks(self, allowed_ips: List[str]) -> List[ipaddress.IPv4Network | ipaddress.IPv6Network]:
"""Parse allowed IP addresses and networks"""
networks = []
for ip_str in allowed_ips:
try:
# Try to parse as network (CIDR notation)
if '/' in ip_str:
networks.append(ipaddress.ip_network(ip_str, strict=False))
else:
# Parse as single IP and convert to /32 network
ip = ipaddress.ip_address(ip_str)
if isinstance(ip, ipaddress.IPv4Address):
networks.append(ipaddress.IPv4Network(f"{ip}/32"))
else:
networks.append(ipaddress.IPv6Network(f"{ip}/128"))
except ValueError as e:
self.logger.warning(f"Invalid IP/network '{ip_str}': {e}")
return networks
def _get_client_ip(self, request: Request) -> str:
"""Extract client IP from request, considering proxies"""
# Check X-Forwarded-For header first (for proxy setups)
forwarded_for = request.headers.get('X-Forwarded-For')
if forwarded_for:
# Take the first IP (original client)
client_ip = forwarded_for.split(',')[0].strip()
elif request.headers.get('X-Real-IP'):
client_ip = request.headers.get('X-Real-IP')
else:
# Direct connection
client_ip = request.client.host if request.client else "unknown"
return client_ip
def _is_ip_allowed(self, client_ip: str) -> bool:
"""Check if client IP is in allowed networks"""
try:
client_addr = ipaddress.ip_address(client_ip)
for network in self._allowed_networks:
if client_addr in network:
return True
return False
except ValueError:
self.logger.warning(f"Invalid client IP format: {client_ip}")
return False
def _extract_admin_token(self, request: Request) -> Optional[str]:
"""Extract admin token from request headers"""
# Try Authorization header first
auth_header = request.headers.get('Authorization', '')
if auth_header.startswith('Bearer '):
return auth_header[7:]
elif auth_header.startswith('Token '):
return auth_header[6:]
# Try X-Admin-Token header
admin_token = request.headers.get('X-Admin-Token', '')
if admin_token:
return admin_token
# Try query parameter as fallback (not recommended for production)
admin_token = request.query_params.get('admin_token', '')
if admin_token:
self.logger.warning("Admin token passed via query parameter - this is insecure for production")
return admin_token
return None
def _verify_admin_token(self, provided_token: str) -> bool:
"""Verify provided admin token against configured token"""
if not self._admin_token_hash:
self.logger.warning("No admin token configured for token management")
return False
provided_hash = self._hash_token(provided_token)
# Use constant-time comparison to prevent timing attacks
return hmac.compare_digest(self._admin_token_hash, provided_hash)
async def check_token_management_access(self, request: Request) -> Optional[JSONResponse]:
"""
Check if request is authorized for token management operations
Returns:
None if access is granted
JSONResponse with error if access is denied
"""
# Check if HTTP token management is enabled
if not self.config.security.enable_http_token_management:
self.logger.warning(f"Token management endpoint access denied - HTTP management disabled: {request.url.path}")
return JSONResponse({
"error": "Token management endpoints are disabled for security",
"message": "HTTP token management is disabled. Use file-based token management instead.",
"suggestion": "Edit tokens.json file directly or enable HTTP management with proper security configuration"
}, status_code=403)
# Extract client IP
client_ip = self._get_client_ip(request)
# Check IP restrictions
if not self._is_ip_allowed(client_ip):
self.logger.warning(f"Token management access denied for IP {client_ip}: not in allowed list")
return JSONResponse({
"error": "Access denied - IP not allowed",
"client_ip": client_ip,
"message": "Token management is restricted to specific IP addresses",
"allowed_networks": [str(net) for net in self._allowed_networks]
}, status_code=403)
# Check admin authentication if required
if self.config.security.require_admin_auth:
admin_token = self._extract_admin_token(request)
if not admin_token:
self.logger.warning(f"Token management access denied for IP {client_ip}: missing admin token")
return JSONResponse({
"error": "Admin authentication required",
"message": "Token management requires admin authentication",
"hint": "Provide admin token in Authorization header: 'Bearer <admin_token>' or 'X-Admin-Token: <admin_token>'"
}, status_code=401)
if not self._verify_admin_token(admin_token):
self.logger.warning(f"Token management access denied for IP {client_ip}: invalid admin token")
return JSONResponse({
"error": "Invalid admin token",
"message": "The provided admin token is invalid"
}, status_code=401)
# Log successful access
self.logger.info(f"Token management access granted for IP {client_ip} to {request.url.path}")
# Access granted
return None
def get_security_info(self) -> Dict[str, Any]:
"""Get current security configuration info (for demo/status pages)"""
return {
"http_token_management_enabled": self.config.security.enable_http_token_management,
"admin_auth_required": self.config.security.require_admin_auth,
"admin_token_configured": bool(self._admin_token_hash),
"allowed_networks_count": len(self._allowed_networks),
"allowed_networks": [str(net) for net in self._allowed_networks],
"security_features": [
"IP address restrictions",
"Admin token authentication" if self.config.security.require_admin_auth else "Optional admin authentication",
"Secure token hashing",
"Request logging and auditing"
]
}
def generate_admin_token(self) -> str:
"""Generate a secure admin token"""
return secrets.token_urlsafe(32)
# Convenience function for middleware creation
def create_token_security_middleware(config: DorisConfig) -> TokenSecurityMiddleware:
"""Create token security middleware with configuration"""
return TokenSecurityMiddleware(config)

View File

@@ -0,0 +1,365 @@
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
JWT Token Validation Module
Provides token validation, blacklist management and security features
"""
import time
import asyncio
from typing import Dict, Set, Optional, Any
from datetime import datetime, timedelta
from collections import defaultdict
from ..utils.logger import get_logger
logger = get_logger(__name__)
class TokenBlacklist:
"""JWT Token Blacklist Manager
Manages revoked tokens to prevent revoked tokens from being used again
Supports both in-memory and persistent storage
"""
def __init__(self, cleanup_interval: int = 3600):
"""Initialize token blacklist
Args:
cleanup_interval: Interval for cleaning up expired tokens (seconds)
"""
self.cleanup_interval = cleanup_interval
# Storage format: {token_jti: expiry_timestamp}
self._blacklisted_tokens: Dict[str, float] = {}
self._cleanup_task = None
logger.info("TokenBlacklist initialized")
async def start(self):
"""Start blacklist manager"""
self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
logger.info("TokenBlacklist started with periodic cleanup")
async def stop(self):
"""Stop blacklist manager"""
if self._cleanup_task:
self._cleanup_task.cancel()
try:
await self._cleanup_task
except asyncio.CancelledError:
pass
logger.info("TokenBlacklist stopped")
async def add_token(self, jti: str, exp: float):
"""Add token to blacklist
Args:
jti: JWT ID (unique identifier)
exp: Token expiration timestamp
"""
self._blacklisted_tokens[jti] = exp
logger.info(f"Token {jti} added to blacklist")
async def is_blacklisted(self, jti: str) -> bool:
"""Check if token is blacklisted
Args:
jti: JWT ID
Returns:
True if blacklisted, False otherwise
"""
return jti in self._blacklisted_tokens
async def remove_token(self, jti: str) -> bool:
"""Remove token from blacklist
Args:
jti: JWT ID
Returns:
True if removed, False if not found
"""
if jti in self._blacklisted_tokens:
del self._blacklisted_tokens[jti]
logger.info(f"Token {jti} removed from blacklist")
return True
return False
async def cleanup_expired(self) -> int:
"""Clean up expired blacklisted tokens
Returns:
Number of tokens cleaned up
"""
current_time = time.time()
expired_tokens = [
jti for jti, exp in self._blacklisted_tokens.items()
if exp <= current_time
]
for jti in expired_tokens:
del self._blacklisted_tokens[jti]
if expired_tokens:
logger.info(f"Cleaned up {len(expired_tokens)} expired tokens from blacklist")
return len(expired_tokens)
async def get_stats(self) -> Dict[str, Any]:
"""Get blacklist statistics"""
current_time = time.time()
active_tokens = sum(1 for exp in self._blacklisted_tokens.values() if exp > current_time)
return {
"total_blacklisted": len(self._blacklisted_tokens),
"active_blacklisted": active_tokens,
"expired_blacklisted": len(self._blacklisted_tokens) - active_tokens,
"cleanup_interval": self.cleanup_interval
}
async def _periodic_cleanup(self):
"""Periodically clean up expired tokens"""
while True:
try:
await asyncio.sleep(self.cleanup_interval)
await self.cleanup_expired()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Error during periodic cleanup: {e}")
class RateLimiter:
"""Token usage rate limiter"""
def __init__(self, max_requests: int = 100, time_window: int = 3600):
"""Initialize rate limiter
Args:
max_requests: Maximum requests within time window
time_window: Time window in seconds
"""
self.max_requests = max_requests
self.time_window = time_window
# Storage format: {user_id: [timestamp1, timestamp2, ...]}
self._request_history: Dict[str, list] = defaultdict(list)
logger.info(f"RateLimiter initialized: {max_requests} requests per {time_window} seconds")
async def is_allowed(self, user_id: str) -> bool:
"""Check if user is allowed to make request
Args:
user_id: User ID
Returns:
True if allowed, False otherwise
"""
current_time = time.time()
user_requests = self._request_history[user_id]
# Clean up expired request records
cutoff_time = current_time - self.time_window
user_requests[:] = [t for t in user_requests if t > cutoff_time]
# Check if limit exceeded
if len(user_requests) >= self.max_requests:
logger.warning(f"Rate limit exceeded for user {user_id}")
return False
# Record current request
user_requests.append(current_time)
return True
async def get_usage(self, user_id: str) -> Dict[str, Any]:
"""Get user usage information
Args:
user_id: User ID
Returns:
Usage statistics
"""
current_time = time.time()
user_requests = self._request_history[user_id]
# Clean up expired records
cutoff_time = current_time - self.time_window
active_requests = [t for t in user_requests if t > cutoff_time]
return {
"user_id": user_id,
"requests_in_window": len(active_requests),
"max_requests": self.max_requests,
"time_window": self.time_window,
"remaining_requests": max(0, self.max_requests - len(active_requests))
}
class TokenValidator:
"""JWT Token Validator
Provides comprehensive JWT token validation functionality, including signature verification,
claim validation, blacklist checking and rate limiting
"""
def __init__(self, config, blacklist: Optional[TokenBlacklist] = None):
"""Initialize token validator
Args:
config: DorisConfig configuration object (with security attribute)
blacklist: Token blacklist manager
"""
self.config = config
self.blacklist = blacklist or TokenBlacklist()
self.rate_limiter = RateLimiter()
# Access JWT settings through the security configuration
if hasattr(config, 'security'):
security_config = config.security
else:
# Fallback if config is passed directly as SecurityConfig
security_config = config
# Validation options
self.verify_signature = security_config.jwt_verify_signature
self.verify_audience = security_config.jwt_verify_audience
self.verify_issuer = security_config.jwt_verify_issuer
self.require_exp = security_config.jwt_require_exp
self.require_iat = security_config.jwt_require_iat
self.require_nbf = security_config.jwt_require_nbf
self.leeway = security_config.jwt_leeway
# Expected values
self.expected_audience = security_config.jwt_audience
self.expected_issuer = security_config.jwt_issuer
logger.info("TokenValidator initialized")
async def validate_claims(self, payload: Dict[str, Any]) -> Dict[str, Any]:
"""Validate JWT claims
Args:
payload: JWT payload
Returns:
Validation result
Raises:
ValueError: Validation failed
"""
current_time = time.time()
# Validate issuer
if self.verify_issuer:
if payload.get('iss') != self.expected_issuer:
raise ValueError(f"Invalid issuer: expected {self.expected_issuer}")
# Validate audience
if self.verify_audience:
aud = payload.get('aud')
if isinstance(aud, list):
if self.expected_audience not in aud:
raise ValueError(f"Invalid audience: {self.expected_audience} not in {aud}")
elif aud != self.expected_audience:
raise ValueError(f"Invalid audience: expected {self.expected_audience}")
# Validate expiration time
if self.require_exp or 'exp' in payload:
exp = payload.get('exp')
if not exp:
raise ValueError("Missing 'exp' claim")
if current_time > exp + self.leeway:
raise ValueError("Token has expired")
# Validate not before time
if self.require_nbf or 'nbf' in payload:
nbf = payload.get('nbf')
if not nbf:
raise ValueError("Missing 'nbf' claim")
if current_time < nbf - self.leeway:
raise ValueError("Token not yet valid")
# Validate issued at time
if self.require_iat or 'iat' in payload:
iat = payload.get('iat')
if not iat:
raise ValueError("Missing 'iat' claim")
# Allow some clock skew, but cannot be future time
if iat > current_time + self.leeway:
raise ValueError("Token issued in the future")
# Check blacklist
jti = payload.get('jti')
if jti and await self.blacklist.is_blacklisted(jti):
raise ValueError("Token has been revoked")
# Rate limit check
user_id = payload.get('sub')
if user_id:
if not await self.rate_limiter.is_allowed(user_id):
raise ValueError("Rate limit exceeded")
return {
"valid": True,
"user_id": user_id,
"payload": payload
}
async def start(self):
"""Start validator"""
await self.blacklist.start()
logger.info("TokenValidator started")
async def stop(self):
"""Stop validator"""
await self.blacklist.stop()
logger.info("TokenValidator stopped")
async def revoke_token(self, jti: str, exp: float):
"""Revoke token
Args:
jti: JWT ID
exp: Token expiration time
"""
await self.blacklist.add_token(jti, exp)
logger.info(f"Token {jti} has been revoked")
async def get_validation_stats(self) -> Dict[str, Any]:
"""Get validation statistics"""
blacklist_stats = await self.blacklist.get_stats()
return {
"blacklist": blacklist_stats,
"validation_config": {
"verify_signature": self.verify_signature,
"verify_audience": self.verify_audience,
"verify_issuer": self.verify_issuer,
"require_exp": self.require_exp,
"require_iat": self.require_iat,
"require_nbf": self.require_nbf,
"leeway": self.leeway
}
}
async def get_user_rate_limit_info(self, user_id: str) -> Dict[str, Any]:
"""Get user rate limit information"""
return await self.rate_limiter.get_usage(user_id)

View File

@@ -28,14 +28,182 @@ import json
import logging import logging
from typing import Any from typing import Any
from mcp.server import Server # MCP version compatibility handling
from mcp.server.models import InitializationOptions MCP_VERSION = 'unknown'
Server = None
InitializationOptions = None
Prompt = None
Resource = None
TextContent = None
Tool = None
def _import_mcp_with_compatibility():
"""Import MCP components with multi-version compatibility"""
global MCP_VERSION, Server, InitializationOptions, Prompt, Resource, TextContent, Tool
try:
# Strategy 1: Try direct server-only imports to avoid client-side issues
from mcp.server import Server as _Server
from mcp.server.models import InitializationOptions as _InitOptions
from mcp.types import ( from mcp.types import (
Prompt, Prompt as _Prompt,
Resource, Resource as _Resource,
TextContent, TextContent as _TextContent,
Tool, Tool as _Tool,
)
# Assign to globals
Server = _Server
InitializationOptions = _InitOptions
Prompt = _Prompt
Resource = _Resource
TextContent = _TextContent
Tool = _Tool
# Try to get version safely
try:
import mcp
MCP_VERSION = getattr(mcp, '__version__', None)
if not MCP_VERSION:
# Fallback: try to get version from package metadata
try:
import importlib.metadata
MCP_VERSION = importlib.metadata.version('mcp')
except Exception:
# Second fallback: try pkg_resources
try:
import pkg_resources
MCP_VERSION = pkg_resources.get_distribution('mcp').version
except Exception:
MCP_VERSION = 'detected-but-version-unknown'
except Exception:
# Version detection failed, but imports worked
try:
import importlib.metadata
MCP_VERSION = importlib.metadata.version('mcp')
except Exception:
try:
import pkg_resources
MCP_VERSION = pkg_resources.get_distribution('mcp').version
except Exception:
MCP_VERSION = 'imported-successfully'
logger = logging.getLogger(__name__)
logger.info(f"MCP components imported successfully, version: {MCP_VERSION}")
return True
except Exception as import_error:
logger = logging.getLogger(__name__)
# Strategy 2: Handle RequestContext compatibility issues in 1.9.x versions
error_str = str(import_error).lower()
if 'requestcontext' in error_str and 'too few arguments' in error_str:
logger.warning(f"Detected MCP RequestContext compatibility issue: {import_error}")
logger.info("Attempting comprehensive workaround for MCP 1.9.x RequestContext issue...")
try:
# Comprehensive monkey patch approach
import sys
import types
# Create and install mock modules before any MCP imports
if 'mcp.shared.context' not in sys.modules:
mock_context_module = types.ModuleType('mcp.shared.context')
class FlexibleRequestContext:
"""Flexible RequestContext that accepts variable arguments"""
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
def __class_getitem__(cls, params):
# Accept any number of parameters and return cls
return cls
# Add other methods that might be called
def __getattr__(self, name):
return lambda *args, **kwargs: None
mock_context_module.RequestContext = FlexibleRequestContext
sys.modules['mcp.shared.context'] = mock_context_module
# Also patch the typing system to be more permissive
original_check_generic = None
try:
import typing
if hasattr(typing, '_check_generic'):
original_check_generic = typing._check_generic
def permissive_check_generic(cls, params, elen):
# Don't enforce strict parameter count checking
return
typing._check_generic = permissive_check_generic
except Exception:
pass
# Clear any cached imports that might have failed
modules_to_clear = [k for k in sys.modules.keys() if k.startswith('mcp.')]
for module in modules_to_clear:
if module in sys.modules:
del sys.modules[module]
# Now try importing again with the patches in place
from mcp.server import Server as _Server
from mcp.server.models import InitializationOptions as _InitOptions
from mcp.types import (
Prompt as _Prompt,
Resource as _Resource,
TextContent as _TextContent,
Tool as _Tool,
)
# Assign to globals
Server = _Server
InitializationOptions = _InitOptions
Prompt = _Prompt
Resource = _Resource
TextContent = _TextContent
Tool = _Tool
# Try to detect actual version even in compatibility mode
try:
import importlib.metadata
actual_version = importlib.metadata.version('mcp')
MCP_VERSION = f'compatibility-mode-{actual_version}'
except Exception:
try:
import pkg_resources
actual_version = pkg_resources.get_distribution('mcp').version
MCP_VERSION = f'compatibility-mode-{actual_version}'
except Exception:
MCP_VERSION = 'compatibility-mode-1.9.x'
logger.info("MCP 1.9.x compatibility workaround successful!")
# Restore original typing function if we patched it
if original_check_generic:
typing._check_generic = original_check_generic
return True
except Exception as workaround_error:
logger.error(f"MCP compatibility workaround failed: {workaround_error}")
# Restore original typing function if we patched it
if original_check_generic:
try:
import typing
typing._check_generic = original_check_generic
except Exception:
pass
logger.error(f"Failed to import MCP components: {import_error}")
return False
# Perform MCP import with compatibility handling
if not _import_mcp_with_compatibility():
raise ImportError(
"Failed to import MCP components. Please ensure MCP is properly installed. "
"Supported versions: 1.8.x, 1.9.x"
) )
from .tools.tools_manager import DorisToolsManager from .tools.tools_manager import DorisToolsManager
@@ -44,11 +212,16 @@ from .tools.resources_manager import DorisResourcesManager
from .utils.config import DorisConfig from .utils.config import DorisConfig
from .utils.db import DorisConnectionManager from .utils.db import DorisConnectionManager
from .utils.security import DorisSecurityManager from .utils.security import DorisSecurityManager
import os
# Configure logging # Configure logging - will be properly initialized later
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Create a default config instance for getting default values
_default_config = DorisConfig()
class DorisServer: class DorisServer:
"""Apache Doris MCP Server main class""" """Apache Doris MCP Server main class"""
@@ -57,20 +230,101 @@ class DorisServer:
self.config = config self.config = config
self.server = Server("doris-mcp-server") self.server = Server("doris-mcp-server")
# Initialize security manager # Initialize security manager (without connection_manager initially)
self.security_manager = DorisSecurityManager(config) self.security_manager = DorisSecurityManager(config)
# Initialize connection manager, pass in security manager # Initialize connection manager, pass in security manager and token manager for token-bound DB config
self.connection_manager = DorisConnectionManager(config, self.security_manager) token_manager = self.security_manager.auth_provider.token_manager if hasattr(self.security_manager, 'auth_provider') and hasattr(self.security_manager.auth_provider, 'token_manager') else None
self.connection_manager = DorisConnectionManager(config, self.security_manager, token_manager)
# Set connection manager reference in security manager for database validation
self.security_manager.connection_manager = self.connection_manager
# Initialize independent managers # Initialize independent managers
self.resources_manager = DorisResourcesManager(self.connection_manager) self.resources_manager = DorisResourcesManager(self.connection_manager)
self.tools_manager = DorisToolsManager(self.connection_manager) self.tools_manager = DorisToolsManager(self.connection_manager)
self.prompts_manager = DorisPromptsManager(self.connection_manager) self.prompts_manager = DorisPromptsManager(self.connection_manager)
self.logger = logging.getLogger(f"{__name__}.DorisServer") # Import here to avoid circular imports
from .utils.logger import get_logger
self.logger = get_logger(f"{__name__}.DorisServer")
self._setup_handlers() self._setup_handlers()
async def _extract_auth_info_from_scope(self, scope, headers):
"""Extract authentication information from ASGI scope and headers"""
auth_info = {}
# Extract client IP
client = scope.get("client")
if client:
auth_info["client_ip"] = client[0]
else:
auth_info["client_ip"] = "unknown"
# Extract token from Authorization header
authorization = headers.get(b'authorization', b'').decode('utf-8')
if authorization:
if authorization.startswith('Bearer '):
auth_info["token"] = authorization[7:]
auth_info["authorization"] = authorization
elif authorization.startswith('Token '):
auth_info["token"] = authorization[6:]
auth_info["authorization"] = authorization
# Extract token from query parameters (for compatibility)
query_string = scope.get("query_string", b"").decode('utf-8')
if query_string and "token=" in query_string:
import urllib.parse
query_params = urllib.parse.parse_qs(query_string)
if "token" in query_params:
auth_info["token"] = query_params["token"][0]
# If no token found, this will be handled by the authentication system
# (either return anonymous context if auth disabled, or raise error if auth enabled)
return auth_info
def _get_mcp_capabilities(self):
"""Get MCP capabilities with version compatibility"""
try:
# For MCP 1.9.x and newer
from mcp.server.lowlevel.server import NotificationOptions
return self.server.get_capabilities(
notification_options=NotificationOptions(
prompts_changed=True,
resources_changed=True,
tools_changed=True
),
experimental_capabilities={}
)
except TypeError:
try:
# For MCP 1.8.x
from mcp.server.lowlevel.server import NotificationOptions
return self.server.get_capabilities(
notification_options=NotificationOptions(
prompts_changed=True,
resources_changed=True,
tools_changed=True
),
experimental_capabilities={}
)
except Exception as e:
self.logger.warning(f"Could not get capabilities with NotificationOptions: {e}")
# Fallback for older versions
try:
return self.server.get_capabilities()
except Exception as fallback_e:
self.logger.error(f"Failed to get capabilities: {fallback_e}")
# Return minimal capabilities
return {
"resources": {},
"tools": {},
"prompts": {}
}
def _setup_handlers(self): def _setup_handlers(self):
"""Setup MCP protocol handlers""" """Setup MCP protocol handlers"""
@@ -174,12 +428,24 @@ class DorisServer:
self.logger.info("Starting Doris MCP Server (stdio mode)") self.logger.info("Starting Doris MCP Server (stdio mode)")
try: try:
# Ensure connection manager is initialized # Initialize security manager first (includes JWT setup if enabled)
await self.connection_manager.initialize() await self.security_manager.initialize()
self.logger.info("Connection manager initialization completed") self.logger.info("Security manager initialization completed")
# Start stdio server - using simpler approach # For stdio mode, we must establish a working database connection
# Use the dedicated stdio mode initialization method
await self.connection_manager.initialize_for_stdio_mode()
# Start stdio server - using compatible import approach
try:
from mcp.server.stdio import stdio_server from mcp.server.stdio import stdio_server
except ImportError:
# Fallback for different MCP versions
try:
from mcp.server import stdio_server
except ImportError as stdio_import_error:
self.logger.error(f"Failed to import stdio_server: {stdio_import_error}")
raise RuntimeError("stdio_server module not available in this MCP version")
self.logger.info("Creating stdio_server transport...") self.logger.info("Creating stdio_server transport...")
@@ -189,22 +455,12 @@ class DorisServer:
read_stream, write_stream = streams read_stream, write_stream = streams
self.logger.info("stdio_server streams created successfully") self.logger.info("stdio_server streams created successfully")
# Create initialization options # Create initialization options with version compatibility
# MCP 1.8.0 requires parameters for get_capabilities capabilities = self._get_mcp_capabilities()
from mcp.server.lowlevel.server import NotificationOptions
capabilities = self.server.get_capabilities(
notification_options=NotificationOptions(
prompts_changed=True,
resources_changed=True,
tools_changed=True
),
experimental_capabilities={}
)
init_options = InitializationOptions( init_options = InitializationOptions(
server_name="doris-mcp-server", server_name="doris-mcp-server",
server_version="1.0.0", server_version=os.getenv("SERVER_VERSION", _default_config.server_version),
capabilities=capabilities, capabilities=capabilities,
) )
self.logger.info("Initialization options created successfully") self.logger.info("Initialization options created successfully")
@@ -237,13 +493,21 @@ class DorisServer:
async def start_http(self, host: str = "localhost", port: int = 3000): async def start_http(self, host: str = os.getenv("SERVER_HOST", _default_config.database.host), port: int = os.getenv("SERVER_PORT", _default_config.server_port), workers: int = 1):
"""Start Streamable HTTP transport mode""" """Start Streamable HTTP transport mode with workers support"""
self.logger.info(f"Starting Doris MCP Server (Streamable HTTP mode) - {host}:{port}") self.logger.info(f"Starting Doris MCP Server (Streamable HTTP mode) - {host}:{port}, workers: {workers}")
try: try:
# Ensure connection manager is initialized # Initialize security manager first (includes JWT setup if enabled)
await self.connection_manager.initialize() await self.security_manager.initialize()
self.logger.info("Security manager initialization completed")
# For HTTP mode, try to initialize global connection pool with graceful degradation
global_pool_created = await self.connection_manager.initialize_for_http_mode()
if global_pool_created:
self.logger.info("Global database connection pool available for HTTP mode")
else:
self.logger.info("HTTP mode running without global database pool, will use token-bound configurations")
# Use Starlette and StreamableHTTPSessionManager according to official example # Use Starlette and StreamableHTTPSessionManager according to official example
import uvicorn import uvicorn
@@ -251,9 +515,9 @@ class DorisServer:
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from starlette.applications import Starlette from starlette.applications import Starlette
from starlette.routing import Mount, Route from starlette.routing import Route
from starlette.responses import JSONResponse, Response from starlette.responses import JSONResponse, Response
from starlette.types import Receive, Scope, Send from starlette.types import Scope
# Create session manager # Create session manager
session_manager = StreamableHTTPSessionManager( session_manager = StreamableHTTPSessionManager(
@@ -268,6 +532,44 @@ class DorisServer:
async def health_check(request): async def health_check(request):
return JSONResponse({"status": "healthy", "service": "doris-mcp-server"}) return JSONResponse({"status": "healthy", "service": "doris-mcp-server"})
# OAuth endpoints
from .auth.oauth_handlers import OAuthHandlers
oauth_handlers = OAuthHandlers(self.security_manager)
async def oauth_login(request):
return await oauth_handlers.handle_login(request)
async def oauth_callback(request):
return await oauth_handlers.handle_callback(request)
async def oauth_provider_info(request):
return await oauth_handlers.handle_provider_info(request)
async def oauth_demo(request):
return await oauth_handlers.handle_demo_page(request)
# Token management endpoints
from .auth.token_handlers import TokenHandlers
token_handlers = TokenHandlers(self.security_manager, self.config)
async def token_create(request):
return await token_handlers.handle_create_token(request)
async def token_revoke(request):
return await token_handlers.handle_revoke_token(request)
async def token_list(request):
return await token_handlers.handle_list_tokens(request)
async def token_stats(request):
return await token_handlers.handle_token_stats(request)
async def token_cleanup(request):
return await token_handlers.handle_cleanup_tokens(request)
async def token_management(request):
return await token_handlers.handle_management_page(request)
# Lifecycle manager - simplified since we manage session_manager externally # Lifecycle manager - simplified since we manage session_manager externally
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def lifespan(app: Starlette) -> AsyncIterator[None]: async def lifespan(app: Starlette) -> AsyncIterator[None]:
@@ -283,6 +585,18 @@ class DorisServer:
debug=True, debug=True,
routes=[ routes=[
Route("/health", health_check, methods=["GET"]), Route("/health", health_check, methods=["GET"]),
# OAuth endpoints
Route("/auth/login", oauth_login, methods=["GET"]),
Route("/auth/callback", oauth_callback, methods=["GET"]),
Route("/auth/provider", oauth_provider_info, methods=["GET"]),
Route("/auth/demo", oauth_demo, methods=["GET"]),
# Token management endpoints
Route("/token/create", token_create, methods=["GET", "POST"]),
Route("/token/revoke", token_revoke, methods=["GET", "DELETE"]),
Route("/token/list", token_list, methods=["GET"]),
Route("/token/stats", token_stats, methods=["GET"]),
Route("/token/cleanup", token_cleanup, methods=["GET", "POST"]),
Route("/token/management", token_management, methods=["GET"]),
], ],
lifespan=lifespan, lifespan=lifespan,
) )
@@ -300,8 +614,10 @@ class DorisServer:
self.logger.info(f"Received request for path: {path}") self.logger.info(f"Received request for path: {path}")
try: try:
# Handle health check # Handle health check, auth, and token management endpoints
if path.startswith("/health"): if (path.startswith("/health") or
path.startswith("/auth/") or
path.startswith("/token/")):
await starlette_app(scope, receive, send) await starlette_app(scope, receive, send)
return return
@@ -314,6 +630,39 @@ class DorisServer:
self.logger.info(f"MCP Request - Method: {method}") self.logger.info(f"MCP Request - Method: {method}")
self.logger.info(f"MCP Request - Headers: {headers}") self.logger.info(f"MCP Request - Headers: {headers}")
# Authentication check for MCP requests
try:
# Extract authentication information
auth_info = await self._extract_auth_info_from_scope(scope, headers)
# Authenticate the request
auth_context = await self.security_manager.authenticate_request(auth_info)
self.logger.info(f"MCP request authenticated: token_id={auth_context.token_id}, client_ip={auth_context.client_ip}")
# Store auth context in scope for potential use by tools/resources
scope["auth_context"] = auth_context
# FIX for Issue #62 Bug 1: Set auth_context in context variable
# This allows tools to access token information for token-bound database configuration
# CRITICAL: Use the global ContextVar from security.py to ensure same instance is used everywhere
try:
from .utils.security import mcp_auth_context_var
mcp_auth_context_var.set(auth_context)
self.logger.debug(f"Set auth_context in context variable with token: {bool(hasattr(auth_context, 'token') and auth_context.token)}")
except Exception as ctx_error:
self.logger.warning(f"Failed to set auth_context in context variable: {ctx_error}")
except Exception as auth_error:
self.logger.error(f"MCP authentication failed: {auth_error}")
# Return 401 Unauthorized
from starlette.responses import JSONResponse
response = JSONResponse(
{"error": "Authentication required", "message": str(auth_error)},
status_code=401
)
await response(scope, receive, send)
return
# Handle Dify compatibility for GET requests # Handle Dify compatibility for GET requests
if method == "GET": if method == "GET":
accept_header = headers.get(b'accept', b'').decode('utf-8') accept_header = headers.get(b'accept', b'').decode('utf-8')
@@ -356,7 +705,23 @@ class DorisServer:
self.logger.warning(f"Unsupported scope type: {scope['type']}") self.logger.warning(f"Unsupported scope type: {scope['type']}")
return return
# Start uvicorn server with session manager lifecycle # Choose startup method based on worker count
if workers > 1:
self.logger.info(f"Using multi-process mode with {workers} workers")
self.logger.info("Note: Multi-worker mode provides full MCP functionality with independent worker processes")
# Use the dedicated multiworker app module with full MCP support
uvicorn.run(
"doris_mcp_server.multiworker_app:app",
host=host,
port=port,
workers=workers,
log_level="info"
)
else:
self.logger.info("Using single-process mode")
# Single worker mode, use original logic with session manager lifecycle
config = uvicorn.Config( config = uvicorn.Config(
app=mcp_app, app=mcp_app,
host=host, host=host,
@@ -383,10 +748,16 @@ class DorisServer:
self.logger.error(f" Exception {i+1}: {type(exc).__name__}: {exc}") self.logger.error(f" Exception {i+1}: {type(exc).__name__}: {exc}")
raise raise
async def shutdown(self): async def shutdown(self):
"""Shutdown server""" """Shutdown server"""
self.logger.info("Shutting down Doris MCP Server") self.logger.info("Shutting down Doris MCP Server")
try: try:
# Shutdown security manager first (includes JWT cleanup)
await self.security_manager.shutdown()
self.logger.info("Security manager shutdown completed")
await self.connection_manager.close() await self.connection_manager.close()
self.logger.info("Doris MCP Server has been shut down") self.logger.info("Doris MCP Server has been shut down")
except Exception as e: except Exception as e:
@@ -406,6 +777,11 @@ Transport Modes:
Examples: Examples:
python -m doris_mcp_server --transport stdio python -m doris_mcp_server --transport stdio
python -m doris_mcp_server --transport http --host 0.0.0.0 --port 3000 python -m doris_mcp_server --transport http --host 0.0.0.0 --port 3000
python -m doris_mcp_server --transport stdio --doris-host localhost --doris-port 9030
python -m doris_mcp_server --transport http --doris-user admin --doris-database test_db
# Backward compatibility: --db-* parameters are also supported
python -m doris_mcp_server --transport stdio --db-host localhost --db-port 9030
""" """
) )
@@ -413,91 +789,151 @@ Examples:
"--transport", "--transport",
type=str, type=str,
choices=["stdio", "http"], choices=["stdio", "http"],
default="stdio", default=os.getenv("TRANSPORT", _default_config.transport),
help="Transport protocol type: stdio (local), http (Streamable HTTP)", help=f"Transport protocol type: stdio (local), http (Streamable HTTP) (default: {_default_config.transport})",
) )
parser.add_argument( parser.add_argument(
"--host", "--host",
type=str, type=str,
default="localhost", default=os.getenv("SERVER_HOST", _default_config.server_host),
help="Host address for HTTP mode (default: localhost)", help=f"Host address for HTTP mode (default: {_default_config.server_host})",
) )
parser.add_argument( parser.add_argument(
"--port", type=int, default=3000, help="Port number for HTTP mode (default: 3000)" "--port",
type=int,
default=3000,
help="Port number for HTTP mode (default: 3000)"
) )
parser.add_argument( parser.add_argument(
"--db-host", "--workers",
type=int,
default=1,
help="Number of worker processes for HTTP mode (default: 1, use 0 for auto-detect CPU cores)"
)
parser.add_argument(
"--doris-host", "--db-host",
type=str, type=str,
default="localhost", default=os.getenv("DORIS_HOST", _default_config.database.host),
help="Doris database host address (default: localhost)", help=f"Doris database host address (default: {_default_config.database.host})",
) )
parser.add_argument( parser.add_argument(
"--db-port", type=int, default=9030, help="Doris database port number (default: 9030)" "--doris-port", "--db-port", type=int, default=9030, help="Doris database port number (default: 9030)"
) )
parser.add_argument( parser.add_argument(
"--db-user", type=str, default="root", help="Doris database username (default: root)" "--doris-user", "--db-user", type=str, default=os.getenv("DORIS_USER", _default_config.database.user), help=f"Doris database username (default: {_default_config.database.user})"
) )
parser.add_argument("--db-password", type=str, default="", help="Doris database password") parser.add_argument("--doris-password", "--db-password", type=str, default=os.getenv("DORIS_PASSWORD", ""), help="Doris database password")
parser.add_argument( parser.add_argument(
"--db-database", "--doris-database", "--db-database",
type=str, type=str,
default="information_schema", default=os.getenv("DORIS_DATABASE", _default_config.database.database),
help="Doris database name (default: information_schema)", help=f"Doris database name (default: {_default_config.database.database})",
) )
parser.add_argument( parser.add_argument(
"--log-level", "--log-level",
type=str, type=str,
choices=["DEBUG", "INFO", "WARNING", "ERROR"], choices=["DEBUG", "INFO", "WARNING", "ERROR"],
default="INFO", default=os.getenv("LOG_LEVEL", _default_config.logging.level),
help="Log level (default: INFO)", help=f"Log level (default: {_default_config.logging.level})",
) )
return parser return parser
async def main(): def update_configuration(config: DorisConfig):
"""Main function""" """Update doris configuration object"""
# For some arguments, if not specified, environment variables or default configurations will be used as default values
parser = create_arg_parser() parser = create_arg_parser()
args = parser.parse_args() args = parser.parse_args()
# Set log level # Update config values
logging.getLogger().setLevel(getattr(logging, args.log_level))
# Create configuration - priority: command line arguments > .env file > default values
config = DorisConfig.from_env() # First load from .env file and environment variables
# Command line arguments override configuration (if provided) # Command line arguments override configuration (if provided)
if args.db_host != "localhost": # If not default value, use command line argument # basic
config.database.host = args.db_host if args.transport != _default_config.transport:
if args.db_port != 9030: config.transport = args.transport
config.database.port = args.db_port if args.host != _default_config.server_host:
if args.db_user != "root": config.server_host = args.host
config.database.user = args.db_user if args.port != _default_config.server_port:
if args.db_password: # Use password if provided config.server_port = args.port
config.database.password = args.db_password server_name = os.getenv("SERVER_NAME")
if args.db_database != "information_schema": if server_name:
config.database.database = args.db_database config.server_name = server_name
if args.log_level != "INFO": server_version = os.getenv("SERVER_VERSION")
if server_version:
config.server_version = server_version
# database
if args.doris_host != _default_config.database.host: # If not default value, use command line argument
config.database.host = args.doris_host
if args.doris_port != _default_config.database.port:
config.database.port = args.doris_port
if args.doris_user != _default_config.database.user:
config.database.user = args.doris_user
if args.doris_password: # Use password if provided
config.database.password = args.doris_password
if args.doris_database != _default_config.database.database:
config.database.database = args.doris_database
# logging
if args.log_level != _default_config.logging.level:
config.logging.level = args.log_level config.logging.level = args.log_level
# workers (add to config for HTTP mode)
if hasattr(args, 'workers'):
config.workers = args.workers
async def main():
"""Main function"""
# Create configuration - priority: command line arguments > env variables > .env file > default values
# First load from .env file and environment variables
config = DorisConfig.from_env()
# Then parse the command line arguments, and update the config object.
update_configuration(config)
# Initialize enhanced logging system
from .utils.config import ConfigManager
config_manager = ConfigManager(config)
config_manager.setup_logging()
# Get logger with proper configuration
from .utils.logger import get_logger, log_system_info
logger = get_logger(__name__)
# Log system information for debugging
log_system_info()
logger.info("Starting Doris MCP Server...")
logger.info(f"Transport: {config.transport}")
logger.info(f"Log Level: {config.logging.level}")
# Create server instance # Create server instance
server = DorisServer(config) server = DorisServer(config)
try: try:
if args.transport == "stdio": if config.transport == "stdio":
await server.start_stdio() await server.start_stdio()
elif args.transport == "http": elif config.transport == "http":
await server.start_http(args.host, args.port) # Get workers configuration with auto-detection support
workers = getattr(config, 'workers', 1)
if workers == 0:
import multiprocessing
workers = multiprocessing.cpu_count()
logger.info(f"Auto-detected {workers} CPU cores for worker processes")
await server.start_http(config.server_host, config.server_port, workers)
else: else:
logger.error(f"Unsupported transport protocol: {args.transport}") logger.error(f"Unsupported transport protocol: {config.transport}")
await server.shutdown() await server.shutdown()
return 1 return 1
@@ -518,6 +954,10 @@ async def main():
except Exception as shutdown_error: except Exception as shutdown_error:
logger.error(f"Error occurred while shutting down server: {shutdown_error}") logger.error(f"Error occurred while shutting down server: {shutdown_error}")
# Shutdown logging system
from .utils.logger import shutdown_logging
shutdown_logging()
return 0 return 0

View File

@@ -0,0 +1,627 @@
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Multi-worker application module for doris-mcp-server
This module provides full MCP functionality with multi-worker support.
Each worker process creates its own MCP server and session manager using the same
robust architecture as the single-worker mode.
"""
import os
import asyncio
from contextlib import asynccontextmanager
import json
import logging
from typing import Any
# Import MCP components with compatibility handling
# Use the same import strategy as main.py for consistency
MCP_VERSION = 'unknown'
Server = None
InitializationOptions = None
Prompt = None
Resource = None
TextContent = None
Tool = None
def _import_mcp_with_compatibility():
"""Import MCP components with multi-version compatibility"""
global MCP_VERSION, Server, InitializationOptions, Prompt, Resource, TextContent, Tool
try:
# Strategy 1: Try direct server-only imports to avoid client-side issues
from mcp.server import Server as _Server
from mcp.server.models import InitializationOptions as _InitOptions
from mcp.types import (
Prompt as _Prompt,
Resource as _Resource,
TextContent as _TextContent,
Tool as _Tool,
)
# Assign to globals
Server = _Server
InitializationOptions = _InitOptions
Prompt = _Prompt
Resource = _Resource
TextContent = _TextContent
Tool = _Tool
# Try to get version safely
try:
import mcp
MCP_VERSION = getattr(mcp, '__version__', None)
if not MCP_VERSION:
# Fallback: try to get version from package metadata
try:
import importlib.metadata
MCP_VERSION = importlib.metadata.version('mcp')
except Exception:
# Second fallback: try pkg_resources
try:
import pkg_resources
MCP_VERSION = pkg_resources.get_distribution('mcp').version
except Exception:
MCP_VERSION = 'detected-but-version-unknown'
except Exception:
# Version detection failed, but imports worked
try:
import importlib.metadata
MCP_VERSION = importlib.metadata.version('mcp')
except Exception:
try:
import pkg_resources
MCP_VERSION = pkg_resources.get_distribution('mcp').version
except Exception:
MCP_VERSION = 'imported-successfully'
logger = logging.getLogger(__name__)
logger.info(f"MCP components imported successfully in multiworker, version: {MCP_VERSION}")
return True
except Exception as import_error:
logger = logging.getLogger(__name__)
# Strategy 2: Handle RequestContext compatibility issues in 1.9.x versions
error_str = str(import_error).lower()
if 'requestcontext' in error_str and 'too few arguments' in error_str:
logger.warning(f"Detected MCP RequestContext compatibility issue: {import_error}")
logger.info("Attempting comprehensive workaround for MCP 1.9.x RequestContext issue...")
try:
# Comprehensive monkey patch approach
import sys
import types
# Create and install mock modules before any MCP imports
if 'mcp.shared.context' not in sys.modules:
mock_context_module = types.ModuleType('mcp.shared.context')
class FlexibleRequestContext:
"""Flexible RequestContext that accepts variable arguments"""
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
def __class_getitem__(cls, params):
# Accept any number of parameters and return cls
return cls
# Add other methods that might be called
def __getattr__(self, name):
return lambda *args, **kwargs: None
mock_context_module.RequestContext = FlexibleRequestContext
sys.modules['mcp.shared.context'] = mock_context_module
# Also patch the typing system to be more permissive
original_check_generic = None
try:
import typing
if hasattr(typing, '_check_generic'):
original_check_generic = typing._check_generic
def permissive_check_generic(cls, params, elen):
# Don't enforce strict parameter count checking
return
typing._check_generic = permissive_check_generic
except Exception:
pass
# Clear any cached imports that might have failed
modules_to_clear = [k for k in sys.modules.keys() if k.startswith('mcp.')]
for module in modules_to_clear:
if module in sys.modules:
del sys.modules[module]
# Now try importing again with the patches in place
from mcp.server import Server as _Server
from mcp.server.models import InitializationOptions as _InitOptions
from mcp.types import (
Prompt as _Prompt,
Resource as _Resource,
TextContent as _TextContent,
Tool as _Tool,
)
# Assign to globals
Server = _Server
InitializationOptions = _InitOptions
Prompt = _Prompt
Resource = _Resource
TextContent = _TextContent
Tool = _Tool
# Try to detect actual version even in compatibility mode
try:
import importlib.metadata
actual_version = importlib.metadata.version('mcp')
MCP_VERSION = f'compatibility-mode-{actual_version}'
except Exception:
try:
import pkg_resources
actual_version = pkg_resources.get_distribution('mcp').version
MCP_VERSION = f'compatibility-mode-{actual_version}'
except Exception:
MCP_VERSION = 'compatibility-mode-1.9.x'
logger.info("MCP 1.9.x compatibility workaround successful in multiworker!")
# Restore original typing function if we patched it
if original_check_generic:
typing._check_generic = original_check_generic
return True
except Exception as workaround_error:
logger.error(f"MCP compatibility workaround failed in multiworker: {workaround_error}")
# Restore original typing function if we patched it
if original_check_generic:
try:
import typing
typing._check_generic = original_check_generic
except Exception:
pass
logger.error(f"Failed to import MCP components in multiworker: {import_error}")
return False
# Perform MCP import with compatibility handling
if not _import_mcp_with_compatibility():
raise ImportError(
"Failed to import MCP components in multiworker. Please ensure MCP is properly installed. "
"Supported versions: 1.8.x, 1.9.x"
)
from starlette.applications import Starlette
from starlette.routing import Route
from starlette.responses import JSONResponse, Response
# Import Doris MCP components
from .tools.tools_manager import DorisToolsManager
from .tools.prompts_manager import DorisPromptsManager
from .tools.resources_manager import DorisResourcesManager
from .utils.config import DorisConfig
from .utils.db import DorisConnectionManager
from .utils.security import DorisSecurityManager
# Global variables for worker-specific instances
_worker_server = None
_worker_session_manager = None
_worker_connection_manager = None
_worker_security_manager = None
_worker_session_manager_context = None
_worker_initialized = False
def get_mcp_capabilities():
"""Get MCP capabilities for worker - use the same logic as main.py"""
try:
# For MCP 1.9.x and newer
from mcp.server.lowlevel.server import NotificationOptions
capabilities = {
"resources": {},
"tools": {},
"prompts": {},
"notification_options": {
"prompts_changed": True,
"resources_changed": True,
"tools_changed": True
}
}
return capabilities
except Exception as e:
# Import logger properly
from .utils.logger import get_logger
logger = get_logger(__name__)
logger.warning(f"Failed to get full capabilities in multiworker: {e}")
return {
"resources": {},
"tools": {},
"prompts": {}
}
async def initialize_worker():
"""Initialize MCP server and managers for this worker process"""
global _worker_server, _worker_session_manager, _worker_connection_manager, _worker_security_manager, _worker_session_manager_context, _worker_initialized, _oauth_handlers, _token_handlers
if _worker_initialized:
return
try:
# Import logger properly
from .utils.logger import get_logger
logger = get_logger(__name__)
logger.info(f"Initializing MCP worker process {os.getpid()}")
# Create configuration
config = DorisConfig.from_env()
# Initialize enhanced logging system
from .utils.config import ConfigManager
config_manager = ConfigManager(config)
config_manager.setup_logging()
# Create security manager
_worker_security_manager = DorisSecurityManager(config)
# Initialize security manager first (includes JWT setup if enabled)
await _worker_security_manager.initialize()
logger.info(f"Worker {os.getpid()} security manager initialization completed")
# Create connection manager with token manager for token-bound DB config
token_manager = _worker_security_manager.auth_provider.token_manager if hasattr(_worker_security_manager, 'auth_provider') and hasattr(_worker_security_manager.auth_provider, 'token_manager') else None
_worker_connection_manager = DorisConnectionManager(config, _worker_security_manager, token_manager)
# Set connection manager reference in security manager for database validation
_worker_security_manager.connection_manager = _worker_connection_manager
await _worker_connection_manager.initialize()
# Create MCP server
_worker_server = Server("doris-mcp-server")
# Create managers
resources_manager = DorisResourcesManager(_worker_connection_manager)
tools_manager = DorisToolsManager(_worker_connection_manager)
prompts_manager = DorisPromptsManager(_worker_connection_manager)
# Setup MCP handlers
@_worker_server.list_resources()
async def handle_list_resources() -> list[Resource]:
"""Handle resource list request"""
try:
logger.info("Handling resource list request in worker")
resources = await resources_manager.list_resources()
logger.info(f"Returning {len(resources)} resources from worker")
return resources
except Exception as e:
logger.error(f"Failed to handle resource list request in worker: {e}")
return []
@_worker_server.read_resource()
async def handle_read_resource(uri: str) -> str:
"""Handle resource read request"""
try:
logger.info(f"Handling resource read request in worker: {uri}")
content = await resources_manager.read_resource(uri)
return content
except Exception as e:
logger.error(f"Failed to handle resource read request in worker: {e}")
return json.dumps(
{"error": f"Failed to read resource: {str(e)}", "uri": uri},
ensure_ascii=False,
indent=2,
)
@_worker_server.list_tools()
async def handle_list_tools() -> list[Tool]:
"""Handle tool list request"""
try:
logger.info("Handling tool list request in worker")
tools = await tools_manager.list_tools()
logger.info(f"Returning {len(tools)} tools from worker")
return tools
except Exception as e:
logger.error(f"Failed to handle tool list request in worker: {e}")
return []
@_worker_server.call_tool()
async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
"""Handle tool call request"""
try:
logger.info(f"Handling tool call request in worker: {name}")
result = await tools_manager.call_tool(name, arguments)
return [TextContent(type="text", text=result)]
except Exception as e:
logger.error(f"Failed to handle tool call request in worker: {e}")
error_result = json.dumps(
{
"error": f"Tool call failed: {str(e)}",
"tool_name": name,
"arguments": arguments,
},
ensure_ascii=False,
indent=2,
)
return [TextContent(type="text", text=error_result)]
@_worker_server.list_prompts()
async def handle_list_prompts() -> list[Prompt]:
"""Handle prompt list request"""
try:
logger.info("Handling prompt list request in worker")
prompts = await prompts_manager.list_prompts()
logger.info(f"Returning {len(prompts)} prompts from worker")
return prompts
except Exception as e:
logger.error(f"Failed to handle prompt list request in worker: {e}")
return []
@_worker_server.get_prompt()
async def handle_get_prompt(name: str, arguments: dict[str, Any]) -> str:
"""Handle prompt get request"""
try:
logger.info(f"Handling prompt get request in worker: {name}")
result = await prompts_manager.get_prompt(name, arguments)
return result
except Exception as e:
logger.error(f"Failed to handle prompt get request in worker: {e}")
error_result = json.dumps(
{
"error": f"Failed to get prompt: {str(e)}",
"prompt_name": name,
"arguments": arguments,
},
ensure_ascii=False,
indent=2,
)
return error_result
# Create session manager for this worker
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
_worker_session_manager = StreamableHTTPSessionManager(
app=_worker_server,
json_response=True,
stateless=True # Use stateless mode for multi-worker compatibility
)
# Start the session manager context
_worker_session_manager_context = _worker_session_manager.run()
await _worker_session_manager_context.__aenter__()
# Initialize OAuth and Token handlers
from .auth.oauth_handlers import OAuthHandlers
from .auth.token_handlers import TokenHandlers
_oauth_handlers = OAuthHandlers(_worker_security_manager)
_token_handlers = TokenHandlers(_worker_security_manager, config)
_worker_initialized = True
logger.info(f"Worker {os.getpid()} MCP initialization completed successfully")
except Exception as e:
from .utils.logger import get_logger
logger = get_logger(__name__)
logger.error(f"Failed to initialize worker {os.getpid()}: {e}")
import traceback
logger.error("Complete error stack:")
logger.error(traceback.format_exc())
raise
async def health_check(request):
"""Health check endpoint that shows worker PID"""
return JSONResponse({
"status": "healthy",
"service": "doris-mcp-server",
"worker_pid": os.getpid(),
"worker_mode": "multi-process-full-mcp",
"mcp_initialized": _worker_initialized,
"mcp_version": MCP_VERSION
})
# OAuth and Token handlers (initialize after worker setup)
_oauth_handlers = None
_token_handlers = None
async def oauth_login(request):
"""OAuth login endpoint"""
if not _oauth_handlers:
return JSONResponse({"error": "OAuth not initialized"}, status_code=503)
return await _oauth_handlers.handle_login(request)
async def oauth_callback(request):
"""OAuth callback endpoint"""
if not _oauth_handlers:
return JSONResponse({"error": "OAuth not initialized"}, status_code=503)
return await _oauth_handlers.handle_callback(request)
async def oauth_provider_info(request):
"""OAuth provider info endpoint"""
if not _oauth_handlers:
return JSONResponse({"error": "OAuth not initialized"}, status_code=503)
return await _oauth_handlers.handle_provider_info(request)
async def oauth_demo(request):
"""OAuth demo page endpoint"""
if not _oauth_handlers:
from starlette.responses import HTMLResponse
return HTMLResponse("<h1>OAuth not initialized</h1>")
return await _oauth_handlers.handle_demo_page(request)
# Token management endpoints
async def token_create(request):
"""Token creation endpoint"""
if not _token_handlers:
return JSONResponse({"error": "Token handlers not initialized"}, status_code=503)
return await _token_handlers.handle_create_token(request)
async def token_revoke(request):
"""Token revocation endpoint"""
if not _token_handlers:
return JSONResponse({"error": "Token handlers not initialized"}, status_code=503)
return await _token_handlers.handle_revoke_token(request)
async def token_list(request):
"""Token listing endpoint"""
if not _token_handlers:
return JSONResponse({"error": "Token handlers not initialized"}, status_code=503)
return await _token_handlers.handle_list_tokens(request)
async def token_stats(request):
"""Token statistics endpoint"""
if not _token_handlers:
return JSONResponse({"error": "Token handlers not initialized"}, status_code=503)
return await _token_handlers.handle_token_stats(request)
async def token_cleanup(request):
"""Token cleanup endpoint"""
if not _token_handlers:
return JSONResponse({"error": "Token handlers not initialized"}, status_code=503)
return await _token_handlers.handle_cleanup_tokens(request)
async def token_management(request):
"""Token management page endpoint"""
if not _token_handlers:
from starlette.responses import HTMLResponse
return HTMLResponse("<h1>Token handlers not initialized</h1>")
return await _token_handlers.handle_management_page(request)
async def root_info(request):
"""Root endpoint"""
return JSONResponse({
"service": "doris-mcp-server",
"mode": "multi-worker-full-mcp",
"worker_pid": os.getpid(),
"mcp_initialized": _worker_initialized,
"mcp_version": MCP_VERSION,
"endpoints": {
"health": "/health",
"mcp": "/mcp"
}
})
@asynccontextmanager
async def lifespan(app):
"""Application lifespan manager"""
# Startup
try:
await initialize_worker()
# Import logger properly
from .utils.logger import get_logger
logger = get_logger(__name__)
logger.info(f"Worker {os.getpid()} startup completed")
yield
finally:
# Shutdown
from .utils.logger import get_logger
logger = get_logger(__name__)
# Close session manager context
if _worker_session_manager_context:
try:
await _worker_session_manager_context.__aexit__(None, None, None)
logger.info(f"Worker {os.getpid()} session manager context closed")
except Exception as e:
logger.error(f"Error closing worker session manager context: {e}")
if _worker_connection_manager:
try:
await _worker_connection_manager.close()
logger.info(f"Worker {os.getpid()} connection manager closed")
except Exception as e:
logger.error(f"Error closing worker connection manager: {e}")
if _worker_security_manager:
try:
await _worker_security_manager.shutdown()
logger.info(f"Worker {os.getpid()} security manager shutdown completed")
except Exception as e:
logger.error(f"Error shutting down worker security manager: {e}")
# Shutdown logging system
try:
from .utils.logger import shutdown_logging
shutdown_logging()
except Exception as e:
logger.error(f"Error shutting down logging system: {e}")
async def mcp_asgi_app(scope, receive, send):
"""ASGI app that handles MCP requests"""
if not _worker_initialized:
# Send error response if worker not initialized
await send({
'type': 'http.response.start',
'status': 503,
'headers': [(b'content-type', b'application/json')]
})
await send({
'type': 'http.response.body',
'body': b'{"error": "Worker not initialized"}'
})
return
# Import logger properly
from .utils.logger import get_logger
logger = get_logger(__name__)
# Get request path for logging
path = scope.get('path', '')
method = scope.get('method', 'UNKNOWN')
logger.debug(f"Worker {os.getpid()} handling MCP request: {method} {path}")
# Handle the request directly without nested run context
await _worker_session_manager.handle_request(scope, receive, send)
# Create Starlette app with basic routes
basic_app = Starlette(
debug=True,
routes=[
Route("/", root_info, methods=["GET"]),
Route("/health", health_check, methods=["GET"]),
# OAuth endpoints
Route("/auth/login", oauth_login, methods=["GET"]),
Route("/auth/callback", oauth_callback, methods=["GET"]),
Route("/auth/provider", oauth_provider_info, methods=["GET"]),
Route("/auth/demo", oauth_demo, methods=["GET"]),
# Token management endpoints
Route("/token/create", token_create, methods=["GET", "POST"]),
Route("/token/revoke", token_revoke, methods=["GET", "DELETE"]),
Route("/token/list", token_list, methods=["GET"]),
Route("/token/stats", token_stats, methods=["GET"]),
Route("/token/cleanup", token_cleanup, methods=["GET", "POST"]),
Route("/token/management", token_management, methods=["GET"]),
],
lifespan=lifespan
)
# Create main ASGI app that routes between basic app and MCP
async def app(scope, receive, send):
"""Main ASGI app that routes requests"""
path = scope.get('path', '/')
if path == "/mcp" or path.startswith('/mcp/'):
# Handle MCP requests with session manager
await mcp_asgi_app(scope, receive, send)
else:
# Handle other requests with basic Starlette app (includes auth endpoints)
await basic_app(scope, receive, send)

View File

@@ -31,6 +31,7 @@ from mcp.types import (
) )
from ..utils.db import DorisConnectionManager from ..utils.db import DorisConnectionManager
from ..utils.sql_security_utils import get_auth_context
class PromptTemplate: class PromptTemplate:
@@ -422,7 +423,8 @@ Please generate accurate and efficient SQL queries based on the above requiremen
AND table_type = 'BASE TABLE' AND table_type = 'BASE TABLE'
""" """
db_result = await connection.execute(db_info_sql) auth_context = get_auth_context()
db_result = await connection.execute(db_info_sql, auth_context=auth_context)
db_info = db_result.data[0] if db_result.data else {} db_info = db_result.data[0] if db_result.data else {}
# Get main table list # Get main table list
@@ -438,7 +440,7 @@ Please generate accurate and efficient SQL queries based on the above requiremen
LIMIT 10 LIMIT 10
""" """
tables_result = await connection.execute(tables_sql) tables_result = await connection.execute(tables_sql, auth_context=auth_context)
context = f"""Current database statistics: context = f"""Current database statistics:
- Total number of tables: {db_info.get("table_count", 0)} - Total number of tables: {db_info.get("table_count", 0)}

View File

@@ -26,6 +26,7 @@ from typing import Any
from mcp.types import Resource from mcp.types import Resource
from ..utils.db import DorisConnectionManager from ..utils.db import DorisConnectionManager
from ..utils.sql_security_utils import get_auth_context
class TableMetadata: class TableMetadata:
@@ -169,7 +170,8 @@ class DorisResourcesManager:
ORDER BY table_name ORDER BY table_name
""" """
result = await connection.execute(tables_query) auth_context = get_auth_context()
result = await connection.execute(tables_query, auth_context=auth_context)
tables = [] tables = []
for row in result.data: for row in result.data:
@@ -204,7 +206,8 @@ class DorisResourcesManager:
ORDER BY ordinal_position ORDER BY ordinal_position
""" """
result = await connection.execute(columns_query, (table_name,)) auth_context = get_auth_context()
result = await connection.execute(columns_query, params=(table_name,), auth_context=auth_context)
return [dict(row) for row in result.data] return [dict(row) for row in result.data]
async def _get_view_metadata(self) -> list[ViewMetadata]: async def _get_view_metadata(self) -> list[ViewMetadata]:
@@ -226,7 +229,8 @@ class DorisResourcesManager:
ORDER BY table_name ORDER BY table_name
""" """
result = await connection.execute(views_query) auth_context = get_auth_context()
result = await connection.execute(views_query, auth_context=auth_context)
views = [] views = []
for row in result.data: for row in result.data:
@@ -257,7 +261,8 @@ class DorisResourcesManager:
AND table_name = %s AND table_name = %s
""" """
table_result = await connection.execute(table_info_query, (table_name,)) auth_context = get_auth_context()
table_result = await connection.execute(table_info_query, params=(table_name,), auth_context=auth_context)
if not table_result.data: if not table_result.data:
raise ValueError(f"Table {table_name} does not exist") raise ValueError(f"Table {table_name} does not exist")
@@ -295,7 +300,8 @@ class DorisResourcesManager:
ORDER BY index_name, seq_in_index ORDER BY index_name, seq_in_index
""" """
result = await connection.execute(indexes_query, (table_name,)) auth_context = get_auth_context()
result = await connection.execute(indexes_query, params=(table_name,), auth_context=auth_context)
return [dict(row) for row in result.data] return [dict(row) for row in result.data]
async def _get_view_definition(self, view_name: str) -> str: async def _get_view_definition(self, view_name: str) -> str:
@@ -312,7 +318,8 @@ class DorisResourcesManager:
AND table_name = %s AND table_name = %s
""" """
result = await connection.execute(view_query, (view_name,)) auth_context = get_auth_context()
result = await connection.execute(view_query, params=(view_name,), auth_context=auth_context)
if not result.data: if not result.data:
raise ValueError(f"View {view_name} does not exist") raise ValueError(f"View {view_name} does not exist")
@@ -340,7 +347,8 @@ class DorisResourcesManager:
AND table_type = 'BASE TABLE' AND table_type = 'BASE TABLE'
""" """
table_result = await connection.execute(table_stats_query) auth_context = get_auth_context()
table_result = await connection.execute(table_stats_query, auth_context=auth_context)
table_stats = table_result.data[0] if table_result.data else {} table_stats = table_result.data[0] if table_result.data else {}
# Get view statistics # Get view statistics
@@ -350,7 +358,7 @@ class DorisResourcesManager:
WHERE table_schema = DATABASE() WHERE table_schema = DATABASE()
""" """
view_result = await connection.execute(view_stats_query) view_result = await connection.execute(view_stats_query, auth_context=auth_context)
view_stats = view_result.data[0] if view_result.data else {} view_stats = view_result.data[0] if view_result.data else {}
stats_info = { stats_info = {

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,542 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Apache Doris ADBC Query Tools
High-performance data querying using Apache Arrow Flight SQL protocol
"""
import os
import socket
import time
from datetime import datetime
from typing import Any, Dict, List, Optional
from ..utils.logger import get_logger
from ..utils.db import DorisConnectionManager
from ..utils.sql_security_utils import get_auth_context
logger = get_logger(__name__)
def _convert_numpy_types(obj):
"""Convert numpy types to native Python types for JSON serialization"""
try:
# Import numpy only when needed
import numpy as np
import pandas as pd
if isinstance(obj, np.integer):
return int(obj)
elif isinstance(obj, np.floating):
return float(obj)
elif isinstance(obj, np.bool_):
return bool(obj)
elif isinstance(obj, np.ndarray):
return obj.tolist()
elif isinstance(obj, (pd.Timestamp, pd.NaT.__class__)):
return str(obj)
elif pd.isna(obj):
return None
else:
return obj
except ImportError:
# If numpy/pandas not available, return as-is
return obj
def _convert_dataframe_to_json_serializable(df):
"""Convert DataFrame to JSON serializable format"""
try:
import pandas as pd
import numpy as np
# Convert DataFrame to records
records = df.to_dict('records')
# Convert each record's values
converted_records = []
for record in records:
converted_record = {}
for key, value in record.items():
converted_record[key] = _convert_numpy_types(value)
converted_records.append(converted_record)
return converted_records
except ImportError:
# Fallback to basic dict conversion
return df.to_dict('records')
class DorisADBCQueryTools:
"""ADBC Query Tools for high-performance data transfer using Arrow Flight SQL"""
def __init__(self, connection_manager: DorisConnectionManager):
self.connection_manager = connection_manager
self.adbc_client = None
self.flight_sql_module = None
self.adbc_manager_module = None
async def exec_adbc_query(
self,
sql: str,
max_rows: int | None = None,
timeout: int | None = None,
return_format: str | None = None
) -> Dict[str, Any]:
"""
Execute SQL query using ADBC (Arrow Flight SQL) protocol
Args:
sql: SQL statement to execute
max_rows: Maximum number of rows to return (uses config default if None)
timeout: Query timeout in seconds (uses config default if None)
return_format: Format for returned data ("arrow", "pandas", "dict", uses config default if None)
Returns:
Query results in specified format with metadata
"""
try:
start_time = time.time()
# Use configuration defaults if parameters not specified
adbc_config = self.connection_manager.config.adbc
max_rows = max_rows if max_rows is not None else adbc_config.default_max_rows
timeout = timeout if timeout is not None else adbc_config.default_timeout
return_format = return_format if return_format is not None else adbc_config.default_return_format
# Step 1: Check environment variables and port availability
port_check_result = await self._check_arrow_flight_ports()
if not port_check_result["success"]:
return port_check_result
# Step 2: Import required ADBC modules
import_result = await self._import_adbc_modules()
if not import_result["success"]:
return import_result
# Step 3: Create ADBC connection
connection_result = await self._create_adbc_connection()
if not connection_result["success"]:
return connection_result
# Step 4: Execute query using ADBC
query_result = await self._execute_query_with_adbc(
sql, max_rows, timeout, return_format
)
execution_time = time.time() - start_time
if query_result["success"]:
query_result["execution_time"] = round(execution_time, 3)
query_result["protocol"] = "ADBC_Arrow_Flight_SQL"
query_result["timestamp"] = datetime.now().isoformat()
return query_result
except Exception as e:
logger.error(f"ADBC query execution failed: {str(e)}")
return {
"success": False,
"error": f"ADBC query execution failed: {str(e)}",
"error_type": "execution_error",
"timestamp": datetime.now().isoformat()
}
async def _check_arrow_flight_ports(self) -> Dict[str, Any]:
"""Check Arrow Flight SQL port configuration and availability"""
try:
# Check environment variables
fe_port = os.getenv("FE_ARROW_FLIGHT_SQL_PORT")
be_port = os.getenv("BE_ARROW_FLIGHT_SQL_PORT")
if not fe_port:
return {
"success": False,
"error": "Missing environment variable FE_ARROW_FLIGHT_SQL_PORT, please configure Arrow Flight SQL FE port in .env file",
"error_type": "missing_fe_port_config"
}
if not be_port:
return {
"success": False,
"error": "Missing environment variable BE_ARROW_FLIGHT_SQL_PORT, please configure Arrow Flight SQL BE port in .env file",
"error_type": "missing_be_port_config"
}
# Convert to integer and validate
try:
fe_port = int(fe_port)
be_port = int(be_port)
except ValueError:
return {
"success": False,
"error": "Invalid Arrow Flight SQL port configuration, please ensure FE_ARROW_FLIGHT_SQL_PORT and BE_ARROW_FLIGHT_SQL_PORT are valid numbers",
"error_type": "invalid_port_format"
}
# Get host address
db_config = self.connection_manager.config.database
fe_host = db_config.host
# Check FE Arrow Flight SQL port availability
fe_available = self._check_port_connectivity(fe_host, fe_port)
if not fe_available:
return {
"success": False,
"error": f"Cannot connect to FE Arrow Flight SQL port {fe_host}:{fe_port}, please check if service is running",
"error_type": "fe_port_unavailable",
"fe_host": fe_host,
"fe_port": fe_port
}
# Get BE host list
be_hosts = await self._get_be_hosts()
if not be_hosts:
return {
"success": False,
"error": "Cannot get BE node information, please check cluster status",
"error_type": "no_be_hosts"
}
# Check at least one BE Arrow Flight SQL port availability
be_available_count = 0
be_check_results = []
for be_host in be_hosts[:3]: # Check first 3 BE nodes
be_available = self._check_port_connectivity(be_host, be_port)
be_check_results.append({
"host": be_host,
"port": be_port,
"available": be_available
})
if be_available:
be_available_count += 1
if be_available_count == 0:
return {
"success": False,
"error": f"Cannot connect to any BE Arrow Flight SQL port (port: {be_port}), please check if BE services are running",
"error_type": "no_be_ports_available",
"be_check_results": be_check_results
}
return {
"success": True,
"fe_host": fe_host,
"fe_port": fe_port,
"be_port": be_port,
"be_hosts": be_hosts,
"be_available_count": be_available_count,
"be_check_results": be_check_results
}
except Exception as e:
logger.error(f"Arrow Flight port check failed: {str(e)}")
return {
"success": False,
"error": f"Arrow Flight port check failed: {str(e)}",
"error_type": "port_check_error"
}
def _check_port_connectivity(self, host: str, port: int, timeout: int | None = None) -> bool:
"""Check port connectivity"""
try:
# Use config timeout if not specified
if timeout is None:
timeout = self.connection_manager.config.adbc.connection_timeout
with socket.create_connection((host, port), timeout=timeout):
return True
except (socket.timeout, socket.error, OSError):
return False
async def _get_be_hosts(self) -> List[str]:
"""Get BE host list"""
try:
db_config = self.connection_manager.config.database
# Use configured BE hosts first
if db_config.be_hosts:
logger.info(f"Using configured BE hosts: {db_config.be_hosts}")
return db_config.be_hosts
# Get BE nodes via SHOW BACKENDS
logger.info("No BE hosts configured, getting BE node information via SHOW BACKENDS")
connection = await self.connection_manager.get_connection("query")
auth_context = get_auth_context()
result = await connection.execute("SHOW BACKENDS", auth_context=auth_context)
be_hosts = []
for row in result.data:
host = row.get("Host")
alive = row.get("Alive", "").lower()
if host and alive == "true":
be_hosts.append(host)
logger.info(f"Got {len(be_hosts)} active BE nodes from SHOW BACKENDS")
return be_hosts
except Exception as e:
logger.error(f"Failed to get BE hosts: {str(e)}")
return []
async def _import_adbc_modules(self) -> Dict[str, Any]:
"""Import ADBC related modules"""
try:
# Import ADBC Driver Manager
try:
import adbc_driver_manager
self.adbc_manager_module = adbc_driver_manager
except ImportError:
return {
"success": False,
"error": "Missing adbc_driver_manager module, please install: pip install adbc_driver_manager",
"error_type": "missing_adbc_manager"
}
# Import ADBC Flight SQL Driver
try:
import adbc_driver_flightsql.dbapi as flight_sql
self.flight_sql_module = flight_sql
except ImportError:
return {
"success": False,
"error": "Missing adbc_driver_flightsql module, please install: pip install adbc_driver_flightsql",
"error_type": "missing_flight_sql_driver"
}
return {
"success": True,
"adbc_manager_version": getattr(adbc_driver_manager, '__version__', 'unknown'),
"flight_sql_version": getattr(flight_sql, '__version__', 'unknown')
}
except Exception as e:
logger.error(f"ADBC module import failed: {str(e)}")
return {
"success": False,
"error": f"ADBC module import failed: {str(e)}",
"error_type": "import_error"
}
async def _create_adbc_connection(self) -> Dict[str, Any]:
"""Create ADBC connection"""
try:
db_config = self.connection_manager.config.database
fe_port = int(os.getenv("FE_ARROW_FLIGHT_SQL_PORT"))
# Build connection URI
uri = f"grpc://{db_config.host}:{fe_port}"
# Create database connection parameters
db_kwargs = {
self.adbc_manager_module.DatabaseOptions.USERNAME.value: db_config.user,
self.adbc_manager_module.DatabaseOptions.PASSWORD.value: db_config.password,
}
# Create connection
self.adbc_client = self.flight_sql_module.connect(
uri=uri,
db_kwargs=db_kwargs
)
return {
"success": True,
"uri": uri,
"connection_established": True
}
except Exception as e:
logger.error(f"Failed to create ADBC connection: {str(e)}")
return {
"success": False,
"error": f"Failed to create ADBC connection: {str(e)}",
"error_type": "connection_error"
}
async def _execute_query_with_adbc(
self,
sql: str,
max_rows: int,
timeout: int,
return_format: str
) -> Dict[str, Any]:
"""Execute query using ADBC"""
try:
if not self.adbc_client:
return {
"success": False,
"error": "ADBC connection not established",
"error_type": "no_connection"
}
# SECURITY FIX: Perform SQL security validation before executing
auth_context = get_auth_context()
if self.connection_manager.security_manager:
# Always perform security validation, even without auth_context
# Use a default context for basic SQL security checks
validation_result = await self.connection_manager.security_manager.validate_sql_security(sql, auth_context)
if not validation_result.is_valid:
return {
"success": False,
"error": f"SQL security validation failed: {validation_result.error_message}",
"error_type": "security_violation",
"risk_level": validation_result.risk_level
}
cursor = self.adbc_client.cursor()
start_time = time.time()
# Execute query
cursor.execute(sql)
# Get results based on return format
if return_format == "arrow":
# Return Arrow format
arrow_data = cursor.fetchallarrow()
# Limit rows
if len(arrow_data) > max_rows:
arrow_data = arrow_data.slice(0, max_rows)
# Convert Arrow data to serializable format
preview_df = arrow_data.to_pandas().head(10) if len(arrow_data) > 0 else None
result_data = {
"format": "arrow",
"num_rows": len(arrow_data),
"num_columns": len(arrow_data.schema),
"column_names": arrow_data.schema.names,
"column_types": [str(field.type) for field in arrow_data.schema],
"data_preview": _convert_dataframe_to_json_serializable(preview_df) if preview_df is not None else [],
"total_bytes": arrow_data.nbytes if hasattr(arrow_data, 'nbytes') else 0
}
elif return_format == "pandas":
# Return Pandas DataFrame
df = cursor.fetch_df()
# Limit rows
if len(df) > max_rows:
df = df.head(max_rows)
result_data = {
"format": "pandas",
"num_rows": len(df),
"num_columns": len(df.columns),
"column_names": df.columns.tolist(),
"column_types": df.dtypes.astype(str).tolist(),
"data": _convert_dataframe_to_json_serializable(df),
"memory_usage": int(df.memory_usage(deep=True).sum())
}
else: # return_format == "dict"
# Return dictionary format
arrow_data = cursor.fetchallarrow()
df = arrow_data.to_pandas()
# Limit rows
if len(df) > max_rows:
df = df.head(max_rows)
result_data = {
"format": "dict",
"num_rows": len(df),
"num_columns": len(df.columns),
"column_names": df.columns.tolist(),
"column_types": df.dtypes.astype(str).tolist(),
"data": _convert_dataframe_to_json_serializable(df)
}
execution_time = time.time() - start_time
cursor.close()
return {
"success": True,
"result": result_data,
"execution_time": round(execution_time, 3),
"sql": sql,
"max_rows_applied": len(result_data.get("data", [])) >= max_rows
}
except Exception as e:
logger.error(f"ADBC query execution failed: {str(e)}")
return {
"success": False,
"error": f"ADBC query execution failed: {str(e)}",
"error_type": "query_execution_error",
"sql": sql
}
async def get_adbc_connection_info(self) -> Dict[str, Any]:
"""Get ADBC connection information and status"""
try:
# Check port status
port_status = await self._check_arrow_flight_ports()
# Check module status
module_status = await self._import_adbc_modules()
# Get configuration information
db_config = self.connection_manager.config.database
fe_port = os.getenv("FE_ARROW_FLIGHT_SQL_PORT")
be_port = os.getenv("BE_ARROW_FLIGHT_SQL_PORT")
connection_info = {
"adbc_available": module_status["success"],
"ports_available": port_status["success"],
"configuration": {
"fe_host": db_config.host,
"fe_arrow_flight_port": fe_port,
"be_arrow_flight_port": be_port,
"user": db_config.user
},
"port_status": port_status,
"module_status": module_status,
"timestamp": datetime.now().isoformat()
}
if port_status["success"] and module_status["success"]:
connection_info["status"] = "ready"
connection_info["message"] = "ADBC Arrow Flight SQL connection ready"
else:
connection_info["status"] = "not_ready"
errors = []
if not port_status["success"]:
errors.append(port_status["error"])
if not module_status["success"]:
errors.append(module_status["error"])
connection_info["message"] = "; ".join(errors)
return connection_info
except Exception as e:
logger.error(f"Failed to get ADBC connection information: {str(e)}")
return {
"status": "error",
"error": f"Failed to get ADBC connection information: {str(e)}",
"timestamp": datetime.now().isoformat()
}
def __del__(self):
"""Cleanup resources"""
try:
if self.adbc_client:
self.adbc_client.close()
except:
pass

File diff suppressed because it is too large Load Diff

View File

@@ -32,6 +32,8 @@ try:
except ImportError: except ImportError:
load_dotenv = None load_dotenv = None
from .logger import get_logger
@dataclass @dataclass
class DatabaseConfig: class DatabaseConfig:
@@ -41,38 +43,105 @@ class DatabaseConfig:
port: int = 9030 port: int = 9030
user: str = "root" user: str = "root"
password: str = "" password: str = ""
database: str = "test" database: str = "information_schema"
charset: str = "utf8mb4" charset: str = "UTF8"
# FE HTTP API port for profile and other HTTP APIs
fe_http_port: int = 8030
# BE nodes configuration for external access
# If be_hosts is empty, will use "show backends" to get BE nodes
be_hosts: list[str] = field(default_factory=list)
be_webserver_port: int = 8040
# Arrow Flight SQL Configuration (Required for ADBC tools)
fe_arrow_flight_sql_port: int | None = None
be_arrow_flight_sql_port: int | None = None
# Connection pool configuration # Connection pool configuration
min_connections: int = 5 # Note: min_connections is fixed at 0 to avoid at_eof connection issues
# This prevents pre-creation of connections which can cause state problems
_min_connections: int = field(default=0, init=False) # Internal use only, always 0
max_connections: int = 20 max_connections: int = 20
connection_timeout: int = 30 connection_timeout: int = 30
health_check_interval: int = 60 health_check_interval: int = 60
max_connection_age: int = 3600 max_connection_age: int = 3600
@property
def min_connections(self) -> int:
"""Minimum connections is always 0 to prevent at_eof issues"""
return self._min_connections
@dataclass @dataclass
class SecurityConfig: class SecurityConfig:
"""Security configuration""" """Security configuration"""
# Authentication configuration # Independent authentication switches - any one enabled allows that method
auth_type: str = "token" # token, basic, oauth enable_token_auth: bool = False # Enable token-based authentication (default: disabled)
token_secret: str = "default_secret" enable_jwt_auth: bool = False # Enable JWT authentication (default: disabled)
enable_oauth_auth: bool = False # Enable OAuth 2.0/OIDC authentication (default: disabled)
# Legacy configuration (kept for backward compatibility)
auth_type: str = "token" # jwt, token, basic, oauth (deprecated: use individual switches)
token_secret: str = "default_secret" # Legacy token secret for backward compatibility
token_expiry: int = 3600 token_expiry: int = 3600
# Enhanced Token Authentication Configuration
token_file_path: str = "tokens.json" # Path to token configuration file
enable_token_expiry: bool = True # Enable token expiration
default_token_expiry_hours: int = 24 * 30 # Default expiry: 30 days
token_hash_algorithm: str = "sha256" # Token hashing algorithm: sha256, sha512
# Token Management Security (New in v0.6.0)
enable_http_token_management: bool = False # Enable HTTP token management endpoints (default: disabled for security)
token_management_admin_token: str = "" # Admin token for token management endpoints (required if HTTP management enabled)
token_management_allowed_ips: list[str] = field(default_factory=lambda: ["127.0.0.1", "::1", "localhost"]) # Allowed IPs for token management
require_admin_auth: bool = True # Require admin authentication for token management (default: true)
# JWT Configuration
jwt_algorithm: str = "RS256" # RS256, ES256, HS256
jwt_issuer: str = "doris-mcp-server"
jwt_audience: str = "doris-mcp-client"
jwt_private_key_path: str = ""
jwt_public_key_path: str = ""
jwt_secret_key: str = "" # Only used for HS256 algorithm
jwt_access_token_expiry: int = 3600 # 1 hour
jwt_refresh_token_expiry: int = 86400 # 24 hours
enable_token_refresh: bool = True
enable_token_revocation: bool = True
key_rotation_interval: int = 30 * 24 * 3600 # 30 days in seconds
# JWT Security Features
jwt_require_iat: bool = True # Require "issued at" claim
jwt_require_exp: bool = True # Require "expires at" claim
jwt_require_nbf: bool = False # Require "not before" claim
jwt_leeway: int = 10 # Clock skew tolerance in seconds
jwt_verify_signature: bool = True # Verify JWT signature
jwt_verify_audience: bool = True # Verify audience claim
jwt_verify_issuer: bool = True # Verify issuer claim
# SQL security configuration # SQL security configuration
enable_security_check: bool = True # Main switch: whether to enable SQL security check
blocked_keywords: list[str] = field( blocked_keywords: list[str] = field(
default_factory=lambda: [ default_factory=lambda: [
# DDL Operations (Data Definition Language)
"DROP", "DROP",
"DELETE",
"TRUNCATE",
"ALTER",
"CREATE", "CREATE",
"ALTER",
"TRUNCATE",
# DML Operations (Data Manipulation Language)
"DELETE",
"INSERT", "INSERT",
"UPDATE", "UPDATE",
# DCL Operations (Data Control Language)
"GRANT", "GRANT",
"REVOKE", "REVOKE",
# System Operations
"EXEC",
"EXECUTE",
"SHUTDOWN",
"KILL",
] ]
) )
max_query_complexity: int = 100 max_query_complexity: int = 100
@@ -85,6 +154,45 @@ class SecurityConfig:
enable_masking: bool = True enable_masking: bool = True
masking_rules: list[dict[str, Any]] = field(default_factory=list) masking_rules: list[dict[str, Any]] = field(default_factory=list)
# OAuth 2.0/OIDC Configuration
oauth_enabled: bool = False
oauth_provider: str = "" # 'google', 'microsoft', 'github', 'custom'
oauth_client_id: str = ""
oauth_client_secret: str = ""
oauth_redirect_uri: str = "http://localhost:3000/auth/callback"
# OIDC Discovery
oidc_discovery_url: str = "" # e.g., https://accounts.google.com/.well-known/openid_configuration
oauth_authorization_endpoint: str = ""
oauth_token_endpoint: str = ""
oauth_userinfo_endpoint: str = ""
oauth_jwks_uri: str = ""
# OAuth Scopes and Settings
oauth_scopes: list[str] = field(default_factory=list)
oauth_state_expiry: int = 600 # State parameter expiry in seconds (10 minutes)
oauth_pkce_enabled: bool = True # Enable PKCE for better security
oauth_nonce_enabled: bool = True # Enable nonce for OIDC
# User Mapping Configuration
oauth_user_id_claim: str = "sub" # JWT claim for user ID
oauth_email_claim: str = "email"
oauth_name_claim: str = "name"
oauth_roles_claim: str = "roles" # Custom claim for roles
oauth_default_roles: list[str] = field(default_factory=lambda: ["oauth_user"])
def __post_init__(self):
"""Initialize default OAuth scopes based on provider"""
if not self.oauth_scopes and self.oauth_provider:
if self.oauth_provider == "google":
self.oauth_scopes = ["openid", "email", "profile"]
elif self.oauth_provider == "microsoft":
self.oauth_scopes = ["openid", "profile", "email", "User.Read"]
elif self.oauth_provider == "github":
self.oauth_scopes = ["user:email", "read:user"]
else:
self.oauth_scopes = ["openid", "email", "profile"]
@dataclass @dataclass
class PerformanceConfig: class PerformanceConfig:
@@ -103,6 +211,52 @@ class PerformanceConfig:
connection_pool_size: int = 20 connection_pool_size: int = 20
idle_timeout: int = 1800 idle_timeout: int = 1800
# Response content size limit (characters)
max_response_content_size: int = 4096
@dataclass
class DataQualityConfig:
"""Data quality analysis configuration"""
# Column analysis configuration
max_columns_per_batch: int = 20 # Maximum columns to analyze in a single batch
default_sample_size: int = 100000 # Default sample size for analysis
# Sampling strategy configuration
small_table_threshold: int = 100000 # Tables smaller than this use full table analysis
medium_table_threshold: int = 1000000 # Tables smaller than this use simple LIMIT sampling
# Tables larger than medium_table_threshold use systematic sampling
# Performance optimization
enable_batch_analysis: bool = True # Enable batch analysis for multiple columns
batch_timeout: int = 300 # Timeout for batch analysis in seconds
# Accuracy vs Performance trade-off
enable_fast_mode: bool = False # Use approximate algorithms for faster results
fast_mode_sample_size: int = 10000 # Sample size for fast mode
# Statistical analysis configuration
enable_distribution_analysis: bool = True # Enable distribution analysis
histogram_bins: int = 20 # Number of bins for histogram analysis
percentile_levels: list[float] = field(default_factory=lambda: [0.25, 0.5, 0.75, 0.95, 0.99]) # Percentile levels to calculate
@dataclass
class ADBCConfig:
"""ADBC (Arrow Flight SQL) configuration"""
# Default query parameters
default_max_rows: int = 100000
default_timeout: int = 60
default_return_format: str = "arrow" # "arrow", "pandas", "dict"
# Connection timeout for ADBC
connection_timeout: int = 30
# Whether to enable ADBC tools
enabled: bool = True
@dataclass @dataclass
class LoggingConfig: class LoggingConfig:
@@ -118,6 +272,11 @@ class LoggingConfig:
enable_audit: bool = True enable_audit: bool = True
audit_file_path: str | None = None audit_file_path: str | None = None
# Log cleanup configuration
enable_cleanup: bool = True
max_age_days: int = 30
cleanup_interval_hours: int = 24
@dataclass @dataclass
class MonitoringConfig: class MonitoringConfig:
@@ -125,11 +284,11 @@ class MonitoringConfig:
# Metrics collection configuration # Metrics collection configuration
enable_metrics: bool = True enable_metrics: bool = True
metrics_port: int = 8081 metrics_port: int = 3001
metrics_path: str = "/metrics" metrics_path: str = "/metrics"
# Health check configuration # Health check configuration
health_check_port: int = 8082 health_check_port: int = 3002
health_check_path: str = "/health" health_check_path: str = "/health"
# Alert configuration # Alert configuration
@@ -143,15 +302,22 @@ class DorisConfig:
# Basic configuration # Basic configuration
server_name: str = "doris-mcp-server" server_name: str = "doris-mcp-server"
server_version: str = "1.0.0" server_version: str = "0.4.1"
server_port: int = 8080 server_host: str = "localhost"
server_port: int = 3000
transport: str = "stdio"
# Temporary files configuration
temp_files_dir: str = "tmp" # Temporary files directory for Explain and Profile outputs
# Sub-configuration modules # Sub-configuration modules
database: DatabaseConfig = field(default_factory=DatabaseConfig) database: DatabaseConfig = field(default_factory=DatabaseConfig)
security: SecurityConfig = field(default_factory=SecurityConfig) security: SecurityConfig = field(default_factory=SecurityConfig)
performance: PerformanceConfig = field(default_factory=PerformanceConfig) performance: PerformanceConfig = field(default_factory=PerformanceConfig)
data_quality: DataQualityConfig = field(default_factory=DataQualityConfig)
logging: LoggingConfig = field(default_factory=LoggingConfig) logging: LoggingConfig = field(default_factory=LoggingConfig)
monitoring: MonitoringConfig = field(default_factory=MonitoringConfig) monitoring: MonitoringConfig = field(default_factory=MonitoringConfig)
adbc: ADBCConfig = field(default_factory=ADBCConfig)
# Custom configuration # Custom configuration
custom_config: dict[str, Any] = field(default_factory=dict) custom_config: dict[str, Any] = field(default_factory=dict)
@@ -181,6 +347,9 @@ class DorisConfig:
def from_env(cls, env_file: str | None = None) -> "DorisConfig": def from_env(cls, env_file: str | None = None) -> "DorisConfig":
"""Load configuration from environment variables """Load configuration from environment variables
The kv pairs in the. env file will be loaded as environment variables,
but the existing environment variables will not be overridden.
Args: Args:
env_file: .env file path, if None, search in the following order: env_file: .env file path, if None, search in the following order:
.env, .env.local, .env.production, .env.development .env, .env.local, .env.production, .env.development
@@ -199,7 +368,7 @@ class DorisConfig:
env_files = [".env", ".env.local", ".env.production", ".env.development"] env_files = [".env", ".env.local", ".env.production", ".env.development"]
for env_path in env_files: for env_path in env_files:
if Path(env_path).exists(): if Path(env_path).exists():
load_dotenv(env_path) load_dotenv(env_path, override=False)
logging.getLogger(__name__).info(f"Loaded environment configuration file: {env_path}") logging.getLogger(__name__).info(f"Loaded environment configuration file: {env_path}")
break break
else: else:
@@ -209,17 +378,45 @@ class DorisConfig:
config = cls() config = cls()
# Database configuration # Database configuration - handle empty strings properly
config.database.host = os.getenv("DORIS_HOST", config.database.host) doris_host = os.getenv("DORIS_HOST", "").strip()
config.database.port = int(os.getenv("DORIS_PORT", str(config.database.port))) config.database.host = doris_host if doris_host else config.database.host
config.database.user = os.getenv("DORIS_USER", config.database.user)
config.database.password = os.getenv("DORIS_PASSWORD", config.database.password) doris_port = os.getenv("DORIS_PORT", "").strip()
config.database.database = os.getenv("DORIS_DATABASE", config.database.database) if doris_port and doris_port.isdigit():
config.database.port = int(doris_port)
doris_user = os.getenv("DORIS_USER", "").strip()
config.database.user = doris_user if doris_user else config.database.user
doris_password = os.getenv("DORIS_PASSWORD", "")
config.database.password = doris_password if doris_password else config.database.password
doris_database = os.getenv("DORIS_DATABASE", "").strip()
config.database.database = doris_database if doris_database else config.database.database
doris_fe_http_port = os.getenv("DORIS_FE_HTTP_PORT", "").strip()
if doris_fe_http_port and doris_fe_http_port.isdigit():
config.database.fe_http_port = int(doris_fe_http_port)
# BE nodes configuration
be_hosts_env = os.getenv("DORIS_BE_HOSTS", "")
if be_hosts_env:
config.database.be_hosts = [host.strip() for host in be_hosts_env.split(",") if host.strip()]
be_webserver_port = os.getenv("DORIS_BE_WEBSERVER_PORT", "").strip()
if be_webserver_port and be_webserver_port.isdigit():
config.database.be_webserver_port = int(be_webserver_port)
# Arrow Flight SQL Configuration
fe_arrow_port_env = os.getenv("FE_ARROW_FLIGHT_SQL_PORT")
if fe_arrow_port_env:
config.database.fe_arrow_flight_sql_port = int(fe_arrow_port_env)
be_arrow_port_env = os.getenv("BE_ARROW_FLIGHT_SQL_PORT")
if be_arrow_port_env:
config.database.be_arrow_flight_sql_port = int(be_arrow_port_env)
# Connection pool configuration # Connection pool configuration
config.database.min_connections = int(
os.getenv("DORIS_MIN_CONNECTIONS", str(config.database.min_connections))
)
config.database.max_connections = int( config.database.max_connections = int(
os.getenv("DORIS_MAX_CONNECTIONS", str(config.database.max_connections)) os.getenv("DORIS_MAX_CONNECTIONS", str(config.database.max_connections))
) )
@@ -234,6 +431,10 @@ class DorisConfig:
) )
# Security configuration # Security configuration
# Independent authentication switches
config.security.enable_token_auth = os.getenv("ENABLE_TOKEN_AUTH", str(config.security.enable_token_auth)).lower() == "true"
config.security.enable_jwt_auth = os.getenv("ENABLE_JWT_AUTH", str(config.security.enable_jwt_auth)).lower() == "true"
config.security.enable_oauth_auth = os.getenv("ENABLE_OAUTH_AUTH", str(config.security.enable_oauth_auth)).lower() == "true"
config.security.auth_type = os.getenv("AUTH_TYPE", config.security.auth_type) config.security.auth_type = os.getenv("AUTH_TYPE", config.security.auth_type)
config.security.token_secret = os.getenv("TOKEN_SECRET", config.security.token_secret) config.security.token_secret = os.getenv("TOKEN_SECRET", config.security.token_secret)
config.security.token_expiry = int( config.security.token_expiry = int(
@@ -245,10 +446,51 @@ class DorisConfig:
config.security.max_query_complexity = int( config.security.max_query_complexity = int(
os.getenv("MAX_QUERY_COMPLEXITY", str(config.security.max_query_complexity)) os.getenv("MAX_QUERY_COMPLEXITY", str(config.security.max_query_complexity))
) )
config.security.enable_security_check = (
os.getenv("ENABLE_SECURITY_CHECK", str(config.security.enable_security_check).lower()).lower() == "true"
)
# Handle blocked keywords environment variable configuration
# Format: BLOCKED_KEYWORDS="DROP,DELETE,TRUNCATE,ALTER,CREATE,INSERT,UPDATE,GRANT,REVOKE"
blocked_keywords_env = os.getenv("BLOCKED_KEYWORDS", "")
if blocked_keywords_env:
# If environment variable is provided, use keywords list from environment variable
config.security.blocked_keywords = [
keyword.strip().upper()
for keyword in blocked_keywords_env.split(",")
if keyword.strip()
]
# If environment variable is empty, keep default configuration unchanged
config.security.enable_masking = ( config.security.enable_masking = (
os.getenv("ENABLE_MASKING", str(config.security.enable_masking).lower()).lower() == "true" os.getenv("ENABLE_MASKING", str(config.security.enable_masking).lower()).lower() == "true"
) )
# Enhanced Token Authentication configuration
config.security.token_file_path = os.getenv("TOKEN_FILE_PATH", config.security.token_file_path)
config.security.enable_token_expiry = (
os.getenv("ENABLE_TOKEN_EXPIRY", str(config.security.enable_token_expiry).lower()).lower() == "true"
)
config.security.default_token_expiry_hours = int(
os.getenv("DEFAULT_TOKEN_EXPIRY_HOURS", str(config.security.default_token_expiry_hours))
)
config.security.token_hash_algorithm = os.getenv("TOKEN_HASH_ALGORITHM", config.security.token_hash_algorithm)
# Token Management Security Configuration (New in v0.6.0)
config.security.enable_http_token_management = (
os.getenv("ENABLE_HTTP_TOKEN_MANAGEMENT", str(config.security.enable_http_token_management).lower()).lower() == "true"
)
config.security.token_management_admin_token = os.getenv("TOKEN_MANAGEMENT_ADMIN_TOKEN", config.security.token_management_admin_token)
# Parse allowed IPs from comma-separated string
allowed_ips_str = os.getenv("TOKEN_MANAGEMENT_ALLOWED_IPS", "")
if allowed_ips_str:
config.security.token_management_allowed_ips = [ip.strip() for ip in allowed_ips_str.split(",") if ip.strip()]
config.security.require_admin_auth = (
os.getenv("REQUIRE_ADMIN_AUTH", str(config.security.require_admin_auth).lower()).lower() == "true"
)
# Performance configuration # Performance configuration
config.performance.enable_query_cache = ( config.performance.enable_query_cache = (
os.getenv("ENABLE_QUERY_CACHE", "true").lower() == "true" os.getenv("ENABLE_QUERY_CACHE", "true").lower() == "true"
@@ -265,6 +507,9 @@ class DorisConfig:
config.performance.query_timeout = int( config.performance.query_timeout = int(
os.getenv("QUERY_TIMEOUT", str(config.performance.query_timeout)) os.getenv("QUERY_TIMEOUT", str(config.performance.query_timeout))
) )
config.performance.max_response_content_size = int(
os.getenv("MAX_RESPONSE_CONTENT_SIZE", str(config.performance.max_response_content_size))
)
# Logging configuration # Logging configuration
config.logging.level = os.getenv("LOG_LEVEL", config.logging.level) config.logging.level = os.getenv("LOG_LEVEL", config.logging.level)
@@ -273,6 +518,15 @@ class DorisConfig:
os.getenv("ENABLE_AUDIT", str(config.logging.enable_audit).lower()).lower() == "true" os.getenv("ENABLE_AUDIT", str(config.logging.enable_audit).lower()).lower() == "true"
) )
config.logging.audit_file_path = os.getenv("AUDIT_FILE_PATH", config.logging.audit_file_path) config.logging.audit_file_path = os.getenv("AUDIT_FILE_PATH", config.logging.audit_file_path)
config.logging.enable_cleanup = (
os.getenv("ENABLE_LOG_CLEANUP", str(config.logging.enable_cleanup).lower()).lower() == "true"
)
config.logging.max_age_days = int(
os.getenv("LOG_MAX_AGE_DAYS", str(config.logging.max_age_days))
)
config.logging.cleanup_interval_hours = int(
os.getenv("LOG_CLEANUP_INTERVAL_HOURS", str(config.logging.cleanup_interval_hours))
)
# Monitoring configuration # Monitoring configuration
config.monitoring.enable_metrics = ( config.monitoring.enable_metrics = (
@@ -289,10 +543,60 @@ class DorisConfig:
) )
config.monitoring.alert_webhook_url = os.getenv("ALERT_WEBHOOK_URL", config.monitoring.alert_webhook_url) config.monitoring.alert_webhook_url = os.getenv("ALERT_WEBHOOK_URL", config.monitoring.alert_webhook_url)
# ADBC configuration
config.adbc.default_max_rows = int(
os.getenv("ADBC_DEFAULT_MAX_ROWS", str(config.adbc.default_max_rows))
)
config.adbc.default_timeout = int(
os.getenv("ADBC_DEFAULT_TIMEOUT", str(config.adbc.default_timeout))
)
config.adbc.default_return_format = os.getenv("ADBC_DEFAULT_RETURN_FORMAT", config.adbc.default_return_format)
config.adbc.connection_timeout = int(
os.getenv("ADBC_CONNECTION_TIMEOUT", str(config.adbc.connection_timeout))
)
config.adbc.enabled = (
os.getenv("ADBC_ENABLED", str(config.adbc.enabled).lower()).lower() == "true"
)
# Data quality configuration
config.data_quality.max_columns_per_batch = int(
os.getenv("DATA_QUALITY_MAX_COLUMNS_PER_BATCH", str(config.data_quality.max_columns_per_batch))
)
config.data_quality.default_sample_size = int(
os.getenv("DATA_QUALITY_DEFAULT_SAMPLE_SIZE", str(config.data_quality.default_sample_size))
)
config.data_quality.small_table_threshold = int(
os.getenv("DATA_QUALITY_SMALL_TABLE_THRESHOLD", str(config.data_quality.small_table_threshold))
)
config.data_quality.medium_table_threshold = int(
os.getenv("DATA_QUALITY_MEDIUM_TABLE_THRESHOLD", str(config.data_quality.medium_table_threshold))
)
config.data_quality.enable_batch_analysis = (
os.getenv("DATA_QUALITY_ENABLE_BATCH_ANALYSIS", str(config.data_quality.enable_batch_analysis).lower()).lower() == "true"
)
config.data_quality.batch_timeout = int(
os.getenv("DATA_QUALITY_BATCH_TIMEOUT", str(config.data_quality.batch_timeout))
)
config.data_quality.enable_fast_mode = (
os.getenv("DATA_QUALITY_ENABLE_FAST_MODE", str(config.data_quality.enable_fast_mode).lower()).lower() == "true"
)
config.data_quality.fast_mode_sample_size = int(
os.getenv("DATA_QUALITY_FAST_MODE_SAMPLE_SIZE", str(config.data_quality.fast_mode_sample_size))
)
config.data_quality.enable_distribution_analysis = (
os.getenv("DATA_QUALITY_ENABLE_DISTRIBUTION_ANALYSIS", str(config.data_quality.enable_distribution_analysis).lower()).lower() == "true"
)
config.data_quality.histogram_bins = int(
os.getenv("DATA_QUALITY_HISTOGRAM_BINS", str(config.data_quality.histogram_bins))
)
# Server configuration # Server configuration
config.server_name = os.getenv("SERVER_NAME", config.server_name) config.server_name = os.getenv("SERVER_NAME", config.server_name)
config.server_version = os.getenv("SERVER_VERSION", config.server_version) config.server_version = os.getenv("SERVER_VERSION", config.server_version)
config.server_port = int(os.getenv("SERVER_PORT", str(config.server_port))) server_port = os.getenv("SERVER_PORT", "").strip()
if server_port and server_port.isdigit():
config.server_port = int(server_port)
config.temp_files_dir = os.getenv("TEMP_FILES_DIR", config.temp_files_dir)
return config return config
@@ -302,7 +606,7 @@ class DorisConfig:
config = cls() config = cls()
# Update basic configuration # Update basic configuration
for key in ["server_name", "server_version", "server_port"]: for key in ["server_name", "server_version", "server_port", "temp_files_dir"]:
if key in config_data: if key in config_data:
setattr(config, key, config_data[key]) setattr(config, key, config_data[key])
@@ -327,6 +631,13 @@ class DorisConfig:
if hasattr(config.performance, key): if hasattr(config.performance, key):
setattr(config.performance, key, value) setattr(config.performance, key, value)
# Update data quality configuration
if "data_quality" in config_data:
dq_config = config_data["data_quality"]
for key, value in dq_config.items():
if hasattr(config.data_quality, key):
setattr(config.data_quality, key, value)
# Update logging configuration # Update logging configuration
if "logging" in config_data: if "logging" in config_data:
log_config = config_data["logging"] log_config = config_data["logging"]
@@ -341,6 +652,13 @@ class DorisConfig:
if hasattr(config.monitoring, key): if hasattr(config.monitoring, key):
setattr(config.monitoring, key, value) setattr(config.monitoring, key, value)
# Update ADBC configuration
if "adbc" in config_data:
adbc_config = config_data["adbc"]
for key, value in adbc_config.items():
if hasattr(config.adbc, key):
setattr(config.adbc, key, value)
# Custom configuration # Custom configuration
config.custom_config = config_data.get("custom", {}) config.custom_config = config_data.get("custom", {})
@@ -352,6 +670,7 @@ class DorisConfig:
"server_name": self.server_name, "server_name": self.server_name,
"server_version": self.server_version, "server_version": self.server_version,
"server_port": self.server_port, "server_port": self.server_port,
"temp_files_dir": self.temp_files_dir,
"database": { "database": {
"host": self.database.host, "host": self.database.host,
"port": self.database.port, "port": self.database.port,
@@ -359,7 +678,12 @@ class DorisConfig:
"password": "***", # Hide password "password": "***", # Hide password
"database": self.database.database, "database": self.database.database,
"charset": self.database.charset, "charset": self.database.charset,
"min_connections": self.database.min_connections, "fe_http_port": self.database.fe_http_port,
"be_hosts": self.database.be_hosts,
"be_webserver_port": self.database.be_webserver_port,
"fe_arrow_flight_sql_port": self.database.fe_arrow_flight_sql_port,
"be_arrow_flight_sql_port": self.database.be_arrow_flight_sql_port,
"min_connections": self.database.min_connections, # Always 0, shown for reference
"max_connections": self.database.max_connections, "max_connections": self.database.max_connections,
"connection_timeout": self.database.connection_timeout, "connection_timeout": self.database.connection_timeout,
"health_check_interval": self.database.health_check_interval, "health_check_interval": self.database.health_check_interval,
@@ -369,6 +693,7 @@ class DorisConfig:
"auth_type": self.security.auth_type, "auth_type": self.security.auth_type,
"token_secret": "***", # Hide secret key "token_secret": "***", # Hide secret key
"token_expiry": self.security.token_expiry, "token_expiry": self.security.token_expiry,
"enable_security_check": self.security.enable_security_check,
"blocked_keywords": self.security.blocked_keywords, "blocked_keywords": self.security.blocked_keywords,
"max_query_complexity": self.security.max_query_complexity, "max_query_complexity": self.security.max_query_complexity,
"max_result_rows": self.security.max_result_rows, "max_result_rows": self.security.max_result_rows,
@@ -384,6 +709,20 @@ class DorisConfig:
"query_timeout": self.performance.query_timeout, "query_timeout": self.performance.query_timeout,
"connection_pool_size": self.performance.connection_pool_size, "connection_pool_size": self.performance.connection_pool_size,
"idle_timeout": self.performance.idle_timeout, "idle_timeout": self.performance.idle_timeout,
"max_response_content_size": self.performance.max_response_content_size,
},
"data_quality": {
"max_columns_per_batch": self.data_quality.max_columns_per_batch,
"default_sample_size": self.data_quality.default_sample_size,
"small_table_threshold": self.data_quality.small_table_threshold,
"medium_table_threshold": self.data_quality.medium_table_threshold,
"enable_batch_analysis": self.data_quality.enable_batch_analysis,
"batch_timeout": self.data_quality.batch_timeout,
"enable_fast_mode": self.data_quality.enable_fast_mode,
"fast_mode_sample_size": self.data_quality.fast_mode_sample_size,
"enable_distribution_analysis": self.data_quality.enable_distribution_analysis,
"histogram_bins": self.data_quality.histogram_bins,
"percentile_levels": self.data_quality.percentile_levels,
}, },
"logging": { "logging": {
"level": self.logging.level, "level": self.logging.level,
@@ -393,6 +732,9 @@ class DorisConfig:
"backup_count": self.logging.backup_count, "backup_count": self.logging.backup_count,
"enable_audit": self.logging.enable_audit, "enable_audit": self.logging.enable_audit,
"audit_file_path": self.logging.audit_file_path, "audit_file_path": self.logging.audit_file_path,
"enable_cleanup": self.logging.enable_cleanup,
"max_age_days": self.logging.max_age_days,
"cleanup_interval_hours": self.logging.cleanup_interval_hours,
}, },
"monitoring": { "monitoring": {
"enable_metrics": self.monitoring.enable_metrics, "enable_metrics": self.monitoring.enable_metrics,
@@ -403,6 +745,13 @@ class DorisConfig:
"enable_alerts": self.monitoring.enable_alerts, "enable_alerts": self.monitoring.enable_alerts,
"alert_webhook_url": self.monitoring.alert_webhook_url, "alert_webhook_url": self.monitoring.alert_webhook_url,
}, },
"adbc": {
"default_max_rows": self.adbc.default_max_rows,
"default_timeout": self.adbc.default_timeout,
"default_return_format": self.adbc.default_return_format,
"connection_timeout": self.adbc.connection_timeout,
"enabled": self.adbc.enabled,
},
"custom": self.custom_config, "custom": self.custom_config,
} }
@@ -435,11 +784,8 @@ class DorisConfig:
if not self.database.user: if not self.database.user:
errors.append("Database username cannot be empty") errors.append("Database username cannot be empty")
if self.database.min_connections <= 0: if self.database.max_connections <= 0:
errors.append("Minimum connections must be greater than 0") errors.append("Maximum connections must be greater than 0")
if self.database.max_connections <= self.database.min_connections:
errors.append("Maximum connections must be greater than minimum connections")
# Validate security configuration # Validate security configuration
if self.security.auth_type not in ["token", "basic", "oauth"]: if self.security.auth_type not in ["token", "basic", "oauth"]:
@@ -464,6 +810,31 @@ class DorisConfig:
if self.performance.query_timeout <= 0: if self.performance.query_timeout <= 0:
errors.append("Query timeout must be greater than 0") errors.append("Query timeout must be greater than 0")
# Validate data quality configuration
if self.data_quality.max_columns_per_batch <= 0:
errors.append("Max columns per batch must be greater than 0")
if self.data_quality.default_sample_size <= 0:
errors.append("Default sample size must be greater than 0")
if self.data_quality.small_table_threshold <= 0:
errors.append("Small table threshold must be greater than 0")
if self.data_quality.medium_table_threshold <= 0:
errors.append("Medium table threshold must be greater than 0")
if self.data_quality.small_table_threshold >= self.data_quality.medium_table_threshold:
errors.append("Small table threshold must be less than medium table threshold")
if self.data_quality.batch_timeout <= 0:
errors.append("Batch timeout must be greater than 0")
if self.data_quality.fast_mode_sample_size <= 0:
errors.append("Fast mode sample size must be greater than 0")
if self.data_quality.histogram_bins <= 0:
errors.append("Histogram bins must be greater than 0")
# Validate logging configuration # Validate logging configuration
if self.logging.level not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]: if self.logging.level not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]:
errors.append("Log level must be one of DEBUG, INFO, WARNING, ERROR, or CRITICAL") errors.append("Log level must be one of DEBUG, INFO, WARNING, ERROR, or CRITICAL")
@@ -474,6 +845,12 @@ class DorisConfig:
if self.logging.backup_count < 0: if self.logging.backup_count < 0:
errors.append("Log backup count cannot be negative") errors.append("Log backup count cannot be negative")
if self.logging.max_age_days <= 0:
errors.append("Log max age days must be greater than 0")
if self.logging.cleanup_interval_hours <= 0:
errors.append("Log cleanup interval hours must be greater than 0")
# Validate monitoring configuration # Validate monitoring configuration
if not (1 <= self.monitoring.metrics_port <= 65535): if not (1 <= self.monitoring.metrics_port <= 65535):
errors.append("Monitoring port must be in the range 1-65535") errors.append("Monitoring port must be in the range 1-65535")
@@ -481,6 +858,19 @@ class DorisConfig:
if not (1 <= self.monitoring.health_check_port <= 65535): if not (1 <= self.monitoring.health_check_port <= 65535):
errors.append("Health check port must be in the range 1-65535") errors.append("Health check port must be in the range 1-65535")
# Validate ADBC configuration
if self.adbc.default_max_rows <= 0:
errors.append("ADBC default max rows must be greater than 0")
if self.adbc.default_timeout <= 0:
errors.append("ADBC default timeout must be greater than 0")
if self.adbc.default_return_format not in ["arrow", "pandas", "dict"]:
errors.append("ADBC default return format must be one of arrow, pandas, or dict")
if self.adbc.connection_timeout <= 0:
errors.append("ADBC connection timeout must be greater than 0")
return errors return errors
def get_connection_string(self) -> str: def get_connection_string(self) -> str:
@@ -492,7 +882,7 @@ class DorisConfig:
return { return {
"server": f"{self.server_name} v{self.server_version}", "server": f"{self.server_name} v{self.server_version}",
"database": f"{self.database.host}:{self.database.port}/{self.database.database}", "database": f"{self.database.host}:{self.database.port}/{self.database.database}",
"connection_pool": f"{self.database.min_connections}-{self.database.max_connections}", "connection_pool": f"0-{self.database.max_connections} (min fixed at 0 for stability)",
"security": { "security": {
"auth_type": self.security.auth_type, "auth_type": self.security.auth_type,
"masking_enabled": self.security.enable_masking, "masking_enabled": self.security.enable_masking,
@@ -518,56 +908,50 @@ class ConfigManager:
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
def setup_logging(self): def setup_logging(self):
"""Setup logging configuration""" """Setup logging configuration using enhanced logger"""
# Configure root logger from .logger import setup_logging, get_logger
root_logger = logging.getLogger() import sys
root_logger.setLevel(getattr(logging, self.config.logging.level.upper()))
# Clear existing handlers # Determine log directory
for handler in root_logger.handlers[:]: log_dir = "logs"
root_logger.removeHandler(handler)
# Create formatter
formatter = logging.Formatter(self.config.logging.format)
# Console handler
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
root_logger.addHandler(console_handler)
# File handler (if configured)
if self.config.logging.file_path: if self.config.logging.file_path:
try: # Extract directory from file path if provided
from logging.handlers import RotatingFileHandler from pathlib import Path
log_dir = str(Path(self.config.logging.file_path).parent)
file_handler = RotatingFileHandler( # Detect if we're in stdio mode by checking if this is likely MCP stdio communication
self.config.logging.file_path, # In stdio mode, we shouldn't output to console as it interferes with JSON protocol
maxBytes=self.config.logging.max_file_size, is_stdio_mode = (
backupCount=self.config.logging.backup_count, self.config.transport == "stdio" or
encoding="utf-8", "--transport" in sys.argv and "stdio" in sys.argv or
not sys.stdout.isatty() # Not a terminal (likely piped/redirected)
) )
file_handler.setFormatter(formatter)
root_logger.addHandler(file_handler)
except Exception as e:
self.logger.warning(f"Failed to setup file logging: {e}")
# Audit log handler (if configured) # Setup enhanced logging with cleanup functionality
if self.config.logging.enable_audit and self.config.logging.audit_file_path: setup_logging(
try: level=self.config.logging.level,
from logging.handlers import RotatingFileHandler log_dir=log_dir,
enable_console=not is_stdio_mode, # Disable console logging in stdio mode
audit_logger = logging.getLogger("audit") enable_file=True,
audit_handler = RotatingFileHandler( enable_audit=self.config.logging.enable_audit,
self.config.logging.audit_file_path, audit_file=self.config.logging.audit_file_path,
maxBytes=self.config.logging.max_file_size, max_file_size=self.config.logging.max_file_size,
backupCount=self.config.logging.backup_count, backup_count=self.config.logging.backup_count,
encoding="utf-8", enable_cleanup=self.config.logging.enable_cleanup,
max_age_days=self.config.logging.max_age_days,
cleanup_interval_hours=self.config.logging.cleanup_interval_hours
) )
audit_handler.setFormatter(formatter)
audit_logger.addHandler(audit_handler) # Update logger to use new system
audit_logger.setLevel(logging.INFO) self.logger = get_logger(__name__)
except Exception as e:
self.logger.warning(f"Failed to setup audit logging: {e}") self.logger.info("Enhanced logging system with cleanup initialized successfully")
self.logger.info(f"Log directory: {log_dir}")
self.logger.info(f"Log level: {self.config.logging.level}")
self.logger.info(f"Audit logging: {'Enabled' if self.config.logging.enable_audit else 'Disabled'}")
self.logger.info(f"Log cleanup: {'Enabled' if self.config.logging.enable_cleanup else 'Disabled'}")
if self.config.logging.enable_cleanup:
self.logger.info(f"Cleanup config: Max age {self.config.logging.max_age_days} days, interval {self.config.logging.cleanup_interval_hours}h")
def validate_config(self) -> bool: def validate_config(self) -> bool:
"""Validate configuration""" """Validate configuration"""

View File

@@ -0,0 +1,771 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Data Exploration Tools Module
Provides table data distribution analysis and exploration capabilities
"""
import time
import math
from datetime import datetime
from typing import Any, Dict, List, Optional, Union
from .db import DorisConnectionManager
from .logger import get_logger
from .sql_security_utils import (
SQLSecurityError,
validate_identifier,
quote_identifier,
build_table_reference,
get_auth_context
)
logger = get_logger(__name__)
class DataExplorationTools:
"""Data exploration tools for table distribution analysis"""
def __init__(self, connection_manager: DorisConnectionManager):
self.connection_manager = connection_manager
logger.info("DataExplorationTools initialized")
# ==================== Private Helper Methods ====================
def _build_full_table_name(self, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> str:
"""Build full table name with catalog and database using three-part naming convention"""
# SECURITY FIX: Use build_table_reference for safe identifier handling
effective_catalog = catalog_name if catalog_name else "internal"
if db_name:
return build_table_reference(table_name, db_name, effective_catalog)
else:
return build_table_reference(table_name, catalog_name=effective_catalog)
async def _get_table_basic_info(self, connection, table_name: str) -> Optional[Dict]:
"""Get basic table information including row count"""
try:
# SECURITY FIX: Get auth_context for security validation
# table_name should already be validated by _build_full_table_name
auth_context = get_auth_context()
count_sql = f"SELECT COUNT(*) as row_count FROM {table_name}"
result = await connection.execute(count_sql, auth_context=auth_context)
if result.data:
return {"row_count": result.data[0]["row_count"]}
return None
except SQLSecurityError as e:
logger.warning(f"Security validation failed for table {table_name}: {str(e)}")
return {"row_count": 0}
except Exception as e:
logger.warning(f"Failed to get basic info for table {table_name}: {str(e)}")
return {"row_count": 0}
async def _get_table_columns_info(self, connection, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> List[Dict]:
"""Get detailed column information"""
try:
# SECURITY FIX: Validate identifiers and use parameterized query
auth_context = get_auth_context()
try:
validate_identifier(table_name, "table name")
if db_name:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected: {e}")
return []
# Build parameterized query
params = [table_name]
where_conditions = ["table_name = %s"]
if db_name:
where_conditions.append("table_schema = %s")
params.append(db_name)
else:
where_conditions.append("table_schema = DATABASE()")
columns_sql = f"""
SELECT
column_name,
data_type,
is_nullable,
column_comment,
ordinal_position
FROM information_schema.columns
WHERE {' AND '.join(where_conditions)}
ORDER BY ordinal_position
"""
result = await connection.execute(columns_sql, params=tuple(params), auth_context=auth_context)
return result.data if result.data else []
except SQLSecurityError as e:
logger.warning(f"Security validation failed: {str(e)}")
return []
except Exception as e:
logger.warning(f"Failed to get columns info for table {table_name}: {str(e)}")
return []
async def _determine_sampling_strategy(self, connection, table_name: str, total_rows: int, sample_size: int) -> Dict[str, Any]:
"""Determine optimal sampling strategy based on table size"""
if total_rows <= sample_size:
# Use all data if table is small enough
return {
"total_rows": total_rows,
"sample_size": total_rows,
"sampling_method": "full_scan",
"sampling_ratio": 1.0,
"use_sampling": False,
"sample_table_expression": table_name
}
else:
# Use random sampling for large tables
sampling_ratio = sample_size / total_rows
return {
"total_rows": total_rows,
"sample_size": sample_size,
"sampling_method": "random_sample",
"sampling_ratio": round(sampling_ratio, 4),
"use_sampling": True,
"sample_table_expression": f"(SELECT * FROM {table_name} ORDER BY RAND() LIMIT {sample_size}) as sample_table"
}
def _select_analysis_columns(self, columns_info: List[Dict], include_all: bool) -> List[Dict]:
"""Select columns for analysis based on strategy"""
if include_all:
return columns_info
# If not analyzing all columns, prioritize key columns
priority_keywords = ['id', 'key', 'code', 'status', 'type', 'amount', 'count', 'date', 'time']
priority_columns = []
other_columns = []
for col in columns_info:
col_name_lower = col["column_name"].lower()
if any(keyword in col_name_lower for keyword in priority_keywords):
priority_columns.append(col)
else:
other_columns.append(col)
# Return priority columns plus first 10 other columns
return priority_columns + other_columns[:10]
def _is_numeric_type(self, data_type: str) -> bool:
"""Check if column type is numeric"""
numeric_types = [
'tinyint', 'smallint', 'int', 'bigint', 'largeint',
'float', 'double', 'decimal', 'numeric'
]
return any(num_type in data_type.lower() for num_type in numeric_types)
def _is_categorical_type(self, data_type: str) -> bool:
"""Check if column type is categorical"""
categorical_types = ['varchar', 'char', 'string', 'text', 'enum']
return any(cat_type in data_type.lower() for cat_type in categorical_types)
def _is_temporal_type(self, data_type: str) -> bool:
"""Check if column type is temporal"""
temporal_types = ['date', 'datetime', 'timestamp', 'time']
return any(temp_type in data_type.lower() for temp_type in temporal_types)
async def _analyze_numeric_distributions(self, connection, table_name: str, numeric_columns: List[Dict], sampling_info: Dict) -> Dict[str, Any]:
"""Analyze distribution patterns for numeric columns"""
numeric_analysis = {}
for column in numeric_columns:
col_name = column["column_name"]
try:
# Basic statistics
table_expr = sampling_info.get("sample_table_expression", table_name)
stats_sql = f"""
SELECT
COUNT({col_name}) as count,
MIN({col_name}) as min_value,
MAX({col_name}) as max_value,
AVG({col_name}) as mean_value,
STDDEV({col_name}) as std_dev
FROM {table_expr}
WHERE {col_name} IS NOT NULL
"""
auth_context = get_auth_context()
stats_result = await connection.execute(stats_sql, auth_context=auth_context)
if stats_result.data and stats_result.data[0]["count"] > 0:
stats = stats_result.data[0]
# Percentiles calculation
percentiles = await self._calculate_percentiles(connection, table_name, col_name, sampling_info)
# Outlier detection
outliers = await self._detect_numeric_outliers(connection, table_name, col_name, percentiles, sampling_info)
# Distribution shape analysis
distribution_shape = await self._analyze_distribution_shape(
connection, table_name, col_name, stats, percentiles, sampling_info
)
numeric_analysis[col_name] = {
"data_type": column["data_type"],
"statistics": {
"count": stats["count"],
"mean": round(float(stats["mean_value"]), 4) if stats["mean_value"] else None,
"std": round(float(stats["std_dev"]), 4) if stats["std_dev"] else None,
"min": float(stats["min_value"]) if stats["min_value"] else None,
"max": float(stats["max_value"]) if stats["max_value"] else None,
**percentiles
},
"distribution_shape": distribution_shape,
"outliers": outliers
}
except Exception as e:
logger.warning(f"Failed to analyze numeric column {col_name}: {str(e)}")
numeric_analysis[col_name] = {"error": str(e)}
return numeric_analysis
async def _calculate_percentiles(self, connection, table_name: str, col_name: str, sampling_info: Dict) -> Dict[str, float]:
"""Calculate percentiles for numeric column"""
try:
table_expr = sampling_info.get("sample_table_expression", table_name)
percentile_sql = f"""
SELECT
PERCENTILE({col_name}, 0.25) as p25,
PERCENTILE({col_name}, 0.50) as p50,
PERCENTILE({col_name}, 0.75) as p75,
PERCENTILE({col_name}, 0.90) as p90,
PERCENTILE({col_name}, 0.95) as p95,
PERCENTILE({col_name}, 0.99) as p99
FROM {table_expr}
WHERE {col_name} IS NOT NULL
"""
auth_context = get_auth_context()
result = await connection.execute(percentile_sql, auth_context=auth_context)
if result.data:
data = result.data[0]
return {
"25%": round(float(data["p25"]), 4) if data["p25"] else None,
"50%": round(float(data["p50"]), 4) if data["p50"] else None,
"75%": round(float(data["p75"]), 4) if data["p75"] else None,
"90%": round(float(data["p90"]), 4) if data["p90"] else None,
"95%": round(float(data["p95"]), 4) if data["p95"] else None,
"99%": round(float(data["p99"]), 4) if data["p99"] else None
}
except Exception as e:
logger.warning(f"Failed to calculate percentiles for {col_name}: {str(e)}")
return {}
async def _detect_numeric_outliers(self, connection, table_name: str, col_name: str, percentiles: Dict, sampling_info: Dict) -> Dict[str, Any]:
"""Detect outliers using IQR method"""
try:
if "25%" not in percentiles or "75%" not in percentiles:
return {"outlier_count": 0, "outlier_rate": 0.0}
q1 = percentiles["25%"]
q3 = percentiles["75%"]
iqr = q3 - q1
lower_bound = q1 - 1.5 * iqr
upper_bound = q3 + 1.5 * iqr
table_expr = sampling_info.get("sample_table_expression", table_name)
outlier_sql = f"""
SELECT
COUNT(*) as total_count,
SUM(CASE WHEN {col_name} < {lower_bound} OR {col_name} > {upper_bound} THEN 1 ELSE 0 END) as outlier_count
FROM {table_expr}
WHERE {col_name} IS NOT NULL
"""
auth_context = get_auth_context()
result = await connection.execute(outlier_sql, auth_context=auth_context)
if result.data:
data = result.data[0]
total_count = data["total_count"]
outlier_count = data["outlier_count"]
outlier_rate = outlier_count / total_count if total_count > 0 else 0
return {
"outlier_count": outlier_count,
"outlier_rate": round(outlier_rate, 4),
"outlier_threshold_lower": round(lower_bound, 4),
"outlier_threshold_upper": round(upper_bound, 4),
"iqr": round(iqr, 4)
}
except Exception as e:
logger.warning(f"Failed to detect outliers for {col_name}: {str(e)}")
return {"outlier_count": 0, "outlier_rate": 0.0}
async def _analyze_distribution_shape(self, connection, table_name: str, col_name: str, stats: Dict, percentiles: Dict, sampling_info: Dict) -> Dict[str, Any]:
"""Analyze the shape of data distribution"""
try:
mean = stats.get("mean_value", 0)
median = percentiles.get("50%", 0)
if mean is None or median is None:
return {"distribution_type": "unknown"}
# Calculate skewness indicator
if abs(mean - median) < 0.01:
skew_indicator = "symmetric"
elif mean > median:
skew_indicator = "right_skewed"
else:
skew_indicator = "left_skewed"
# Estimate kurtosis based on percentile spread
if "25%" in percentiles and "75%" in percentiles:
iqr = percentiles["75%"] - percentiles["25%"]
range_90 = percentiles.get("90%", percentiles["75%"]) - percentiles.get("10%", percentiles["25%"])
if iqr > 0:
kurtosis_indicator = "normal" if 2.5 <= range_90/iqr <= 3.5 else ("heavy_tailed" if range_90/iqr > 3.5 else "light_tailed")
else:
kurtosis_indicator = "unknown"
else:
kurtosis_indicator = "unknown"
return {
"skewness_indicator": skew_indicator,
"kurtosis_indicator": kurtosis_indicator,
"distribution_type": self._classify_distribution_type(skew_indicator, kurtosis_indicator),
"mean_median_ratio": round(mean / median, 4) if median != 0 else None
}
except Exception as e:
logger.warning(f"Failed to analyze distribution shape for {col_name}: {str(e)}")
return {"distribution_type": "unknown"}
def _classify_distribution_type(self, skew: str, kurtosis: str) -> str:
"""Classify distribution type based on skewness and kurtosis"""
if skew == "symmetric" and kurtosis == "normal":
return "approximately_normal"
elif skew == "right_skewed":
return "right_skewed"
elif skew == "left_skewed":
return "left_skewed"
elif kurtosis == "heavy_tailed":
return "heavy_tailed"
else:
return "non_normal"
async def _analyze_categorical_distributions(self, connection, table_name: str, categorical_columns: List[Dict], sampling_info: Dict) -> Dict[str, Any]:
"""Analyze distribution patterns for categorical columns"""
categorical_analysis = {}
for column in categorical_columns:
col_name = column["column_name"]
try:
# Basic cardinality and distribution
cardinality_sql = f"""
SELECT
COUNT(DISTINCT {col_name}) as cardinality,
COUNT({col_name}) as non_null_count
FROM {table_name}
WHERE {col_name} IS NOT NULL
{sampling_info.get('sample_query_suffix', '')}
"""
auth_context = get_auth_context()
cardinality_result = await connection.execute(cardinality_sql, auth_context=auth_context)
if cardinality_result.data:
cardinality_data = cardinality_result.data[0]
cardinality = cardinality_data["cardinality"]
non_null_count = cardinality_data["non_null_count"]
# Value distribution (top values)
value_distribution = await self._get_categorical_value_distribution(
connection, table_name, col_name, sampling_info, non_null_count
)
# Calculate entropy and concentration
entropy = self._calculate_entropy(value_distribution)
concentration_ratio = value_distribution[0]["percentage"] if value_distribution else 0
categorical_analysis[col_name] = {
"data_type": column["data_type"],
"cardinality": cardinality,
"non_null_count": non_null_count,
"value_distribution": value_distribution,
"entropy": round(entropy, 3),
"concentration_ratio": round(concentration_ratio, 4),
"diversity_score": round(cardinality / non_null_count, 4) if non_null_count > 0 else 0
}
except Exception as e:
logger.warning(f"Failed to analyze categorical column {col_name}: {str(e)}")
categorical_analysis[col_name] = {"error": str(e)}
return categorical_analysis
async def _get_categorical_value_distribution(self, connection, table_name: str, col_name: str, sampling_info: Dict, total_count: int) -> List[Dict]:
"""Get value distribution for categorical column"""
try:
# Use sample table expression if sampling is enabled
table_expr = sampling_info.get("sample_table_expression", table_name)
distribution_sql = f"""
SELECT
{col_name} as value,
COUNT(*) as count
FROM {table_expr}
WHERE {col_name} IS NOT NULL
GROUP BY {col_name}
ORDER BY COUNT(*) DESC
LIMIT 20
"""
auth_context = get_auth_context()
result = await connection.execute(distribution_sql, auth_context=auth_context)
if result.data:
distribution = []
for row in result.data:
count = row["count"]
percentage = count / total_count if total_count > 0 else 0
distribution.append({
"value": str(row["value"]),
"count": count,
"percentage": round(percentage, 4)
})
return distribution
except Exception as e:
logger.warning(f"Failed to get value distribution for {col_name}: {str(e)}")
return []
def _calculate_entropy(self, value_distribution: List[Dict]) -> float:
"""Calculate Shannon entropy for categorical distribution"""
if not value_distribution:
return 0.0
entropy = 0.0
for item in value_distribution:
p = item["percentage"]
if p > 0:
entropy -= p * math.log2(p)
return entropy
async def _analyze_temporal_distributions(self, connection, table_name: str, temporal_columns: List[Dict], sampling_info: Dict) -> Dict[str, Any]:
"""Analyze distribution patterns for temporal columns"""
temporal_analysis = {}
for column in temporal_columns:
col_name = column["column_name"]
try:
# Date range analysis
table_expr = sampling_info.get("sample_table_expression", table_name)
range_sql = f"""
SELECT
MIN({col_name}) as earliest,
MAX({col_name}) as latest,
COUNT({col_name}) as non_null_count
FROM {table_expr}
WHERE {col_name} IS NOT NULL
"""
auth_context = get_auth_context()
range_result = await connection.execute(range_sql, auth_context=auth_context)
if range_result.data and range_result.data[0]["non_null_count"] > 0:
range_data = range_result.data[0]
earliest = range_data["earliest"]
latest = range_data["latest"]
# Calculate span
date_span_info = self._calculate_date_span(earliest, latest)
# Temporal patterns analysis
temporal_patterns = await self._analyze_temporal_patterns(
connection, table_name, col_name, sampling_info
)
temporal_analysis[col_name] = {
"data_type": column["data_type"],
"non_null_count": range_data["non_null_count"],
"date_range": {
"earliest": str(earliest),
"latest": str(latest),
**date_span_info
},
"temporal_patterns": temporal_patterns
}
except Exception as e:
logger.warning(f"Failed to analyze temporal column {col_name}: {str(e)}")
temporal_analysis[col_name] = {"error": str(e)}
return temporal_analysis
def _calculate_date_span(self, earliest, latest) -> Dict[str, Any]:
"""Calculate date span information"""
try:
if isinstance(earliest, str):
earliest = datetime.fromisoformat(earliest.replace('Z', '+00:00'))
if isinstance(latest, str):
latest = datetime.fromisoformat(latest.replace('Z', '+00:00'))
span = latest - earliest
span_days = span.days
return {
"span_days": span_days,
"span_years": round(span_days / 365.25, 2),
"span_description": self._describe_time_span(span_days)
}
except Exception as e:
logger.warning(f"Failed to calculate date span: {str(e)}")
return {"span_days": 0}
def _describe_time_span(self, days: int) -> str:
"""Describe time span in human readable format"""
if days < 1:
return "less_than_day"
elif days < 7:
return "days"
elif days < 30:
return "weeks"
elif days < 365:
return "months"
else:
return "years"
async def _analyze_temporal_patterns(self, connection, table_name: str, col_name: str, sampling_info: Dict) -> Dict[str, Any]:
"""Analyze temporal patterns like seasonality and trends"""
try:
table_expr = sampling_info.get("sample_table_expression", table_name)
# Weekly pattern analysis
weekly_pattern_sql = f"""
SELECT
DAYOFWEEK({col_name}) as day_of_week,
COUNT(*) as count
FROM {table_expr}
WHERE {col_name} IS NOT NULL
GROUP BY DAYOFWEEK({col_name})
ORDER BY day_of_week
"""
auth_context = get_auth_context()
weekly_result = await connection.execute(weekly_pattern_sql, auth_context=auth_context)
weekly_pattern = []
if weekly_result.data:
total_records = sum(row["count"] for row in weekly_result.data)
for row in weekly_result.data:
percentage = row["count"] / total_records if total_records > 0 else 0
weekly_pattern.append(round(percentage, 3))
# Monthly trend analysis (simplified)
monthly_trend_sql = f"""
SELECT
YEAR({col_name}) as year,
MONTH({col_name}) as month,
COUNT(*) as count
FROM {table_expr}
WHERE {col_name} IS NOT NULL
GROUP BY YEAR({col_name}), MONTH({col_name})
ORDER BY year, month
LIMIT 12
"""
monthly_result = await connection.execute(monthly_trend_sql, auth_context=auth_context)
monthly_trend = "stable" # Simplified trend analysis
if monthly_result.data and len(monthly_result.data) > 3:
counts = [row["count"] for row in monthly_result.data]
if len(counts) > 1:
trend_direction = "increasing" if counts[-1] > counts[0] else "decreasing"
monthly_trend = trend_direction
return {
"weekly_pattern": weekly_pattern,
"monthly_trend": monthly_trend,
"seasonal_component": self._estimate_seasonality(weekly_pattern)
}
except Exception as e:
logger.warning(f"Failed to analyze temporal patterns for {col_name}: {str(e)}")
return {"weekly_pattern": [], "monthly_trend": "unknown"}
def _estimate_seasonality(self, weekly_pattern: List[float]) -> float:
"""Estimate seasonality strength based on weekly pattern variance"""
if len(weekly_pattern) < 7:
return 0.0
mean_percentage = sum(weekly_pattern) / len(weekly_pattern)
variance = sum((x - mean_percentage) ** 2 for x in weekly_pattern) / len(weekly_pattern)
# Normalize variance to 0-1 scale as seasonality indicator
seasonality = min(variance * 10, 1.0) # Scaling factor
return round(seasonality, 3)
async def _generate_data_quality_insights(self, connection, table_name: str, columns: List[Dict], sampling_info: Dict) -> Dict[str, Any]:
"""Generate overall data quality insights"""
try:
total_columns = len(columns)
# Calculate null rates across all columns
null_analysis = await self._analyze_overall_null_rates(connection, table_name, columns, sampling_info)
# Identify potential data quality issues
quality_issues = []
# High null rate columns
high_null_columns = [col for col, rate in null_analysis["column_null_rates"].items() if rate > 0.2]
if high_null_columns:
quality_issues.append({
"issue_type": "high_null_rates",
"severity": "medium",
"affected_columns": high_null_columns,
"description": f"{len(high_null_columns)} columns have null rates > 20%"
})
# Calculate overall data quality score
avg_null_rate = sum(null_analysis["column_null_rates"].values()) / len(null_analysis["column_null_rates"]) if null_analysis["column_null_rates"] else 0
data_quality_score = max(0, 1 - avg_null_rate)
return {
"total_columns_analyzed": total_columns,
"null_analysis": null_analysis,
"data_quality_score": round(data_quality_score, 3),
"quality_issues": quality_issues,
"recommendations": self._generate_quality_recommendations(quality_issues, null_analysis)
}
except Exception as e:
logger.warning(f"Failed to generate data quality insights: {str(e)}")
return {"data_quality_score": 0.0, "error": str(e)}
async def _analyze_overall_null_rates(self, connection, table_name: str, columns: List[Dict], sampling_info: Dict) -> Dict[str, Any]:
"""Analyze null rates across all columns"""
column_null_rates = {}
total_null_count = 0
total_cell_count = 0
for column in columns:
col_name = column["column_name"]
try:
table_expr = sampling_info.get("sample_table_expression", table_name)
null_sql = f"""
SELECT
COUNT(*) as total_count,
COUNT({col_name}) as non_null_count
FROM {table_expr}
"""
auth_context = get_auth_context()
result = await connection.execute(null_sql, auth_context=auth_context)
if result.data:
data = result.data[0]
total_count = data["total_count"]
non_null_count = data["non_null_count"]
null_count = total_count - non_null_count
null_rate = null_count / total_count if total_count > 0 else 0
column_null_rates[col_name] = round(null_rate, 4)
total_null_count += null_count
total_cell_count += total_count
except Exception as e:
logger.warning(f"Failed to analyze null rate for column {col_name}: {str(e)}")
column_null_rates[col_name] = 0.0
overall_null_rate = total_null_count / total_cell_count if total_cell_count > 0 else 0
return {
"column_null_rates": column_null_rates,
"overall_null_rate": round(overall_null_rate, 4),
"columns_with_nulls": len([rate for rate in column_null_rates.values() if rate > 0])
}
def _generate_quality_recommendations(self, quality_issues: List[Dict], null_analysis: Dict) -> List[Dict]:
"""Generate data quality improvement recommendations"""
recommendations = []
# Recommendations based on null analysis
overall_null_rate = null_analysis.get("overall_null_rate", 0)
if overall_null_rate > 0.1:
recommendations.append({
"type": "data_completeness",
"priority": "high" if overall_null_rate > 0.3 else "medium",
"description": f"Overall null rate is {overall_null_rate:.1%}",
"action": "Review data collection and validation processes"
})
# Recommendations based on quality issues
for issue in quality_issues:
if issue["issue_type"] == "high_null_rates":
recommendations.append({
"type": "column_completeness",
"priority": issue["severity"],
"description": issue["description"],
"action": f"Focus on improving data completeness for: {', '.join(issue['affected_columns'][:3])}"
})
return recommendations
def _generate_analysis_summary(self, distribution_analysis: Dict[str, Any]) -> Dict[str, Any]:
"""Generate high-level summary of distribution analysis"""
summary = {
"numeric_columns_count": len(distribution_analysis.get("numeric_columns", {})),
"categorical_columns_count": len(distribution_analysis.get("categorical_columns", {})),
"temporal_columns_count": len(distribution_analysis.get("temporal_columns", {}))
}
# Identify interesting patterns
patterns = []
# Check for highly skewed numeric columns
numeric_cols = distribution_analysis.get("numeric_columns", {})
skewed_cols = [
col for col, info in numeric_cols.items()
if isinstance(info, dict) and
info.get("distribution_shape", {}).get("skewness_indicator") in ["right_skewed", "left_skewed"]
]
if skewed_cols:
patterns.append(f"Found {len(skewed_cols)} skewed numeric columns")
# Check for high cardinality categorical columns
categorical_cols = distribution_analysis.get("categorical_columns", {})
high_cardinality_cols = [
col for col, info in categorical_cols.items()
if isinstance(info, dict) and info.get("cardinality", 0) > 1000
]
if high_cardinality_cols:
patterns.append(f"Found {len(high_cardinality_cols)} high cardinality categorical columns")
summary["notable_patterns"] = patterns
return summary

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -15,77 +15,573 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
""" """
Logging configuration for Doris MCP Server. Enhanced Logging configuration for Doris MCP Server.
Features:
- Log level-based file separation
- Timestamped log entries
- Automatic log rotation
- Comprehensive logging coverage
""" """
import logging import logging
import logging.config import logging.config
import logging.handlers
import sys import sys
import os
import asyncio
import time
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any, Optional
from datetime import datetime, timedelta
import threading
def setup_logging( class TimestampedFormatter(logging.Formatter):
level: str = "INFO", """Custom formatter with enhanced timestamp and structured format"""
log_file: str | None = None,
log_format: str | None = None, def __init__(self, fmt=None, datefmt=None, style='%'):
) -> None: if fmt is None:
fmt = "%(asctime)s.%(msecs)03d %(level_aligned)s %(name)s:%(lineno)d - %(message)s"
if datefmt is None:
datefmt = "%Y-%m-%d %H:%M:%S"
super().__init__(fmt, datefmt, style)
def format(self, record):
"""Format log record with enhanced information and proper alignment"""
# Add process info if available
if hasattr(record, 'process') and record.process:
record.process_info = f"[PID:{record.process}]"
else:
record.process_info = ""
# Add thread info if available
if hasattr(record, 'thread') and record.thread:
record.thread_info = f"[TID:{record.thread}]"
else:
record.thread_info = ""
# Format with proper alignment after the level name
# Calculate padding needed for alignment
level_name = record.levelname
max_level_length = 8 # Length of "CRITICAL"
padding = max_level_length - len(level_name)
record.level_aligned = f"[{level_name}]{' ' * padding}"
return super().format(record)
class LevelBasedFileHandler(logging.Handler):
"""Custom handler that writes different log levels to different files"""
def __init__(self, log_dir: str, base_name: str = "doris_mcp_server",
max_bytes: int = 10*1024*1024, backup_count: int = 5):
super().__init__()
self.log_dir = Path(log_dir)
self.base_name = base_name
self.max_bytes = max_bytes
self.backup_count = backup_count
# Ensure log directory exists
self.log_dir.mkdir(parents=True, exist_ok=True)
# Create handlers for different log levels
self.handlers = {}
self._setup_level_handlers()
def _setup_level_handlers(self):
"""Setup rotating file handlers for different log levels"""
level_files = {
'DEBUG': 'debug.log',
'INFO': 'info.log',
'WARNING': 'warning.log',
'ERROR': 'error.log',
'CRITICAL': 'critical.log'
}
formatter = TimestampedFormatter()
for level, filename in level_files.items():
file_path = self.log_dir / f"{self.base_name}_{filename}"
handler = logging.handlers.RotatingFileHandler(
file_path,
maxBytes=self.max_bytes,
backupCount=self.backup_count,
encoding='utf-8'
)
handler.setFormatter(formatter)
handler.setLevel(getattr(logging, level))
self.handlers[level] = handler
def emit(self, record):
"""Emit log record to appropriate level-based file"""
level_name = record.levelname
if level_name in self.handlers:
try:
self.handlers[level_name].emit(record)
except Exception:
self.handleError(record)
def close(self):
"""Close all handlers"""
for handler in self.handlers.values():
handler.close()
super().close()
class LogCleanupManager:
"""Log file cleanup manager for automatic maintenance"""
def __init__(self, log_dir: str, max_age_days: int = 30, cleanup_interval_hours: int = 24):
""" """
Setup logging configuration. Initialize log cleanup manager.
Args: Args:
level: Logging level (DEBUG, INFO, WARNING, ERROR) log_dir: Directory containing log files
log_file: Optional log file path max_age_days: Maximum age of log files in days (default: 30 days)
log_format: Optional custom log format cleanup_interval_hours: Cleanup interval in hours (default: 24 hours)
""" """
if log_format is None: self.log_dir = Path(log_dir)
log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" self.max_age_days = max_age_days
self.cleanup_interval_hours = cleanup_interval_hours
self.cleanup_thread = None
self.stop_event = threading.Event()
self.logger = None
# Base configuration def start_cleanup_scheduler(self):
config: dict[str, Any] = { """Start the cleanup scheduler in a background thread"""
"version": 1, if self.cleanup_thread and self.cleanup_thread.is_alive():
"disable_existing_loggers": False, return
"formatters": {
"default": {"format": log_format, "datefmt": "%Y-%m-%d %H:%M:%S"} self.stop_event.clear()
}, self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True)
"handlers": { self.cleanup_thread.start()
"console": {
"class": "logging.StreamHandler", # Get logger for this class
"level": level, if not self.logger:
"formatter": "default", self.logger = logging.getLogger("doris_mcp_server.log_cleanup")
"stream": sys.stdout,
} self.logger.info(f"Log cleanup scheduler started - cleanup every {self.cleanup_interval_hours}h, max age {self.max_age_days} days")
},
"root": {"level": level, "handlers": ["console"]}, def stop_cleanup_scheduler(self):
"loggers": { """Stop the cleanup scheduler"""
"doris_mcp_server": { if self.cleanup_thread and self.cleanup_thread.is_alive():
"level": level, self.stop_event.set()
"handlers": ["console"], self.cleanup_thread.join(timeout=5)
"propagate": False, if self.logger:
} self.logger.info("Log cleanup scheduler stopped")
},
def _cleanup_loop(self):
"""Background loop for periodic cleanup"""
while not self.stop_event.is_set():
try:
self.cleanup_old_logs()
# Sleep for the specified interval, but check stop event every 60 seconds
for _ in range(self.cleanup_interval_hours * 60): # Convert hours to minutes
if self.stop_event.wait(60): # Wait 60 seconds or until stop event
break
except Exception as e:
if self.logger:
self.logger.error(f"Error in log cleanup loop: {e}")
# Sleep for 5 minutes before retrying
self.stop_event.wait(300)
def cleanup_old_logs(self):
"""Clean up old log files based on age"""
if not self.log_dir.exists():
return
current_time = datetime.now()
cutoff_time = current_time - timedelta(days=self.max_age_days)
cleaned_files = []
cleaned_size = 0
# Pattern for log files (including backup files)
log_patterns = [
"doris_mcp_server_*.log",
"doris_mcp_server_*.log.*" # Backup files
]
for pattern in log_patterns:
for log_file in self.log_dir.glob(pattern):
try:
# Get file modification time
file_mtime = datetime.fromtimestamp(log_file.stat().st_mtime)
if file_mtime < cutoff_time:
file_size = log_file.stat().st_size
log_file.unlink() # Delete the file
cleaned_files.append(log_file.name)
cleaned_size += file_size
except Exception as e:
if self.logger:
self.logger.warning(f"Failed to cleanup log file {log_file}: {e}")
if cleaned_files and self.logger:
size_mb = cleaned_size / (1024 * 1024)
self.logger.info(f"Cleaned up {len(cleaned_files)} old log files, freed {size_mb:.2f} MB")
self.logger.debug(f"Cleaned files: {', '.join(cleaned_files)}")
def get_cleanup_stats(self) -> dict:
"""Get statistics about log files and cleanup status"""
if not self.log_dir.exists():
return {"error": "Log directory does not exist"}
stats = {
"log_directory": str(self.log_dir.absolute()),
"max_age_days": self.max_age_days,
"cleanup_interval_hours": self.cleanup_interval_hours,
"scheduler_running": self.cleanup_thread and self.cleanup_thread.is_alive(),
"total_files": 0,
"total_size_mb": 0,
"files_by_age": {"recent": 0, "old": 0},
"oldest_file": None,
"newest_file": None
} }
# Add file handler if log_file is specified current_time = datetime.now()
if log_file: cutoff_time = current_time - timedelta(days=self.max_age_days)
# Ensure log directory exists oldest_time = None
log_path = Path(log_file) newest_time = None
log_path.parent.mkdir(parents=True, exist_ok=True)
config["handlers"]["file"] = { log_patterns = ["doris_mcp_server_*.log", "doris_mcp_server_*.log.*"]
"class": "logging.handlers.RotatingFileHandler",
"level": level,
"formatter": "default",
"filename": log_file,
"maxBytes": 10485760, # 10MB
"backupCount": 5,
}
# Add file handler to root and package loggers for pattern in log_patterns:
config["root"]["handlers"].append("file") for log_file in self.log_dir.glob(pattern):
config["loggers"]["doris_mcp_server"]["handlers"].append("file") try:
file_stat = log_file.stat()
file_mtime = datetime.fromtimestamp(file_stat.st_mtime)
logging.config.dictConfig(config) stats["total_files"] += 1
stats["total_size_mb"] += file_stat.st_size / (1024 * 1024)
if file_mtime < cutoff_time:
stats["files_by_age"]["old"] += 1
else:
stats["files_by_age"]["recent"] += 1
if oldest_time is None or file_mtime < oldest_time:
oldest_time = file_mtime
stats["oldest_file"] = {"name": log_file.name, "age_days": (current_time - file_mtime).days}
if newest_time is None or file_mtime > newest_time:
newest_time = file_mtime
stats["newest_file"] = {"name": log_file.name, "age_days": (current_time - file_mtime).days}
except Exception:
continue
stats["total_size_mb"] = round(stats["total_size_mb"], 2)
return stats
class DorisLoggerManager:
"""Centralized logger manager for Doris MCP Server"""
def __init__(self):
self.is_initialized = False
self.log_dir = None
self.config = None
self.loggers = {}
self.cleanup_manager = None
def setup_logging(self,
level: str = "INFO",
log_dir: str = "logs",
enable_console: bool = True,
enable_file: bool = True,
enable_audit: bool = True,
audit_file: Optional[str] = None,
max_file_size: int = 10*1024*1024,
backup_count: int = 5,
enable_cleanup: bool = True,
max_age_days: int = 30,
cleanup_interval_hours: int = 24) -> None:
"""
Setup comprehensive logging configuration.
Args:
level: Base logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
log_dir: Directory for log files
enable_console: Enable console output
enable_file: Enable file logging
enable_audit: Enable audit logging
audit_file: Custom audit log file path
max_file_size: Maximum size per log file (bytes)
backup_count: Number of backup files to keep
enable_cleanup: Enable automatic log cleanup
max_age_days: Maximum age of log files in days (default: 30)
cleanup_interval_hours: Cleanup interval in hours (default: 24)
"""
if self.is_initialized:
return
self.log_dir = Path(log_dir)
log_dir_writable = True # Initialize the variable
# Try to create log directory, fallback to console-only if fails
try:
self.log_dir.mkdir(parents=True, exist_ok=True)
except (OSError, PermissionError) as e:
# If we can't create log directory (e.g., read-only filesystem in stdio mode),
# fall back to console-only logging
log_dir_writable = False
enable_file = False
enable_audit = False
enable_cleanup = False
# Don't use print() in stdio mode as it interferes with MCP JSON protocol
# Log the warning through the logging system instead, which will be handled after setup
# Clear existing handlers
root_logger = logging.getLogger()
for handler in root_logger.handlers[:]:
root_logger.removeHandler(handler)
# Set root logger level
root_logger.setLevel(logging.DEBUG) # Allow all levels, handlers will filter
handlers = []
# Console handler
if enable_console:
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(getattr(logging, level.upper()))
console_formatter = TimestampedFormatter(
fmt="%(asctime)s.%(msecs)03d %(level_aligned)s %(name)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S"
)
console_handler.setFormatter(console_formatter)
handlers.append(console_handler)
# Level-based file handlers
if enable_file:
level_handler = LevelBasedFileHandler(
log_dir=str(self.log_dir),
base_name="doris_mcp_server",
max_bytes=max_file_size,
backup_count=backup_count
)
level_handler.setLevel(logging.DEBUG) # Accept all levels
handlers.append(level_handler)
# Combined application log (all levels in one file)
if enable_file:
app_log_file = self.log_dir / "doris_mcp_server_all.log"
app_handler = logging.handlers.RotatingFileHandler(
app_log_file,
maxBytes=max_file_size,
backupCount=backup_count,
encoding='utf-8'
)
app_handler.setLevel(getattr(logging, level.upper()))
app_formatter = TimestampedFormatter()
app_handler.setFormatter(app_formatter)
handlers.append(app_handler)
# Audit logger (separate from main logging)
if enable_audit:
audit_file_path = audit_file or str(self.log_dir / "doris_mcp_server_audit.log")
audit_logger = logging.getLogger("audit")
audit_logger.setLevel(logging.INFO)
# Clear existing audit handlers
for handler in audit_logger.handlers[:]:
audit_logger.removeHandler(handler)
audit_handler = logging.handlers.RotatingFileHandler(
audit_file_path,
maxBytes=max_file_size,
backupCount=backup_count,
encoding='utf-8'
)
audit_formatter = TimestampedFormatter(
fmt="%(asctime)s.%(msecs)03d [AUDIT] %(name)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S"
)
audit_handler.setFormatter(audit_formatter)
audit_logger.addHandler(audit_handler)
audit_logger.propagate = False # Don't propagate to root logger
# Add all handlers to root logger
for handler in handlers:
root_logger.addHandler(handler)
# Setup package-specific loggers
self._setup_package_loggers(level)
# Setup log cleanup manager
if enable_cleanup and enable_file:
self.cleanup_manager = LogCleanupManager(
log_dir=str(self.log_dir),
max_age_days=max_age_days,
cleanup_interval_hours=cleanup_interval_hours
)
self.cleanup_manager.start_cleanup_scheduler()
self.is_initialized = True
# Log initialization message
logger = self.get_logger("doris_mcp_server.logger")
logger.info("=" * 80)
logger.info("Doris MCP Server Logging System Initialized")
logger.info(f"Log Level: {level}")
if log_dir_writable:
logger.info(f"Log Directory: {self.log_dir.absolute()}")
else:
logger.info("Log Directory: Not available (console-only mode)")
logger.info(f"Console Logging: {'Enabled' if enable_console else 'Disabled'}")
logger.info(f"File Logging: {'Enabled' if enable_file else 'Disabled (fallback mode)'}")
logger.info(f"Audit Logging: {'Enabled' if enable_audit else 'Disabled (fallback mode)'}")
logger.info(f"Log Cleanup: {'Enabled' if enable_cleanup and enable_file else 'Disabled (fallback mode)'}")
if enable_cleanup and enable_file:
logger.info(f"Cleanup Settings: Max age {max_age_days} days, interval {cleanup_interval_hours}h")
if not log_dir_writable:
logger.warning("Running in console-only logging mode due to filesystem permissions")
logger.warning(f"Could not create log directory '{log_dir}' - stdio mode fallback enabled")
logger.info("=" * 80)
def _setup_package_loggers(self, level: str):
"""Setup specific loggers for different modules"""
package_loggers = [
"doris_mcp_server",
"doris_mcp_server.main",
"doris_mcp_server.utils",
"doris_mcp_server.tools",
"doris_mcp_client"
]
for logger_name in package_loggers:
logger = logging.getLogger(logger_name)
logger.setLevel(getattr(logging, level.upper()))
# Don't add handlers here - they inherit from root logger
def get_logger(self, name: str) -> logging.Logger:
"""
Get a logger instance with proper configuration.
Args:
name: Logger name (usually __name__)
Returns:
Configured logger instance
"""
if name not in self.loggers:
logger = logging.getLogger(name)
self.loggers[name] = logger
return self.loggers[name]
def get_audit_logger(self) -> logging.Logger:
"""Get the audit logger"""
return logging.getLogger("audit")
def log_system_info(self):
"""Log system information for debugging"""
logger = self.get_logger("doris_mcp_server.system")
logger.info("System Information:")
logger.info(f"Python Version: {sys.version}")
logger.info(f"Platform: {sys.platform}")
logger.info(f"Working Directory: {os.getcwd()}")
logger.info(f"Process ID: {os.getpid()}")
# Log environment variables (filtered)
env_vars = ["LOG_LEVEL", "LOG_FILE_PATH", "ENABLE_AUDIT", "AUDIT_FILE_PATH"]
for var in env_vars:
value = os.getenv(var, "Not Set")
logger.info(f"Environment {var}: {value}")
def get_cleanup_stats(self) -> dict:
"""Get log cleanup statistics"""
if self.cleanup_manager:
return self.cleanup_manager.get_cleanup_stats()
else:
return {"error": "Log cleanup is not enabled"}
def manual_cleanup(self) -> dict:
"""Manually trigger log cleanup and return statistics"""
if self.cleanup_manager:
self.cleanup_manager.cleanup_old_logs()
return self.cleanup_manager.get_cleanup_stats()
else:
return {"error": "Log cleanup is not enabled"}
def shutdown(self):
"""Shutdown logging system"""
if not self.is_initialized:
return
logger = self.get_logger("doris_mcp_server.logger")
logger.info("Shutting down logging system...")
# Stop cleanup manager
if self.cleanup_manager:
self.cleanup_manager.stop_cleanup_scheduler()
# Close all handlers
root_logger = logging.getLogger()
for handler in root_logger.handlers[:]:
try:
handler.close()
except Exception as e:
print(f"Error closing handler: {e}")
# Close audit logger handlers
audit_logger = logging.getLogger("audit")
for handler in audit_logger.handlers[:]:
try:
handler.close()
except Exception as e:
print(f"Error closing audit handler: {e}")
self.is_initialized = False
# Global logger manager instance
_logger_manager = DorisLoggerManager()
def setup_logging(level: str = "INFO",
log_dir: str = "logs",
enable_console: bool = True,
enable_file: bool = True,
enable_audit: bool = True,
audit_file: Optional[str] = None,
max_file_size: int = 10*1024*1024,
backup_count: int = 5,
enable_cleanup: bool = True,
max_age_days: int = 30,
cleanup_interval_hours: int = 24) -> None:
"""
Setup logging configuration (convenience function).
Args:
level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
log_dir: Directory for log files
enable_console: Enable console output
enable_file: Enable file logging
enable_audit: Enable audit logging
audit_file: Custom audit log file path
max_file_size: Maximum size per log file (bytes)
backup_count: Number of backup files to keep
enable_cleanup: Enable automatic log cleanup
max_age_days: Maximum age of log files in days (default: 30)
cleanup_interval_hours: Cleanup interval in hours (default: 24)
"""
_logger_manager.setup_logging(
level=level,
log_dir=log_dir,
enable_console=enable_console,
enable_file=enable_file,
enable_audit=enable_audit,
audit_file=audit_file,
max_file_size=max_file_size,
backup_count=backup_count,
enable_cleanup=enable_cleanup,
max_age_days=max_age_days,
cleanup_interval_hours=cleanup_interval_hours
)
def get_logger(name: str) -> logging.Logger: def get_logger(name: str) -> logging.Logger:
@@ -93,9 +589,60 @@ def get_logger(name: str) -> logging.Logger:
Get a logger instance. Get a logger instance.
Args: Args:
name: Logger name name: Logger name (usually __name__)
Returns: Returns:
Logger instance Configured logger instance
""" """
return logging.getLogger(name) return _logger_manager.get_logger(name)
def get_audit_logger() -> logging.Logger:
"""Get the audit logger"""
return _logger_manager.get_audit_logger()
def log_system_info():
"""Log system information for debugging"""
_logger_manager.log_system_info()
def get_cleanup_stats() -> dict:
"""Get log cleanup statistics"""
return _logger_manager.get_cleanup_stats()
def manual_cleanup() -> dict:
"""Manually trigger log cleanup and return statistics"""
return _logger_manager.manual_cleanup()
def shutdown_logging():
"""Shutdown logging system"""
_logger_manager.shutdown()
# Compatibility function for existing code
def setup_logging_old(level: str = "INFO",
log_file: str | None = None,
log_format: str | None = None) -> None:
"""
Legacy setup function for backward compatibility.
Args:
level: Logging level (DEBUG, INFO, WARNING, ERROR)
log_file: Optional log file path (deprecated - use log_dir instead)
log_format: Optional custom log format (deprecated)
"""
# Extract directory from log_file if provided
log_dir = "logs"
if log_file:
log_dir = str(Path(log_file).parent)
setup_logging(
level=level,
log_dir=log_dir,
enable_console=True,
enable_file=True,
enable_audit=True
)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -33,7 +33,11 @@ from datetime import datetime, timedelta, date
from typing import Any, Dict from typing import Any, Dict
from decimal import Decimal from decimal import Decimal
import sqlparse
from .db import DorisConnectionManager, QueryResult from .db import DorisConnectionManager, QueryResult
from .logger import get_logger
from .sql_security_utils import get_auth_context
@dataclass @dataclass
@@ -92,7 +96,7 @@ class QueryCache:
self.max_size = max_size self.max_size = max_size
self.default_ttl = default_ttl self.default_ttl = default_ttl
self.cache: dict[str, CachedQuery] = {} self.cache: dict[str, CachedQuery] = {}
self.logger = logging.getLogger(__name__) self.logger = get_logger(__name__)
def _generate_cache_key( def _generate_cache_key(
self, sql: str, parameters: dict[str, Any] | None = None self, sql: str, parameters: dict[str, Any] | None = None
@@ -194,7 +198,7 @@ class QueryOptimizer:
def __init__(self, config): def __init__(self, config):
self.config = config self.config = config
self.logger = logging.getLogger(__name__) self.logger = get_logger(__name__)
self.optimization_rules = self._load_optimization_rules() self.optimization_rules = self._load_optimization_rules()
def _load_optimization_rules(self) -> list[dict[str, Any]]: def _load_optimization_rules(self) -> list[dict[str, Any]]:
@@ -318,7 +322,7 @@ class DorisQueryExecutor:
def __init__(self, connection_manager: DorisConnectionManager, config=None): def __init__(self, connection_manager: DorisConnectionManager, config=None):
self.connection_manager = connection_manager self.connection_manager = connection_manager
self.config = config or self._create_default_config() self.config = config or self._create_default_config()
self.logger = logging.getLogger(__name__) self.logger = get_logger(__name__)
# Initialize components # Initialize components
cache_config = getattr(self.config, 'performance', None) cache_config = getattr(self.config, 'performance', None)
@@ -425,27 +429,27 @@ class DorisQueryExecutor:
self, query_request: QueryRequest, auth_context self, query_request: QueryRequest, auth_context
) -> QueryResult: ) -> QueryResult:
"""Internal query execution""" """Internal query execution"""
# Database configuration should already be handled during authentication
# No need to configure again during query execution
# Optimize query # Optimize query
optimized_sql = await self.query_optimizer.optimize_query( optimized_sql = await self.query_optimizer.optimize_query(
query_request.sql, {"user_roles": getattr(auth_context, 'roles', [])} query_request.sql, {"user_roles": getattr(auth_context, 'roles', [])}
) )
# Execute query # Execute query
connection = await self.connection_manager.get_connection(
query_request.session_id
)
# Set timeout if specified # Set timeout if specified
if query_request.timeout: if query_request.timeout:
try: try:
result = await asyncio.wait_for( result = await asyncio.wait_for(
connection.execute(optimized_sql, query_request.parameters, auth_context), self.connection_manager.execute_query(query_request.session_id, optimized_sql, query_request.parameters, auth_context),
timeout=query_request.timeout timeout=query_request.timeout
) )
except asyncio.TimeoutError: except asyncio.TimeoutError:
raise Exception(f"Query timeout after {query_request.timeout} seconds") raise Exception(f"Query timeout after {query_request.timeout} seconds")
else: else:
result = await connection.execute(optimized_sql, query_request.parameters, auth_context) result = await self.connection_manager.execute_query(query_request.session_id, optimized_sql, query_request.parameters, auth_context)
return result return result
@@ -466,6 +470,51 @@ class DorisQueryExecutor:
f"Slow query detected: {execution_time:.2f}s (threshold: {self.slow_query_threshold}s)" f"Slow query detected: {execution_time:.2f}s (threshold: {self.slow_query_threshold}s)"
) )
async def execute_batch_sqls_for_mcp(
self, sqls: list[str],
timeout: int = 30,
session_id: str = "mcp_session",
user_id: str = "mcp_user",
auth_context=None
) -> dict[str, Any]:
"""Execute multiple sqls in batch"""
if not sqls:
return {
"success": False,
"error": "SQL query is required",
"data": None
}
query_requests = [
QueryRequest(
sql=sql,
session_id=session_id,
user_id=user_id,
timeout=timeout,
cache_enabled=False
)
for sql in sqls
]
query_results = await self.execute_batch_queries(query_requests, auth_context)
# Serialize data for JSON response
results = [
{
"data": [self._serialize_row_data(data) for data in result.data],
"row_count": result.row_count,
"execution_time": result.execution_time,
"metadata": {
"columns": result.metadata.get("columns", []),
"query": result.sql
}
}
for result in query_results
]
return {
"success": True,
"multiple_results": True,
"results": results
}
async def execute_batch_queries( async def execute_batch_queries(
self, query_requests: list[QueryRequest], auth_context=None self, query_requests: list[QueryRequest], auth_context=None
) -> list[QueryResult]: ) -> list[QueryResult]:
@@ -483,20 +532,24 @@ class DorisQueryExecutor:
self.execute_query(request, auth_context) for request in query_requests self.execute_query(request, auth_context) for request in query_requests
] ]
try: query_results = []
results = await asyncio.gather(*tasks, return_exceptions=True) results = await asyncio.gather(*tasks, return_exceptions=True)
except Exception as e: for result in results:
self.logger.error(f"Batch query execution failed: {e}") if isinstance(result, Exception):
raise self.logger.error(f"Batch query execution failed: {result}")
raise result
else:
query_results.append(result)
return results return query_results
async def explain_query(self, sql: str, session_id: str) -> dict[str, Any]: async def explain_query(self, sql: str, session_id: str) -> dict[str, Any]:
"""Get query execution plan""" """Get query execution plan"""
explain_sql = f"EXPLAIN {sql}" explain_sql = f"EXPLAIN {sql}"
connection = await self.connection_manager.get_connection(session_id) connection = await self.connection_manager.get_connection(session_id)
result = await connection.execute(explain_sql) auth_context = get_auth_context()
result = await connection.execute(explain_sql, auth_context=auth_context)
return { return {
"query": sql, "query": sql,
@@ -545,9 +598,17 @@ class DorisQueryExecutor:
limit: int = 1000, limit: int = 1000,
timeout: int = 30, timeout: int = 30,
session_id: str = "mcp_session", session_id: str = "mcp_session",
user_id: str = "mcp_user" user_id: str = "mcp_user",
auth_context = None # FIX for Issue #62 Bug 1: Accept auth_context with token
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Execute SQL query for MCP interface - unified method""" """Execute SQL query for MCP interface - unified method
FIX for Issue #62 Bug 1: Now accepts auth_context parameter to support token-bound database configuration
"""
max_retries = 2
retry_count = 0
while retry_count <= max_retries:
try: try:
if not sql: if not sql:
return { return {
@@ -556,69 +617,166 @@ class DorisQueryExecutor:
"data": None "data": None
} }
# Import required security modules
from .security import DorisSecurityManager, AuthContext, SecurityLevel
# FIX: Use provided auth_context if available (contains token for DB config)
# Otherwise create default auth context for backward compatibility
if auth_context is None:
auth_context = AuthContext(
user_id=user_id,
roles=["read_only_user"], # Restrictive role for MCP interface
permissions=["read_data"], # Only read permissions
session_id=session_id,
security_level=SecurityLevel.INTERNAL,
token="" # No token in default context
)
else:
# Use provided auth_context (may contain token for database configuration)
self.logger.debug(f"Using provided auth_context with token: {bool(hasattr(auth_context, 'token') and auth_context.token)}")
# Perform SQL security validation if enabled
if hasattr(self.connection_manager, 'config') and hasattr(self.connection_manager.config, 'security'):
if self.connection_manager.config.security.enable_security_check:
try:
# 🔧 FIX: Use existing security_manager to avoid creating multiple TokenManager instances
# Creating new DorisSecurityManager each time causes multiple hot reload monitors
security_manager = getattr(self.connection_manager, 'security_manager', None)
if not security_manager:
# Fallback: create new one only if not available (should rarely happen)
self.logger.warning("No existing security_manager, creating new instance")
security_manager = DorisSecurityManager(self.connection_manager.config)
validation_result = await security_manager.validate_sql_security(sql, auth_context)
if not validation_result.is_valid:
self.logger.warning(f"SQL security validation failed for query: {sql[:100]}...")
return {
"success": False,
"error": f"SQL security validation failed: {validation_result.error_message}",
"error_type": "security_violation",
"blocked_operations": validation_result.blocked_operations,
"risk_level": validation_result.risk_level,
"data": None,
"metadata": {
"query": sql,
"validation_details": {
"blocked_operations": validation_result.blocked_operations,
"risk_level": validation_result.risk_level
}
}
}
else:
self.logger.debug(f"SQL security validation passed for query: {sql[:100]}...")
except Exception as security_error:
self.logger.error(f"Security validation error: {str(security_error)}")
# In case of security validation error, fail safe
return {
"success": False,
"error": f"Security validation system error: {str(security_error)}",
"error_type": "security_system_error",
"data": None,
"metadata": {
"query": sql,
"security_error": str(security_error)
}
}
else:
self.logger.info("SQL security check is disabled in configuration")
else:
self.logger.warning("Security configuration not found, proceeding without validation")
# Add LIMIT if not present and it's a SELECT query # Add LIMIT if not present and it's a SELECT query
if sql.upper().startswith("SELECT") and "LIMIT" not in sql.upper(): if sql.upper().startswith("SELECT") and "LIMIT" not in sql.upper():
if sql.endswith(";"): if sql.endswith(";"):
sql = sql[:-1] sql = sql[:-1]
sql = f"{sql} LIMIT {limit}" sql = f"{sql} LIMIT {limit}"
# Create auth context for MCP calls all_statements = [
class MockAuthContext: s.strip()
def __init__(self): for s in sqlparse.split(sql)
self.user_id = user_id if s.strip()
self.roles = ["data_analyst"] ]
self.permissions = ["read_data", "execute_query"] if len(all_statements) > 1:
self.session_id = session_id return await self.execute_batch_sqls_for_mcp(sqls=all_statements, timeout=timeout,
self.security_level = "internal" session_id=session_id, user_id=user_id,
auth_context=auth_context)
auth_context = MockAuthContext()
# Create query request # Create query request
query_request = QueryRequest( query_request = QueryRequest(
sql=sql, sql=sql,
session_id=session_id, session_id=session_id,
user_id=user_id, user_id=user_id,
timeout=timeout, timeout=timeout,
cache_enabled=True cache_enabled=False # Disable cache for MCP calls to ensure fresh data
) )
# Execute query # Execute query with retry logic
result = await self.execute_query(query_request, auth_context) result = await self.execute_query(query_request, auth_context)
# Process results # Serialize data for JSON response
processed_data = [] serialized_data = []
if result.data:
for row in result.data: for row in result.data:
processed_row = self._serialize_row_data(row) serialized_data.append(self._serialize_row_data(row))
processed_data.append(processed_row)
return { return {
"success": True, "success": True,
"data": processed_data, "data": serialized_data,
"metadata": {
"row_count": result.row_count, "row_count": result.row_count,
"execution_time": result.execution_time, "execution_time": result.execution_time,
"metadata": {
"columns": result.metadata.get("columns", []), "columns": result.metadata.get("columns", []),
"query": sql "query": sql
}, }
"error": None
} }
except Exception as e: except Exception as e:
error_msg = str(e) error_msg = str(e)
self.logger.error(f"SQL execution error: {error_msg}") error_str = error_msg.lower()
# Analyze error for better user feedback # Check if it's a connection-related error that we should retry
connection_errors = [
"at_eof", "connection", "closed", "nonetype",
"transport", "reader", "broken pipe", "connection reset"
]
is_connection_error = any(err in error_str for err in connection_errors)
if is_connection_error and retry_count < max_retries:
retry_count += 1
self.logger.warning(f"Connection error detected, retrying ({retry_count}/{max_retries}): {e}")
# Release the problematic connection
try:
await self.connection_manager.release_connection(session_id)
except Exception:
pass # Ignore cleanup errors
# Wait a bit before retry
await asyncio.sleep(0.5 * retry_count)
continue
else:
# If we've exhausted retries or it's not a connection error, return error
error_analysis = self._analyze_error(error_msg) error_analysis = self._analyze_error(error_msg)
return { return {
"success": False, "success": False,
"error": error_analysis.get("user_message", error_msg), "error": error_analysis.get("user_message", error_msg),
"error_type": error_analysis.get("error_type", "execution_error"), "error_type": error_analysis.get("error_type", "general_error"),
"data": None, "data": None,
"metadata": { "metadata": {
"query": sql, "query": sql,
"error_details": error_msg "error_details": error_msg,
"retry_count": retry_count
}
}
# This should never be reached, but just in case
return {
"success": False,
"error": "Maximum retries exceeded",
"data": None,
"metadata": {
"query": sql,
"retry_count": retry_count
} }
} }
@@ -649,7 +807,12 @@ class DorisQueryExecutor:
"""Analyze error message and provide user-friendly feedback""" """Analyze error message and provide user-friendly feedback"""
error_msg_lower = error_message.lower() error_msg_lower = error_message.lower()
if "table" in error_msg_lower and "doesn't exist" in error_msg_lower: if "at_eof" in error_msg_lower or "nonetype" in error_msg_lower and "at_eof" in error_msg_lower:
return {
"error_type": "connection_lost",
"user_message": "Database connection was lost. The query has been automatically retried. If this persists, please restart the server."
}
elif "table" in error_msg_lower and "doesn't exist" in error_msg_lower:
return { return {
"error_type": "table_not_found", "error_type": "table_not_found",
"user_message": "The specified table does not exist. Please check the table name and database." "user_message": "The specified table does not exist. Please check the table name and database."
@@ -674,6 +837,11 @@ class DorisQueryExecutor:
"error_type": "timeout", "error_type": "timeout",
"user_message": "Query execution timed out. Try simplifying your query or adding more specific filters." "user_message": "Query execution timed out. Try simplifying your query or adding more specific filters."
} }
elif "connection" in error_msg_lower and ("closed" in error_msg_lower or "reset" in error_msg_lower):
return {
"error_type": "connection_error",
"user_message": "Database connection was interrupted. The query has been automatically retried."
}
else: else:
return { return {
"error_type": "general_error", "error_type": "general_error",
@@ -701,7 +869,7 @@ class QueryPerformanceMonitor:
def __init__(self, query_executor: DorisQueryExecutor): def __init__(self, query_executor: DorisQueryExecutor):
self.query_executor = query_executor self.query_executor = query_executor
self.logger = logging.getLogger(__name__) self.logger = get_logger(__name__)
self.performance_records = [] self.performance_records = []
async def record_query_performance( async def record_query_performance(
@@ -785,32 +953,51 @@ class QueryPerformanceMonitor:
# Unified convenience function for MCP integration # Unified convenience function for MCP integration
async def execute_sql_query(sql: str, connection_manager: DorisConnectionManager, **kwargs) -> Dict[str, Any]: async def execute_sql_query(sql: str, connection_manager: DorisConnectionManager, **kwargs) -> Dict[str, Any]:
"""Execute SQL query - unified convenience function for MCP tools""" """Execute SQL query - unified convenience function for MCP tools
This function now includes security validation to ensure safe query execution.
All queries are validated against the configured security policies before execution.
FIX for Issue #62 Bug 1: Now supports auth_context parameter for token-bound database configuration
FIX for Issue #58 Problem 2: Removed executor.close() to prevent ClosedResourceError in multi-worker mode
"""
try: try:
# Create query executor # Create query executor with the connection manager's configuration
executor = DorisQueryExecutor(connection_manager) executor = DorisQueryExecutor(connection_manager)
try:
# Extract parameters from kwargs or use defaults # Extract parameters from kwargs or use defaults
limit = kwargs.get("limit", 1000) limit = kwargs.get("limit", 1000)
timeout = kwargs.get("timeout", 30) timeout = kwargs.get("timeout", 30)
session_id = kwargs.get("session_id", "mcp_session") session_id = kwargs.get("session_id", "mcp_session")
user_id = kwargs.get("user_id", "mcp_user") user_id = kwargs.get("user_id", "mcp_user")
auth_context = kwargs.get("auth_context", None) # FIX: Extract auth_context
# The execute_sql_for_mcp method now includes security validation
result = await executor.execute_sql_for_mcp( result = await executor.execute_sql_for_mcp(
sql=sql, sql=sql,
limit=limit, limit=limit,
timeout=timeout, timeout=timeout,
session_id=session_id, session_id=session_id,
user_id=user_id user_id=user_id,
auth_context=auth_context # FIX: Pass auth_context with token
) )
# FIX for Issue #58 Problem 2: Do NOT close executor here
# In multi-worker mode, closing here causes ClosedResourceError
# The executor's resources (cache, background tasks) will be managed
# by the connection_manager lifecycle and Python's garbage collection
# This prevents premature cleanup while MCP session manager is still processing
return result return result
finally:
await executor.close()
except Exception as e: except Exception as e:
return { return {
"success": False, "success": False,
"error": f"Query execution failed: {str(e)}", "error": f"Query execution failed: {str(e)}",
"data": None "error_type": "execution_error",
"data": None,
"metadata": {
"query": sql,
"execution_error": str(e)
}
} }

View File

@@ -31,14 +31,16 @@ from dotenv import load_dotenv
from datetime import datetime, timedelta from datetime import datetime, timedelta
# Import unified logging configuration # Import unified logging configuration
from doris_mcp_server.utils.logger import get_logger from .logger import get_logger
from .sql_security_utils import (
SQLSecurityError,
validate_identifier,
quote_identifier
)
# Configure logging # Configure logging
logger = get_logger(__name__) logger = get_logger(__name__)
# Load environment variables
load_dotenv(override=True)
METADATA_DB_NAME="information_schema" METADATA_DB_NAME="information_schema"
ENABLE_MULTI_DATABASE=os.getenv("ENABLE_MULTI_DATABASE",True) ENABLE_MULTI_DATABASE=os.getenv("ENABLE_MULTI_DATABASE",True)
MULTI_DATABASE_NAMES=os.getenv("MULTI_DATABASE_NAMES","") MULTI_DATABASE_NAMES=os.getenv("MULTI_DATABASE_NAMES","")
@@ -416,7 +418,7 @@ class MetadataExtractor:
return matches return matches
def get_table_schema(self, table_name: str, db_name: Optional[str] = None, catalog_name: str = None) -> Dict[str, Any]: async def get_table_schema(self, table_name: str, db_name: Optional[str] = None, catalog_name: str = None) -> Dict[str, Any]:
""" """
Get the schema information for a table Get the schema information for a table
@@ -434,12 +436,22 @@ class MetadataExtractor:
logger.warning("Database name not specified") logger.warning("Database name not specified")
return {} return {}
# SECURITY FIX: Validate identifiers to prevent SQL injection
try:
validate_identifier(table_name, "table name")
validate_identifier(db_name, "database name")
if effective_catalog:
validate_identifier(effective_catalog, "catalog name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected in get_table_schema: {e}")
return {}
cache_key = f"schema_{effective_catalog or 'default'}_{db_name}_{table_name}" cache_key = f"schema_{effective_catalog or 'default'}_{db_name}_{table_name}"
if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl: if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl:
return self.metadata_cache[cache_key] return self.metadata_cache[cache_key]
try: try:
# Use information_schema.columns table to get table schema # Use information_schema.columns table to get table schema (async)
query = f""" query = f"""
SELECT SELECT
COLUMN_NAME, COLUMN_NAME,
@@ -459,7 +471,7 @@ class MetadataExtractor:
ORDINAL_POSITION ORDINAL_POSITION
""" """
result = self._execute_query_with_catalog(query, db_name, effective_catalog) result = await self._execute_query_with_catalog_async(query, db_name, effective_catalog)
if not result: if not result:
logger.warning(f"Table {effective_catalog or 'default'}.{db_name}.{table_name} does not exist or has no columns") logger.warning(f"Table {effective_catalog or 'default'}.{db_name}.{table_name} does not exist or has no columns")
@@ -468,7 +480,6 @@ class MetadataExtractor:
# Create structured table schema information # Create structured table schema information
columns = [] columns = []
for col in result: for col in result:
# Ensure using actual column values, not column names
column_info = { column_info = {
"name": col.get("COLUMN_NAME", ""), "name": col.get("COLUMN_NAME", ""),
"type": col.get("DATA_TYPE", ""), "type": col.get("DATA_TYPE", ""),
@@ -481,8 +492,8 @@ class MetadataExtractor:
} }
columns.append(column_info) columns.append(column_info)
# Get table comment # Get table comment (async)
table_comment = self.get_table_comment(table_name, db_name, effective_catalog) table_comment = await self.get_table_comment_async(table_name, db_name, effective_catalog)
# Build complete structure # Build complete structure
schema = { schema = {
@@ -493,7 +504,7 @@ class MetadataExtractor:
"create_time": datetime.now().isoformat() "create_time": datetime.now().isoformat()
} }
# Get table type information # Get table type information (async)
try: try:
table_type_query = f""" table_type_query = f"""
SELECT SELECT
@@ -505,7 +516,7 @@ class MetadataExtractor:
TABLE_SCHEMA = '{db_name}' TABLE_SCHEMA = '{db_name}'
AND TABLE_NAME = '{table_name}' AND TABLE_NAME = '{table_name}'
""" """
table_type_result = self._execute_query(table_type_query) table_type_result = await self._execute_query_async(table_type_query)
if table_type_result: if table_type_result:
schema["table_type"] = table_type_result[0].get("TABLE_TYPE", "") schema["table_type"] = table_type_result[0].get("TABLE_TYPE", "")
schema["engine"] = table_type_result[0].get("ENGINE", "") schema["engine"] = table_type_result[0].get("ENGINE", "")
@@ -521,6 +532,7 @@ class MetadataExtractor:
logger.error(f"Error getting table schema: {str(e)}") logger.error(f"Error getting table schema: {str(e)}")
return {} return {}
# Deprecated: sync method (kept for compatibility, will be removed)
def get_table_comment(self, table_name: str, db_name: Optional[str] = None, catalog_name: str = None) -> str: def get_table_comment(self, table_name: str, db_name: Optional[str] = None, catalog_name: str = None) -> str:
""" """
Get the comment for a table Get the comment for a table
@@ -539,6 +551,16 @@ class MetadataExtractor:
logger.warning("Database name not specified") logger.warning("Database name not specified")
return "" return ""
# SECURITY FIX: Validate identifiers
try:
validate_identifier(table_name, "table name")
validate_identifier(db_name, "database name")
if effective_catalog:
validate_identifier(effective_catalog, "catalog name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected: {e}")
return ""
cache_key = f"table_comment_{effective_catalog or 'default'}_{db_name}_{table_name}" cache_key = f"table_comment_{effective_catalog or 'default'}_{db_name}_{table_name}"
if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl: if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl:
return self.metadata_cache[cache_key] return self.metadata_cache[cache_key]
@@ -571,6 +593,7 @@ class MetadataExtractor:
logger.error(f"Error getting table comment: {str(e)}") logger.error(f"Error getting table comment: {str(e)}")
return "" return ""
# Deprecated: sync method (kept for compatibility, will be removed)
def get_column_comments(self, table_name: str, db_name: Optional[str] = None, catalog_name: str = None) -> Dict[str, str]: def get_column_comments(self, table_name: str, db_name: Optional[str] = None, catalog_name: str = None) -> Dict[str, str]:
""" """
Get comments for all columns in a table Get comments for all columns in a table
@@ -589,6 +612,16 @@ class MetadataExtractor:
logger.warning("Database name not specified") logger.warning("Database name not specified")
return {} return {}
# SECURITY FIX: Validate identifiers
try:
validate_identifier(table_name, "table name")
validate_identifier(db_name, "database name")
if effective_catalog:
validate_identifier(effective_catalog, "catalog name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected: {e}")
return {}
cache_key = f"column_comments_{effective_catalog or 'default'}_{db_name}_{table_name}" cache_key = f"column_comments_{effective_catalog or 'default'}_{db_name}_{table_name}"
if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl: if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl:
return self.metadata_cache[cache_key] return self.metadata_cache[cache_key]
@@ -626,6 +659,7 @@ class MetadataExtractor:
logger.error(f"Error getting column comments: {str(e)}") logger.error(f"Error getting column comments: {str(e)}")
return {} return {}
# Deprecated: sync method (kept for compatibility, will be removed)
def get_table_indexes(self, table_name: str, db_name: Optional[str] = None, catalog_name: str = None) -> List[Dict[str, Any]]: def get_table_indexes(self, table_name: str, db_name: Optional[str] = None, catalog_name: str = None) -> List[Dict[str, Any]]:
""" """
Get the index information for a table Get the index information for a table
@@ -644,64 +678,62 @@ class MetadataExtractor:
logger.error("Database name not specified") logger.error("Database name not specified")
return [] return []
# SECURITY FIX: Validate identifiers
try:
validate_identifier(table_name, "table name")
validate_identifier(db_name, "database name")
if effective_catalog:
validate_identifier(effective_catalog, "catalog name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected: {e}")
return []
cache_key = f"indexes_{effective_catalog or 'default'}_{db_name}_{table_name}" cache_key = f"indexes_{effective_catalog or 'default'}_{db_name}_{table_name}"
if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl: if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl:
return self.metadata_cache[cache_key] return self.metadata_cache[cache_key]
try: try:
# Build query with catalog prefix if specified # Build query with catalog prefix if specified (identifiers already validated)
safe_table = quote_identifier(table_name, "table name")
safe_db = quote_identifier(db_name, "database name")
if effective_catalog: if effective_catalog:
query = f"SHOW INDEX FROM `{effective_catalog}`.`{db_name}`.`{table_name}`" safe_catalog = quote_identifier(effective_catalog, "catalog name")
query = f"SHOW INDEX FROM {safe_catalog}.{safe_db}.{safe_table}"
logger.info(f"Using three-part naming for index query: {query}") logger.info(f"Using three-part naming for index query: {query}")
else: else:
query = f"SHOW INDEX FROM `{db_name}`.`{table_name}`" query = f"SHOW INDEX FROM {safe_db}.{safe_table}"
try: try:
df = self._execute_query(query, return_dataframe=True) # NOTE: Deprecated sync path retained for compatibility; use async variant instead.
# Deprecated sync path removed; return empty indexes on failure
# Process results result = []
indexes = [] indexes = []
current_index = None current_index = None
if result:
if not df.empty: for r in result:
for _, row in df.iterrows():
try: try:
index_name = row['Key_name'] index_name = r.get('Key_name')
column_name = row['Column_name'] column_name = r.get('Column_name')
if current_index is None or current_index.get('name') != index_name:
if current_index is None or current_index['name'] != index_name:
if current_index is not None: if current_index is not None:
indexes.append(current_index) indexes.append(current_index)
current_index = { current_index = {
'name': index_name, 'name': index_name,
'columns': [column_name], 'columns': [column_name] if column_name else [],
'unique': row['Non_unique'] == 0, 'unique': r.get('Non_unique', 1) == 0,
'type': row['Index_type'] 'type': r.get('Index_type', '')
} }
else: else:
if column_name:
current_index['columns'].append(column_name) current_index['columns'].append(column_name)
except Exception as row_error: except Exception as row_error:
logger.warning(f"Failed to process index row data: {row_error}") logger.warning(f"Failed to process index row data: {row_error}")
continue continue
if current_index is not None: if current_index is not None:
indexes.append(current_index) indexes.append(current_index)
except Exception as df_error: except Exception as df_error:
logger.warning(f"DataFrame processing failed, trying regular query: {df_error}") logger.warning(f"Sync index query (deprecated) failed: {df_error}")
# Fall back to regular query
result = self._execute_query(query, return_dataframe=False)
indexes = [] indexes = []
if result:
# Simple processing, no complex index grouping
for row in result:
if isinstance(row, dict):
indexes.append({
'name': row.get('Key_name', ''),
'columns': [row.get('Column_name', '')],
'unique': row.get('Non_unique', 1) == 0,
'type': row.get('Index_type', '')
})
# Update cache # Update cache
self.metadata_cache[cache_key] = indexes self.metadata_cache[cache_key] = indexes
@@ -712,7 +744,7 @@ class MetadataExtractor:
logger.error(f"Error getting index information: {str(e)}") logger.error(f"Error getting index information: {str(e)}")
return [] return []
def get_table_relationships(self) -> List[Dict[str, Any]]: async def get_table_relationships(self) -> List[Dict[str, Any]]:
""" """
Infer table relationships from table comments and naming patterns Infer table relationships from table comments and naming patterns
@@ -725,13 +757,13 @@ class MetadataExtractor:
try: try:
# Get all tables # Get all tables
tables = self.get_database_tables(self.db_name) tables = await self.get_database_tables_async(self.db_name)
relationships = [] relationships = []
# Simple foreign key naming convention detection # Simple foreign key naming convention detection
# Example: If a table has a column named xxx_id and another table named xxx exists, it might be a foreign key relationship # Example: If a table has a column named xxx_id and another table named xxx exists, it might be a foreign key relationship
for table_name in tables: for table_name in tables:
schema = self.get_table_schema(table_name, self.db_name) schema = await self.get_table_schema(table_name, self.db_name)
columns = schema.get("columns", []) columns = schema.get("columns", [])
for column in columns: for column in columns:
@@ -743,7 +775,7 @@ class MetadataExtractor:
# Check if the possible table exists # Check if the possible table exists
if ref_table_name in tables: if ref_table_name in tables:
# Find possible primary key column # Find possible primary key column
ref_schema = self.get_table_schema(ref_table_name, self.db_name) ref_schema = await self.get_table_schema(ref_table_name, self.db_name)
ref_columns = ref_schema.get("columns", []) ref_columns = ref_schema.get("columns", [])
# Assume primary key column name is id # Assume primary key column name is id
@@ -766,6 +798,7 @@ class MetadataExtractor:
logger.error(f"Error inferring table relationships: {str(e)}") logger.error(f"Error inferring table relationships: {str(e)}")
return [] return []
# Deprecated: sync method (kept for compatibility, will be removed)
def get_recent_audit_logs(self, days: int = 7, limit: int = 100) -> pd.DataFrame: def get_recent_audit_logs(self, days: int = 7, limit: int = 100) -> pd.DataFrame:
""" """
Get recent audit logs Get recent audit logs
@@ -792,13 +825,14 @@ class MetadataExtractor:
ORDER BY time DESC ORDER BY time DESC
LIMIT {limit} LIMIT {limit}
""" """
df = self._execute_query(query, return_dataframe=True) # Deprecated sync path removed; this method is deprecated overall
df = pd.DataFrame()
return df return df
except Exception as e: except Exception as e:
logger.error(f"Error getting audit logs: {str(e)}") logger.error(f"Error getting audit logs: {str(e)}")
return pd.DataFrame() return pd.DataFrame()
def get_catalog_list(self) -> List[Dict[str, Any]]: async def get_catalog_list(self) -> List[Dict[str, Any]]:
""" """
Get a list of all catalogs in Doris with detailed information Get a list of all catalogs in Doris with detailed information
@@ -812,7 +846,7 @@ class MetadataExtractor:
try: try:
# Use SHOW CATALOGS command to get catalog list # Use SHOW CATALOGS command to get catalog list
query = "SHOW CATALOGS" query = "SHOW CATALOGS"
result = self._execute_query(query) result = await self._execute_query_async(query)
if not result: if not result:
catalogs = [] catalogs = []
@@ -1101,7 +1135,8 @@ class MetadataExtractor:
AND TABLE_NAME = '{table_name}' AND TABLE_NAME = '{table_name}'
""" """
partitions = self._execute_query(query) # Deprecated sync path removed
partitions = []
if not partitions: if not partitions:
return {} return {}
@@ -1124,31 +1159,25 @@ class MetadataExtractor:
logger.error(f"Error getting partition information for table {db_name}.{table_name}: {str(e)}") logger.error(f"Error getting partition information for table {db_name}.{table_name}: {str(e)}")
return {} return {}
def _execute_query_with_catalog(self, query: str, db_name: str = None, catalog_name: str = None): # Removed sync _execute_query_with_catalog; use async variant instead
async def _execute_query_with_catalog_async(self, query: str, db_name: str = None, catalog_name: str = None):
""" """
Execute query with catalog-aware metadata operations using three-part naming Async version of _execute_query_with_catalog to avoid cross-event-loop issues.
Args: When catalog_name is provided and the SQL targets information_schema, we rewrite
query: SQL query to execute the SQL to use three-part naming: `{catalog}.information_schema` and execute it
db_name: Database name to use via the same running event loop.
catalog_name: Catalog name for three-part naming
Returns:
Query result
""" """
try: try:
# If catalog_name is specified, modify the query to use three-part naming
# for information_schema queries
if catalog_name and 'information_schema' in query.lower(): if catalog_name and 'information_schema' in query.lower():
# Replace 'information_schema' with 'catalog_name.information_schema'
modified_query = query.replace('information_schema', f'{catalog_name}.information_schema') modified_query = query.replace('information_schema', f'{catalog_name}.information_schema')
logger.info(f"Modified query for catalog {catalog_name}: {modified_query}") logger.info(f"Modified query for catalog {catalog_name}: {modified_query}")
return self._execute_query(modified_query, db_name) return await self._execute_query_async(modified_query, db_name)
else: else:
# Execute the original query return await self._execute_query_async(query, db_name)
return self._execute_query(query, db_name)
except Exception as e: except Exception as e:
logger.error(f"Error executing query with catalog: {str(e)}") logger.error(f"Error executing async query with catalog: {str(e)}")
raise raise
async def _execute_query_async(self, query: str, db_name: str = None, return_dataframe: bool = False): async def _execute_query_async(self, query: str, db_name: str = None, return_dataframe: bool = False):
@@ -1165,8 +1194,17 @@ class MetadataExtractor:
""" """
try: try:
if self.connection_manager: if self.connection_manager:
# FIX: Get auth_context from global ContextVar for token-bound database configuration
# This ensures all query methods use the correct user's connection pool
auth_context = None
try:
from .security import mcp_auth_context_var
auth_context = mcp_auth_context_var.get()
except Exception:
pass
# Use the injected connection manager directly (async) # Use the injected connection manager directly (async)
result = await self.connection_manager.execute_query(self._session_id, query, None) result = await self.connection_manager.execute_query(self._session_id, query, None, auth_context)
# Extract data from QueryResult # Extract data from QueryResult
if hasattr(result, 'data'): if hasattr(result, 'data'):
@@ -1200,76 +1238,35 @@ class MetadataExtractor:
else: else:
return [] return []
def _execute_query(self, query: str, db_name: str = None, return_dataframe: bool = False): # Removed sync _execute_query; use async methods exclusively
"""
Execute database query with proper session management (sync wrapper)
Args:
query: SQL query to execute
db_name: Database name to use (optional)
return_dataframe: Whether to return a pandas DataFrame instead of list
Returns:
Query result data (list of dictionaries or pandas DataFrame)
"""
try:
if self.connection_manager:
import asyncio
# Try to run the async query
try:
# Check if there's a running event loop
loop = asyncio.get_running_loop()
# If we're in an async context, we need to run in a separate thread
import concurrent.futures
def run_in_new_loop():
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
try:
return new_loop.run_until_complete(
self._execute_query_async(query, db_name, return_dataframe)
)
finally:
new_loop.close()
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_in_new_loop)
return future.result(timeout=30)
except RuntimeError:
# No running loop, we can safely create one
return asyncio.run(
self._execute_query_async(query, db_name, return_dataframe)
)
else:
# Fallback: Return empty result
logger.warning("No connection manager provided, returning empty result")
if return_dataframe:
import pandas as pd
return pd.DataFrame()
else:
return []
except Exception as e:
logger.error(f"Error executing query: {str(e)}")
# Return empty result instead of raising exception to prevent cascade failures
if return_dataframe:
import pandas as pd
return pd.DataFrame()
else:
return []
async def get_table_schema_async(self, table_name: str, db_name: str = None, catalog_name: str = None) -> List[Dict[str, Any]]: async def get_table_schema_async(self, table_name: str, db_name: str = None, catalog_name: str = None) -> List[Dict[str, Any]]:
"""Asynchronously get table schema information""" """Asynchronously get table schema information"""
try: try:
# Use async query method # Use async query method
effective_catalog = catalog_name or self.catalog_name effective_catalog = catalog_name or self.catalog_name
effective_db = db_name or self.db_name
# Build query statement # SECURITY FIX: Validate identifiers
try:
validate_identifier(table_name, "table name")
if effective_db:
validate_identifier(effective_db, "database name")
if effective_catalog and effective_catalog != "internal": if effective_catalog and effective_catalog != "internal":
query = f"DESCRIBE `{effective_catalog}`.`{db_name or self.db_name}`.`{table_name}`" validate_identifier(effective_catalog, "catalog name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected: {e}")
return []
# Build query statement using safe identifiers
safe_table = quote_identifier(table_name, "table name")
safe_db = quote_identifier(effective_db, "database name") if effective_db else None
if effective_catalog and effective_catalog != "internal":
safe_catalog = quote_identifier(effective_catalog, "catalog name")
query = f"DESCRIBE {safe_catalog}.{safe_db}.{safe_table}"
else: else:
query = f"DESCRIBE `{db_name or self.db_name}`.`{table_name}`" query = f"DESCRIBE {safe_db}.{safe_table}"
# Execute async query # Execute async query
result = await self._execute_query_async(query, db_name) result = await self._execute_query_async(query, db_name)
@@ -1302,8 +1299,15 @@ class MetadataExtractor:
try: try:
effective_catalog = catalog_name or self.catalog_name effective_catalog = catalog_name or self.catalog_name
# SECURITY FIX: Validate catalog name if provided
if effective_catalog and effective_catalog != "internal": if effective_catalog and effective_catalog != "internal":
query = f"SHOW DATABASES FROM `{effective_catalog}`" try:
validate_identifier(effective_catalog, "catalog name")
except SQLSecurityError as e:
logger.warning(f"Invalid catalog name rejected: {e}")
return []
safe_catalog = quote_identifier(effective_catalog, "catalog name")
query = f"SHOW DATABASES FROM {safe_catalog}"
else: else:
query = "SHOW DATABASES" query = "SHOW DATABASES"
@@ -1333,10 +1337,23 @@ class MetadataExtractor:
effective_catalog = catalog_name or self.catalog_name effective_catalog = catalog_name or self.catalog_name
effective_db = db_name or self.db_name effective_db = db_name or self.db_name
# SECURITY FIX: Validate identifiers
try:
if effective_db:
validate_identifier(effective_db, "database name")
if effective_catalog and effective_catalog != "internal": if effective_catalog and effective_catalog != "internal":
query = f"SHOW TABLES FROM `{effective_catalog}`.`{effective_db}`" validate_identifier(effective_catalog, "catalog name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected: {e}")
return []
safe_db = quote_identifier(effective_db, "database name") if effective_db else None
if effective_catalog and effective_catalog != "internal":
safe_catalog = quote_identifier(effective_catalog, "catalog name")
query = f"SHOW TABLES FROM {safe_catalog}.{safe_db}"
else: else:
query = f"SHOW TABLES FROM `{effective_db}`" query = f"SHOW TABLES FROM {safe_db}"
result = await self._execute_query_async(query, effective_db) result = await self._execute_query_async(query, effective_db)
@@ -1389,6 +1406,162 @@ class MetadataExtractor:
logger.error(f"Failed to get catalog list: {e}") logger.error(f"Failed to get catalog list: {e}")
return [] return []
async def get_table_comment_async(self, table_name: str, db_name: str = None, catalog_name: str = None) -> str:
"""Async version: get the comment for a table."""
try:
effective_db = db_name or self.db_name
effective_catalog = catalog_name or self.catalog_name
# SECURITY FIX: Validate identifiers
try:
validate_identifier(table_name, "table name")
if effective_db:
validate_identifier(effective_db, "database name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected: {e}")
return ""
query = f"""
SELECT
TABLE_COMMENT
FROM
information_schema.tables
WHERE
TABLE_SCHEMA = '{effective_db}'
AND TABLE_NAME = '{table_name}'
"""
result = await self._execute_query_with_catalog_async(query, effective_db, effective_catalog)
if not result or not result[0]:
return ""
return result[0].get("TABLE_COMMENT", "") or ""
except Exception as e:
logger.error(f"Failed to get table comment asynchronously: {e}")
return ""
async def get_column_comments_async(self, table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, str]:
"""Async version: get comments for all columns in a table."""
try:
effective_db = db_name or self.db_name
effective_catalog = catalog_name or self.catalog_name
# SECURITY FIX: Validate identifiers
try:
validate_identifier(table_name, "table name")
if effective_db:
validate_identifier(effective_db, "database name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected: {e}")
return {}
query = f"""
SELECT
COLUMN_NAME,
COLUMN_COMMENT
FROM
information_schema.columns
WHERE
TABLE_SCHEMA = '{effective_db}'
AND TABLE_NAME = '{table_name}'
ORDER BY
ORDINAL_POSITION
"""
rows = await self._execute_query_with_catalog_async(query, effective_db, effective_catalog)
comments: Dict[str, str] = {}
for col in rows or []:
name = col.get("COLUMN_NAME", "")
if name:
comments[name] = col.get("COLUMN_COMMENT", "") or ""
return comments
except Exception as e:
logger.error(f"Failed to get column comments asynchronously: {e}")
return {}
async def get_table_indexes_async(self, table_name: str, db_name: str = None, catalog_name: str = None) -> List[Dict[str, Any]]:
"""Async version: get index information for a table."""
try:
effective_db = db_name or self.db_name
effective_catalog = catalog_name or self.catalog_name
# SECURITY FIX: Validate identifiers
try:
validate_identifier(table_name, "table name")
if effective_db:
validate_identifier(effective_db, "database name")
if effective_catalog:
validate_identifier(effective_catalog, "catalog name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected: {e}")
return []
# Build query with catalog prefix if specified (using safe identifiers)
safe_table = quote_identifier(table_name, "table name")
safe_db = quote_identifier(effective_db, "database name") if effective_db else None
if effective_catalog:
safe_catalog = quote_identifier(effective_catalog, "catalog name")
query = f"SHOW INDEX FROM {safe_catalog}.{safe_db}.{safe_table}"
logger.info(f"Using three-part naming for async index query: {query}")
else:
query = f"SHOW INDEX FROM {safe_db}.{safe_table}"
rows = await self._execute_query_async(query, effective_db)
indexes: List[Dict[str, Any]] = []
if rows:
# Group by Key_name
current_index: Dict[str, Any] | None = None
for r in rows:
try:
index_name = r.get('Key_name')
column_name = r.get('Column_name')
if current_index is None or current_index.get('name') != index_name:
if current_index is not None:
indexes.append(current_index)
current_index = {
'name': index_name,
'columns': [column_name] if column_name else [],
'unique': r.get('Non_unique', 1) == 0,
'type': r.get('Index_type', '')
}
else:
if column_name:
current_index['columns'].append(column_name)
except Exception as row_error:
logger.warning(f"Failed to process async index row data: {row_error}")
continue
if current_index is not None:
indexes.append(current_index)
return indexes
except Exception as e:
logger.error(f"Error getting index information asynchronously: {str(e)}")
return []
async def get_recent_audit_logs_async(self, days: int = 7, limit: int = 100):
"""Async version: get recent audit logs and return a pandas DataFrame."""
try:
start_date = (datetime.now() - timedelta(days=days)).strftime('%Y-%m-%d')
query = f"""
SELECT client_ip, user, db, time, stmt_id, stmt, state, error_code
FROM `__internal_schema`.`audit_log`
WHERE `time` >= '{start_date}'
AND state = 'EOF' AND error_code = 0
AND `stmt` NOT LIKE 'SHOW%'
AND `stmt` NOT LIKE 'DESC%'
AND `stmt` NOT LIKE 'EXPLAIN%'
AND `stmt` NOT LIKE 'SELECT 1%'
ORDER BY time DESC
LIMIT {limit}
"""
rows = await self._execute_query_async(query)
import pandas as pd
return pd.DataFrame(rows or [])
except Exception as e:
logger.error(f"Error getting audit logs asynchronously: {str(e)}")
import pandas as pd
return pd.DataFrame()
# ==================== Business layer methods (original metadata_tools.py functionality) ==================== # ==================== Business layer methods (original metadata_tools.py functionality) ====================
def _format_response(self, success: bool, result: Any = None, error: str = None, message: str = "") -> Dict[str, Any]: def _format_response(self, success: bool, result: Any = None, error: str = None, message: str = "") -> Dict[str, Any]:
@@ -1417,6 +1590,9 @@ class MetadataExtractor:
""" """
Execute SQL query and return results, supports catalog federation queries Execute SQL query and return results, supports catalog federation queries
Unified interface for MCP tools Unified interface for MCP tools
FIX for Issue #62 Bug 1: Now retrieves auth_context from context variable to support token-bound database configuration
FIX for Issue #62 Bug 3: Now uses db_name and catalog_name parameters to switch database context
""" """
logger.info(f"Executing SQL query: {sql}, DB: {db_name}, Catalog: {catalog_name}, MaxRows: {max_rows}, Timeout: {timeout}") logger.info(f"Executing SQL query: {sql}, DB: {db_name}, Catalog: {catalog_name}, MaxRows: {max_rows}, Timeout: {timeout}")
@@ -1424,15 +1600,86 @@ class MetadataExtractor:
if not sql: if not sql:
return self._format_response(success=False, error="No SQL statement provided", message="Please provide SQL statement to execute") return self._format_response(success=False, error="No SQL statement provided", message="Please provide SQL statement to execute")
# FIX for Issue #62 Bug 3: Build context switching SQL if db_name or catalog_name is specified
# SECURITY FIX: Validate catalog_name and db_name to prevent SQL injection
final_sql = sql
if catalog_name or db_name:
context_statements = []
# Validate and sanitize catalog_name
if catalog_name:
try:
validate_identifier(catalog_name, "catalog name")
except SQLSecurityError as e:
logger.warning(f"Invalid catalog name rejected: {e}")
return self._format_response(
success=False,
error=f"Invalid catalog name: {catalog_name}",
message="Catalog name contains invalid characters"
)
# Use quote_identifier to safely escape the catalog name
safe_catalog = quote_identifier(catalog_name, "catalog name")
context_statements.append(f"USE CATALOG {safe_catalog}")
logger.debug(f"Switching to catalog: {catalog_name}")
# Validate and sanitize db_name
if db_name:
try:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
logger.warning(f"Invalid database name rejected: {e}")
return self._format_response(
success=False,
error=f"Invalid database name: {db_name}",
message="Database name contains invalid characters"
)
# Use quote_identifier to safely escape the database name
safe_db = quote_identifier(db_name, "database name")
if catalog_name:
safe_catalog = quote_identifier(catalog_name, "catalog name")
context_statements.append(f"USE {safe_catalog}.{safe_db}")
else:
context_statements.append(f"USE {safe_db}")
logger.debug(f"Switching to database: {db_name}")
# Combine context switching with original SQL
if context_statements:
# Remove trailing semicolon from context statements if present
context_sql = "; ".join(context_statements)
# Ensure original SQL doesn't start with semicolon
sql_clean = sql.lstrip(";").strip()
final_sql = f"{context_sql}; {sql_clean}"
logger.debug(f"Modified SQL with context switching: {final_sql[:200]}...")
# FIX: Try to get auth_context from context variable (set by HTTP middleware)
# This allows token-bound database configuration to work
# CRITICAL: Use the global ContextVar from security.py to ensure same instance is used everywhere
auth_context = None
try:
from .security import mcp_auth_context_var
# Get auth_context from the global context variable
# This will be set by the HTTP request handler in main.py
auth_context = mcp_auth_context_var.get()
if auth_context:
logger.debug(f"Retrieved auth_context from context variable with token: {bool(hasattr(auth_context, 'token') and auth_context.token)}")
else:
logger.debug("No auth_context found in context variable, using default")
except Exception as ctx_error:
logger.debug(f"Could not retrieve auth_context from context variable: {ctx_error}")
auth_context = None
# Import query executor # Import query executor
from .query_executor import execute_sql_query from .query_executor import execute_sql_query
# Call execute_sql_query to execute query # Call execute_sql_query to execute query with auth_context
exec_result = await execute_sql_query( exec_result = await execute_sql_query(
sql=sql, sql=final_sql, # Use modified SQL with context switching
connection_manager=self.connection_manager, connection_manager=self.connection_manager,
limit=max_rows, limit=max_rows,
timeout=timeout timeout=timeout,
auth_context=auth_context # FIX: Pass auth_context with token
) )
return exec_result return exec_result
@@ -1453,6 +1700,36 @@ class MetadataExtractor:
if not table_name: if not table_name:
return self._format_response(success=False, error="Missing table_name parameter") return self._format_response(success=False, error="Missing table_name parameter")
# SECURITY: Validate identifiers before processing
try:
validate_identifier(table_name, "table name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid table name: {table_name}",
message="Table name contains invalid characters"
)
if db_name:
try:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid database name: {db_name}",
message="Database name contains invalid characters"
)
if catalog_name and catalog_name != "internal":
try:
validate_identifier(catalog_name, "catalog name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid catalog name: {catalog_name}",
message="Catalog name contains invalid characters"
)
try: try:
schema = await self.get_table_schema_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name) schema = await self.get_table_schema_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
@@ -1476,6 +1753,27 @@ class MetadataExtractor:
"""Get list of all table names in specified database - MCP interface""" """Get list of all table names in specified database - MCP interface"""
logger.info(f"Getting database table list: DB: {db_name}, Catalog: {catalog_name}") logger.info(f"Getting database table list: DB: {db_name}, Catalog: {catalog_name}")
# SECURITY: Validate identifiers
if db_name:
try:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid database name: {db_name}",
message="Database name contains invalid characters"
)
if catalog_name and catalog_name != "internal":
try:
validate_identifier(catalog_name, "catalog name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid catalog name: {catalog_name}",
message="Catalog name contains invalid characters"
)
try: try:
tables = await self.get_database_tables_async(db_name=db_name, catalog_name=catalog_name) tables = await self.get_database_tables_async(db_name=db_name, catalog_name=catalog_name)
return self._format_response(success=True, result=tables) return self._format_response(success=True, result=tables)
@@ -1506,8 +1804,38 @@ class MetadataExtractor:
if not table_name: if not table_name:
return self._format_response(success=False, error="Missing table_name parameter") return self._format_response(success=False, error="Missing table_name parameter")
# SECURITY: Validate identifiers
try: try:
comment = self.get_table_comment(table_name=table_name, db_name=db_name, catalog_name=catalog_name) validate_identifier(table_name, "table name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid table name: {table_name}",
message="Table name contains invalid characters"
)
if db_name:
try:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid database name: {db_name}",
message="Database name contains invalid characters"
)
if catalog_name and catalog_name != "internal":
try:
validate_identifier(catalog_name, "catalog name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid catalog name: {catalog_name}",
message="Catalog name contains invalid characters"
)
try:
comment = await self.get_table_comment_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
return self._format_response(success=True, result=comment) return self._format_response(success=True, result=comment)
except Exception as e: except Exception as e:
logger.error(f"Failed to get table comment: {str(e)}", exc_info=True) logger.error(f"Failed to get table comment: {str(e)}", exc_info=True)
@@ -1525,8 +1853,38 @@ class MetadataExtractor:
if not table_name: if not table_name:
return self._format_response(success=False, error="Missing table_name parameter") return self._format_response(success=False, error="Missing table_name parameter")
# SECURITY: Validate identifiers
try: try:
comments = self.get_column_comments(table_name=table_name, db_name=db_name, catalog_name=catalog_name) validate_identifier(table_name, "table name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid table name: {table_name}",
message="Table name contains invalid characters"
)
if db_name:
try:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid database name: {db_name}",
message="Database name contains invalid characters"
)
if catalog_name and catalog_name != "internal":
try:
validate_identifier(catalog_name, "catalog name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid catalog name: {catalog_name}",
message="Catalog name contains invalid characters"
)
try:
comments = await self.get_column_comments_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
return self._format_response(success=True, result=comments) return self._format_response(success=True, result=comments)
except Exception as e: except Exception as e:
logger.error(f"Failed to get table column comments: {str(e)}", exc_info=True) logger.error(f"Failed to get table column comments: {str(e)}", exc_info=True)
@@ -1544,8 +1902,38 @@ class MetadataExtractor:
if not table_name: if not table_name:
return self._format_response(success=False, error="Missing table_name parameter") return self._format_response(success=False, error="Missing table_name parameter")
# SECURITY: Validate identifiers
try: try:
indexes = self.get_table_indexes(table_name=table_name, db_name=db_name, catalog_name=catalog_name) validate_identifier(table_name, "table name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid table name: {table_name}",
message="Table name contains invalid characters"
)
if db_name:
try:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid database name: {db_name}",
message="Database name contains invalid characters"
)
if catalog_name and catalog_name != "internal":
try:
validate_identifier(catalog_name, "catalog name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid catalog name: {catalog_name}",
message="Catalog name contains invalid characters"
)
try:
indexes = await self.get_table_indexes_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
return self._format_response(success=True, result=indexes) return self._format_response(success=True, result=indexes)
except Exception as e: except Exception as e:
logger.error(f"Failed to get table indexes: {str(e)}", exc_info=True) logger.error(f"Failed to get table indexes: {str(e)}", exc_info=True)
@@ -1569,7 +1957,7 @@ class MetadataExtractor:
logger.info(f"Getting audit logs: Days: {days}, Limit: {limit}") logger.info(f"Getting audit logs: Days: {days}, Limit: {limit}")
try: try:
logs_df = self.get_recent_audit_logs(days=days, limit=limit) logs_df = await self.get_recent_audit_logs_async(days=days, limit=limit)
# Convert DataFrame to JSON format # Convert DataFrame to JSON format
if hasattr(logs_df, 'to_dict'): if hasattr(logs_df, 'to_dict'):

View File

@@ -20,18 +20,25 @@ Doris Security Management Module
Implements enterprise-level authentication, authorization, SQL security validation and data masking functionality Implements enterprise-level authentication, authorization, SQL security validation and data masking functionality
""" """
import hashlib
import logging import logging
import re import re
from dataclasses import dataclass from contextvars import ContextVar
from dataclasses import dataclass, field
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from typing import Any from typing import Any, Optional
import sqlparse import sqlparse
from sqlparse.sql import Statement from sqlparse.sql import Statement
from sqlparse.tokens import Keyword, Name from sqlparse.tokens import Keyword, Name
from .logger import get_logger
from .config import DatabaseConfig
# Global ContextVar for auth_context - must be a single instance shared across all modules
# This allows token-bound database configuration to work correctly in concurrent requests
mcp_auth_context_var: ContextVar['AuthContext'] = ContextVar('mcp_auth_context', default=None)
class SecurityLevel(Enum): class SecurityLevel(Enum):
"""Security level enumeration""" """Security level enumeration"""
@@ -44,15 +51,18 @@ class SecurityLevel(Enum):
@dataclass @dataclass
class AuthContext: class AuthContext:
"""Authentication context""" """Authentication context for audit and session tracking"""
user_id: str token_id: str = "" # Token identifier for audit logging
roles: list[str] user_id: str = "" # User identifier
permissions: list[str] roles: list[str] = field(default_factory=list) # User roles
session_id: str permissions: list[str] = field(default_factory=list) # User permissions
login_time: datetime | None = None security_level: 'SecurityLevel' = field(default_factory=lambda: SecurityLevel.INTERNAL) # Security level
client_ip: str = "unknown" # Client IP address
session_id: str = "" # Session identifier
login_time: datetime = field(default_factory=datetime.utcnow)
last_activity: datetime | None = None last_activity: datetime | None = None
security_level: SecurityLevel = SecurityLevel.INTERNAL token: str = "" # Raw token for token-bound database configuration
@dataclass @dataclass
@@ -85,12 +95,13 @@ class DorisSecurityManager:
Provides complete security control functionality, including authentication, authorization, SQL security validation and data masking Provides complete security control functionality, including authentication, authorization, SQL security validation and data masking
""" """
def __init__(self, config): def __init__(self, config, connection_manager=None):
self.config = config self.config = config
self.logger = logging.getLogger(__name__) self.logger = get_logger(__name__)
self.connection_manager = connection_manager
# Initialize security components # Initialize security components
self.auth_provider = AuthenticationProvider(config) self.auth_provider = AuthenticationProvider(config, self)
self.authz_provider = AuthorizationProvider(config) self.authz_provider = AuthorizationProvider(config)
self.sql_validator = SQLSecurityValidator(config) self.sql_validator = SQLSecurityValidator(config)
self.masking_processor = DataMaskingProcessor(config) self.masking_processor = DataMaskingProcessor(config)
@@ -100,31 +111,55 @@ class DorisSecurityManager:
self.sensitive_tables = self._load_sensitive_tables() self.sensitive_tables = self._load_sensitive_tables()
self.masking_rules = self._load_masking_rules() self.masking_rules = self._load_masking_rules()
# Track initialization state
self._initialized = False
async def initialize(self):
"""Initialize security manager components"""
if self._initialized:
return
try:
# Initialize authentication provider (for JWT setup)
await self.auth_provider.initialize()
self._initialized = True
self.logger.info("DorisSecurityManager initialized successfully")
except Exception as e:
self.logger.error(f"Failed to initialize DorisSecurityManager: {e}")
raise
async def shutdown(self):
"""Shutdown security manager components"""
try:
await self.auth_provider.shutdown()
self._initialized = False
self.logger.info("DorisSecurityManager shutdown completed")
except Exception as e:
self.logger.error(f"Error during DorisSecurityManager shutdown: {e}")
raise
def _load_blocked_keywords(self) -> set[str]: def _load_blocked_keywords(self) -> set[str]:
"""Load blocked SQL keywords""" """Load blocked SQL keywords from configuration"""
default_blocked = { # Load keywords from configuration, unified source of truth
"DROP",
"DELETE",
"TRUNCATE",
"ALTER",
"CREATE",
"INSERT",
"UPDATE",
"GRANT",
"REVOKE",
"EXEC",
"EXECUTE",
"SHUTDOWN",
"KILL",
}
# Load custom rules from configuration file
if hasattr(self.config, 'get'): if hasattr(self.config, 'get'):
custom_blocked = set(self.config.get("blocked_keywords", [])) # Dictionary-style configuration
blocked_keywords = self.config.get("blocked_keywords", [])
elif hasattr(self.config, 'security') and hasattr(self.config.security, 'blocked_keywords'):
# DorisConfig object, get through security.blocked_keywords
blocked_keywords = self.config.security.blocked_keywords
else: else:
custom_blocked = set() # Fallback to default if no configuration available
blocked_keywords = [
"DROP", "CREATE", "ALTER", "TRUNCATE",
"DELETE", "INSERT", "UPDATE",
"GRANT", "REVOKE",
"EXEC", "EXECUTE", "SHUTDOWN", "KILL"
]
return default_blocked.union(custom_blocked) return set(blocked_keywords)
def _load_sensitive_tables(self) -> dict[str, SecurityLevel]: def _load_sensitive_tables(self) -> dict[str, SecurityLevel]:
"""Load sensitive table configuration""" """Load sensitive table configuration"""
@@ -189,8 +224,59 @@ class DorisSecurityManager:
return default_rules return default_rules
async def authenticate_request(self, auth_info: dict[str, Any]) -> AuthContext: async def authenticate_request(self, auth_info: dict[str, Any]) -> AuthContext:
"""Validate request authentication information""" """Validate request authentication information
return await self.auth_provider.authenticate(auth_info)
Tries authentication methods in order: Token -> JWT -> OAuth
Any one method succeeding allows access
If all methods are disabled, returns anonymous context
"""
# Check if any authentication method is enabled
if not (self.config.security.enable_token_auth or
self.config.security.enable_jwt_auth or
self.config.security.enable_oauth_auth):
self.logger.debug("All authentication methods are disabled")
# Return anonymous context when no authentication is enabled
return AuthContext(
token_id="anonymous",
user_id="anonymous",
roles=["anonymous"],
permissions=["read"],
security_level=SecurityLevel.PUBLIC,
client_ip=auth_info.get("client_ip", "unknown"),
session_id="anonymous_session"
)
# Try authentication methods in order of preference
last_error = None
# 1. Try Token authentication first (most common)
if self.config.security.enable_token_auth:
try:
return await self.auth_provider.authenticate_token(auth_info)
except Exception as e:
self.logger.debug(f"Token authentication failed: {e}")
last_error = e
# 2. Try JWT authentication
if self.config.security.enable_jwt_auth:
try:
return await self.auth_provider.authenticate_jwt(auth_info)
except Exception as e:
self.logger.debug(f"JWT authentication failed: {e}")
last_error = e
# 3. Try OAuth authentication
if self.config.security.enable_oauth_auth:
try:
return await self.auth_provider.authenticate_oauth(auth_info)
except Exception as e:
self.logger.debug(f"OAuth authentication failed: {e}")
last_error = e
# All enabled authentication methods failed
error_message = f"Authentication failed: {str(last_error)}" if last_error else "No authentication method succeeded"
self.logger.warning(f"Authentication failed for client {auth_info.get('client_ip', 'unknown')}: {error_message}")
raise ValueError(error_message)
async def authorize_resource_access( async def authorize_resource_access(
self, auth_context: AuthContext, resource_uri: str self, auth_context: AuthContext, resource_uri: str
@@ -212,44 +298,363 @@ class DorisSecurityManager:
"""Apply data masking processing""" """Apply data masking processing"""
return await self.masking_processor.process(data, auth_context) return await self.masking_processor.process(data, auth_context)
# OAuth-specific methods
def get_oauth_authorization_url(self) -> tuple[str, str]:
"""Get OAuth authorization URL
Returns:
Tuple of (authorization_url, state)
"""
if not self.auth_provider.oauth_provider:
raise ValueError("OAuth is not enabled")
return self.auth_provider.oauth_provider.get_authorization_url()
async def handle_oauth_callback(self, code: str, state: str) -> AuthContext:
"""Handle OAuth callback
Args:
code: Authorization code from OAuth provider
state: State parameter for CSRF protection
Returns:
AuthContext for authenticated user
"""
if not self.auth_provider.oauth_provider:
raise ValueError("OAuth is not enabled")
return await self.auth_provider.oauth_provider.handle_callback(code, state)
def get_oauth_provider_info(self) -> dict[str, Any]:
"""Get OAuth provider information
Returns:
OAuth provider information
"""
if not self.auth_provider.oauth_provider:
return {"enabled": False}
return self.auth_provider.oauth_provider.get_provider_info()
# Token management methods
async def create_token(
self,
token_id: str,
expires_hours: Optional[int] = None,
description: str = "",
custom_token: Optional[str] = None,
database_config: Optional[DatabaseConfig] = None
) -> str:
"""Create a new API access token
Args:
token_id: Unique token identifier for audit and management
expires_hours: Token expiration in hours (None for no expiration)
description: Token description for management purposes
custom_token: Custom token string (if None, generates random token)
database_config: Optional database configuration for this token
Returns:
Generated token string
"""
if not self.auth_provider.token_manager:
raise ValueError("Token manager not initialized")
return await self.auth_provider.token_manager.create_token(
token_id=token_id,
expires_hours=expires_hours,
description=description,
custom_token=custom_token,
database_config=database_config
)
async def revoke_token(self, token_id: str) -> bool:
"""Revoke a token by token ID
Args:
token_id: Token ID to revoke
Returns:
True if token was revoked successfully
"""
if not self.auth_provider.token_manager:
raise ValueError("Token manager not initialized")
return await self.auth_provider.token_manager.revoke_token(token_id)
async def list_tokens(self) -> list[dict[str, Any]]:
"""List all tokens (without sensitive data)
Returns:
List of token information
"""
if not self.auth_provider.token_manager:
raise ValueError("Token manager not initialized")
return await self.auth_provider.token_manager.list_tokens()
async def cleanup_expired_tokens(self) -> int:
"""Remove expired tokens and return count
Returns:
Number of expired tokens removed
"""
if not self.auth_provider.token_manager:
return 0
return await self.auth_provider.token_manager.cleanup_expired_tokens()
def get_token_stats(self) -> dict[str, Any]:
"""Get token statistics
Returns:
Token statistics dictionary
"""
if not self.auth_provider.token_manager:
return {"error": "Token manager not initialized"}
return self.auth_provider.token_manager.get_token_stats()
async def _validate_token_database_config(self, token: str, token_info) -> None:
"""Validate database configuration for token immediately during authentication
This ensures database connectivity issues are caught at authentication time,
not during query execution, providing better user experience.
Args:
token: Raw authentication token
token_info: TokenInfo object from token validation
Raises:
ValueError: If database configuration is invalid or connection fails
"""
try:
if not self.connection_manager:
self.logger.warning("Connection manager not available for immediate database validation")
return
# Configure and test database connection for this token
success, config_source = await self.connection_manager.configure_for_token(token)
if success:
self.logger.info(f"Database configuration validated successfully for token {token_info.token_id} (source: {config_source})")
else:
raise ValueError("Database configuration validation failed")
except Exception as e:
error_msg = f"Database configuration validation failed for token {token_info.token_id}: {str(e)}"
self.logger.error(error_msg)
raise ValueError(error_msg)
class AuthenticationProvider: class AuthenticationProvider:
"""Authentication provider""" """Authentication provider"""
def __init__(self, config): def __init__(self, config, security_manager=None):
self.config = config self.config = config
self.logger = logging.getLogger(__name__) self.logger = get_logger(__name__)
self.session_cache = {} self.session_cache = {}
self.jwt_manager = None
self.oauth_provider = None
self.token_manager = None
self.security_manager = security_manager
async def authenticate(self, auth_info: dict[str, Any]) -> AuthContext: # Initialize authentication providers based on individual switches
"""Perform identity authentication""" auth_methods_enabled = []
auth_type = auth_info.get("type", "token")
if auth_type == "token": # Initialize Token manager if enabled
return await self._authenticate_token(auth_info) if config.security.enable_token_auth:
elif auth_type == "basic": self._initialize_token_manager()
return await self._authenticate_basic(auth_info) auth_methods_enabled.append("Token")
# Initialize JWT manager if enabled
if config.security.enable_jwt_auth:
self._initialize_jwt_manager()
auth_methods_enabled.append("JWT")
# Initialize OAuth provider if enabled
if config.security.enable_oauth_auth or (hasattr(config.security, 'oauth_enabled') and config.security.oauth_enabled):
self._initialize_oauth_provider()
auth_methods_enabled.append("OAuth")
if auth_methods_enabled:
self.logger.info(f"Authentication enabled with methods: {', '.join(auth_methods_enabled)}")
else: else:
raise ValueError(f"Unsupported authentication type: {auth_type}") self.logger.info("All authentication methods are disabled - anonymous access allowed")
def _initialize_jwt_manager(self):
"""Initialize JWT manager"""
try:
from ..auth.jwt_manager import JWTManager
self.jwt_manager = JWTManager(self.config)
self.logger.info("JWT manager initialized")
except ImportError as e:
self.logger.error(f"Failed to import JWT manager: {e}")
raise
except Exception as e:
self.logger.error(f"Failed to initialize JWT manager: {e}")
raise
def _initialize_token_manager(self):
"""Initialize Token manager"""
try:
from ..auth.token_manager import TokenManager
self.token_manager = TokenManager(self.config)
self.logger.info("Token manager initialized")
except ImportError as e:
self.logger.error(f"Failed to import Token manager: {e}")
raise
except Exception as e:
self.logger.error(f"Failed to initialize Token manager: {e}")
raise
def _initialize_oauth_provider(self):
"""Initialize OAuth provider"""
try:
from ..auth.oauth_provider import OAuthAuthenticationProvider
self.oauth_provider = OAuthAuthenticationProvider(self.config)
self.logger.info("OAuth provider initialized")
except ImportError as e:
self.logger.error(f"Failed to import OAuth provider: {e}")
raise
except Exception as e:
self.logger.error(f"Failed to initialize OAuth provider: {e}")
raise
async def initialize(self):
"""Initialize authentication provider asynchronously"""
if self.jwt_manager:
success = await self.jwt_manager.initialize()
if not success:
raise RuntimeError("Failed to initialize JWT manager")
self.logger.info("JWT authentication provider initialized successfully")
if self.token_manager:
# Token manager doesn't need async initialization, just log success
self.logger.info("Token authentication provider initialized successfully")
if self.oauth_provider:
success = await self.oauth_provider.initialize()
if not success:
raise RuntimeError("Failed to initialize OAuth provider")
self.logger.info("OAuth authentication provider initialized successfully")
async def shutdown(self):
"""Shutdown authentication provider"""
if self.jwt_manager:
await self.jwt_manager.shutdown()
self.logger.info("JWT authentication provider shutdown completed")
if self.token_manager:
# Token manager doesn't need async shutdown, just log
self.logger.info("Token authentication provider shutdown completed")
if self.oauth_provider:
await self.oauth_provider.shutdown()
self.logger.info("OAuth authentication provider shutdown completed")
async def authenticate_token(self, auth_info: dict[str, Any]) -> AuthContext:
"""Perform token authentication"""
if not self.config.security.enable_token_auth:
raise ValueError("Token authentication is not enabled")
return await self._authenticate_token(auth_info)
async def authenticate_jwt(self, auth_info: dict[str, Any]) -> AuthContext:
"""Perform JWT authentication"""
if not self.config.security.enable_jwt_auth:
raise ValueError("JWT authentication is not enabled")
return await self._authenticate_jwt(auth_info)
async def authenticate_oauth(self, auth_info: dict[str, Any]) -> AuthContext:
"""Perform OAuth authentication"""
if not self.config.security.enable_oauth_auth:
raise ValueError("OAuth authentication is not enabled")
return await self._authenticate_oauth(auth_info)
async def _authenticate_jwt(self, auth_info: dict[str, Any]) -> AuthContext:
"""JWT authentication"""
if not self.jwt_manager:
raise ValueError("JWT manager not initialized")
token = auth_info.get("token")
if not token:
# Try to extract from Authorization header
authorization = auth_info.get("authorization")
if authorization and authorization.startswith('Bearer '):
token = authorization[7:]
if not token:
raise ValueError("Missing JWT token")
try:
# Use JWT middleware for authentication
from ..auth.auth_middleware import AuthMiddleware
middleware = AuthMiddleware(self.jwt_manager)
return await middleware.authenticate_request(auth_info)
except Exception as e:
self.logger.error(f"JWT authentication failed: {e}")
raise ValueError(f"JWT authentication failed: {str(e)}")
async def _authenticate_oauth(self, auth_info: dict[str, Any]) -> AuthContext:
"""OAuth authentication"""
if not self.oauth_provider:
raise ValueError("OAuth provider not initialized")
# Handle different OAuth authentication scenarios
if "access_token" in auth_info:
# Direct OAuth access token authentication
return await self.oauth_provider.authenticate_with_token(auth_info["access_token"])
elif "code" in auth_info and "state" in auth_info:
# OAuth callback authentication
return await self.oauth_provider.handle_callback(auth_info["code"], auth_info["state"])
else:
raise ValueError("OAuth authentication requires either access_token or code+state")
async def _authenticate_token(self, auth_info: dict[str, Any]) -> AuthContext: async def _authenticate_token(self, auth_info: dict[str, Any]) -> AuthContext:
"""Token authentication""" """Token authentication"""
if not self.token_manager:
raise ValueError("Token manager not initialized")
token = auth_info.get("token") token = auth_info.get("token")
if not token:
# Try to extract from Authorization header
authorization = auth_info.get("authorization")
if authorization and authorization.startswith('Bearer '):
token = authorization[7:]
elif authorization and authorization.startswith('Token '):
token = authorization[6:]
if not token: if not token:
raise ValueError("Missing authentication token") raise ValueError("Missing authentication token")
# Validate token (simplified implementation, should validate JWT or query authentication service in practice) try:
user_info = await self._validate_token(token) # Validate token using TokenManager
validation_result = await self.token_manager.validate_token(token)
if not validation_result.is_valid:
raise ValueError(f"Token validation failed: {validation_result.error_message}")
token_info = validation_result.token_info
# Immediately validate database configuration for this token
if self.security_manager:
await self.security_manager._validate_token_database_config(token, token_info)
return AuthContext( return AuthContext(
user_id=user_info["user_id"], token_id=token_info.token_id,
roles=user_info["roles"], user_id=token_info.token_id, # Use token_id as user_id for token auth
permissions=user_info["permissions"], roles=["token_user"], # Default role for token users
session_id=auth_info.get("session_id", "default"), permissions=["read", "write"], # Default permissions for token users
security_level=SecurityLevel.INTERNAL,
client_ip=auth_info.get("client_ip", "unknown"),
session_id=auth_info.get("session_id", f"session_{token_info.token_id}"),
login_time=datetime.utcnow(), login_time=datetime.utcnow(),
security_level=SecurityLevel(user_info.get("security_level", "internal")), last_activity=token_info.last_used,
token=token # Store raw token for token-bound database configuration
) )
except Exception as e:
self.logger.error(f"Token authentication failed: {e}")
raise ValueError(f"Token authentication failed: {str(e)}")
async def _authenticate_basic(self, auth_info: dict[str, Any]) -> AuthContext: async def _authenticate_basic(self, auth_info: dict[str, Any]) -> AuthContext:
"""Basic authentication (username password)""" """Basic authentication (username password)"""
username = auth_info.get("username") username = auth_info.get("username")
@@ -328,7 +733,7 @@ class AuthorizationProvider:
def __init__(self, config): def __init__(self, config):
self.config = config self.config = config
self.logger = logging.getLogger(__name__) self.logger = get_logger(__name__)
self.permission_cache = {} self.permission_cache = {}
# Load sensitive tables configuration # Load sensitive tables configuration
@@ -471,42 +876,79 @@ class SQLSecurityValidator:
def __init__(self, config): def __init__(self, config):
self.config = config self.config = config
self.logger = logging.getLogger(__name__) self.logger = get_logger(__name__)
# Handle DorisConfig object or dictionary configuration # Handle DorisConfig object or dictionary configuration
if hasattr(config, 'get'): if hasattr(config, 'get'):
# Dictionary configuration # Dictionary configuration
self.blocked_keywords = set(config.get("blocked_keywords", [])) self.blocked_keywords = set(config.get("blocked_keywords", []))
self.max_query_complexity = config.get("max_query_complexity", 100) self.max_query_complexity = config.get("max_query_complexity", 100)
self.enable_security_check = config.get("enable_security_check", True)
elif hasattr(config, 'security'):
# DorisConfig object with security attribute - unified source from config
self.blocked_keywords = set(config.security.blocked_keywords)
self.max_query_complexity = config.security.max_query_complexity
self.enable_security_check = getattr(config.security, 'enable_security_check', True)
else: else:
# DorisConfig object, use default values # Fallback to default if no configuration available
self.blocked_keywords = set(["DROP", "DELETE", "TRUNCATE", "ALTER", "CREATE", "INSERT", "UPDATE"]) self.blocked_keywords = set([
"DROP", "CREATE", "ALTER", "TRUNCATE",
"DELETE", "INSERT", "UPDATE",
"GRANT", "REVOKE",
"EXEC", "EXECUTE", "SHUTDOWN", "KILL"
])
self.max_query_complexity = 100 self.max_query_complexity = 100
self.enable_security_check = True
async def validate(self, sql: str, auth_context: AuthContext) -> ValidationResult: async def validate(self, sql: str, auth_context: AuthContext) -> ValidationResult:
"""Validate SQL query security""" """Validate SQL query security"""
# If security check is disabled, always return valid
if not self.enable_security_check:
self.logger.debug("SQL security check is disabled, allowing all queries")
return ValidationResult(is_valid=True)
try: try:
# Parse SQL statement # SECURITY FIX: Parse ALL SQL statements, not just the first one
parsed = sqlparse.parse(sql)[0] # This prevents bypassing security checks by injecting additional statements
all_statements = sqlparse.parse(sql)
if not all_statements:
return ValidationResult(
is_valid=False,
error_message="Empty or invalid SQL statement",
risk_level="medium"
)
# SECURITY FIX: Validate each statement individually
for idx, parsed in enumerate(all_statements):
# Skip empty statements (e.g., from trailing semicolons)
if not parsed.tokens or str(parsed).strip() == '':
continue
self.logger.debug(f"Validating SQL statement {idx + 1}/{len(all_statements)}: {str(parsed)[:100]}...")
# Check blocked operations first (more specific) # Check blocked operations first (more specific)
keyword_result = await self._check_blocked_keywords(parsed) keyword_result = await self._check_blocked_keywords(parsed)
if not keyword_result.is_valid: if not keyword_result.is_valid:
keyword_result.error_message = f"Statement {idx + 1}: {keyword_result.error_message}"
return keyword_result return keyword_result
# Check SQL injection risks # Check SQL injection risks
injection_result = await self._check_sql_injection(sql, parsed) injection_result = await self._check_sql_injection(sql, parsed)
if not injection_result.is_valid: if not injection_result.is_valid:
injection_result.error_message = f"Statement {idx + 1}: {injection_result.error_message}"
return injection_result return injection_result
# Check query complexity # Check query complexity
complexity_result = await self._check_query_complexity(parsed) complexity_result = await self._check_query_complexity(parsed)
if not complexity_result.is_valid: if not complexity_result.is_valid:
complexity_result.error_message = f"Statement {idx + 1}: {complexity_result.error_message}"
return complexity_result return complexity_result
# Check table access permissions # Check table access permissions
table_result = await self._check_table_access(parsed, auth_context) table_result = await self._check_table_access(parsed, auth_context)
if not table_result.is_valid: if not table_result.is_valid:
table_result.error_message = f"Statement {idx + 1}: {table_result.error_message}"
return table_result return table_result
return ValidationResult(is_valid=True) return ValidationResult(is_valid=True)
@@ -522,28 +964,69 @@ class SQLSecurityValidator:
async def _check_sql_injection( async def _check_sql_injection(
self, sql: str, parsed: Statement self, sql: str, parsed: Statement
) -> ValidationResult: ) -> ValidationResult:
"""Check SQL injection risks""" """Check SQL injection risks with improved pattern detection
# Check common SQL injection patterns
FIX for Issue #62 Bug 2: Improved patterns to reduce false positives
Now better distinguishes between legitimate SQL (like BETWEEN...AND) and injection attempts
"""
# Improved injection patterns that are more specific and less prone to false positives
injection_patterns = [ injection_patterns = [
r"(\s|^)(union|select|insert|update|delete|drop|create|alter)\s+.*\s+(union|select|insert|update|delete|drop|create|alter)", # Stacked queries with dangerous operations (true injection risk)
r"(\s|^)(or|and)\s+\d+\s*=\s*\d+", r";\s*(DROP|DELETE|TRUNCATE|ALTER|CREATE|INSERT|UPDATE)\s+",
r"(\s|^)(or|and)\s+['\"].*['\"]",
r";\s*(drop|delete|truncate|alter|create)", # UNION-based injection (but allow legitimate UNION queries)
r"(exec|execute|sp_|xp_)", # Only flag if UNION is followed by suspicious patterns like SELECT with WHERE 1=1
r"(script|javascript|vbscript)", r"UNION\s+(ALL\s+)?SELECT\s+.*\s+(WHERE|AND|OR)\s+\d+\s*=\s*\d+",
r"(char|ascii|substring|concat)\s*\(",
# Boolean-based blind injection with comments (true injection pattern)
r"(WHERE|AND|OR)\s+\d+\s*=\s*\d+\s*(--|#|/\*)",
# Quote-based injection attempts (but not in legitimate strings)
r"(WHERE|AND|OR)\s+(['\"])[^\2]*\2\s*=\s*\2[^\2]*\2",
# Time-based blind injection
r"(SLEEP|WAITFOR|BENCHMARK)\s*\(",
# System stored procedure injection
r"(EXEC|EXECUTE|SP_|XP_)\s*\(",
# Script injection attempts
r"<\s*(SCRIPT|JAVASCRIPT|VBSCRIPT)",
] ]
sql_lower = sql.lower() # FIX: Don't flag legitimate SQL functions and keywords
# These patterns are too broad and cause false positives:
# - REMOVED: r"(char|ascii|substring|concat)\s*\(" - These are legitimate SQL functions
# - REMOVED: r"(\s|^)(or|and)\s+\d+\s*=\s*\d+" - This flags BETWEEN...AND constructs
# - REMOVED: r"(\s|^)(or|and)\s+['\"].*['\"]" - This is too broad
sql_upper = sql.upper()
# Special case: Allow BETWEEN...AND which is legitimate SQL
# This prevents false positives like "WHERE dt BETWEEN '2025-01-01' AND '2025-01-31'"
if "BETWEEN" in sql_upper and "AND" in sql_upper:
# This is likely a BETWEEN clause, not injection
# Check if AND appears in a BETWEEN context
between_pattern = r"BETWEEN\s+[^\s]+\s+AND\s+[^\s]+"
if re.search(between_pattern, sql_upper, re.IGNORECASE):
# Remove BETWEEN clauses before checking other patterns
sql_cleaned = re.sub(between_pattern, "BETWEEN_CLAUSE", sql_upper, flags=re.IGNORECASE)
sql_to_check = sql_cleaned
else:
sql_to_check = sql_upper
else:
sql_to_check = sql_upper
for pattern in injection_patterns: for pattern in injection_patterns:
if re.search(pattern, sql_lower, re.IGNORECASE): if re.search(pattern, sql_to_check, re.IGNORECASE):
self.logger.warning(f"Potential SQL injection pattern detected: {pattern}")
return ValidationResult( return ValidationResult(
is_valid=False, is_valid=False,
error_message="Potential SQL injection risk detected", error_message="Potential SQL injection risk detected",
risk_level="high", risk_level="high",
) )
# Check suspicious quotes and comments # Check suspicious quotes and comments (with improved detection)
if self._has_suspicious_quotes_or_comments(sql): if self._has_suspicious_quotes_or_comments(sql):
return ValidationResult( return ValidationResult(
is_valid=False, is_valid=False,
@@ -554,20 +1037,68 @@ class SQLSecurityValidator:
return ValidationResult(is_valid=True) return ValidationResult(is_valid=True)
def _has_suspicious_quotes_or_comments(self, sql: str) -> bool: def _has_suspicious_quotes_or_comments(self, sql: str) -> bool:
"""Check suspicious quote and comment patterns""" """Check suspicious quote and comment patterns with improved detection
# Check unmatched quotes
single_quotes = sql.count("'")
double_quotes = sql.count('"')
FIX for Issue #62 Bug 2: Improved detection to reduce false positives
Now distinguishes between legitimate comments/strings and injection attempts
"""
try:
# Use sqlparse to parse the SQL and distinguish between code and comments/strings
import sqlparse
from sqlparse.tokens import Comment, String
# Parse the SQL
parsed = sqlparse.parse(sql)
if not parsed:
# If parsing fails, be conservative
return True
statement = parsed[0]
# Check for unmatched quotes ONLY in non-string tokens
# This prevents false positives from legitimate string content
non_string_content = []
has_string_tokens = False
for token in statement.flatten():
if token.ttype in (String.Single, String.Double):
has_string_tokens = True
# Skip string content - quotes inside strings are legitimate
continue
elif token.ttype in (Comment.Single, Comment.Multi):
# Comments are generally OK, but check for suspicious injection patterns
comment_value = str(token).lower()
# Check if comment contains dangerous SQL keywords
dangerous_in_comments = ['drop', 'delete', 'insert', 'update', 'union', 'exec', 'execute']
if any(keyword in comment_value for keyword in dangerous_in_comments):
self.logger.warning(f"Suspicious SQL keyword in comment: {token}")
return True
# Normal comments are OK
continue
else:
# Accumulate non-string, non-comment content
non_string_content.append(str(token))
# Check for unmatched quotes in non-string content
non_string_text = ''.join(non_string_content)
single_quotes = non_string_text.count("'")
double_quotes = non_string_text.count('"')
# Only flag if there are unmatched quotes in actual SQL code (not in strings)
if single_quotes % 2 != 0 or double_quotes % 2 != 0: if single_quotes % 2 != 0 or double_quotes % 2 != 0:
return True return True
# Check SQL comments # FIX: Don't flag legitimate SQL comments
if "--" in sql or "/*" in sql: # Comments are OK as long as they don't contain dangerous patterns (already checked above)
return True
return False return False
except Exception as e:
self.logger.debug(f"SQL parsing error in quote/comment check: {e}")
# On parsing error, fall back to conservative check
# But be more lenient than before
return False # Don't flag on parse errors to reduce false positives
async def _check_blocked_keywords(self, parsed: Statement) -> ValidationResult: async def _check_blocked_keywords(self, parsed: Statement) -> ValidationResult:
"""Check blocked keywords""" """Check blocked keywords"""
blocked_operations = [] blocked_operations = []
@@ -628,6 +1159,10 @@ class SQLSecurityValidator:
self, parsed: Statement, auth_context: AuthContext self, parsed: Statement, auth_context: AuthContext
) -> ValidationResult: ) -> ValidationResult:
"""Check table access permissions""" """Check table access permissions"""
# If no auth_context, skip table access checks (rely on other security checks)
if auth_context is None:
return ValidationResult(is_valid=True)
# Extract table names from query # Extract table names from query
tables = self._extract_table_names(parsed) tables = self._extract_table_names(parsed)
@@ -676,7 +1211,7 @@ class DataMaskingProcessor:
def __init__(self, config): def __init__(self, config):
self.config = config self.config = config
self.logger = logging.getLogger(__name__) self.logger = get_logger(__name__)
self.masking_algorithms = self._init_masking_algorithms() self.masking_algorithms = self._init_masking_algorithms()
self.masking_rules = self._load_masking_rules() self.masking_rules = self._load_masking_rules()

View File

@@ -0,0 +1,788 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Security Analytics Tools Module
Provides data access analysis, user behavior monitoring, and security insights
"""
import time
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from collections import Counter, defaultdict
from .db import DorisConnectionManager
from .logger import get_logger
from .sql_security_utils import get_auth_context
logger = get_logger(__name__)
class SecurityAnalyticsTools:
"""Security analytics tools for access pattern analysis and user monitoring"""
def __init__(self, connection_manager: DorisConnectionManager):
self.connection_manager = connection_manager
logger.info("SecurityAnalyticsTools initialized")
async def analyze_data_access_patterns(
self,
days: int = 7,
include_system_users: bool = False,
min_query_threshold: int = 5
) -> Dict[str, Any]:
"""
Analyze data access patterns for users and roles
Args:
days: Number of days to analyze
include_system_users: Whether to include system/service users
min_query_threshold: Minimum queries for a user to be included in analysis
Returns:
Comprehensive access pattern analysis
"""
try:
start_time = time.time()
# 🚀 PROGRESS: Initialize security analysis
logger.info("=" * 70)
logger.info(f"🔒 Starting Data Access Pattern Analysis")
logger.info(f"📅 Analysis period: {days} days")
logger.info(f"👥 Include system users: {include_system_users}")
logger.info(f"🎯 Min query threshold: {min_query_threshold}")
logger.info("=" * 70)
connection = await self.connection_manager.get_connection("query")
# Define analysis period
end_date = datetime.now()
start_date = end_date - timedelta(days=days)
logger.info(f"📊 Period: {start_date.strftime('%Y-%m-%d')} to {end_date.strftime('%Y-%m-%d')}")
# 🚀 PROGRESS: Step 1 - Get audit log data
logger.info("📋 Step 1/5: Retrieving audit log data...")
audit_start = time.time()
audit_data = await self._get_audit_log_data(connection, start_date, end_date, include_system_users)
audit_time = time.time() - audit_start
if not audit_data:
logger.warning("⚠️ No audit data available for the specified period")
return {
"error": "No audit data available for the specified period",
"analysis_period": {
"start_date": start_date.isoformat(),
"end_date": end_date.isoformat(),
"days": days
}
}
logger.info(f"✅ Retrieved {len(audit_data)} audit records in {audit_time:.2f}s")
# 🚀 PROGRESS: Step 2 - Analyze user access patterns
logger.info("👤 Step 2/5: Analyzing user access patterns...")
user_start = time.time()
user_access_analysis = await self._analyze_user_access_patterns(
audit_data, min_query_threshold
)
user_time = time.time() - user_start
logger.info(f"✅ Analyzed {len(user_access_analysis)} users in {user_time:.2f}s")
# 🚀 PROGRESS: Step 3 - Analyze role-based access
logger.info("🎭 Step 3/5: Analyzing role-based access patterns...")
role_start = time.time()
role_access_analysis = await self._analyze_role_access_patterns(
connection, user_access_analysis
)
role_time = time.time() - role_start
logger.info(f"✅ Role analysis completed in {role_time:.2f}s")
# 🚀 PROGRESS: Step 4 - Detect security anomalies
logger.info("🚨 Step 4/5: Detecting security anomalies...")
anomaly_start = time.time()
security_alerts = await self._detect_security_anomalies(
audit_data, user_access_analysis
)
anomaly_time = time.time() - anomaly_start
logger.info(f"✅ Found {len(security_alerts)} security alerts in {anomaly_time:.2f}s")
# Log alert summary
if security_alerts:
high_alerts = sum(1 for alert in security_alerts if alert.get("severity") == "high")
medium_alerts = sum(1 for alert in security_alerts if alert.get("severity") == "medium")
logger.info(f"🚨 Alert breakdown: {high_alerts} high, {medium_alerts} medium")
# 🚀 PROGRESS: Step 5 - Generate access insights
logger.info("💡 Step 5/5: Generating access insights...")
insights_start = time.time()
access_insights = await self._generate_access_insights(
user_access_analysis, role_access_analysis
)
insights_time = time.time() - insights_start
logger.info(f"✅ Access insights generated in {insights_time:.2f}s")
execution_time = time.time() - start_time
return {
"analysis_period": {
"start_date": start_date.isoformat(),
"end_date": end_date.isoformat(),
"days": days
},
"analysis_timestamp": datetime.now().isoformat(),
"execution_time_seconds": round(execution_time, 3),
"user_access_summary": self._generate_user_access_summary(user_access_analysis),
"user_access_details": user_access_analysis,
"role_analysis": role_access_analysis,
"security_alerts": security_alerts,
"access_insights": access_insights,
"recommendations": self._generate_security_recommendations(security_alerts, access_insights)
}
except Exception as e:
logger.error(f"Data access pattern analysis failed: {str(e)}")
return {
"error": str(e),
"analysis_timestamp": datetime.now().isoformat()
}
# ==================== Private Helper Methods ====================
async def _get_audit_log_data(self, connection, start_date: datetime, end_date: datetime, include_system_users: bool) -> List[Dict]:
"""Retrieve audit log data for the specified period"""
try:
# System users filter
system_user_filter = ""
if not include_system_users:
system_users = ['root', 'admin', 'system', 'doris', 'information_schema']
user_list = ','.join([f'"{user}"' for user in system_users])
system_user_filter = f"AND `user` NOT IN ({user_list})"
audit_sql = f"""
SELECT
`user` as user_name,
`client_ip` as host,
`time` as query_time,
`stmt` as sql_statement,
`state` as query_status,
`scan_bytes` as scan_bytes,
`scan_rows` as scan_rows,
`return_rows` as return_rows,
`query_time` as execution_time_ms
FROM internal.__internal_schema.audit_log
WHERE `time` >= '{start_date.strftime('%Y-%m-%d %H:%M:%S')}'
AND `time` <= '{end_date.strftime('%Y-%m-%d %H:%M:%S')}'
AND `stmt` IS NOT NULL
AND `stmt` != ''
{system_user_filter}
ORDER BY `time` DESC
LIMIT 10000
"""
# SECURITY FIX: Pass auth_context to execute
auth_context = get_auth_context()
result = await connection.execute(audit_sql, auth_context=auth_context)
return result.data if result.data else []
except Exception as e:
logger.warning(f"Failed to get audit log data: {str(e)}")
# Try alternative method without detailed metrics
try:
simple_audit_sql = f"""
SELECT
`user` as user_name,
`client_ip` as host,
`time` as query_time,
`stmt` as sql_statement,
`state` as query_status
FROM internal.__internal_schema.audit_log
WHERE `time` >= '{start_date.strftime('%Y-%m-%d %H:%M:%S')}'
AND `time` <= '{end_date.strftime('%Y-%m-%d %H:%M:%S')}'
AND `stmt` IS NOT NULL
{system_user_filter}
ORDER BY `time` DESC
LIMIT 10000
"""
auth_context = get_auth_context()
result = await connection.execute(simple_audit_sql, auth_context=auth_context)
return result.data if result.data else []
except Exception as e2:
logger.error(f"Failed to get simplified audit log data: {str(e2)}")
return []
async def _analyze_user_access_patterns(self, audit_data: List[Dict], min_query_threshold: int) -> List[Dict]:
"""Analyze access patterns for individual users"""
user_stats = defaultdict(lambda: {
"total_queries": 0,
"unique_tables_accessed": set(),
"hosts": set(),
"query_types": Counter(),
"query_times": [],
"failed_queries": 0,
"data_volume_read_bytes": 0,
"data_volume_read_rows": 0,
"hourly_pattern": [0] * 24,
"daily_pattern": [0] * 7,
"query_statements": []
})
# Process audit data
for entry in audit_data:
user_name = entry.get("user_name", "unknown")
query_time = entry.get("query_time")
sql_statement = entry.get("sql_statement", "")
query_status = entry.get("query_status", "")
stats = user_stats[user_name]
stats["total_queries"] += 1
# Extract table names from SQL
tables = self._extract_table_names_from_sql(sql_statement)
stats["unique_tables_accessed"].update(tables)
# Host tracking
if entry.get("host"):
stats["hosts"].add(entry["host"])
# Query type analysis
query_type = self._classify_query_type(sql_statement)
stats["query_types"][query_type] += 1
# Query time patterns
if query_time:
try:
if isinstance(query_time, str):
query_dt = datetime.fromisoformat(query_time.replace('Z', '+00:00'))
else:
query_dt = query_time
stats["query_times"].append(query_dt)
stats["hourly_pattern"][query_dt.hour] += 1
stats["daily_pattern"][query_dt.weekday()] += 1
except Exception:
pass
# Error tracking
if query_status and "error" in query_status.lower():
stats["failed_queries"] += 1
# Data volume tracking
if entry.get("scan_bytes"):
try:
stats["data_volume_read_bytes"] += int(entry["scan_bytes"])
except (ValueError, TypeError):
pass
if entry.get("scan_rows"):
try:
stats["data_volume_read_rows"] += int(entry["scan_rows"])
except (ValueError, TypeError):
pass
# Store sample queries
if len(stats["query_statements"]) < 10:
stats["query_statements"].append({
"sql": sql_statement[:200] + "..." if len(sql_statement) > 200 else sql_statement,
"timestamp": str(query_time),
"type": query_type
})
# Convert to analysis results
user_analysis = []
for user_name, stats in user_stats.items():
if stats["total_queries"] >= min_query_threshold:
# Calculate patterns and insights
access_pattern = self._classify_access_pattern(stats["hourly_pattern"])
table_access_frequency = dict(Counter(
table for entry in audit_data
if entry.get("user_name") == user_name
for table in self._extract_table_names_from_sql(entry.get("sql_statement", ""))
).most_common(10))
user_analysis.append({
"user_name": user_name,
"access_stats": {
"total_queries": stats["total_queries"],
"unique_tables_accessed": len(stats["unique_tables_accessed"]),
"unique_hosts": len(stats["hosts"]),
"data_volume_read_gb": round(stats["data_volume_read_bytes"] / (1024**3), 3),
"data_volume_read_rows": stats["data_volume_read_rows"],
"failed_queries": stats["failed_queries"],
"success_rate": round((stats["total_queries"] - stats["failed_queries"]) / stats["total_queries"], 3) if stats["total_queries"] > 0 else 0,
"peak_access_hour": stats["hourly_pattern"].index(max(stats["hourly_pattern"])) if max(stats["hourly_pattern"]) > 0 else None,
"access_pattern": access_pattern
},
"query_type_distribution": dict(stats["query_types"]),
"table_access_frequency": table_access_frequency,
"hosts_used": list(stats["hosts"]),
"sample_queries": stats["query_statements"],
"temporal_patterns": {
"hourly_distribution": stats["hourly_pattern"],
"daily_distribution": stats["daily_pattern"]
}
})
return sorted(user_analysis, key=lambda x: x["access_stats"]["total_queries"], reverse=True)
def _extract_table_names_from_sql(self, sql: str) -> List[str]:
"""Extract table names from SQL statement (simplified implementation)"""
if not sql:
return []
import re
# Simple regex patterns to match table names
patterns = [
r'\bFROM\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)',
r'\bJOIN\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)',
r'\bINTO\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)',
r'\bUPDATE\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)',
r'\bDELETE\s+FROM\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)'
]
tables = []
for pattern in patterns:
matches = re.findall(pattern, sql, re.IGNORECASE)
tables.extend(matches)
# Clean up table names (remove quotes, aliases, etc.)
cleaned_tables = []
for table in tables:
# Remove backticks, quotes, and get just the table name
clean_table = table.strip('`"\'').split(' ')[0]
if clean_table and not clean_table.upper() in ['SELECT', 'WHERE', 'AND', 'OR']:
cleaned_tables.append(clean_table)
return list(set(cleaned_tables))
def _classify_query_type(self, sql: str) -> str:
"""Classify SQL query type"""
if not sql:
return "unknown"
sql_upper = sql.upper().strip()
if sql_upper.startswith('SELECT'):
return "SELECT"
elif sql_upper.startswith('INSERT'):
return "INSERT"
elif sql_upper.startswith('UPDATE'):
return "UPDATE"
elif sql_upper.startswith('DELETE'):
return "DELETE"
elif sql_upper.startswith('CREATE'):
return "CREATE"
elif sql_upper.startswith('ALTER'):
return "ALTER"
elif sql_upper.startswith('DROP'):
return "DROP"
elif sql_upper.startswith('SHOW'):
return "SHOW"
elif sql_upper.startswith('DESCRIBE') or sql_upper.startswith('DESC'):
return "DESCRIBE"
else:
return "OTHER"
def _classify_access_pattern(self, hourly_pattern: List[int]) -> str:
"""Classify user access pattern based on hourly distribution"""
if not hourly_pattern or max(hourly_pattern) == 0:
return "no_pattern"
# Find peak hours
max_queries = max(hourly_pattern)
peak_hours = [i for i, count in enumerate(hourly_pattern) if count == max_queries]
# Business hours: 9-17
business_hours = set(range(9, 18))
peak_in_business_hours = any(hour in business_hours for hour in peak_hours)
# Night hours: 22-6
night_hours = set(list(range(22, 24)) + list(range(0, 7)))
peak_in_night_hours = any(hour in night_hours for hour in peak_hours)
if peak_in_business_hours and not peak_in_night_hours:
return "regular_business_hours"
elif peak_in_night_hours:
return "night_shift_or_batch"
elif len(peak_hours) > 6: # Distributed throughout day
return "distributed_access"
else:
return "irregular_pattern"
async def _analyze_role_access_patterns(self, connection, user_access_analysis: List[Dict]) -> Dict[str, Any]:
"""Analyze access patterns by role"""
try:
# Get user roles information
user_roles = await self._get_user_roles(connection)
# Group users by roles
role_stats = defaultdict(lambda: {
"user_count": 0,
"total_queries": 0,
"unique_tables": set(),
"query_types": Counter(),
"avg_queries_per_user": 0,
"users": []
})
# Process user access data
for user_data in user_access_analysis:
user_name = user_data["user_name"]
user_stats = user_data["access_stats"]
query_types = user_data["query_type_distribution"]
# Get user roles (default to 'unknown' if not found)
roles = user_roles.get(user_name, ["unknown"])
for role in roles:
stats = role_stats[role]
stats["user_count"] += 1
stats["total_queries"] += user_stats["total_queries"]
stats["users"].append(user_name)
# Aggregate query types
for query_type, count in query_types.items():
stats["query_types"][query_type] += count
# Calculate role analysis
role_analysis = {}
for role, stats in role_stats.items():
if stats["user_count"] > 0:
avg_queries = stats["total_queries"] / stats["user_count"]
# Calculate privilege usage (simplified)
total_role_queries = sum(stats["query_types"].values())
privilege_usage = {}
if total_role_queries > 0:
privilege_usage = {
query_type: round(count / total_role_queries, 3)
for query_type, count in stats["query_types"].items()
}
role_analysis[role] = {
"user_count": stats["user_count"],
"users": stats["users"],
"total_queries": stats["total_queries"],
"avg_queries_per_user": round(avg_queries, 1),
"query_type_distribution": dict(stats["query_types"]),
"privilege_usage": privilege_usage,
"activity_level": self._classify_role_activity_level(avg_queries)
}
return role_analysis
except Exception as e:
logger.warning(f"Failed to analyze role access patterns: {str(e)}")
return {}
async def _get_user_roles(self, connection) -> Dict[str, List[str]]:
"""Get user roles mapping"""
try:
# Try to get user role information
roles_sql = """
SELECT
User as user_name,
COALESCE(Default_role, 'default') as role_name
FROM mysql.user
"""
auth_context = get_auth_context()
result = await connection.execute(roles_sql, auth_context=auth_context)
user_roles = defaultdict(list)
if result.data:
for row in result.data:
user_name = row.get("user_name", "")
role_name = row.get("role_name", "default")
if user_name:
user_roles[user_name].append(role_name)
return dict(user_roles)
except Exception as e:
logger.warning(f"Failed to get user roles: {str(e)}")
return {}
def _classify_role_activity_level(self, avg_queries: float) -> str:
"""Classify role activity level based on average queries"""
if avg_queries > 100:
return "high"
elif avg_queries > 20:
return "medium"
elif avg_queries > 5:
return "low"
else:
return "minimal"
async def _detect_security_anomalies(self, audit_data: List[Dict], user_access_analysis: List[Dict]) -> List[Dict]:
"""Detect potential security anomalies"""
alerts = []
# 1. Detect unusual access times
for user_data in user_access_analysis:
user_name = user_data["user_name"]
hourly_pattern = user_data["temporal_patterns"]["hourly_distribution"]
# Check for significant night-time activity
night_queries = sum(hourly_pattern[22:24]) + sum(hourly_pattern[0:6])
total_queries = sum(hourly_pattern)
if total_queries > 0 and night_queries / total_queries > 0.3: # >30% night activity
alerts.append({
"alert_type": "unusual_access_time",
"severity": "medium",
"user": user_name,
"description": f"User {user_name} has {night_queries/total_queries:.1%} of queries during night hours",
"night_query_percentage": round(night_queries/total_queries, 3),
"timestamp": datetime.now().isoformat()
})
# 2. Detect users with high failure rates
for user_data in user_access_analysis:
user_name = user_data["user_name"]
success_rate = user_data["access_stats"]["success_rate"]
total_queries = user_data["access_stats"]["total_queries"]
if total_queries > 10 and success_rate < 0.8: # <80% success rate
alerts.append({
"alert_type": "high_failure_rate",
"severity": "medium",
"user": user_name,
"description": f"User {user_name} has low query success rate ({success_rate:.1%})",
"success_rate": success_rate,
"total_queries": total_queries,
"timestamp": datetime.now().isoformat()
})
# 3. Detect unusual data volume access
data_volumes = [user["access_stats"]["data_volume_read_gb"] for user in user_access_analysis]
if data_volumes:
avg_volume = sum(data_volumes) / len(data_volumes)
std_dev = (sum((x - avg_volume) ** 2 for x in data_volumes) / len(data_volumes)) ** 0.5
threshold = avg_volume + 2 * std_dev # 2 standard deviations above mean
for user_data in user_access_analysis:
user_name = user_data["user_name"]
volume = user_data["access_stats"]["data_volume_read_gb"]
if volume > threshold and volume > 1.0: # >1GB and above threshold
alerts.append({
"alert_type": "unusual_data_volume",
"severity": "high" if volume > threshold * 2 else "medium",
"user": user_name,
"description": f"User {user_name} read {volume:.2f}GB (threshold: {threshold:.2f}GB)",
"data_volume_gb": volume,
"threshold_gb": round(threshold, 2),
"timestamp": datetime.now().isoformat()
})
# 4. Detect users accessing many different tables
for user_data in user_access_analysis:
user_name = user_data["user_name"]
unique_tables = user_data["access_stats"]["unique_tables_accessed"]
total_queries = user_data["access_stats"]["total_queries"]
# High table diversity might indicate privilege escalation or data mining
if unique_tables > 20 and total_queries > 50:
alerts.append({
"alert_type": "broad_table_access",
"severity": "medium",
"user": user_name,
"description": f"User {user_name} accessed {unique_tables} different tables",
"unique_tables_count": unique_tables,
"total_queries": total_queries,
"timestamp": datetime.now().isoformat()
})
return sorted(alerts, key=lambda x: {"high": 3, "medium": 2, "low": 1}.get(x["severity"], 0), reverse=True)
async def _generate_access_insights(self, user_access_analysis: List[Dict], role_analysis: Dict[str, Any]) -> Dict[str, Any]:
"""Generate access insights and patterns"""
insights = {
"user_behavior_patterns": {},
"role_effectiveness": {},
"security_posture": {}
}
# User behavior patterns
if user_access_analysis:
total_users = len(user_access_analysis)
active_users = len([u for u in user_access_analysis if u["access_stats"]["total_queries"] > 10])
power_users = len([u for u in user_access_analysis if u["access_stats"]["total_queries"] > 100])
# Access pattern distribution
pattern_distribution = Counter(
user["access_stats"]["access_pattern"] for user in user_access_analysis
)
insights["user_behavior_patterns"] = {
"total_users_analyzed": total_users,
"active_users": active_users,
"power_users": power_users,
"access_pattern_distribution": dict(pattern_distribution),
"avg_queries_per_user": round(
sum(u["access_stats"]["total_queries"] for u in user_access_analysis) / total_users, 1
) if total_users > 0 else 0
}
# Role effectiveness
if role_analysis:
most_active_role = max(role_analysis.items(), key=lambda x: x[1]["total_queries"])
least_active_role = min(role_analysis.items(), key=lambda x: x[1]["total_queries"])
insights["role_effectiveness"] = {
"total_roles": len(role_analysis),
"most_active_role": {
"role": most_active_role[0],
"total_queries": most_active_role[1]["total_queries"],
"user_count": most_active_role[1]["user_count"]
},
"least_active_role": {
"role": least_active_role[0],
"total_queries": least_active_role[1]["total_queries"],
"user_count": least_active_role[1]["user_count"]
},
"avg_users_per_role": round(
sum(role_info["user_count"] for role_info in role_analysis.values()) / len(role_analysis), 1
)
}
# Security posture assessment
if user_access_analysis:
users_with_failures = len([u for u in user_access_analysis if u["access_stats"]["failed_queries"] > 0])
users_night_access = len([
u for u in user_access_analysis
if any(u["temporal_patterns"]["hourly_distribution"][hour] > 0 for hour in list(range(22, 24)) + list(range(0, 6)))
])
insights["security_posture"] = {
"users_with_query_failures": users_with_failures,
"users_with_night_access": users_night_access,
"security_score": self._calculate_security_score(user_access_analysis),
"risk_level": self._assess_overall_risk_level(user_access_analysis)
}
return insights
def _calculate_security_score(self, user_access_analysis: List[Dict]) -> float:
"""Calculate overall security score (0-1, higher is better)"""
if not user_access_analysis:
return 0.0
total_users = len(user_access_analysis)
# Factors that contribute to security score
users_with_high_success_rate = len([u for u in user_access_analysis if u["access_stats"]["success_rate"] > 0.9])
users_with_normal_patterns = len([u for u in user_access_analysis if u["access_stats"]["access_pattern"] == "regular_business_hours"])
success_rate_score = users_with_high_success_rate / total_users
pattern_score = users_with_normal_patterns / total_users
# Combined score
overall_score = (success_rate_score * 0.6 + pattern_score * 0.4)
return round(overall_score, 3)
def _assess_overall_risk_level(self, user_access_analysis: List[Dict]) -> str:
"""Assess overall security risk level"""
security_score = self._calculate_security_score(user_access_analysis)
if security_score > 0.8:
return "low"
elif security_score > 0.6:
return "medium"
else:
return "high"
def _generate_user_access_summary(self, user_access_analysis: List[Dict]) -> Dict[str, Any]:
"""Generate summary statistics for user access"""
if not user_access_analysis:
return {
"total_users": 0,
"active_users": 0,
"high_activity_users": 0,
"dormant_users": 0
}
total_users = len(user_access_analysis)
active_users = len([u for u in user_access_analysis if u["access_stats"]["total_queries"] > 10])
high_activity_users = len([u for u in user_access_analysis if u["access_stats"]["total_queries"] > 100])
dormant_users = total_users - active_users
return {
"total_users": total_users,
"active_users": active_users,
"high_activity_users": high_activity_users,
"dormant_users": dormant_users,
"activity_distribution": {
"high": high_activity_users,
"medium": active_users - high_activity_users,
"low": dormant_users
}
}
def _generate_security_recommendations(self, security_alerts: List[Dict], access_insights: Dict[str, Any]) -> List[Dict]:
"""Generate security recommendations based on analysis"""
recommendations = []
# Recommendations based on alerts
if security_alerts:
high_severity_alerts = [alert for alert in security_alerts if alert["severity"] == "high"]
if high_severity_alerts:
recommendations.append({
"type": "urgent_security_review",
"priority": "high",
"description": f"Found {len(high_severity_alerts)} high-severity security alerts",
"action": "Immediate review of flagged users and access patterns required",
"affected_users": list(set(alert["user"] for alert in high_severity_alerts if "user" in alert))
})
# Night access recommendations
night_access_alerts = [alert for alert in security_alerts if alert["alert_type"] == "unusual_access_time"]
if night_access_alerts:
recommendations.append({
"type": "access_time_policy",
"priority": "medium",
"description": f"{len(night_access_alerts)} users have significant night-time access",
"action": "Review access time policies and consider time-based restrictions",
"affected_users": [alert["user"] for alert in night_access_alerts]
})
# Recommendations based on insights
security_posture = access_insights.get("security_posture", {})
risk_level = security_posture.get("risk_level", "unknown")
if risk_level == "high":
recommendations.append({
"type": "overall_security_improvement",
"priority": "high",
"description": "Overall security posture indicates high risk",
"action": "Comprehensive security audit and policy review recommended"
})
# Role-based recommendations
role_effectiveness = access_insights.get("role_effectiveness", {})
if role_effectiveness and role_effectiveness.get("total_roles", 0) < 3:
recommendations.append({
"type": "role_management",
"priority": "medium",
"description": "Limited role diversity detected",
"action": "Consider implementing more granular role-based access control"
})
return recommendations

View File

@@ -0,0 +1,301 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
SQL Security Utilities Module
Provides SQL identifier validation, escaping, and safe query building utilities
to prevent SQL injection attacks.
"""
import re
from contextvars import ContextVar
from typing import Optional, Tuple, List, Any
from .logger import get_logger
logger = get_logger(__name__)
# Context variable for auth_context (set by HTTP middleware)
auth_context_var: ContextVar = ContextVar('mcp_auth_context', default=None)
class SQLSecurityError(Exception):
"""Exception raised for SQL security validation failures"""
pass
class SQLSecurityUtils:
"""
SQL Security Utilities for preventing SQL injection attacks.
Provides:
- Identifier validation (database names, table names, column names)
- Safe identifier quoting with backticks
- Safe table reference building
- Auth context retrieval from context variables
"""
# Valid SQL identifier pattern: letters, numbers, underscores
# Must start with letter or underscore, not a number
# Supports Unicode letters for international database/table names
IDENTIFIER_PATTERN = re.compile(r'^[a-zA-Z_\u4e00-\u9fff][a-zA-Z0-9_\u4e00-\u9fff]*$')
# Maximum identifier length (MySQL/Doris standard)
MAX_IDENTIFIER_LENGTH = 64
# SQL reserved keywords that should be quoted
SQL_KEYWORDS = {
'SELECT', 'FROM', 'WHERE', 'INSERT', 'UPDATE', 'DELETE', 'DROP',
'CREATE', 'ALTER', 'TABLE', 'DATABASE', 'INDEX', 'VIEW', 'AND',
'OR', 'NOT', 'NULL', 'TRUE', 'FALSE', 'IN', 'LIKE', 'BETWEEN',
'JOIN', 'LEFT', 'RIGHT', 'INNER', 'OUTER', 'ON', 'AS', 'ORDER',
'BY', 'GROUP', 'HAVING', 'LIMIT', 'OFFSET', 'UNION', 'ALL',
'DISTINCT', 'INTO', 'VALUES', 'SET', 'DEFAULT', 'PRIMARY', 'KEY',
'FOREIGN', 'REFERENCES', 'CHECK', 'UNIQUE', 'CONSTRAINT'
}
@classmethod
def validate_identifier(cls, name: str, identifier_type: str = "identifier") -> str:
"""
Validate a SQL identifier (database name, table name, column name, etc.)
Args:
name: The identifier to validate
identifier_type: Type description for error messages (e.g., "database name", "table name")
Returns:
The validated identifier (unchanged if valid)
Raises:
SQLSecurityError: If the identifier is invalid
"""
if not name:
raise SQLSecurityError(f"Empty {identifier_type} is not allowed")
if not isinstance(name, str):
raise SQLSecurityError(f"Invalid {identifier_type}: must be a string, got {type(name).__name__}")
# Strip whitespace
name = name.strip()
if not name:
raise SQLSecurityError(f"Empty {identifier_type} is not allowed")
# Check length
if len(name) > cls.MAX_IDENTIFIER_LENGTH:
raise SQLSecurityError(
f"Invalid {identifier_type}: '{name[:20]}...' exceeds maximum length of {cls.MAX_IDENTIFIER_LENGTH} characters"
)
# Check for dangerous characters that could be SQL injection
dangerous_chars = ["'", '"', ';', '--', '/*', '*/', '\\', '\x00']
for char in dangerous_chars:
if char in name:
raise SQLSecurityError(
f"Invalid {identifier_type}: '{name}' contains forbidden character '{char}'"
)
# Validate pattern
if not cls.IDENTIFIER_PATTERN.match(name):
raise SQLSecurityError(
f"Invalid {identifier_type}: '{name}' contains invalid characters. "
f"Only letters, numbers, and underscores are allowed, and must start with a letter or underscore."
)
logger.debug(f"Validated {identifier_type}: {name}")
return name
@classmethod
def quote_identifier(cls, name: str, identifier_type: str = "identifier") -> str:
"""
Safely quote a SQL identifier using backticks.
Args:
name: The identifier to quote
identifier_type: Type description for error messages
Returns:
The quoted identifier (e.g., `table_name`)
Raises:
SQLSecurityError: If the identifier is invalid
"""
# First validate the identifier
validated_name = cls.validate_identifier(name, identifier_type)
# Escape any backticks within the name (double them)
escaped_name = validated_name.replace('`', '``')
return f"`{escaped_name}`"
@classmethod
def build_table_reference(
cls,
table_name: str,
db_name: Optional[str] = None,
catalog_name: Optional[str] = None,
quote: bool = True
) -> str:
"""
Build a safe, fully-qualified table reference.
Args:
table_name: The table name (required)
db_name: The database name (optional)
catalog_name: The catalog name (optional)
quote: Whether to quote identifiers with backticks (default: True)
Returns:
A safe table reference string (e.g., `catalog`.`db`.`table`)
Raises:
SQLSecurityError: If any identifier is invalid
"""
parts = []
if catalog_name:
if quote:
parts.append(cls.quote_identifier(catalog_name, "catalog name"))
else:
parts.append(cls.validate_identifier(catalog_name, "catalog name"))
if db_name:
if quote:
parts.append(cls.quote_identifier(db_name, "database name"))
else:
parts.append(cls.validate_identifier(db_name, "database name"))
if quote:
parts.append(cls.quote_identifier(table_name, "table name"))
else:
parts.append(cls.validate_identifier(table_name, "table name"))
return '.'.join(parts)
@classmethod
def build_column_reference(
cls,
column_name: str,
table_name: Optional[str] = None,
quote: bool = True
) -> str:
"""
Build a safe column reference.
Args:
column_name: The column name (required)
table_name: The table name (optional, for qualified references)
quote: Whether to quote identifiers with backticks (default: True)
Returns:
A safe column reference string (e.g., `table`.`column`)
Raises:
SQLSecurityError: If any identifier is invalid
"""
parts = []
if table_name:
if quote:
parts.append(cls.quote_identifier(table_name, "table name"))
else:
parts.append(cls.validate_identifier(table_name, "table name"))
if quote:
parts.append(cls.quote_identifier(column_name, "column name"))
else:
parts.append(cls.validate_identifier(column_name, "column name"))
return '.'.join(parts)
@classmethod
def validate_and_build_where_condition(
cls,
column_name: str,
operator: str = "=",
use_param: bool = True
) -> Tuple[str, bool]:
"""
Build a safe WHERE condition for a column.
Args:
column_name: The column name
operator: The comparison operator (=, !=, <, >, <=, >=, LIKE, IN)
use_param: Whether to use parameterized placeholder (%s)
Returns:
Tuple of (condition_string, needs_param)
e.g., ("`column` = %s", True) or ("`column` = DATABASE()", False)
Raises:
SQLSecurityError: If column name is invalid or operator is not allowed
"""
# Validate column name
quoted_column = cls.quote_identifier(column_name, "column name")
# Validate operator
allowed_operators = {'=', '!=', '<>', '<', '>', '<=', '>=', 'LIKE', 'IN', 'IS'}
if operator.upper() not in allowed_operators:
raise SQLSecurityError(f"Invalid operator: '{operator}'. Allowed: {allowed_operators}")
if use_param:
return f"{quoted_column} {operator} %s", True
else:
return f"{quoted_column} {operator}", False
@staticmethod
def get_auth_context():
"""
Get auth_context from the context variable.
This retrieves the auth_context that was set by the HTTP middleware
during request processing.
Returns:
The auth_context object, or None if not available
"""
try:
auth_context = auth_context_var.get()
if auth_context:
logger.debug(f"Retrieved auth_context from context variable")
return auth_context
except Exception as e:
logger.debug(f"Could not retrieve auth_context: {e}")
return None
@staticmethod
def set_auth_context(auth_context):
"""
Set auth_context in the context variable.
This is typically called by the HTTP middleware during request processing.
Args:
auth_context: The auth_context object to set
"""
auth_context_var.set(auth_context)
logger.debug("Set auth_context in context variable")
# Convenience functions for direct use
validate_identifier = SQLSecurityUtils.validate_identifier
quote_identifier = SQLSecurityUtils.quote_identifier
build_table_reference = SQLSecurityUtils.build_table_reference
build_column_reference = SQLSecurityUtils.build_column_reference
get_auth_context = SQLSecurityUtils.get_auth_context
set_auth_context = SQLSecurityUtils.set_auth_context

147
examples/cursor/README.md Normal file
View File

@@ -0,0 +1,147 @@
<!--
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->
# Cursor Example: Integrating Doris MCP Server
This guide provides step-by-step instructions on how to integrate the `doris-mcp-server` with the [Cursor](https://cursor.sh/) IDE. This integration allows you to interact with your Apache Doris database using natural language queries directly within Cursor's AI chat.
## Table of Contents
* [Prerequisites](#prerequisites)
* [Step 1: Set Up the Project](#step-1-set-up-the-project)
* [Step 2: Configure the MCP Server in Cursor](#step-2-configure-the-mcp-server-in-cursor)
* [Step 3: Verify the Integration](#step-3-verify-the-integration)
* [Step 4: Query Your Database](#step-4-query-your-database)
* [Example 1: List Tables](#example-1-list-tables)
* [Example 2: Analyze Sales Trends](#example-2-analyze-sales-trends)
---
### Prerequisites
Before you begin, ensure you have the following installed and configured:
* The **Cursor** IDE
* **Git** for cloning the repository
* Access to an **Apache Doris** cluster (FE host, port, username, and password)
* **uv**, a fast Python package installer and runner
You can install `uv` with one of the following commands:
```bash
# For macOS (recommended)
brew install uv
# For other systems using pipx
pipx install uv
```
---
### Step 1: Set Up the Project
First, clone the `doris-mcp-server` repository to your local machine:
```bash
git clone https://github.com/apache/doris-mcp-server.git
cd doris-mcp-server
```
The necessary dependencies are listed in `requirements.txt` and will be managed automatically by `uv` in the next step.
---
### Step 2: Configure the MCP Server in Cursor
1. Open the cloned `doris-mcp-server` directory in Cursor.
2. Click the ⚙️ icon (top-right), then go to **Tools & Integrations**.
![add MCP Server](../images/cursor_add_mcp.png)
3. Click **Add a custom MCP Server**.
4. Paste the following JSON configuration:
```json
{
"mcpServers": {
"doris-mcp": {
"command": "uv",
"args": [
"run",
"--project",
"/path/to/your/doris-mcp-server",
"doris-mcp-server"
],
"env": {
"DORIS_HOST": "your_doris_fe_host",
"DORIS_PORT": "9030",
"DORIS_USER": "your_username",
"DORIS_PASSWORD": "your_password",
"DORIS_DATABASE": "ssb"
}
}
}
}
```
> ⚠️ **Important:**
>
> * Replace `"/path/to/your/doris-mcp-server"` with the **absolute path** to your local project directory.
> * Fill in your actual Doris FE host, username, password, and database name.
---
### Step 3: Verify the Integration
Once saved, go back to the **Settings** panel. If everything is configured correctly, youll see a green status dot next to `doris-mcp-server`, along with available tools like `exec_query`.
![MCP Server](../images/cursor_doris-mcp.png)
---
### Step 4: Query Your Database
You can now chat with Cursor Agent to run SQL queries against your Doris database.
1. Open the chat panel using `Cmd + K` (macOS) or `Ctrl + K` (Windows/Linux), or click the chat icon in the top-right.
2. Switch to **Agent Mode**.
3. Start asking questions using natural language.
![ask](../images/cursor_agent.png)
---
#### Example 1: List Tables
> **Prompt:** What tables are in the `ssb` database?
The agent will call the `get_db_table_list` tool and return the results.
![ask](../images/cursor_ask1.png)
---
#### Example 2: Analyze Sales Trends
> **Prompt:** What has been the sales trend over the past ten years in the `ssb` database, and which year had the fastest growth?
The agent will generate an appropriate SQL query, send it to the MCP server, and interpret the results to give you growth trends and highlights.
![ask](../images/cursor_ask2.png)

156
examples/dify/README.md Normal file
View File

@@ -0,0 +1,156 @@
<!--
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->
# Dify Example: Integrating Doris MCP Server
This document demonstrates how to integrate and use `doris-mcp-server` in Dify to perform Doris SQL calls via MCP.
## Table of Contents
- [Prerequisites](#prerequisites)
- [Starting the MCP Server](#starting-the-mcp-server)
- [Ngrok Tunnel (Optional)](#ngrok-tunnel-optional)
- [Installing & Configuring the Plugin in Dify](#installing--configuring-the-plugin-in-dify)
- [Creating a Dify App](#creating-a-dify-app)
- [Adding MCP Tools](#adding-mcp-tools)
- [Example Calls](#example-calls)
-----
### Prerequisites
First, install `mcp-doris-server`:
```bash
pip install mcp-doris-server
```
## Starting the MCP Server
Run the startup script:
```bash
# Full configuration with database connection
doris-mcp-server \
--transport http \
--host 0.0.0.0 \
--port 3000 \
--db-host 127.0.0.1 \
--db-port 9030 \
--db-user root \
--db-password your_password
```
If successful, you'll see logs similar to this:
![Server start logs](../images/dify_start_server.png)
-----
## Ngrok Tunnel (Optional)
If your Dify deployment requires a publicly accessible endpoint, you can use the **ngrok** tool. Ngrok is a third-party service that securely exposes local servers to the internet.
-----
## Installing & Configuring the Plugin in Dify
1. In the Dify console, go to **Plugin Marketplace**, search for, and install **MCPSSE / StreamableHTTP**:
![Install plugin](../images/dify_install_plugin.png)
2. After installation, click **Configure** and set the URL to your public or local address. For example, if you're using `ngrok`, this should be the public URL `ngrok` provides, in the format `https://<your-domain>/mcp`. If Dify can directly access your local server, use `http://localhost:3000/mcp`.
```json
{
"doris_mcp_server": {
"transport": "streamable_http",
"url": "https://<your-domain>/mcp"
}
}
```
![Configure plugin](../images/dify_config_mcp.png)
3. Click **Save**. If configured correctly, you'll see a green **Authorized** indicator:
![Authorized](../images/dify_authorized.png)
-----
## Creating a Dify App
1. In the Dify console, click **New App** → **Blank App**.
![Create app](../images/dify_create_app.png)
2. Select **Agent** as the template and set the **App Name** (e.g., `Doris ChatBI`).
![Agent setup](../images/dify_agent_setup.png)
3. Import from DSL,[dify_doris_dsl.yml](dify_doris_dsl.yml)
-----
## Instructions & Tool Configuration
### Instruction Block
Paste the following into the **Instruction** field:
```
<instruction>
Use MCP tools to complete tasks as much as possible. Carefully read the annotations, method names, and parameter descriptions of each tool. Please follow these steps:
1. Analyze the user's question and match the most appropriate tool.
2. Use tool names and parameters exactly as defined; do not invent new ones.
3. Pass parameters in the required JSON format.
4. When calling tools, use:
{"mcp_sse_call_tool": {"tool_name": "<tool_name>", "arguments": "{}"}}
5. Output plain text only—no XML tags.
<input>
User question: user_query
</input>
<output>
Return tool results or a final answer, including analysis.
</output>
</instruction>
```
### Adding MCP Tools
In the **Tools** pane, click **Add** twice to add two entries, both named `mcp_sse` (they will inherit the transport and URL from the plugin):
![Add tools](../images/dify_add_tools.png)
-----
## Example Calls
### List Tables in Database
* **User**: What tables are in the database?
* **Result**: Dify will call the MCP tool to run `SHOW TABLES` and return the list.
![Query tables](../images/dify_query_tabels.png)
### Sales Trend Over Ten Years
* **User**: What has been the sales trend over the past ten years in the ssb database, and which year had the fastest growth?
* **Result**: The tool will execute the SQL, calculate growth rates, and return data.
![Sales trend](../images/dify_sale_trend.png)

View File

@@ -0,0 +1,127 @@
app:
description: ''
icon: 🤖
icon_background: '#FFEAD5'
mode: agent-chat
name: doris
use_icon_as_answer_icon: false
dependencies:
- current_identifier: null
type: marketplace
value:
marketplace_plugin_unique_identifier: langgenius/deepseek:0.0.5@21408d5c48cd9f18d66b08883d0999fe89e6d049c891324c2229dea23b9665d5
- current_identifier: null
type: marketplace
value:
marketplace_plugin_unique_identifier: junjiem/mcp_sse:0.2.1@53cc613667fcf91dd7208dd5f6d2c8df3c7ff0af8b79e8f3c0a430f1b39bda4c
kind: app
model_config:
agent_mode:
enabled: true
max_iteration: 10
prompt: null
strategy: function_call
tools:
- enabled: true
isDeleted: false
notAuthor: false
provider_id: junjiem/mcp_sse/mcp_sse
provider_name: junjiem/mcp_sse/mcp_sse
provider_type: builtin
tool_label: 获取 MCP 工具列表
tool_name: mcp_sse_list_tools
tool_parameters:
prompts_as_tools: 1
resources_as_tools: 1
servers_config: null
- enabled: true
isDeleted: false
notAuthor: false
provider_id: junjiem/mcp_sse/mcp_sse
provider_name: junjiem/mcp_sse/mcp_sse
provider_type: builtin
tool_label: 调用 MCP 工具
tool_name: mcp_sse_call_tool
tool_parameters:
arguments: ''
prompts_as_tools: ''
resources_as_tools: ''
servers_config: ''
tool_name: ''
annotation_reply:
enabled: false
chat_prompt_config: {}
completion_prompt_config: {}
dataset_configs:
datasets:
datasets: []
reranking_enable: true
reranking_mode: reranking_model
reranking_model:
reranking_model_name: ''
reranking_provider_name: ''
retrieval_model: multiple
top_k: 4
dataset_query_variable: ''
external_data_tools: []
file_upload:
allowed_file_extensions:
- .JPG
- .JPEG
- .PNG
- .GIF
- .WEBP
- .SVG
- .MP4
- .MOV
- .MPEG
- .WEBM
allowed_file_types: []
allowed_file_upload_methods:
- remote_url
- local_file
enabled: false
image:
detail: high
enabled: false
number_limits: 3
transfer_methods:
- remote_url
- local_file
number_limits: 3
model:
completion_params:
stop: []
mode: chat
name: deepseek-chat
provider: langgenius/deepseek/deepseek
more_like_this:
enabled: false
opening_statement: ''
pre_prompt: "<instruction>\nUse MCP tools to complete tasks as much as possible.\
\ Carefully read the annotations, method names, and parameter descriptions of\
\ each tool. Please follow these steps:\n1. Analyze the user's question and match\
\ the most appropriate tool.\n2. Use tool names and parameters exactly as defined;\
\ do not invent new ones.\n3. Pass parameters in the required JSON format.\n4.\
\ When calling tools, use:\n {\"mcp_sse_call_tool\": {\"tool_name\": \"<tool_name>\"\
, \"arguments\": \"{}\"}}\n5. Output plain text only—no XML tags.\n<input>\nUser\
\ question: user_query\n</input>\n<output>\nReturn tool results or a final answer,\
\ including analysis.\n</output>\n</instruction>"
prompt_type: simple
retriever_resource:
enabled: true
sensitive_word_avoidance:
configs: []
enabled: false
type: ''
speech_to_text:
enabled: false
suggested_questions: []
suggested_questions_after_answer:
enabled: false
text_to_speech:
enabled: false
language: ''
voice: ''
user_input_form: []
version: 0.3.0

Binary file not shown.

After

Width:  |  Height:  |  Size: 323 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 673 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 118 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 232 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 80 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 17 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 258 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 44 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 66 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 127 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 317 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 369 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 272 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 73 KiB

View File

@@ -20,7 +20,7 @@ build-backend = "hatchling.build"
[project] [project]
name = "doris-mcp-server" name = "doris-mcp-server"
version = "0.3.0" version = "0.6.1"
description = "Enterprise-grade Model Context Protocol (MCP) server implementation for Apache Doris" description = "Enterprise-grade Model Context Protocol (MCP) server implementation for Apache Doris"
authors = [ authors = [
{name = "Yijia Su", email = "freeoneplus@apache.org"} {name = "Yijia Su", email = "freeoneplus@apache.org"}
@@ -42,10 +42,14 @@ classifiers = [
dependencies = [ dependencies = [
# Core MCP dependencies # Core MCP dependencies
"mcp>=1.0.0", "mcp>=1.8.0,<2.0.0",
# Database drivers # Database drivers
"aiomysql>=0.2.0", "aiomysql>=0.2.0",
"PyMySQL>=1.1.0", "PyMySQL>=1.1.0",
# ADBC (Arrow Flight SQL) dependencies
"adbc-driver-manager>=0.8.0",
"adbc-driver-flightsql>=0.8.0",
"pyarrow>=14.0.0",
# Async and utility libraries # Async and utility libraries
"asyncio-mqtt>=0.16.0", "asyncio-mqtt>=0.16.0",
"aiofiles>=23.0.0", "aiofiles>=23.0.0",

View File

@@ -1,21 +1,5 @@
# Licensed to the Apache Software Foundation (ASF) under one # Development dependencies - auto-generated from pyproject.toml
# or more contributor license agreements. See the NOTICE file # Installation command: pip install -r requirements-dev.txt
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# 开发依赖 - 从 pyproject.toml 自动生成
# 安装命令: pip install -r requirements-dev.txt
pytest>=7.4.0 pytest>=7.4.0
pytest-asyncio>=0.23.0 pytest-asyncio>=0.23.0

View File

@@ -1,26 +1,13 @@
# Licensed to the Apache Software Foundation (ASF) under one # Main dependencies - auto-generated from pyproject.toml
# or more contributor license agreements. See the NOTICE file # Do not edit this file manually, use 'python generate_requirements.py' to regenerate
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# 主要依赖 - 从 pyproject.toml 自动生成
# 请不要手动编辑此文件,使用 python generate_requirements.py 重新生成
# === 核心依赖 === # === Core Dependencies ===
mcp>=1.0.0 mcp>=1.8.0,<2.0.0
aiomysql>=0.2.0 aiomysql>=0.2.0
PyMySQL>=1.1.0 PyMySQL>=1.1.0
adbc-driver-manager>=0.8.0
adbc-driver-flightsql>=0.8.0
pyarrow>=14.0.0
asyncio-mqtt>=0.16.0 asyncio-mqtt>=0.16.0
aiofiles>=23.0.0 aiofiles>=23.0.0
aiohttp>=3.9.0 aiohttp>=3.9.0
@@ -53,8 +40,11 @@ click>=8.1.0
typer>=0.9.0 typer>=0.9.0
requests>=2.31.0 requests>=2.31.0
tqdm>=4.66.0 tqdm>=4.66.0
pytest>=8.4.0
pytest-asyncio>=1.0.0
pytest-cov>=6.1.1
# === 开发依赖 === # === Development Dependencies ===
pytest>=7.4.0 pytest>=7.4.0
pytest-asyncio>=0.23.0 pytest-asyncio>=0.23.0
pytest-cov>=4.1.0 pytest-cov>=4.1.0

View File

@@ -64,9 +64,11 @@ else
fi fi
# Set HTTP-specific environment variables # Set HTTP-specific environment variables
# FIX for Issue #62 Bug 4: Use SERVER_PORT instead of MCP_PORT for consistency with code
export MCP_TRANSPORT_TYPE="http" export MCP_TRANSPORT_TYPE="http"
export MCP_HOST="${MCP_HOST:-0.0.0.0}" export MCP_HOST="${MCP_HOST:-0.0.0.0}"
export MCP_PORT="${MCP_PORT:-3000}" export SERVER_PORT="${SERVER_PORT:-3000}" # Changed from MCP_PORT to SERVER_PORT
export WORKERS="${WORKERS:-1}"
export ALLOWED_ORIGINS="${ALLOWED_ORIGINS:-*}" export ALLOWED_ORIGINS="${ALLOWED_ORIGINS:-*}"
export LOG_LEVEL="${LOG_LEVEL:-info}" export LOG_LEVEL="${LOG_LEVEL:-info}"
export MCP_ALLOW_CREDENTIALS="${MCP_ALLOW_CREDENTIALS:-false}" export MCP_ALLOW_CREDENTIALS="${MCP_ALLOW_CREDENTIALS:-false}"
@@ -76,14 +78,15 @@ export MCP_DEBUG_ADAPTER="true"
export PYTHONPATH="$(pwd):$PYTHONPATH" export PYTHONPATH="$(pwd):$PYTHONPATH"
echo -e "${GREEN}Starting MCP server (Streamable HTTP mode)...${NC}" echo -e "${GREEN}Starting MCP server (Streamable HTTP mode)...${NC}"
echo -e "${YELLOW}Service will run on http://${MCP_HOST}:${MCP_PORT}/mcp${NC}" echo -e "${YELLOW}Service will run on http://${MCP_HOST}:${SERVER_PORT}/mcp${NC}"
echo -e "${YELLOW}Health Check: http://${MCP_HOST}:${MCP_PORT}/health${NC}" echo -e "${YELLOW}Health Check: http://${MCP_HOST}:${SERVER_PORT}/health${NC}"
echo -e "${YELLOW}MCP Endpoint: http://${MCP_HOST}:${MCP_PORT}/mcp${NC}" echo -e "${YELLOW}MCP Endpoint: http://${MCP_HOST}:${SERVER_PORT}/mcp${NC}"
echo -e "${YELLOW}Local access: http://localhost:${MCP_PORT}/mcp${NC}" echo -e "${YELLOW}Local access: http://localhost:${SERVER_PORT}/mcp${NC}"
echo -e "${YELLOW}Workers: ${WORKERS}${NC}"
echo -e "${YELLOW}Use Ctrl+C to stop the service${NC}" echo -e "${YELLOW}Use Ctrl+C to stop the service${NC}"
# Start the server in HTTP mode (Streamable HTTP) # Start the server in HTTP mode (Streamable HTTP)
python -m doris_mcp_server.main --transport http --host ${MCP_HOST} --port ${MCP_PORT} python -m doris_mcp_server.main --transport http --host ${MCP_HOST} --port ${SERVER_PORT} --workers ${WORKERS}
# Check exit status # Check exit status
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
@@ -95,4 +98,4 @@ fi
echo -e "${YELLOW}Tip: If the page displays abnormally, please clear your browser cache or use incognito mode${NC}" echo -e "${YELLOW}Tip: If the page displays abnormally, please clear your browser cache or use incognito mode${NC}"
echo -e "${YELLOW}Chrome browser clear cache shortcut: Ctrl+Shift+Del (Windows) or Cmd+Shift+Del (Mac)${NC}" echo -e "${YELLOW}Chrome browser clear cache shortcut: Ctrl+Shift+Del (Windows) or Cmd+Shift+Del (Mac)${NC}"
echo -e "${CYAN}For testing HTTP endpoints, you can use:${NC}" echo -e "${CYAN}For testing HTTP endpoints, you can use:${NC}"
echo -e "${CYAN} curl -X POST http://localhost:${MCP_PORT}/mcp -H 'Content-Type: application/json' -d '{\"method\":\"tools/list\"}'${NC}" echo -e "${CYAN} curl -X POST http://localhost:${SERVER_PORT}/mcp -H 'Content-Type: application/json' -d '{\"method\":\"tools/list\"}'${NC}"

View File

@@ -47,22 +47,29 @@ def event_loop():
@pytest.fixture @pytest.fixture
def test_config(): def test_config():
"""Provide test configuration""" """Test configuration fixture"""
return { from doris_mcp_server.utils.config import DorisConfig, DatabaseConfig, SecurityConfig
"doris_host": "localhost",
"doris_port": 9030, config = DorisConfig()
"doris_user": "test_user",
"doris_password": "test_password", # Database configuration
"doris_database": "test_db", config.database.host = "localhost"
"blocked_keywords": ["DROP", "DELETE", "TRUNCATE", "ALTER", "CREATE", "INSERT", "UPDATE"], config.database.port = 9030
"sensitive_tables": { config.database.user = "test_user"
"user_info": "confidential", config.database.password = "test_password"
"payment_records": "secret", config.database.database = "test_db"
"employee_data": "confidential", config.database.health_check_interval = 60
"public_reports": "public" config.database.max_connections = 20
}, config.database.connection_timeout = 30
"max_query_complexity": 100 config.database.max_connection_age = 3600
}
# Security configuration
config.security.enable_masking = True
config.security.auth_type = "token"
config.security.token_secret = "test_secret"
config.security.token_expiry = 3600
return config
@pytest.fixture @pytest.fixture

View File

@@ -34,17 +34,9 @@ class TestEndToEndIntegration:
@pytest.fixture @pytest.fixture
def mock_config(self): def mock_config(self):
"""Create mock configuration""" """Create mock configuration"""
from doris_mcp_server.utils.config import DatabaseConfig, SecurityConfig from doris_mcp_server.utils.config import ADBCConfig, DatabaseConfig, SecurityConfig
config = Mock(spec=DorisConfig) config = Mock(spec=DorisConfig)
config.doris_host = "localhost"
config.doris_port = 9030
config.doris_user = "test_user"
config.doris_password = "test_password"
config.doris_database = "test_db"
config.server_host = "localhost"
config.server_port = 8000
config.enable_security = True
# Add database config # Add database config
config.database = Mock(spec=DatabaseConfig) config.database = Mock(spec=DatabaseConfig)
@@ -54,7 +46,6 @@ class TestEndToEndIntegration:
config.database.password = "test_password" config.database.password = "test_password"
config.database.database = "test_db" config.database.database = "test_db"
config.database.health_check_interval = 60 config.database.health_check_interval = 60
config.database.min_connections = 5
config.database.max_connections = 20 config.database.max_connections = 20
config.database.connection_timeout = 30 config.database.connection_timeout = 30
config.database.max_connection_age = 3600 config.database.max_connection_age = 3600
@@ -65,6 +56,11 @@ class TestEndToEndIntegration:
config.security.auth_type = "token" config.security.auth_type = "token"
config.security.token_secret = "test_secret" config.security.token_secret = "test_secret"
config.security.token_expiry = 3600 config.security.token_expiry = 3600
config.security.blocked_keywords = ["DROP"]
# Add adbc config
config.adbc = Mock(spec=ADBCConfig)
config.adbc.enabled = True
return config return config
@@ -239,7 +235,7 @@ class TestEndToEndIntegration:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_tool_execution_with_security(self, doris_server): async def test_tool_execution_with_security(self, doris_server):
"""Test tool execution with security checks""" """Test tool execution with security checks"""
with patch.object(doris_server.tools_manager.query_executor, 'execute_query') as mock_execute: with patch.object(doris_server.tools_manager.connection_manager, 'execute_query') as mock_execute:
mock_execute.return_value = [{"Database": "test_db"}] mock_execute.return_value = [{"Database": "test_db"}]
# Test tool execution through tools manager # Test tool execution through tools manager
@@ -266,7 +262,7 @@ class TestEndToEndIntegration:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_performance_monitoring_integration(self, doris_server): async def test_performance_monitoring_integration(self, doris_server):
"""Test performance monitoring integration""" """Test performance monitoring integration"""
with patch.object(doris_server.tools_manager.query_executor, 'execute_query') as mock_execute: with patch.object(doris_server.tools_manager.connection_manager, 'execute_query') as mock_execute:
mock_execute.return_value = [ mock_execute.return_value = [
{ {
"query_count": 1500, "query_count": 1500,
@@ -277,10 +273,7 @@ class TestEndToEndIntegration:
] ]
# Test performance stats tool # Test performance stats tool
result = await doris_server.tools_manager.call_tool("performance_stats", { result = await doris_server.tools_manager.call_tool("get_db_list", {})
"metric_type": "queries",
"time_range": "1h"
})
result_data = json.loads(result) result_data = json.loads(result)
# Accept either success result or error (due to mock environment) # Accept either success result or error (due to mock environment)

View File

@@ -0,0 +1,367 @@
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
SQL Security Test Suite for Apache Doris MCP Server
Tests for:
1. SQL injection prevention via identifier validation
2. Multi-statement SQL parsing in security validator
3. auth_context enforcement
"""
import pytest
import asyncio
from unittest.mock import MagicMock, AsyncMock, patch
class TestSQLSecurityUtils:
"""Test cases for sql_security_utils module"""
def test_validate_identifier_accepts_valid_names(self):
"""Test that valid identifiers are accepted"""
from doris_mcp_server.utils.sql_security_utils import validate_identifier
valid_names = [
"users",
"my_table",
"Table123",
"_private_table",
"CamelCaseTable",
"table_with_numbers_123",
]
for name in valid_names:
result = validate_identifier(name, "table")
assert result == name, f"Valid identifier '{name}' should be accepted"
def test_validate_identifier_rejects_sql_injection(self):
"""Test that SQL injection attempts are rejected"""
from doris_mcp_server.utils.sql_security_utils import (
validate_identifier,
SQLSecurityError
)
injection_attempts = [
# Basic SQL injection
"'; DROP TABLE users; --",
"table' OR '1'='1",
"table'; DELETE FROM users; --",
# Union-based injection
"table' UNION SELECT * FROM passwords --",
# Comment injection
"table/**/OR/**/1=1",
"table--comment",
# Special characters
"table`; DROP TABLE users;",
'table"; DROP TABLE users;',
"table\"; DELETE FROM",
# Backtick escape attempt
"analytics`; SELECT * FROM sensitive_table;--",
# Whitespace injection
"table name with spaces",
"table\ttab",
"table\nnewline",
]
for injection in injection_attempts:
with pytest.raises(SQLSecurityError):
validate_identifier(injection, "table")
def test_validate_identifier_rejects_empty(self):
"""Test that empty identifiers are rejected"""
from doris_mcp_server.utils.sql_security_utils import (
validate_identifier,
SQLSecurityError
)
with pytest.raises(SQLSecurityError):
validate_identifier("", "table")
with pytest.raises(SQLSecurityError):
validate_identifier(None, "table")
def test_validate_identifier_rejects_too_long(self):
"""Test that identifiers exceeding max length are rejected"""
from doris_mcp_server.utils.sql_security_utils import (
validate_identifier,
SQLSecurityError
)
# Doris identifier max length is typically 64 characters
long_name = "a" * 100
with pytest.raises(SQLSecurityError):
validate_identifier(long_name, "table")
def test_quote_identifier_adds_backticks(self):
"""Test that quote_identifier properly escapes identifiers"""
from doris_mcp_server.utils.sql_security_utils import quote_identifier
assert quote_identifier("my_table", "table") == "`my_table`"
assert quote_identifier("users", "table") == "`users`"
assert quote_identifier("Table123", "table") == "`Table123`"
def test_quote_identifier_validates_first(self):
"""Test that quote_identifier validates before quoting"""
from doris_mcp_server.utils.sql_security_utils import (
quote_identifier,
SQLSecurityError
)
with pytest.raises(SQLSecurityError):
quote_identifier("'; DROP TABLE users; --", "table")
class TestSQLSecurityValidator:
"""Test cases for SQLSecurityValidator multi-statement parsing"""
@pytest.fixture
def dict_config(self):
"""Create dictionary configuration"""
return {
"blocked_keywords": [
"DROP", "CREATE", "ALTER", "TRUNCATE",
"DELETE", "INSERT", "UPDATE",
"GRANT", "REVOKE", "EXEC", "EXECUTE"
],
"max_query_complexity": 100,
"enable_security_check": True
}
@pytest.fixture
def mock_auth_context(self):
"""Create mock auth context"""
from doris_mcp_server.utils.security import AuthContext, SecurityLevel
return AuthContext(
user_id="test_user",
roles=["user"],
security_level=SecurityLevel.INTERNAL
)
@pytest.mark.asyncio
async def test_validates_all_statements(self, dict_config, mock_auth_context):
"""Test that validator checks ALL SQL statements, not just the first"""
from doris_mcp_server.utils.security import SQLSecurityValidator
validator = SQLSecurityValidator(dict_config)
# Multi-statement with injection in second statement
# This should be BLOCKED
malicious_sql = "SELECT 1; DROP TABLE users; SELECT 2"
result = await validator.validate(malicious_sql, mock_auth_context)
assert not result.is_valid, "Multi-statement injection should be blocked"
# Check for either DROP keyword detection or SQL injection detection
error_upper = result.error_message.upper()
assert ("DROP" in error_upper or
"INJECTION" in error_upper or
"BLOCKED" in error_upper), f"Expected DROP/injection/blocked in: {result.error_message}"
@pytest.mark.asyncio
async def test_blocks_hidden_dangerous_statement(self, dict_config, mock_auth_context):
"""Test that dangerous statements hidden after safe ones are blocked"""
from doris_mcp_server.utils.security import SQLSecurityValidator
validator = SQLSecurityValidator(dict_config)
# Safe statement followed by dangerous one
malicious_sql = """
SELECT * FROM users WHERE id = 1;
DELETE FROM audit_log;
SELECT 1;
"""
result = await validator.validate(malicious_sql, mock_auth_context)
assert not result.is_valid, "Hidden DELETE statement should be blocked"
@pytest.mark.asyncio
async def test_allows_safe_multi_statement(self, dict_config, mock_auth_context):
"""Test that multiple safe SELECT statements are allowed"""
from doris_mcp_server.utils.security import SQLSecurityValidator
validator = SQLSecurityValidator(dict_config)
safe_sql = """
SELECT * FROM users;
SELECT COUNT(*) FROM orders;
SELECT id, name FROM products;
"""
result = await validator.validate(safe_sql, mock_auth_context)
assert result.is_valid, f"Multiple safe SELECT statements should be allowed, got: {result.error_message}"
@pytest.mark.asyncio
async def test_context_switch_injection_blocked(self, dict_config, mock_auth_context):
"""Test that context switch SQL injection is blocked"""
from doris_mcp_server.utils.security import SQLSecurityValidator
validator = SQLSecurityValidator(dict_config)
# Simulating the exec_query_for_mcp attack vector
injected_sql = """
USE `analytics`; SELECT * FROM sensitive_table;-- `;
SELECT * FROM public_table;
"""
result = await validator.validate(injected_sql, mock_auth_context)
# The validator should process all statements
# Even if USE is allowed, subsequent unauthorized access should be caught
# by table access checks (if configured)
class TestExecQueryForMCP:
"""Test cases for exec_query_for_mcp function"""
@pytest.mark.asyncio
async def test_rejects_malicious_db_name(self):
"""Test that malicious db_name is rejected"""
from doris_mcp_server.utils.sql_security_utils import (
validate_identifier,
SQLSecurityError
)
# The attack vector from security report
malicious_db_name = "analytics`; SELECT * FROM sensitive_table;--"
with pytest.raises(SQLSecurityError):
validate_identifier(malicious_db_name, "database name")
@pytest.mark.asyncio
async def test_rejects_malicious_catalog_name(self):
"""Test that malicious catalog_name is rejected"""
from doris_mcp_server.utils.sql_security_utils import (
validate_identifier,
SQLSecurityError
)
malicious_catalog_name = "internal'; DROP DATABASE production;--"
with pytest.raises(SQLSecurityError):
validate_identifier(malicious_catalog_name, "catalog name")
class TestDependencyAnalysisTools:
"""Test cases for dependency_analysis_tools security fixes"""
@pytest.mark.asyncio
async def test_get_tables_metadata_rejects_injection(self):
"""Test that _get_tables_metadata rejects SQL injection"""
from doris_mcp_server.utils.sql_security_utils import (
validate_identifier,
SQLSecurityError
)
# The attack vector from security report
injection_db_name = "test_db' OR '1'='1' --"
with pytest.raises(SQLSecurityError):
validate_identifier(injection_db_name, "database name")
class TestAuthContextEnforcement:
"""Test cases for auth_context enforcement"""
def test_execute_requires_auth_context_for_security(self):
"""Test that security checks require auth_context"""
# This test documents the expected behavior:
# When auth_context is None, security checks are skipped
# When auth_context is provided, security checks are performed
# The fix ensures all execute() calls pass auth_context
pass
@pytest.mark.asyncio
async def test_get_auth_context_returns_context(self):
"""Test that get_auth_context retrieves context from ContextVar"""
from doris_mcp_server.utils.sql_security_utils import get_auth_context
# When no context is set, should return None
result = get_auth_context()
# This is expected - context is set by HTTP middleware
assert result is None or hasattr(result, 'user_id')
class TestIntegrationScenarios:
"""Integration test scenarios for security fixes"""
def test_attack_scenario_1_permission_bypass(self):
"""
Attack Scenario 1: Permission Bypass in Multi-Tenant Environment
Expected: User can only query their own database (db_name="tenant_a_db")
Attack: Inject "tenant_a_db' OR '1'='1' --" to query ALL databases
Result: Should be BLOCKED by validate_identifier()
"""
from doris_mcp_server.utils.sql_security_utils import (
validate_identifier,
SQLSecurityError
)
with pytest.raises(SQLSecurityError):
validate_identifier("tenant_a_db' OR '1'='1' --", "database name")
def test_attack_scenario_2_union_injection(self):
"""
Attack Scenario 2: UNION-based Information Disclosure
Attack: Inject UNION SELECT to extract sensitive data
Result: Should be BLOCKED
"""
from doris_mcp_server.utils.sql_security_utils import (
validate_identifier,
SQLSecurityError
)
with pytest.raises(SQLSecurityError):
validate_identifier(
"test' UNION SELECT password FROM users --",
"database name"
)
def test_attack_scenario_3_backtick_escape(self):
"""
Attack Scenario 3: Backtick Escape Attempt
Attack: Use backticks to break out of quoted identifier
Result: Should be BLOCKED
"""
from doris_mcp_server.utils.sql_security_utils import (
validate_identifier,
SQLSecurityError
)
with pytest.raises(SQLSecurityError):
validate_identifier(
"analytics`; SELECT * FROM sensitive_table;--",
"database name"
)
# Run tests with: pytest tests/test_sql_security.py -v
if __name__ == "__main__":
pytest.main([__file__, "-v", "--tb=short"])

View File

@@ -0,0 +1,871 @@
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
SQL Injection API Integration Tests
This module tests SQL injection prevention through the MCP HTTP API.
It sends malicious payloads and verifies they are properly blocked.
Prerequisites:
- MCP server running on localhost:3000
- Run with: pytest test/security/test_sql_injection_api.py -v
Usage:
# Start server first
bash start_server.sh
# Run tests
pytest test/security/test_sql_injection_api.py -v --no-cov
"""
import pytest
import httpx
import json
import asyncio
from typing import Optional
# Server configuration
MCP_BASE_URL = "http://localhost:3000"
MCP_ENDPOINT = f"{MCP_BASE_URL}/mcp"
HEALTH_ENDPOINT = f"{MCP_BASE_URL}/health"
TIMEOUT = 30.0
class MCPClient:
"""Simple MCP HTTP client for testing"""
def __init__(self, base_url: str = MCP_BASE_URL):
self.base_url = base_url
self.mcp_endpoint = f"{base_url}/mcp"
self.session_id: Optional[str] = None
self.request_id = 0
self.client = httpx.AsyncClient(timeout=TIMEOUT)
async def close(self):
await self.client.aclose()
def _next_id(self) -> int:
self.request_id += 1
return self.request_id
async def initialize(self) -> dict:
"""Initialize MCP session"""
response = await self.client.post(
self.mcp_endpoint,
headers={
"Content-Type": "application/json",
"Accept": "application/json, text/event-stream"
},
json={
"jsonrpc": "2.0",
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
"clientInfo": {
"name": "sql-injection-test",
"version": "1.0.0"
}
},
"id": self._next_id()
}
)
# Extract session ID from response header
self.session_id = response.headers.get("mcp-session-id")
return self._parse_response(response.text)
async def call_tool(self, tool_name: str, arguments: dict) -> dict:
"""Call an MCP tool"""
if not self.session_id:
await self.initialize()
response = await self.client.post(
self.mcp_endpoint,
headers={
"Content-Type": "application/json",
"Accept": "application/json, text/event-stream",
"mcp-session-id": self.session_id
},
json={
"jsonrpc": "2.0",
"method": "tools/call",
"params": {
"name": tool_name,
"arguments": arguments
},
"id": self._next_id()
}
)
return self._parse_response(response.text)
def _parse_response(self, text: str) -> dict:
"""Parse JSON response"""
try:
return json.loads(text)
except json.JSONDecodeError:
# Try SSE format
lines = text.strip().split("\n")
for line in lines:
if line.startswith("data: "):
try:
return json.loads(line[6:])
except json.JSONDecodeError:
continue
return {"raw": text}
def print_result(test_name: str, payload: dict, result: dict):
"""Print test result in a readable format"""
print(f"\n{'='*60}")
print(f"TEST: {test_name}")
print(f"{'='*60}")
print(f"PAYLOAD: {json.dumps(payload, ensure_ascii=False)}")
print(f"{'-'*60}")
# Extract inner result content
if "result" in result and "content" in result.get("result", {}):
for item in result["result"]["content"]:
if item.get("type") == "text":
try:
inner = json.loads(item["text"])
print("RESPONSE:")
print(f" success: {inner.get('success')}")
if inner.get('error'):
print(f" error: {inner.get('error')}")
if inner.get('error_type'):
print(f" error_type: {inner.get('error_type')}")
if inner.get('risk_level'):
print(f" risk_level: {inner.get('risk_level')}")
if inner.get('message'):
print(f" message: {inner.get('message')}")
if inner.get('data') is not None and inner.get('success'):
data_str = json.dumps(inner.get('data'), ensure_ascii=False)
if len(data_str) > 200:
data_str = data_str[:200] + "..."
print(f" data: {data_str}")
except (json.JSONDecodeError, TypeError):
print(f"RESPONSE (raw): {item.get('text', '')[:500]}")
elif "error" in result:
print(f"RESPONSE ERROR: {result['error']}")
else:
print(f"RESPONSE (raw): {json.dumps(result, ensure_ascii=False)[:500]}")
print(f"{'='*60}\n")
class TestSQLInjectionAPI:
"""Test SQL injection prevention through MCP API"""
@pytest.fixture
async def mcp_client(self):
"""Create MCP client instance"""
client = MCPClient()
yield client
await client.close()
@pytest.fixture
def is_server_running(self):
"""Check if MCP server is running"""
import httpx
try:
response = httpx.get(HEALTH_ENDPOINT, timeout=5.0)
return response.status_code == 200
except Exception:
return False
@pytest.mark.asyncio
async def test_server_health(self):
"""Test that MCP server is running and healthy"""
async with httpx.AsyncClient() as client:
response = await client.get(HEALTH_ENDPOINT)
assert response.status_code == 200
data = response.json()
assert data["status"] == "healthy"
@pytest.mark.asyncio
async def test_exec_query_with_drop_injection(self, mcp_client):
"""Test exec_query rejects DROP TABLE injection"""
# Classic SQL injection: append DROP TABLE
payload = {"sql": "SELECT * FROM users; DROP TABLE users; --"}
result = await mcp_client.call_tool("exec_query", payload)
print_result("DROP TABLE Injection", payload, result)
# Should return error, not execute the DROP
assert self._is_blocked_or_error(result), \
f"DROP TABLE injection should be blocked"
@pytest.mark.asyncio
async def test_exec_query_with_union_injection(self, mcp_client):
"""Test exec_query blocks UNION-based injection attempts"""
# UNION injection to extract data from other tables
payload = {"sql": "SELECT id FROM users UNION SELECT password FROM admin_users"}
result = await mcp_client.call_tool("exec_query", payload)
print_result("UNION Injection", payload, result)
@pytest.mark.asyncio
async def test_exec_query_with_delete_injection(self, mcp_client):
"""Test exec_query rejects DELETE injection"""
payload = {"sql": "SELECT 1; DELETE FROM users WHERE 1=1; SELECT 2"}
result = await mcp_client.call_tool("exec_query", payload)
print_result("DELETE Injection", payload, result)
assert self._is_blocked_or_error(result), \
f"DELETE injection should be blocked"
@pytest.mark.asyncio
async def test_exec_query_with_update_injection(self, mcp_client):
"""Test exec_query rejects UPDATE injection"""
payload = {"sql": "SELECT 1; UPDATE users SET role='admin' WHERE id=1; --"}
result = await mcp_client.call_tool("exec_query", payload)
print_result("UPDATE Injection", payload, result)
assert self._is_blocked_or_error(result), \
f"UPDATE injection should be blocked"
@pytest.mark.asyncio
async def test_exec_query_db_name_injection(self, mcp_client):
"""Test exec_query rejects SQL injection via db_name parameter"""
# Attack vector: inject SQL via db_name parameter
payload = {"sql": "SELECT 1", "db_name": "test'; DROP TABLE users; --"}
result = await mcp_client.call_tool("exec_query", payload)
print_result("db_name Parameter Injection", payload, result)
assert self._is_blocked_or_error(result), \
f"db_name injection should be blocked"
@pytest.mark.asyncio
async def test_exec_query_catalog_name_injection(self, mcp_client):
"""Test exec_query rejects SQL injection via catalog_name parameter"""
# Attack vector: inject SQL via catalog_name parameter
payload = {"sql": "SELECT 1", "catalog_name": "internal`; SELECT * FROM mysql.user; --"}
result = await mcp_client.call_tool("exec_query", payload)
print_result("catalog_name Parameter Injection", payload, result)
assert self._is_blocked_or_error(result), \
f"catalog_name injection should be blocked"
@pytest.mark.asyncio
async def test_get_table_schema_injection(self, mcp_client):
"""Test get_table_schema rejects SQL injection via table_name"""
# Attack vector: inject SQL via table_name parameter
payload = {"table_name": "users'; DROP TABLE users; --"}
result = await mcp_client.call_tool("get_table_schema", payload)
print_result("table_name Injection (get_table_schema)", payload, result)
assert self._is_blocked_or_error(result), \
f"table_name injection should be blocked"
@pytest.mark.asyncio
async def test_get_table_schema_db_injection(self, mcp_client):
"""Test get_table_schema rejects SQL injection via db_name"""
payload = {"table_name": "users", "db_name": "test' OR '1'='1"}
result = await mcp_client.call_tool("get_table_schema", payload)
print_result("db_name Injection (get_table_schema)", payload, result)
assert self._is_blocked_or_error(result), \
f"db_name injection in get_table_schema should be blocked"
@pytest.mark.asyncio
async def test_analyze_dependencies_injection(self, mcp_client):
"""Test analyze_dependencies rejects SQL injection"""
# This was the original vulnerability reported
payload = {"table_name": "users", "db_name": "test_db' OR '1'='1' --"}
result = await mcp_client.call_tool("analyze_dependencies", payload)
print_result("analyze_dependencies Injection (Original Report)", payload, result)
assert self._is_blocked_or_error(result), \
f"analyze_dependencies db_name injection should be blocked"
@pytest.mark.asyncio
async def test_stacked_queries_injection(self, mcp_client):
"""Test that stacked queries (multiple statements) are blocked"""
# Multiple statements injection
payload = {"sql": "SELECT * FROM users WHERE id = 1; INSERT INTO audit_log VALUES (NULL, 'hacked', NOW()); SELECT 1;"}
result = await mcp_client.call_tool("exec_query", payload)
print_result("Stacked Queries (INSERT) Injection", payload, result)
assert self._is_blocked_or_error(result), \
f"Stacked queries with INSERT should be blocked"
@pytest.mark.asyncio
async def test_comment_based_injection(self, mcp_client):
"""Test that comment-based injection is blocked"""
# Using comments to bypass filters
payload = {"sql": "SELECT * FROM users WHERE id = 1/**/OR/**/1=1"}
result = await mcp_client.call_tool("exec_query", payload)
print_result("Comment-based Injection", payload, result)
@pytest.mark.asyncio
async def test_hex_encoded_injection(self, mcp_client):
"""Test that hex-encoded injection attempts are handled"""
# Hex-encoded 'DROP' attempt
payload = {"sql": "SELECT 0x44524F50205441424C4520757365727320"}
result = await mcp_client.call_tool("exec_query", payload)
print_result("Hex Encoded Injection", payload, result)
@pytest.mark.asyncio
async def test_backtick_escape_injection(self, mcp_client):
"""Test backtick escape injection is blocked"""
# Attempt to escape backtick quoting
payload = {"sql": "SELECT 1", "db_name": "analytics`; SELECT * FROM sensitive_table;--"}
result = await mcp_client.call_tool("exec_query", payload)
print_result("Backtick Escape Injection", payload, result)
assert self._is_blocked_or_error(result), \
f"Backtick escape injection should be blocked"
@pytest.mark.asyncio
async def test_valid_query_succeeds(self, mcp_client):
"""Test that valid queries still work"""
# Simple valid query should work
payload = {"sql": "SELECT 1 AS test_value"}
result = await mcp_client.call_tool("exec_query", payload)
print_result("Valid Query (should succeed)", payload, result)
@pytest.mark.asyncio
async def test_valid_show_databases(self, mcp_client):
"""Test that SHOW DATABASES works"""
payload = {"sql": "SHOW DATABASES"}
result = await mcp_client.call_tool("exec_query", payload)
print_result("SHOW DATABASES (should succeed)", payload, result)
def _is_blocked_or_error(self, result: dict) -> bool:
"""Check if result indicates blocked or error"""
if not result:
return True
# Check for JSON-RPC error
if "error" in result:
return True
# Check for error in result content
if "result" in result:
result_content = result.get("result", {})
if isinstance(result_content, dict):
# Check for isError flag
if result_content.get("isError"):
return True
# Check content array for error messages
content = result_content.get("content", [])
for item in content:
if isinstance(item, dict):
text = item.get("text", "")
# Parse the JSON text content
try:
text_data = json.loads(text)
# Check for success: false
if text_data.get("success") is False:
return True
# Check for error field
if text_data.get("error"):
return True
except (json.JSONDecodeError, TypeError):
pass
# Check text for security keywords
if any(keyword in text.lower() for keyword in [
"error", "blocked", "invalid", "security",
"injection", "denied", "forbidden", "not allowed",
"security_violation", "risk_level"
]):
return True
# Check raw text response
raw = result.get("raw", "")
if isinstance(raw, str) and any(keyword in raw.lower() for keyword in [
"error", "blocked", "invalid", "security"
]):
return True
return False
class TestIdentifierInjectionAPI:
"""Test identifier-based SQL injection prevention"""
@pytest.fixture
async def mcp_client(self):
"""Create MCP client instance"""
client = MCPClient()
yield client
await client.close()
@pytest.mark.asyncio
async def test_table_name_with_semicolon(self, mcp_client):
"""Test table name containing semicolon is rejected"""
payload = {"table_name": "users; DROP TABLE users"}
result = await mcp_client.call_tool("get_table_schema", payload)
print_result("Table Name with Semicolon", payload, result)
# Should be blocked by identifier validation
assert self._contains_error_indicator(result), \
f"Table name with semicolon should be rejected"
@pytest.mark.asyncio
async def test_table_name_with_quotes(self, mcp_client):
"""Test table name containing quotes is rejected"""
payload = {"table_name": "users' OR '1'='1"}
result = await mcp_client.call_tool("get_table_schema", payload)
print_result("Table Name with Quotes", payload, result)
assert self._contains_error_indicator(result), \
f"Table name with quotes should be rejected"
@pytest.mark.asyncio
async def test_db_name_with_special_chars(self, mcp_client):
"""Test database name with special characters is rejected"""
special_chars = [
"test;db",
"test'db",
"test\"db",
"test`db",
"test--db",
"test/*db*/",
]
for db_name in special_chars:
payload = {"table_name": "users", "db_name": db_name}
result = await mcp_client.call_tool("get_table_schema", payload)
print_result(f"Special Char in db_name: {db_name}", payload, result)
assert self._contains_error_indicator(result), \
f"db_name '{db_name}' should be rejected"
@pytest.mark.asyncio
async def test_valid_identifiers_accepted(self, mcp_client):
"""Test that valid identifiers are accepted"""
valid_names = [
"users",
"my_table",
"Table123",
"_internal_table",
]
for table_name in valid_names:
payload = {"table_name": table_name}
result = await mcp_client.call_tool("get_table_schema", payload)
print_result(f"Valid Identifier: {table_name}", payload, result)
def _contains_error_indicator(self, result: dict) -> bool:
"""Check if result contains error indicators"""
if not result:
return True
# Check for JSON-RPC error
if "error" in result:
return True
# Check result content
result_str = json.dumps(result).lower()
error_keywords = [
"error", "invalid", "illegal", "blocked",
"security", "injection", "denied", "forbidden"
]
return any(keyword in result_str for keyword in error_keywords)
class TestMultiStatementInjectionAPI:
"""Test multi-statement SQL injection prevention"""
@pytest.fixture
async def mcp_client(self):
"""Create MCP client instance"""
client = MCPClient()
yield client
await client.close()
@pytest.mark.asyncio
async def test_hidden_drop_after_select(self, mcp_client):
"""Test DROP hidden after legitimate SELECT is blocked"""
payload = {"sql": "SELECT id, name FROM users WHERE status = 'active'; DROP TABLE audit_log; SELECT 1;"}
result = await mcp_client.call_tool("exec_query", payload)
print_result("Hidden DROP after SELECT", payload, result)
assert self._is_dangerous_blocked(result), \
f"Hidden DROP statement should be blocked"
@pytest.mark.asyncio
async def test_hidden_truncate_after_select(self, mcp_client):
"""Test TRUNCATE hidden after SELECT is blocked"""
payload = {"sql": "SELECT 1; TRUNCATE TABLE users"}
result = await mcp_client.call_tool("exec_query", payload)
print_result("Hidden TRUNCATE after SELECT", payload, result)
assert self._is_dangerous_blocked(result), \
f"Hidden TRUNCATE should be blocked"
@pytest.mark.asyncio
async def test_hidden_grant_after_select(self, mcp_client):
"""Test GRANT hidden after SELECT is blocked"""
payload = {"sql": "SELECT 1; GRANT ALL ON *.* TO 'hacker'@'%'"}
result = await mcp_client.call_tool("exec_query", payload)
print_result("Hidden GRANT after SELECT", payload, result)
assert self._is_dangerous_blocked(result), \
f"Hidden GRANT should be blocked"
@pytest.mark.asyncio
async def test_multiple_safe_selects_allowed(self, mcp_client):
"""Test that multiple SELECT statements may be allowed"""
payload = {"sql": "SELECT 1; SELECT 2; SELECT 3;"}
result = await mcp_client.call_tool("exec_query", payload)
print_result("Multiple Safe SELECTs", payload, result)
def _is_dangerous_blocked(self, result: dict) -> bool:
"""Check if dangerous operation was blocked"""
if not result:
return True
# Check for error
if "error" in result:
return True
# Check result content for blocking indicators
result_str = json.dumps(result).lower()
block_indicators = [
"drop", "truncate", "grant", "revoke",
"blocked", "denied", "forbidden", "not allowed",
"security", "error"
]
return any(indicator in result_str for indicator in block_indicators)
class TestADBCQueryInjectionAPI:
"""Test ADBC query SQL injection prevention"""
@pytest.fixture
async def mcp_client(self):
"""Create MCP client instance"""
client = MCPClient()
yield client
await client.close()
@pytest.mark.asyncio
async def test_exec_adbc_query_drop_injection(self, mcp_client):
"""Test exec_adbc_query rejects DROP TABLE injection"""
payload = {"sql": "SELECT * FROM users; DROP TABLE users; --"}
result = await mcp_client.call_tool("exec_adbc_query", payload)
print_result("ADBC DROP TABLE Injection", payload, result)
assert self._is_blocked_or_error(result), \
"ADBC DROP TABLE injection should be blocked"
@pytest.mark.asyncio
async def test_exec_adbc_query_delete_injection(self, mcp_client):
"""Test exec_adbc_query rejects DELETE injection"""
payload = {"sql": "SELECT 1; DELETE FROM users; --"}
result = await mcp_client.call_tool("exec_adbc_query", payload)
print_result("ADBC DELETE Injection", payload, result)
assert self._is_blocked_or_error(result), \
"ADBC DELETE injection should be blocked"
@pytest.mark.asyncio
async def test_exec_adbc_query_valid(self, mcp_client):
"""Test exec_adbc_query allows valid queries"""
payload = {"sql": "SELECT 1 AS test"}
result = await mcp_client.call_tool("exec_adbc_query", payload)
print_result("ADBC Valid Query", payload, result)
def _is_blocked_or_error(self, result: dict) -> bool:
"""Check if result indicates blocked or error"""
if not result:
return True
if "error" in result:
return True
result_str = json.dumps(result).lower()
return any(kw in result_str for kw in ["error", "blocked", "invalid", "security", "injection"])
class TestMetadataToolsInjectionAPI:
"""Test metadata tools SQL injection prevention"""
@pytest.fixture
async def mcp_client(self):
"""Create MCP client instance"""
client = MCPClient()
yield client
await client.close()
@pytest.mark.asyncio
async def test_get_db_table_list_db_injection(self, mcp_client):
"""Test get_db_table_list rejects db_name injection"""
payload = {"db_name": "test'; DROP TABLE users; --"}
result = await mcp_client.call_tool("get_db_table_list", payload)
print_result("get_db_table_list db_name Injection", payload, result)
assert self._is_blocked_or_error(result), \
"db_name injection should be blocked"
@pytest.mark.asyncio
async def test_get_db_table_list_catalog_injection(self, mcp_client):
"""Test get_db_table_list rejects catalog_name injection"""
payload = {"catalog_name": "internal`; SELECT * FROM mysql.user; --"}
result = await mcp_client.call_tool("get_db_table_list", payload)
print_result("get_db_table_list catalog_name Injection", payload, result)
assert self._is_blocked_or_error(result), \
"catalog_name injection should be blocked"
@pytest.mark.asyncio
async def test_get_table_comment_injection(self, mcp_client):
"""Test get_table_comment rejects table_name injection"""
payload = {"table_name": "users'; DROP TABLE users; --"}
result = await mcp_client.call_tool("get_table_comment", payload)
print_result("get_table_comment table_name Injection", payload, result)
assert self._is_blocked_or_error(result), \
"table_name injection should be blocked"
@pytest.mark.asyncio
async def test_get_table_column_comments_injection(self, mcp_client):
"""Test get_table_column_comments rejects injection"""
payload = {"table_name": "users'; DROP TABLE users; --", "db_name": "test"}
result = await mcp_client.call_tool("get_table_column_comments", payload)
print_result("get_table_column_comments Injection", payload, result)
assert self._is_blocked_or_error(result), \
"table_name injection should be blocked"
@pytest.mark.asyncio
async def test_get_table_indexes_injection(self, mcp_client):
"""Test get_table_indexes rejects table_name injection"""
payload = {"table_name": "users; DROP TABLE users", "db_name": "test"}
result = await mcp_client.call_tool("get_table_indexes", payload)
print_result("get_table_indexes Injection", payload, result)
assert self._is_blocked_or_error(result), \
"table_name injection should be blocked"
def _is_blocked_or_error(self, result: dict) -> bool:
"""Check if result indicates blocked or error"""
if not result:
return True
if "error" in result:
return True
result_str = json.dumps(result).lower()
return any(kw in result_str for kw in ["error", "blocked", "invalid", "security", "injection"])
class TestAnalyticsToolsInjectionAPI:
"""Test analytics tools SQL injection prevention"""
@pytest.fixture
async def mcp_client(self):
"""Create MCP client instance"""
client = MCPClient()
yield client
await client.close()
@pytest.mark.asyncio
async def test_analyze_columns_table_injection(self, mcp_client):
"""Test analyze_columns rejects table_name injection"""
payload = {"table_name": "users'; DROP TABLE users; --"}
result = await mcp_client.call_tool("analyze_columns", payload)
print_result("analyze_columns table_name Injection", payload, result)
assert self._is_blocked_or_error(result), \
"table_name injection should be blocked"
@pytest.mark.asyncio
async def test_analyze_columns_db_injection(self, mcp_client):
"""Test analyze_columns rejects db_name injection"""
payload = {"table_name": "users", "db_name": "test' OR '1'='1"}
result = await mcp_client.call_tool("analyze_columns", payload)
print_result("analyze_columns db_name Injection", payload, result)
assert self._is_blocked_or_error(result), \
"db_name injection should be blocked"
@pytest.mark.asyncio
async def test_get_table_basic_info_injection(self, mcp_client):
"""Test get_table_basic_info rejects injection"""
payload = {"table_name": "users; DROP TABLE audit_log"}
result = await mcp_client.call_tool("get_table_basic_info", payload)
print_result("get_table_basic_info Injection", payload, result)
assert self._is_blocked_or_error(result), \
"table_name injection should be blocked"
@pytest.mark.asyncio
async def test_analyze_table_storage_injection(self, mcp_client):
"""Test analyze_table_storage rejects injection"""
payload = {"table_name": "users`; SELECT * FROM sensitive; --"}
result = await mcp_client.call_tool("analyze_table_storage", payload)
print_result("analyze_table_storage Injection", payload, result)
assert self._is_blocked_or_error(result), \
"table_name injection should be blocked"
@pytest.mark.asyncio
async def test_get_sql_explain_injection(self, mcp_client):
"""Test get_sql_explain rejects SQL injection"""
payload = {"sql": "SELECT 1; DROP TABLE users; --"}
result = await mcp_client.call_tool("get_sql_explain", payload)
print_result("get_sql_explain SQL Injection", payload, result)
assert self._is_blocked_or_error(result), \
"SQL injection should be blocked"
@pytest.mark.asyncio
async def test_get_sql_profile_injection(self, mcp_client):
"""Test get_sql_profile rejects SQL injection"""
payload = {"sql": "SELECT 1; DELETE FROM audit_log; --"}
result = await mcp_client.call_tool("get_sql_profile", payload)
print_result("get_sql_profile SQL Injection", payload, result)
assert self._is_blocked_or_error(result), \
"SQL injection should be blocked"
def _is_blocked_or_error(self, result: dict) -> bool:
"""Check if result indicates blocked or error"""
if not result:
return True
if "error" in result:
return True
result_str = json.dumps(result).lower()
return any(kw in result_str for kw in ["error", "blocked", "invalid", "security", "injection"])
class TestGovernanceToolsInjectionAPI:
"""Test data governance tools SQL injection prevention"""
@pytest.fixture
async def mcp_client(self):
"""Create MCP client instance"""
client = MCPClient()
yield client
await client.close()
@pytest.mark.asyncio
async def test_trace_column_lineage_table_injection(self, mcp_client):
"""Test trace_column_lineage rejects table_name injection"""
payload = {"table_name": "users'; DROP TABLE users; --", "column_name": "id"}
result = await mcp_client.call_tool("trace_column_lineage", payload)
print_result("trace_column_lineage table_name Injection", payload, result)
assert self._is_blocked_or_error(result), \
"table_name injection should be blocked"
@pytest.mark.asyncio
async def test_trace_column_lineage_column_injection(self, mcp_client):
"""Test trace_column_lineage rejects column_name injection"""
payload = {"table_name": "users", "column_name": "id; DROP TABLE users"}
result = await mcp_client.call_tool("trace_column_lineage", payload)
print_result("trace_column_lineage column_name Injection", payload, result)
assert self._is_blocked_or_error(result), \
"column_name injection should be blocked"
@pytest.mark.asyncio
async def test_monitor_data_freshness_injection(self, mcp_client):
"""Test monitor_data_freshness rejects table_name injection"""
payload = {"table_name": "users`; SELECT * FROM passwords; --"}
result = await mcp_client.call_tool("monitor_data_freshness", payload)
print_result("monitor_data_freshness Injection", payload, result)
assert self._is_blocked_or_error(result), \
"table_name injection should be blocked"
@pytest.mark.asyncio
async def test_analyze_data_access_patterns_injection(self, mcp_client):
"""Test analyze_data_access_patterns rejects injection"""
payload = {"table_name": "users' UNION SELECT password FROM admin --"}
result = await mcp_client.call_tool("analyze_data_access_patterns", payload)
print_result("analyze_data_access_patterns Injection", payload, result)
assert self._is_blocked_or_error(result), \
"table_name injection should be blocked"
def _is_blocked_or_error(self, result: dict) -> bool:
"""Check if result indicates blocked or error"""
if not result:
return True
if "error" in result:
return True
result_str = json.dumps(result).lower()
return any(kw in result_str for kw in ["error", "blocked", "invalid", "security", "injection"])
class TestPerformanceToolsInjectionAPI:
"""Test performance analytics tools SQL injection prevention"""
@pytest.fixture
async def mcp_client(self):
"""Create MCP client instance"""
client = MCPClient()
yield client
await client.close()
@pytest.mark.asyncio
async def test_analyze_slow_queries_db_injection(self, mcp_client):
"""Test analyze_slow_queries_topn rejects db_name injection"""
payload = {"db_name": "test'; DROP TABLE audit_log; --"}
result = await mcp_client.call_tool("analyze_slow_queries_topn", payload)
print_result("analyze_slow_queries_topn db_name Injection", payload, result)
assert self._is_blocked_or_error(result), \
"db_name injection should be blocked"
@pytest.mark.asyncio
async def test_analyze_resource_growth_db_injection(self, mcp_client):
"""Test analyze_resource_growth_curves rejects db_name injection"""
payload = {"db_name": "test`; GRANT ALL ON *.* TO 'hacker'; --"}
result = await mcp_client.call_tool("analyze_resource_growth_curves", payload)
print_result("analyze_resource_growth_curves db_name Injection", payload, result)
assert self._is_blocked_or_error(result), \
"db_name injection should be blocked"
@pytest.mark.asyncio
async def test_get_table_data_size_injection(self, mcp_client):
"""Test get_table_data_size rejects table_name injection"""
payload = {"table_name": "users; TRUNCATE TABLE logs"}
result = await mcp_client.call_tool("get_table_data_size", payload)
print_result("get_table_data_size Injection", payload, result)
assert self._is_blocked_or_error(result), \
"table_name injection should be blocked"
def _is_blocked_or_error(self, result: dict) -> bool:
"""Check if result indicates blocked or error"""
if not result:
return True
if "error" in result:
return True
result_str = json.dumps(result).lower()
return any(kw in result_str for kw in ["error", "blocked", "invalid", "security", "injection"])
# Pytest configuration for async tests
@pytest.fixture(scope="session")
def event_loop():
"""Create event loop for async tests"""
loop = asyncio.new_event_loop()
yield loop
loop.close()
if __name__ == "__main__":
pytest.main([__file__, "-v", "--tb=short", "-x"])

View File

@@ -44,17 +44,31 @@
} }
}, },
"expected_tools": [ "expected_tools": [
"analyze_columns",
"analyze_data_access_patterns",
"analyze_data_flow_dependencies",
"analyze_resource_growth_curves",
"analyze_slow_queries_topn",
"analyze_table_storage",
"exec_adbc_query",
"exec_query", "exec_query",
"get_adbc_connection_info",
"get_catalog_list",
"get_db_list", "get_db_list",
"get_db_table_list", "get_db_table_list",
"get_table_schema", "get_memory_stats",
"get_table_comment", "get_monitoring_metrics",
"get_table_column_comments",
"get_table_indexes",
"column_analysis",
"performance_stats",
"get_recent_audit_logs", "get_recent_audit_logs",
"get_catalog_list" "get_sql_explain",
"get_sql_profile",
"get_table_basic_info",
"get_table_column_comments",
"get_table_comment",
"get_table_data_size",
"get_table_indexes",
"get_table_schema",
"monitor_data_freshness",
"trace_column_lineage"
], ],
"expected_resources": [ "expected_resources": [
"database", "database",

View File

@@ -185,8 +185,9 @@ async def test_server_connectivity(transport: Optional[str] = None) -> bool:
logger.error(f"Connectivity test failed: {e}") logger.error(f"Connectivity test failed: {e}")
return False return False
result = await client.connect_and_run(test_connection) await client.connect_and_run(test_connection)
return result return True
except Exception as e: except Exception as e:
logger.error(f"Failed to test server connectivity: {e}") logger.error(f"Failed to test server connectivity: {e}")
return False return False

View File

@@ -72,8 +72,7 @@ class TestToolsClientServer:
return tools return tools
result = await client.connect_and_run(test_callback) await client.connect_and_run(test_callback)
assert len(result) > 0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_call_tool_exec_query_via_client(self, client, test_config): async def test_call_tool_exec_query_via_client(self, client, test_config):
@@ -91,14 +90,13 @@ class TestToolsClientServer:
assert "success" in result, "Result should contain 'success' field" assert "success" in result, "Result should contain 'success' field"
if result["success"]: if result["success"]:
assert "result" in result, "Successful result should contain 'result' field" assert "data" in result, "Successful result should contain 'data' field"
else: else:
assert "error" in result, "Failed result should contain 'error' field" assert "error" in result, "Failed result should contain 'error' field"
return result return result
result = await client.connect_and_run(test_callback) await client.connect_and_run(test_callback)
# Don't assert success=True as it depends on actual server state
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_call_tool_get_db_list_via_client(self, client, test_config): async def test_call_tool_get_db_list_via_client(self, client, test_config):
@@ -115,8 +113,7 @@ class TestToolsClientServer:
return result return result
result = await client.connect_and_run(test_callback) await client.connect_and_run(test_callback)
assert "success" in result
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_call_tool_get_table_schema_via_client(self, client, test_config): async def test_call_tool_get_table_schema_via_client(self, client, test_config):
@@ -133,27 +130,7 @@ class TestToolsClientServer:
assert "success" in result, "Result should contain 'success' field" assert "success" in result, "Result should contain 'success' field"
return result return result
result = await client.connect_and_run(test_callback) await client.connect_and_run(test_callback)
assert "success" in result
@pytest.mark.asyncio
async def test_call_tool_performance_stats_via_client(self, client, test_config):
"""Test calling performance_stats tool through client"""
if not test_config.is_performance_tests_enabled():
pytest.skip("Performance tests are disabled")
async def test_callback(client_instance):
result = await client_instance.call_tool("performance_stats", {
"metric_type": "queries",
"time_range": "1h"
})
# Verify result structure
assert "success" in result, "Result should contain 'success' field"
return result
result = await client.connect_and_run(test_callback)
assert "success" in result
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_tool_error_handling_via_client(self, client, test_config): async def test_tool_error_handling_via_client(self, client, test_config):
@@ -168,8 +145,7 @@ class TestToolsClientServer:
assert "success" in result, "Result should contain 'success' field" assert "success" in result, "Result should contain 'success' field"
return result return result
result = await client.connect_and_run(test_callback) await client.connect_and_run(test_callback)
assert "success" in result
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_tool_with_auth_token_via_client(self, client, test_config): async def test_tool_with_auth_token_via_client(self, client, test_config):
@@ -188,5 +164,4 @@ class TestToolsClientServer:
assert "success" in result, "Result should contain 'success' field" assert "success" in result, "Result should contain 'success' field"
return result return result
result = await client.connect_and_run(test_callback) await client.connect_and_run(test_callback)
assert "success" in result

View File

@@ -36,11 +36,6 @@ class TestDorisToolsManager:
from doris_mcp_server.utils.config import DatabaseConfig, SecurityConfig from doris_mcp_server.utils.config import DatabaseConfig, SecurityConfig
config = Mock(spec=DorisConfig) config = Mock(spec=DorisConfig)
config.doris_host = "localhost"
config.doris_port = 9030
config.doris_user = "test_user"
config.doris_password = "test_password"
config.doris_database = "test_db"
# Add database config # Add database config
config.database = Mock(spec=DatabaseConfig) config.database = Mock(spec=DatabaseConfig)
@@ -50,7 +45,6 @@ class TestDorisToolsManager:
config.database.password = "test_password" config.database.password = "test_password"
config.database.database = "test_db" config.database.database = "test_db"
config.database.health_check_interval = 60 config.database.health_check_interval = 60
config.database.min_connections = 5
config.database.max_connections = 20 config.database.max_connections = 20
config.database.connection_timeout = 30 config.database.connection_timeout = 30
config.database.max_connection_age = 3600 config.database.max_connection_age = 3600
@@ -235,62 +229,7 @@ class TestDorisToolsManager:
elif "result" in result_data: elif "result" in result_data:
assert len(result_data["result"]) >= 0 # May be empty if no catalogs assert len(result_data["result"]) >= 0 # May be empty if no catalogs
@pytest.mark.asyncio
async def test_column_analysis_tool(self, tools_manager):
"""Test column_analysis tool"""
with patch.object(tools_manager.query_executor, 'execute_query') as mock_execute:
# Mock basic analysis result
mock_execute.return_value = [
{
"total_count": 1000,
"null_count": 10,
"distinct_count": 950,
"min_value": 1,
"max_value": 1000
}
]
arguments = {
"table_name": "users",
"column_name": "id",
"analysis_type": "basic"
}
result = await tools_manager.call_tool("column_analysis", arguments)
result_data = json.loads(result) if isinstance(result, str) else result
# Check if result has analysis field or result field
if "analysis" in result_data:
assert result_data["analysis"]["total_count"] == 1000
elif "result" in result_data:
assert "result" in result_data # Just check result exists
@pytest.mark.asyncio
async def test_performance_stats_tool(self, tools_manager):
"""Test performance_stats tool"""
with patch.object(tools_manager.query_executor, 'execute_query') as mock_execute:
mock_execute.return_value = [
{
"query_count": 1500,
"avg_execution_time": 0.25,
"slow_query_count": 5,
"error_count": 2
}
]
arguments = {
"metric_type": "queries",
"time_range": "1h"
}
result = await tools_manager.call_tool("performance_stats", arguments)
result_data = json.loads(result) if isinstance(result, str) else result
# Check if result has stats field or result field
if "stats" in result_data:
assert result_data["stats"]["query_count"] == 1500
elif "result" in result_data:
assert "result" in result_data # Just check result exists
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_invalid_tool_name(self, tools_manager): async def test_invalid_tool_name(self, tools_manager):

78
test/utils/test_db.py Normal file
View File

@@ -0,0 +1,78 @@
from unittest.mock import MagicMock
import pytest
from doris_mcp_server.utils.db import DorisConnection, DorisSessionCache
@pytest.fixture
def session_cache():
"""Provides a DorisSessionCache instance with a mock connection manager."""
connection_manager = MagicMock()
cache = DorisSessionCache(connection_manager=connection_manager)
yield cache, connection_manager
class TestDorisSessionCache:
def test_initialization(self, session_cache):
cache, _ = session_cache
assert cache.cache_system_session is True
assert cache.cache_user_session is False
assert not cache.cached
def test_should_cache(self, session_cache):
cache, _ = session_cache
assert cache._should_cache("query") is True
assert cache._should_cache("system") is True
assert cache._should_cache("user-test-session-id") is False
cache.cache_user_session = True
assert cache._should_cache("user-test-session-id") is True
def test_save_and_get_session(self, session_cache):
cache, _ = session_cache
mock_connection = MagicMock(spec=DorisConnection)
mock_connection.session_id = "query"
cache.save(mock_connection)
retrieved_conn = cache.get("query")
assert retrieved_conn is mock_connection
mock_user_connection = MagicMock(spec=DorisConnection)
mock_user_connection.session_id = "user-test-session-id"
cache.save(mock_user_connection)
assert cache.get("user-test-session-id") is None
cache.cache_user_session = True
cache.save(mock_user_connection)
retrieved_user_conn = cache.get("user-test-session-id")
assert retrieved_user_conn is mock_user_connection
def test_remove_session(self, session_cache):
cache, _ = session_cache
mock_connection = MagicMock(spec=DorisConnection)
mock_connection.session_id = "system"
cache.save(mock_connection)
assert cache.get("system") is not None
cache.remove("system")
assert cache.get("system") is None
def test_clear_cache(self, session_cache):
cache, connection_manager = session_cache
mock_conn1 = MagicMock(spec=DorisConnection)
mock_conn1.session_id = "query"
mock_conn2 = MagicMock(spec=DorisConnection)
mock_conn2.session_id = "system"
cache.save(mock_conn1)
cache.save(mock_conn2)
assert len(cache.cached) == 2
cache.clear()
assert not cache.cached
connection_manager.release_connection.assert_any_call("query", mock_conn1)
connection_manager.release_connection.assert_any_call("system", mock_conn2)
assert connection_manager.release_connection.call_count == 2

View File

@@ -35,11 +35,6 @@ class TestDorisQueryExecutor:
from doris_mcp_server.utils.config import DatabaseConfig, SecurityConfig from doris_mcp_server.utils.config import DatabaseConfig, SecurityConfig
config = Mock(spec=DorisConfig) config = Mock(spec=DorisConfig)
config.doris_host = "localhost"
config.doris_port = 9030
config.doris_user = "test_user"
config.doris_password = "test_password"
config.doris_database = "test_db"
# Add database config # Add database config
config.database = Mock(spec=DatabaseConfig) config.database = Mock(spec=DatabaseConfig)
@@ -49,11 +44,17 @@ class TestDorisQueryExecutor:
config.database.password = "test_password" config.database.password = "test_password"
config.database.database = "test_db" config.database.database = "test_db"
config.database.health_check_interval = 60 config.database.health_check_interval = 60
config.database.min_connections = 5
config.database.max_connections = 20 config.database.max_connections = 20
config.database.connection_timeout = 30 config.database.connection_timeout = 30
config.database.max_connection_age = 3600 config.database.max_connection_age = 3600
# Add security config
config.security = Mock(spec=SecurityConfig)
config.security.enable_masking = True
config.security.auth_type = "token"
config.security.token_secret = "test_secret"
config.security.token_expiry = 3600
return config return config
@pytest.fixture @pytest.fixture
@@ -200,3 +201,73 @@ class TestDorisQueryExecutor:
if result["success"]: if result["success"]:
assert "data" in result assert "data" in result
assert "row_count" in result assert "row_count" in result
@pytest.mark.asyncio
async def test_execute_multi_sql_statements(self, query_executor):
"""Test execution of multiple SQL statements"""
from doris_mcp_server.utils.query_executor import QueryResult
# Disable security check for this test
query_executor.connection_manager.config.security.enable_security_check = False
with patch.object(query_executor, 'execute_query') as mock_execute:
# Mock results for three SQL statements
mock_execute.side_effect = [
QueryResult(
data=[{"id": 1, "name": "张三"}],
row_count=1,
execution_time=0.1,
sql="SELECT id, name FROM users WHERE id = 1",
metadata={"columns": ["id", "name"]}
),
QueryResult(
data=[{"id": 2, "name": "李四"}],
row_count=1,
execution_time=0.12,
sql="SELECT id, name FROM users WHERE id = 2",
metadata={"columns": ["id", "name"]}
),
QueryResult(
data=[{"count": 100}],
row_count=1,
execution_time=0.08,
sql="SELECT COUNT(*) as count FROM users",
metadata={"columns": ["count"]}
)
]
# Execute multiple SQL statements separated by semicolons
multi_sql = """
SELECT id, name FROM users WHERE id = 1;
SELECT id, name FROM users WHERE id = 2;
SELECT COUNT(*) as count FROM users;
"""
result = await query_executor.execute_sql_for_mcp(multi_sql)
# Verify the result structure for multiple statements
assert result["success"] is True
assert result["multiple_results"] is True
assert "results" in result
assert len(result["results"]) == 3
# Verify first query result
assert result["results"][0]["data"] == [{"id": 1, "name": "张三"}]
assert result["results"][0]["row_count"] == 1
assert result["results"][0]["metadata"]["columns"] == ["id", "name"]
assert result["results"][0]["metadata"]["query"] == "SELECT id, name FROM users WHERE id = 1"
# Verify second query result
assert result["results"][1]["data"] == [{"id": 2, "name": "李四"}]
assert result["results"][1]["row_count"] == 1
assert result["results"][1]["metadata"]["columns"] == ["id", "name"]
assert result["results"][1]["metadata"]["query"] == "SELECT id, name FROM users WHERE id = 2"
# Verify third query result
assert result["results"][2]["data"] == [{"count": 100}]
assert result["results"][2]["row_count"] == 1
assert result["results"][2]["metadata"]["columns"] == ["count"]
assert result["results"][2]["metadata"]["query"] == "SELECT COUNT(*) as count FROM users"
# Verify execute_query was called three times
assert mock_execute.call_count == 3

View File

@@ -21,8 +21,6 @@ Tests the query execution functionality through actual MCP client-server communi
Assumes the server is already running and configured properly Assumes the server is already running and configured properly
""" """
import asyncio
import json
import pytest import pytest
import os import os
import sys import sys
@@ -66,14 +64,13 @@ class TestQueryExecutorClientServer:
assert "success" in result, "Result should contain 'success' field" assert "success" in result, "Result should contain 'success' field"
if result["success"]: if result["success"]:
assert "result" in result, "Successful result should contain 'result' field" assert "data" in result, "Successful result should contain 'data' field"
else: else:
assert "error" in result, "Failed result should contain 'error' field" assert "error" in result, "Failed result should contain 'error' field"
return result return result
result = await client.connect_and_run(test_callback) await client.connect_and_run(test_callback)
assert "success" in result
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_show_databases_query_via_client(self, client, test_config): async def test_show_databases_query_via_client(self, client, test_config):
@@ -87,8 +84,7 @@ class TestQueryExecutorClientServer:
assert "success" in result, "Result should contain 'success' field" assert "success" in result, "Result should contain 'success' field"
return result return result
result = await client.connect_and_run(test_callback) await client.connect_and_run(test_callback)
assert "success" in result
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_information_schema_query_via_client(self, client, test_config): async def test_information_schema_query_via_client(self, client, test_config):
@@ -102,8 +98,7 @@ class TestQueryExecutorClientServer:
assert "success" in result, "Result should contain 'success' field" assert "success" in result, "Result should contain 'success' field"
return result return result
result = await client.connect_and_run(test_callback) await client.connect_and_run(test_callback)
assert "success" in result
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_query_with_max_rows_parameter_via_client(self, client, test_config): async def test_query_with_max_rows_parameter_via_client(self, client, test_config):
@@ -118,8 +113,7 @@ class TestQueryExecutorClientServer:
assert "success" in result, "Result should contain 'success' field" assert "success" in result, "Result should contain 'success' field"
return result return result
result = await client.connect_and_run(test_callback) await client.connect_and_run(test_callback)
assert "success" in result
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_query_error_handling_via_client(self, client, test_config): async def test_query_error_handling_via_client(self, client, test_config):
@@ -131,8 +125,7 @@ class TestQueryExecutorClientServer:
assert "success" in result, "Result should contain 'success' field" assert "success" in result, "Result should contain 'success' field"
return result return result
result = await client.connect_and_run(test_callback) await client.connect_and_run(test_callback)
assert "success" in result
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_query_with_auth_token_via_client(self, client, test_config): async def test_query_with_auth_token_via_client(self, client, test_config):
@@ -152,5 +145,4 @@ class TestQueryExecutorClientServer:
assert "success" in result, "Result should contain 'success' field" assert "success" in result, "Result should contain 'success' field"
return result return result
result = await client.connect_and_run(test_callback) await client.connect_and_run(test_callback)
assert "success" in result

64
tokens.json Normal file
View File

@@ -0,0 +1,64 @@
{
"version": "1.0",
"description": "Doris MCP Server Token configuration file",
"created_at": "2025-09-01T00:00:00Z",
"tokens": [
{
"token_id": "admin-token",
"token": "doris_admin_token_123456",
"description": "Doris admin API access token",
"expires_hours": null,
"is_active": true,
"database_config": {
"host": "127.0.0.1",
"port": 9030,
"user": "root",
"password": "",
"database": "information_schema",
"charset": "UTF8",
"fe_http_port": 8030
}
},
{
"token_id": "analyst-token",
"token": "doris_analyst_token_123456",
"description": "Doris analyst API access token",
"expires_hours": 8760,
"is_active": true,
"database_config": {
"host": "127.0.0.1",
"port": 9030,
"user": "root",
"password": "",
"database": "information_schema",
"charset": "UTF8",
"fe_http_port": 8030
}
},
{
"token_id": "readonly-token",
"token": "doris_readonly_token_123456",
"description": "Doris readonly API access token",
"expires_hours": 4320,
"is_active": true,
"database_config": {
"host": "127.0.0.1",
"port": 9030,
"user": "root",
"password": "",
"database": "information_schema",
"charset": "UTF8",
"fe_http_port": 8030
}
}
],
"notes": [
"The admin_token, analyst_token, readonly_token is default token,Please change the token before using in production!",
"The token_id is the key of the token,Please use the token_id to identify the token",
"The token is the value of the token,Please use the token to identify the token",
"The description is the description of the token,Please use the description to identify the token",
"The expires_hours is the expires hours of the token,Please use the expires_hours to identify the token",
"The is_active is the is active of the token,Please use the is_active to identify the token",
"The token_id, token, description, expires_hours, is_active is the metadata of the token,Please use the metadata to identify the token"
]
}

112
uv.lock generated
View File

@@ -1,19 +1,3 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
version = 1 version = 1
revision = 1 revision = 1
requires-python = ">=3.12" requires-python = ">=3.12"
@@ -22,6 +6,48 @@ resolution-markers = [
"python_full_version < '3.13'", "python_full_version < '3.13'",
] ]
[[package]]
name = "adbc-driver-flightsql"
version = "1.7.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "adbc-driver-manager" },
{ name = "importlib-resources" },
]
sdist = { url = "https://files.pythonhosted.org/packages/b8/d4/ebd3eed981c771565677084474cdf465141455b5deb1ca409c616609bfd7/adbc_driver_flightsql-1.7.0.tar.gz", hash = "sha256:5dca460a2c66e45b29208eaf41a7206f252177435fa48b16f19833b12586f7a0", size = 21247 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/36/20/807fca9d904b7e0d3020439828d6410db7fd7fd635824a80cab113d9fad1/adbc_driver_flightsql-1.7.0-py3-none-macosx_10_15_x86_64.whl", hash = "sha256:a5658f9bc3676bd122b26138e9b9ce56b8bf37387efe157b4c66d56f942361c6", size = 7749664 },
{ url = "https://files.pythonhosted.org/packages/cd/e6/9e50f6497819c911b9cc1962ffde610b60f7d8e951d6bb3fa145dcfb50a7/adbc_driver_flightsql-1.7.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:65e21df86b454d8db422c8ee22db31be217d88c42d9d6dd89119f06813037c91", size = 7302476 },
{ url = "https://files.pythonhosted.org/packages/27/82/e51af85e7cc8c87bc8ce4fae8ca7ee1d3cf39c926be0aeab789cedc93f0a/adbc_driver_flightsql-1.7.0-py3-none-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:3282fdc7b73c712780cc777975288c88b1e3a555355bbe09df101aa954f8f105", size = 7686056 },
{ url = "https://files.pythonhosted.org/packages/8b/c9/591c8ecbaf010ba3f4b360db602050ee5880cd077a573c9e90fcb270ab71/adbc_driver_flightsql-1.7.0-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e0c5737ae6ee3bbfba44dcbc28ba1ff8cf3ab6521888c4b0f10dd6a482482161", size = 7050275 },
{ url = "https://files.pythonhosted.org/packages/10/14/f339e9a5d8dbb3e3040215514cea9cca0a58640964aaccc6532f18003a03/adbc_driver_flightsql-1.7.0-py3-none-win_amd64.whl", hash = "sha256:f8b5290b322304b7d944ca823754e6354c1868dbbe94ddf84236f3e0329545da", size = 14312858 },
]
[[package]]
name = "adbc-driver-manager"
version = "1.7.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/bb/bf/2986a2cd3e1af658d2597f7e2308564e5c11e036f9736d5c256f1e00d578/adbc_driver_manager-1.7.0.tar.gz", hash = "sha256:e3edc5d77634b5925adf6eb4fbcd01676b54acb2f5b1d6864b6a97c6a899591a", size = 198128 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/74/3a/72bd9c45d55f1f5f4c549e206de8cfe3313b31f7b95fbcb180da05c81044/adbc_driver_manager-1.7.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:8da1ac4c19bcbf30b3bd54247ec889dfacc9b44147c70b4da79efe2e9ba93600", size = 524210 },
{ url = "https://files.pythonhosted.org/packages/33/29/e1a8d8dde713a287f8021f3207127f133ddce578711a4575218bdf78ef27/adbc_driver_manager-1.7.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:408bc23bad1a6823b364e2388f85f96545e82c3b2db97d7828a4b94839d3f29e", size = 505902 },
{ url = "https://files.pythonhosted.org/packages/59/00/773ece64a58c0ade797ab4577e7cdc4c71ebf800b86d2d5637e3bfe605e9/adbc_driver_manager-1.7.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cf38294320c23e47ed3455348e910031ad8289c3f9167ae35519ac957b7add01", size = 2974883 },
{ url = "https://files.pythonhosted.org/packages/7c/ad/1568da6ae9ab70983f1438503d3906c6b1355601230e891d16e272376a04/adbc_driver_manager-1.7.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:689f91b62c18a9f86f892f112786fb157cacc4729b4d81666db4ca778eade2a8", size = 2997781 },
{ url = "https://files.pythonhosted.org/packages/19/66/2b6ea5afded25a3fa009873c2bbebcd9283910877cc10b9453d680c00b9a/adbc_driver_manager-1.7.0-cp312-cp312-win_amd64.whl", hash = "sha256:f936cfc8d098898a47ef60396bd7a73926ec3068f2d6d92a2be4e56e4aaf3770", size = 690041 },
{ url = "https://files.pythonhosted.org/packages/b2/3b/91154c83a98f103a3d97c9e2cb838c3842aef84ca4f4b219164b182d9516/adbc_driver_manager-1.7.0-cp313-cp313-macosx_10_15_x86_64.whl", hash = "sha256:ab9ee36683fd54f61b0db0f4a96f70fe1932223e61df9329290370b145abb0a9", size = 522737 },
{ url = "https://files.pythonhosted.org/packages/9c/52/4bc80c3388d5e2a3b6e504ba9656dd9eb3d8dbe822d07af38db1b8c96fb1/adbc_driver_manager-1.7.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:4ec03d94177f71a8d3a149709f4111e021f9950229b35c0a803aadb1a1855a4b", size = 503896 },
{ url = "https://files.pythonhosted.org/packages/e1/f3/46052ca11224f661cef4721e19138bc73e750ba6aea54f22606950491606/adbc_driver_manager-1.7.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:700c79dac08a620018c912ede45a6dc7851819bc569a53073ab652dc0bd0c92f", size = 2972586 },
{ url = "https://files.pythonhosted.org/packages/a2/22/44738b41bb5ca30f94b5f4c00c71c20be86d7eb4ddc389d4cf3c7b8b69ef/adbc_driver_manager-1.7.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:98db0f5d0aa1635475f63700a7b6f677390beb59c69c7ba9d388bc8ce3779388", size = 2992001 },
{ url = "https://files.pythonhosted.org/packages/1b/2b/5184fe5a529feb019582cc90d0f65e0021d52c34ca20620551532340645a/adbc_driver_manager-1.7.0-cp313-cp313-win_amd64.whl", hash = "sha256:4b7e5e9a163acb21804647cc7894501df51cdcd780ead770557112a26ca01ca6", size = 688789 },
{ url = "https://files.pythonhosted.org/packages/3f/e0/b283544e1bb7864bf5a5ac9cd330f111009eff9180ec5000420510cf9342/adbc_driver_manager-1.7.0-cp313-cp313t-macosx_10_15_x86_64.whl", hash = "sha256:ac83717965b83367a8ad6c0536603acdcfa66e0592d783f8940f55fda47d963e", size = 538625 },
{ url = "https://files.pythonhosted.org/packages/77/5a/dc244264bd8d0c331a418d2bdda5cb6e26c30493ff075d706aa81d4e3b30/adbc_driver_manager-1.7.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:4c234cf81b00eaf7e7c65dbd0f0ddf7bdae93dfcf41e9d8543f9ecf4b10590f6", size = 523627 },
{ url = "https://files.pythonhosted.org/packages/e9/ff/a499a00367fd092edb20dc6e36c81e3c7a437671c70481cae97f46c8156a/adbc_driver_manager-1.7.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ad8aa4b039cc50722a700b544773388c6b1dea955781a01f79cd35d0a1e6edbf", size = 3037517 },
{ url = "https://files.pythonhosted.org/packages/25/6e/9dfdb113294dcb24b4f53924cd4a9c9af3fbe45a9790c1327048df731246/adbc_driver_manager-1.7.0-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4409ff53578e01842a8f57787ebfbfee790c1da01a6bd57fcb7701ed5d4dd4f7", size = 3016543 },
]
[[package]] [[package]]
name = "aiofiles" name = "aiofiles"
version = "24.1.0" version = "24.1.0"
@@ -536,9 +562,11 @@ wheels = [
[[package]] [[package]]
name = "doris-mcp-server" name = "doris-mcp-server"
version = "0.3.0" version = "0.6.1"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "adbc-driver-flightsql" },
{ name = "adbc-driver-manager" },
{ name = "aiofiles" }, { name = "aiofiles" },
{ name = "aiohttp" }, { name = "aiohttp" },
{ name = "aiomysql" }, { name = "aiomysql" },
@@ -555,6 +583,7 @@ dependencies = [
{ name = "pandas" }, { name = "pandas" },
{ name = "passlib", extra = ["bcrypt"] }, { name = "passlib", extra = ["bcrypt"] },
{ name = "prometheus-client" }, { name = "prometheus-client" },
{ name = "pyarrow" },
{ name = "pydantic" }, { name = "pydantic" },
{ name = "pydantic-settings" }, { name = "pydantic-settings" },
{ name = "pyjwt" }, { name = "pyjwt" },
@@ -625,6 +654,8 @@ dev = [
[package.metadata] [package.metadata]
requires-dist = [ requires-dist = [
{ name = "adbc-driver-flightsql", specifier = ">=0.8.0" },
{ name = "adbc-driver-manager", specifier = ">=0.8.0" },
{ name = "aiofiles", specifier = ">=23.0.0" }, { name = "aiofiles", specifier = ">=23.0.0" },
{ name = "aiohttp", specifier = ">=3.9.0" }, { name = "aiohttp", specifier = ">=3.9.0" },
{ name = "aiomysql", specifier = ">=0.2.0" }, { name = "aiomysql", specifier = ">=0.2.0" },
@@ -642,7 +673,7 @@ requires-dist = [
{ name = "httpx", specifier = ">=0.26.0" }, { name = "httpx", specifier = ">=0.26.0" },
{ name = "isort", marker = "extra == 'dev'", specifier = ">=5.13.0" }, { name = "isort", marker = "extra == 'dev'", specifier = ">=5.13.0" },
{ name = "jaeger-client", marker = "extra == 'monitoring'", specifier = ">=4.8.0" }, { name = "jaeger-client", marker = "extra == 'monitoring'", specifier = ">=4.8.0" },
{ name = "mcp", specifier = ">=1.0.0" }, { name = "mcp", specifier = ">=1.8.0,<2.0.0" },
{ name = "mypy", marker = "extra == 'dev'", specifier = ">=1.8.0" }, { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.8.0" },
{ name = "myst-parser", marker = "extra == 'dev'", specifier = ">=2.0.0" }, { name = "myst-parser", marker = "extra == 'dev'", specifier = ">=2.0.0" },
{ name = "myst-parser", marker = "extra == 'docs'", specifier = ">=2.0.0" }, { name = "myst-parser", marker = "extra == 'docs'", specifier = ">=2.0.0" },
@@ -656,6 +687,7 @@ requires-dist = [
{ name = "pre-commit", marker = "extra == 'dev'", specifier = ">=3.6.0" }, { name = "pre-commit", marker = "extra == 'dev'", specifier = ">=3.6.0" },
{ name = "prometheus-client", specifier = ">=0.19.0" }, { name = "prometheus-client", specifier = ">=0.19.0" },
{ name = "prometheus-client", marker = "extra == 'monitoring'", specifier = ">=0.19.0" }, { name = "prometheus-client", marker = "extra == 'monitoring'", specifier = ">=0.19.0" },
{ name = "pyarrow", specifier = ">=14.0.0" },
{ name = "pydantic", specifier = ">=2.5.0" }, { name = "pydantic", specifier = ">=2.5.0" },
{ name = "pydantic-settings", specifier = ">=2.1.0" }, { name = "pydantic-settings", specifier = ">=2.1.0" },
{ name = "pyjwt", specifier = ">=2.8.0" }, { name = "pyjwt", specifier = ">=2.8.0" },
@@ -948,6 +980,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/20/b0/36bd937216ec521246249be3bf9855081de4c5e06a0c9b4219dbeda50373/importlib_metadata-8.7.0-py3-none-any.whl", hash = "sha256:e5dd1551894c77868a30651cef00984d50e1002d06942a7101d34870c5f02afd", size = 27656 }, { url = "https://files.pythonhosted.org/packages/20/b0/36bd937216ec521246249be3bf9855081de4c5e06a0c9b4219dbeda50373/importlib_metadata-8.7.0-py3-none-any.whl", hash = "sha256:e5dd1551894c77868a30651cef00984d50e1002d06942a7101d34870c5f02afd", size = 27656 },
] ]
[[package]]
name = "importlib-resources"
version = "6.5.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/cf/8c/f834fbf984f691b4f7ff60f50b514cc3de5cc08abfc3295564dd89c5e2e7/importlib_resources-6.5.2.tar.gz", hash = "sha256:185f87adef5bcc288449d98fb4fba07cea78bc036455dd44c5fc4a2fe78fed2c", size = 44693 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/a4/ed/1f1afb2e9e7f38a545d628f864d562a5ae64fe6f7a10e28ffb9b185b4e89/importlib_resources-6.5.2-py3-none-any.whl", hash = "sha256:789cfdc3ed28c78b67a06acb8126751ced69a3d5f79c095a98298cd8a760ccec", size = 37461 },
]
[[package]] [[package]]
name = "iniconfig" name = "iniconfig"
version = "2.1.0" version = "2.1.0"
@@ -1621,6 +1662,41 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/7b/d7/7831438e6c3ebbfa6e01a927127a6cb42ad3ab844247f3c5b96bea25d73d/psutil-6.1.1-cp37-abi3-win_amd64.whl", hash = "sha256:f35cfccb065fff93529d2afb4a2e89e363fe63ca1e4a5da22b603a85833c2649", size = 254444 }, { url = "https://files.pythonhosted.org/packages/7b/d7/7831438e6c3ebbfa6e01a927127a6cb42ad3ab844247f3c5b96bea25d73d/psutil-6.1.1-cp37-abi3-win_amd64.whl", hash = "sha256:f35cfccb065fff93529d2afb4a2e89e363fe63ca1e4a5da22b603a85833c2649", size = 254444 },
] ]
[[package]]
name = "pyarrow"
version = "20.0.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/a2/ee/a7810cb9f3d6e9238e61d312076a9859bf3668fd21c69744de9532383912/pyarrow-20.0.0.tar.gz", hash = "sha256:febc4a913592573c8d5805091a6c2b5064c8bd6e002131f01061797d91c783c1", size = 1125187 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/a1/d6/0c10e0d54f6c13eb464ee9b67a68b8c71bcf2f67760ef5b6fbcddd2ab05f/pyarrow-20.0.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:75a51a5b0eef32727a247707d4755322cb970be7e935172b6a3a9f9ae98404ba", size = 30815067 },
{ url = "https://files.pythonhosted.org/packages/7e/e2/04e9874abe4094a06fd8b0cbb0f1312d8dd7d707f144c2ec1e5e8f452ffa/pyarrow-20.0.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:211d5e84cecc640c7a3ab900f930aaff5cd2702177e0d562d426fb7c4f737781", size = 32297128 },
{ url = "https://files.pythonhosted.org/packages/31/fd/c565e5dcc906a3b471a83273039cb75cb79aad4a2d4a12f76cc5ae90a4b8/pyarrow-20.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ba3cf4182828be7a896cbd232aa8dd6a31bd1f9e32776cc3796c012855e1199", size = 41334890 },
{ url = "https://files.pythonhosted.org/packages/af/a9/3bdd799e2c9b20c1ea6dc6fa8e83f29480a97711cf806e823f808c2316ac/pyarrow-20.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c3a01f313ffe27ac4126f4c2e5ea0f36a5fc6ab51f8726cf41fee4b256680bd", size = 42421775 },
{ url = "https://files.pythonhosted.org/packages/10/f7/da98ccd86354c332f593218101ae56568d5dcedb460e342000bd89c49cc1/pyarrow-20.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:a2791f69ad72addd33510fec7bb14ee06c2a448e06b649e264c094c5b5f7ce28", size = 40687231 },
{ url = "https://files.pythonhosted.org/packages/bb/1b/2168d6050e52ff1e6cefc61d600723870bf569cbf41d13db939c8cf97a16/pyarrow-20.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:4250e28a22302ce8692d3a0e8ec9d9dde54ec00d237cff4dfa9c1fbf79e472a8", size = 42295639 },
{ url = "https://files.pythonhosted.org/packages/b2/66/2d976c0c7158fd25591c8ca55aee026e6d5745a021915a1835578707feb3/pyarrow-20.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:89e030dc58fc760e4010148e6ff164d2f44441490280ef1e97a542375e41058e", size = 42908549 },
{ url = "https://files.pythonhosted.org/packages/31/a9/dfb999c2fc6911201dcbf348247f9cc382a8990f9ab45c12eabfd7243a38/pyarrow-20.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:6102b4864d77102dbbb72965618e204e550135a940c2534711d5ffa787df2a5a", size = 44557216 },
{ url = "https://files.pythonhosted.org/packages/a0/8e/9adee63dfa3911be2382fb4d92e4b2e7d82610f9d9f668493bebaa2af50f/pyarrow-20.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:96d6a0a37d9c98be08f5ed6a10831d88d52cac7b13f5287f1e0f625a0de8062b", size = 25660496 },
{ url = "https://files.pythonhosted.org/packages/9b/aa/daa413b81446d20d4dad2944110dcf4cf4f4179ef7f685dd5a6d7570dc8e/pyarrow-20.0.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:a15532e77b94c61efadde86d10957950392999503b3616b2ffcef7621a002893", size = 30798501 },
{ url = "https://files.pythonhosted.org/packages/ff/75/2303d1caa410925de902d32ac215dc80a7ce7dd8dfe95358c165f2adf107/pyarrow-20.0.0-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:dd43f58037443af715f34f1322c782ec463a3c8a94a85fdb2d987ceb5658e061", size = 32277895 },
{ url = "https://files.pythonhosted.org/packages/92/41/fe18c7c0b38b20811b73d1bdd54b1fccba0dab0e51d2048878042d84afa8/pyarrow-20.0.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aa0d288143a8585806e3cc7c39566407aab646fb9ece164609dac1cfff45f6ae", size = 41327322 },
{ url = "https://files.pythonhosted.org/packages/da/ab/7dbf3d11db67c72dbf36ae63dcbc9f30b866c153b3a22ef728523943eee6/pyarrow-20.0.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b6953f0114f8d6f3d905d98e987d0924dabce59c3cda380bdfaa25a6201563b4", size = 42411441 },
{ url = "https://files.pythonhosted.org/packages/90/c3/0c7da7b6dac863af75b64e2f827e4742161128c350bfe7955b426484e226/pyarrow-20.0.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:991f85b48a8a5e839b2128590ce07611fae48a904cae6cab1f089c5955b57eb5", size = 40677027 },
{ url = "https://files.pythonhosted.org/packages/be/27/43a47fa0ff9053ab5203bb3faeec435d43c0d8bfa40179bfd076cdbd4e1c/pyarrow-20.0.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:97c8dc984ed09cb07d618d57d8d4b67a5100a30c3818c2fb0b04599f0da2de7b", size = 42281473 },
{ url = "https://files.pythonhosted.org/packages/bc/0b/d56c63b078876da81bbb9ba695a596eabee9b085555ed12bf6eb3b7cab0e/pyarrow-20.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9b71daf534f4745818f96c214dbc1e6124d7daf059167330b610fc69b6f3d3e3", size = 42893897 },
{ url = "https://files.pythonhosted.org/packages/92/ac/7d4bd020ba9145f354012838692d48300c1b8fe5634bfda886abcada67ed/pyarrow-20.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e8b88758f9303fa5a83d6c90e176714b2fd3852e776fc2d7e42a22dd6c2fb368", size = 44543847 },
{ url = "https://files.pythonhosted.org/packages/9d/07/290f4abf9ca702c5df7b47739c1b2c83588641ddfa2cc75e34a301d42e55/pyarrow-20.0.0-cp313-cp313-win_amd64.whl", hash = "sha256:30b3051b7975801c1e1d387e17c588d8ab05ced9b1e14eec57915f79869b5031", size = 25653219 },
{ url = "https://files.pythonhosted.org/packages/95/df/720bb17704b10bd69dde086e1400b8eefb8f58df3f8ac9cff6c425bf57f1/pyarrow-20.0.0-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:ca151afa4f9b7bc45bcc791eb9a89e90a9eb2772767d0b1e5389609c7d03db63", size = 30853957 },
{ url = "https://files.pythonhosted.org/packages/d9/72/0d5f875efc31baef742ba55a00a25213a19ea64d7176e0fe001c5d8b6e9a/pyarrow-20.0.0-cp313-cp313t-macosx_12_0_x86_64.whl", hash = "sha256:4680f01ecd86e0dd63e39eb5cd59ef9ff24a9d166db328679e36c108dc993d4c", size = 32247972 },
{ url = "https://files.pythonhosted.org/packages/d5/bc/e48b4fa544d2eea72f7844180eb77f83f2030b84c8dad860f199f94307ed/pyarrow-20.0.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7f4c8534e2ff059765647aa69b75d6543f9fef59e2cd4c6d18015192565d2b70", size = 41256434 },
{ url = "https://files.pythonhosted.org/packages/c3/01/974043a29874aa2cf4f87fb07fd108828fc7362300265a2a64a94965e35b/pyarrow-20.0.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e1f8a47f4b4ae4c69c4d702cfbdfe4d41e18e5c7ef6f1bb1c50918c1e81c57b", size = 42353648 },
{ url = "https://files.pythonhosted.org/packages/68/95/cc0d3634cde9ca69b0e51cbe830d8915ea32dda2157560dda27ff3b3337b/pyarrow-20.0.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:a1f60dc14658efaa927f8214734f6a01a806d7690be4b3232ba526836d216122", size = 40619853 },
{ url = "https://files.pythonhosted.org/packages/29/c2/3ad40e07e96a3e74e7ed7cc8285aadfa84eb848a798c98ec0ad009eb6bcc/pyarrow-20.0.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:204a846dca751428991346976b914d6d2a82ae5b8316a6ed99789ebf976551e6", size = 42241743 },
{ url = "https://files.pythonhosted.org/packages/eb/cb/65fa110b483339add6a9bc7b6373614166b14e20375d4daa73483755f830/pyarrow-20.0.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:f3b117b922af5e4c6b9a9115825726cac7d8b1421c37c2b5e24fbacc8930612c", size = 42839441 },
{ url = "https://files.pythonhosted.org/packages/98/7b/f30b1954589243207d7a0fbc9997401044bf9a033eec78f6cb50da3f304a/pyarrow-20.0.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:e724a3fd23ae5b9c010e7be857f4405ed5e679db5c93e66204db1a69f733936a", size = 44503279 },
{ url = "https://files.pythonhosted.org/packages/37/40/ad395740cd641869a13bcf60851296c89624662575621968dcfafabaa7f6/pyarrow-20.0.0-cp313-cp313t-win_amd64.whl", hash = "sha256:82f1ee5133bd8f49d31be1299dc07f585136679666b502540db854968576faf9", size = 25944982 },
]
[[package]] [[package]]
name = "pyasn1" name = "pyasn1"
version = "0.6.1" version = "0.6.1"