37 Commits
0.3.0 ... 0.5.1

Author SHA1 Message Date
Yijia Su
6d3c128f54 0.5.1 Version (#28)
0.5.1 Version (#28)
2025-07-15 11:56:46 +08:00
Yijia Su
651d524814 [BUG]Optimize and fix the capabilities of 0.5.0 tools (#26)
1. **Unified Naming for CLI Arguments and Environment Variables** 
- All database-related CLI arguments now use the `--doris-*` prefix, and environment variables use `DORIS_*` for consistency and maintainability. 
- Backward compatibility: old `--db-*` arguments are still supported.

2. **Automatic Filtering of System SQL in Slow Query TopN** 
- Slow query analysis now automatically excludes SQL statements involving `__internal_schema`, `information_schema`, and `mysql` system databases, ensuring only business-related slow queries are counted. 
- Filtering is performed at the SQL level using `NOT LIKE` and `state != 'ERR'` for efficiency and safety.

3. **Unified Query Timeout Configuration** 
- If no `timeout` is specified for query execution, the system will use the `config.performance.query_timeout` value as the default, falling back to 30 seconds if not configured.
- This avoids hardcoding and makes timeout management more flexible.

4. **Tool execution optimization**
- Significantly reduce the execution time of some data governance and operation and maintenance tools
- Optimize execution logic and reduce data scanning
- Enable concurrent scanning to speed up retrieval

5. **Log system optimization**
- Fix the Console log printing logic and output the log content correctly
- Add advanced tool execution process log output to facilitate further positioning of error locations

6. **DB Connection optimization**
- Fixed a connection pool acquisition exception caused by deadlock

7. **Other Improvements**
- Help documentation and CLI examples updated to reflect new and legacy parameter compatibility.
- Code comments and documentation further standardized for better team collaboration and open-source community understanding.
2025-07-14 19:04:11 +08:00
Yijia Su
54572d0861 [Feature]Add 9 New Tools (#23)
release 0.5.0
2025-07-11 12:03:13 +08:00
Yijia Su
d12dfbd014 [improvement]Optimize and refactor the log system (#21)
* add logger system AND fix Readme
2025-07-10 14:02:10 +08:00
Yijia Su
4052b7e938 [BUG]Completely solve the at_eof problem (#20)
* fix at_eof bug

* update uv.lock

* fix bug and change pool min values

* Fixed startup errors caused by multiple versions of MCP services

* fix connection bug
2025-07-10 13:08:32 +08:00
Yijia Su
693c48d5ee [BUG]Fixed startup errors caused by multiple versions of MCP services (#13)
* fix at_eof bug

* update uv.lock

* fix bug and change pool min values

* Fixed startup errors caused by multiple versions of MCP services
2025-07-03 15:04:16 +08:00
Yijia Su
c1ce9a5cc7 [Config]Delete the minimum data pool variable (#11)
* fix at_eof bug

* update uv.lock

* fix bug and change pool min values
2025-07-02 19:57:45 +08:00
Yijia Su
282a1c0bd9 [BUG]Further fix the at_eof problem caused by aiomysql (#9)
* fix at_eof bug

* update uv.lock
2025-07-02 19:29:37 +08:00
Yijia Su
e3b9bf96ab Update .asf.yaml (#10) 2025-07-02 19:26:30 +08:00
Gerry Qi
667cecbbe0 Add .gitignore file (#7)
* Add dify dsl demo

* Deploying on docker

* Add .gitignore file

---------

Co-authored-by: Gerry.qi 齐晓明 <Gerry.qi@pousheng.com>
2025-07-02 18:30:46 +08:00
haijun huang
c777905bd3 fix the cofig of doris-mcp-server (#6) 2025-07-02 18:29:28 +08:00
haijun huang
d4ea125e35 add cursor demo (#4)
* add cursor demo

* fix image
2025-07-02 10:00:22 +08:00
Gerry Qi
f135d9b949 Add dify dsl demo (#3)
* Add dify dsl demo

* Deploying on docker

---------

Co-authored-by: Gerry.qi 齐晓明 <Gerry.qi@pousheng.com>
Co-authored-by: Gerry.qi <Gerry.qi@outlook.com>
2025-06-27 16:28:58 +08:00
Yijia Su
124dd0da88 Update .asf.yaml 2025-06-27 12:54:52 +08:00
Yijia Su
775b4cb630 Update .asf.yaml 2025-06-27 12:53:00 +08:00
FreeOnePlus
26e8bc1149 change pipy project name 2025-06-27 12:44:57 +08:00
FreeOnePlus
8526cb75fe v0.4.2 preview 2025-06-26 20:23:54 +08:00
FreeOnePlus
97006a756d v0.4.1 preview 2025-06-26 18:55:30 +08:00
Yijia Su
72865654e2 Merge pull request #2 from echo-hhj/example
[demo]add dify demo
2025-06-23 12:44:23 +08:00
HuangHaijun
050c09f902 fix contents 2025-06-19 13:10:32 +08:00
HuangHaijun
159399bd38 fix the way to start server 2025-06-19 13:03:17 +08:00
HuangHaijun
e859fbb778 fix1 2025-06-18 12:54:14 +08:00
HuangHaijun
1b9cb29f5f fix 2025-06-18 12:53:18 +08:00
HuangHaijun
c95c0fe03c add dify demo 2025-06-18 12:44:05 +08:00
FreeOnePlus
1e2e79d90d v0.4.0 preview 2025-06-12 19:36:16 +08:00
FreeOnePlus
609816bc4a fix doc bug 2025-06-12 05:10:07 +08:00
FreeOnePlus
5d46d153e1 1. Fix DB Connection BUG
2. Modify the global default configuration items and obtain them from Config
2025-06-11 11:52:15 +08:00
FreeOnePlus
0a81d5693b add readme QA module 2025-06-10 21:28:39 +08:00
FreeOnePlus
a4306867f6 Merge remote-tracking branch 'origin/master' 2025-06-10 21:11:05 +08:00
FreeOnePlus
a22ff3ae9b add readme QA module 2025-06-10 21:04:46 +08:00
Yijia Su
2c5f26889c Update .asf.yaml
open issue,discussions
2025-06-10 14:04:14 +08:00
FreeOnePlus
e47534c296 Merge remote-tracking branch 'origin/master' 2025-06-09 23:07:54 +08:00
FreeOnePlus
0f52591259 Add pip install command & fix pyproject.toml bug 2025-06-09 22:54:01 +08:00
Yijia Su
3b429f37b3 Merge pull request #1 from iouAkira/patch-1
fix Dockerfile
2025-06-09 18:48:58 +08:00
FreeOnePlus
f5a4c8abbe Add pip install command & fix pyproject.toml bug 2025-06-09 18:42:34 +08:00
FreeOnePlus
87563ef6e1 Add pip install command 2025-06-09 18:37:41 +08:00
Akira
b6157c500b fix Dockerfile 2025-06-09 09:21:42 +08:00
63 changed files with 17784 additions and 1039 deletions

View File

@@ -24,9 +24,15 @@ github:
- olap - olap
- lakehouse - lakehouse
- mcp - mcp
- ai
enabled_merge_buttons: enabled_merge_buttons:
squash: true squash: true
merge: false merge: false
rebase: false rebase: false
features:
issues: true
projects: true
notifications: notifications:
pullrequests_status: commits@doris.apache.org issues: commits@doris.apache.org
commits: commits@doris.apache.org
pullrequests: commits@doris.apache.org

View File

@@ -1,71 +1,196 @@
# Licensed to the Apache Software Foundation (ASF) under one # ===================================================================
# or more contributor license agreements. See the NOTICE file # Doris MCP Server Environment Configuration Example
# distributed with this work for additional information # ===================================================================
# regarding copyright ownership. The ASF licenses this file # Copy this file to .env and modify the configuration values as needed
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# Doris MCP Server Environment Configuration # ===================================================================
# Copy this file to .env and modify the values as needed # Database Connection Configuration
# ===================================================================
# Database Configuration # Doris FE (Frontend) connection settings
DORIS_HOST=localhost DORIS_HOST=localhost
DORIS_PORT=9030 DORIS_PORT=9030
DORIS_USER=root DORIS_USER=root
DORIS_PASSWORD=your_password_here DORIS_PASSWORD=
DORIS_DATABASE=your_database_name DORIS_DATABASE=information_schema
# Connection Pool Settings # Doris FE HTTP API port (for Profile and other HTTP APIs)
DORIS_MIN_CONNECTIONS=5 DORIS_FE_HTTP_PORT=8030
# Doris BE (Backend) nodes configuration (optional, for external access)
# Format: host1,host2,host3 (if empty, will use "show backends" to get BE nodes)
DORIS_BE_HOSTS=
DORIS_BE_WEBSERVER_PORT=8040
# Connection pool configuration
DORIS_MAX_CONNECTIONS=20 DORIS_MAX_CONNECTIONS=20
DORIS_CONNECTION_TIMEOUT=30 DORIS_CONNECTION_TIMEOUT=30
DORIS_HEALTH_CHECK_INTERVAL=60 DORIS_HEALTH_CHECK_INTERVAL=60
DORIS_MAX_CONNECTION_AGE=3600 DORIS_MAX_CONNECTION_AGE=3600
# Security Settings # Arrow Flight SQL Configuration (Required for ADBC tools)
# FE_ARROW_FLIGHT_SQL_PORT=
# BE_ARROW_FLIGHT_SQL_PORT=
# ===================================================================
# Security Configuration
# ===================================================================
# Authentication configuration
AUTH_TYPE=token AUTH_TYPE=token
TOKEN_SECRET=your_256_bit_secret_key_here TOKEN_SECRET=your_secret_key_here
TOKEN_EXPIRY=3600 TOKEN_EXPIRY=3600
# SQL security check
ENABLE_SECURITY_CHECK=true
# Blocked keywords (comma separated)
BLOCKED_KEYWORDS=DROP,CREATE,ALTER,TRUNCATE,DELETE,INSERT,UPDATE,GRANT,REVOKE,EXEC,EXECUTE,SHUTDOWN,KILL
# Query limits
MAX_QUERY_COMPLEXITY=100
MAX_RESULT_ROWS=10000 MAX_RESULT_ROWS=10000
# Data masking
ENABLE_MASKING=true ENABLE_MASKING=true
# Performance Settings # ===================================================================
# Performance Configuration
# ===================================================================
# Query cache
ENABLE_QUERY_CACHE=true ENABLE_QUERY_CACHE=true
CACHE_TTL=300 CACHE_TTL=300
MAX_CACHE_SIZE=1000 MAX_CACHE_SIZE=1000
# Concurrency control
MAX_CONCURRENT_QUERIES=50 MAX_CONCURRENT_QUERIES=50
QUERY_TIMEOUT=300 QUERY_TIMEOUT=300
# Logging Configuration # Response content size limit (characters)
LOG_LEVEL=INFO MAX_RESPONSE_CONTENT_SIZE=4096
LOG_FILE_PATH=./log/doris-mcp-server.log
ENABLE_AUDIT=true
AUDIT_FILE_PATH=./log/doris-mcp-audit.log
# Monitoring Settings # ===================================================================
# ADBC (Arrow Flight SQL) Configuration
# ===================================================================
# Enable/disable ADBC tools
ADBC_ENABLED=true
# Default ADBC query parameters
ADBC_DEFAULT_MAX_ROWS=100000
ADBC_DEFAULT_TIMEOUT=60
# Format: "arrow", "pandas", "dict"
ADBC_DEFAULT_RETURN_FORMAT=arrow
# ADBC connection timeout
ADBC_CONNECTION_TIMEOUT=300
# ===================================================================
# Logging Configuration
# ===================================================================
# Basic logging configuration
LOG_LEVEL=INFO
LOG_FILE_PATH=
# Audit logging
ENABLE_AUDIT=true
AUDIT_FILE_PATH=
# Log file rotation configuration
LOG_MAX_FILE_SIZE=10485760
LOG_BACKUP_COUNT=5
# ===================================================================
# Log Cleanup Configuration - NEW!
# ===================================================================
# Enable automatic log cleanup
ENABLE_LOG_CLEANUP=true
# Maximum age of log files in days (files older than this will be deleted)
LOG_MAX_AGE_DAYS=30
# Cleanup check interval in hours
LOG_CLEANUP_INTERVAL_HOURS=24
# ===================================================================
# Monitoring Configuration
# ===================================================================
# Metrics collection
ENABLE_METRICS=true ENABLE_METRICS=true
METRICS_PORT=3001 METRICS_PORT=3001
METRICS_PATH=/metrics
HEALTH_CHECK_PORT=3002 HEALTH_CHECK_PORT=3002
HEALTH_CHECK_PATH=/health
# Alert configuration
ENABLE_ALERTS=false ENABLE_ALERTS=false
ALERT_WEBHOOK_URL= ALERT_WEBHOOK_URL=
# Server Settings # ===================================================================
# Server Configuration
# ===================================================================
# Basic server information
SERVER_NAME=doris-mcp-server SERVER_NAME=doris-mcp-server
SERVER_VERSION=0.3.0 SERVER_VERSION=0.5.1
SERVER_PORT=3000 SERVER_PORT=3000
# Development Settings (for development environment only) # Temporary files directory
DEBUG=false TEMP_FILES_DIR=tmp
VERBOSE=false
# ===================================================================
# Configuration Examples for Different Environments
# ===================================================================
# Development Environment Example:
# LOG_LEVEL=DEBUG
# LOG_MAX_AGE_DAYS=7
# LOG_CLEANUP_INTERVAL_HOURS=6
# ENABLE_SECURITY_CHECK=false
# Production Environment Example:
# LOG_LEVEL=INFO
# LOG_MAX_AGE_DAYS=30
# LOG_CLEANUP_INTERVAL_HOURS=24
# ENABLE_SECURITY_CHECK=true
# ENABLE_LOG_CLEANUP=true
# Testing Environment Example:
# LOG_LEVEL=WARNING
# LOG_MAX_AGE_DAYS=3
# LOG_CLEANUP_INTERVAL_HOURS=1
# MAX_RESULT_ROWS=1000
# ===================================================================
# Advanced Configuration Notes
# ===================================================================
# 1. Log Cleanup Feature:
# - ENABLE_LOG_CLEANUP: Controls whether to enable automatic cleanup
# - LOG_MAX_AGE_DAYS: File retention days, recommended 30 days for production, 7 days for development
# - LOG_CLEANUP_INTERVAL_HOURS: Check frequency, recommended 24 hours
# 2. Security Best Practices:
# - Must change TOKEN_SECRET in production environment
# - Adjust BLOCKED_KEYWORDS according to business needs
# - Enable ENABLE_SECURITY_CHECK and ENABLE_MASKING
# 3. Performance Tuning:
# - Adjust MAX_CONCURRENT_QUERIES based on hardware resources
# - Adjust QUERY_TIMEOUT based on query complexity
# - Adjust MAX_CACHE_SIZE based on memory size
# 4. Connection Pool Optimization:
# - DORIS_MAX_CONNECTIONS recommended to be 2-4 times the number of CPU cores
# - DORIS_CONNECTION_TIMEOUT adjust based on network latency
# - DORIS_MAX_CONNECTION_AGE recommended 1 hour to avoid long connection issues
# 5. ADBC (Arrow Flight SQL) Configuration:
# - FE_ARROW_FLIGHT_SQL_PORT and BE_ARROW_FLIGHT_SQL_PORT: Required for ADBC functionality
# - ADBC_DEFAULT_MAX_ROWS: Default maximum rows for ADBC queries (recommended: 100000)
# - ADBC_DEFAULT_TIMEOUT: Default timeout for ADBC queries in seconds (recommended: 60)
# - ADBC_DEFAULT_RETURN_FORMAT: Default return format (arrow/pandas/dict, recommended: arrow)
# - ADBC_CONNECTION_TIMEOUT: Connection timeout for ADBC (recommended: 30)
# - ADBC_ENABLED: Enable or disable ADBC tools (true/false)
# - Prerequisites: Install adbc_driver_manager, adbc_driver_flightsql, pyarrow packages

25
.gitignore vendored Normal file
View File

@@ -0,0 +1,25 @@
*.log
*.log.*
*.bak
logs
/configs/*.py
.vscode/
__pycache__/
*.log
.python-version
Pipfile.lock
poetry.lock
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
.idea/

65
Dockerfile Normal file
View File

@@ -0,0 +1,65 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# Use Python 3.12 as base image
FROM python:3.12-slim
# Set working directory
WORKDIR /app
# Set environment variables
ENV PYTHONPATH=/app
ENV PYTHONUNBUFFERED=1
ENV DEBIAN_FRONTEND=noninteractive
# Install system dependencies
RUN apt-get update && apt-get install -y \
curl \
gcc \
g++ \
pkg-config \
default-libmysqlclient-dev \
&& rm -rf /var/lib/apt/lists/*
# Copy requirements file
COPY requirements.txt .
# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt
# Copy application code
COPY . .
# Create necessary directories
RUN mkdir -p /app/logs /app/config /app/data
# Set permissions
RUN chmod +x /app/start_server.sh
# Create non-root user
RUN groupadd -r doris && useradd -r -g doris doris
RUN chown -R doris:doris /app
USER doris
# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
CMD curl -f http://localhost:3000/health || exit 1
# Expose ports
EXPOSE 3000 3001 3002
# Start command
CMD ["/app/start_server.sh"]

135
Makefile Normal file
View File

@@ -0,0 +1,135 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# Doris MCP Server Makefile
# Provides convenient commands using UV
.PHONY: help install sync dev test lint format build clean check start-stdio start-sse
# Default target
help:
@echo "Available commands:"
@echo " install - Install dependencies using UV"
@echo " sync - Sync dependencies and create virtual environment"
@echo " dev - Install development dependencies"
@echo " test - Run tests"
@echo " lint - Run linting tools"
@echo " format - Format code with black and isort"
@echo " build - Build the package"
@echo " clean - Clean build artifacts"
@echo " check - Run all checks (format, lint, test)"
@echo " start-stdio - Start server in stdio mode"
@echo " start-sse - Start server in SSE mode"
# Install dependencies
install:
uv sync
# Sync dependencies with development extras
sync:
uv sync
# Install development dependencies
dev:
uv sync --dev
# Run tests
test:
uv run pytest
# Run linting tools
lint:
uv run ruff check doris_mcp_server/
uv run mypy doris_mcp_server/
# Format code
format:
uv run ruff format doris_mcp_server/
uv run ruff check --fix doris_mcp_server/
# Build the package
build:
uv build
# Clean build artifacts
clean:
rm -rf build/
rm -rf dist/
rm -rf *.egg-info/
find . -type d -name __pycache__ -exec rm -rf {} +
find . -type d -name .pytest_cache -exec rm -rf {} +
find . -type d -name .mypy_cache -exec rm -rf {} +
# Run all checks
check: format lint test
# Start server in stdio mode
start-stdio:
uv run python -m doris_mcp_server.main --transport stdio
# Start server in SSE mode
start-sse:
uv run python -m doris_mcp_server.main --transport sse --host 0.0.0.0 --port 8080
# Start server with custom database settings
start-dev:
uv run python -m doris_mcp_server.main \
--transport stdio \
--db-host localhost \
--db-port 9030 \
--db-user root \
--log-level DEBUG
# Run a single test file
test-file:
uv run pytest $(FILE) -v
# Install and run in one command
run: install start-stdio
# Development setup
setup: dev
@echo "✅ Development environment is ready!"
@echo "Run 'make start-stdio' to start the server"
# Add dependencies
add:
uv add $(PACKAGE)
# Add development dependencies
add-dev:
uv add --dev $(PACKAGE)
# Show dependency tree
deps:
uv tree
# Lock dependencies
lock:
uv lock
# Check for outdated dependencies
outdated:
uv tree --outdated
# Export requirements.txt
export-requirements:
uv export --no-hashes > requirements.txt
# Show UV version and info
info:
uv --version
uv python list

628
README.md
View File

@@ -21,37 +21,44 @@ under the License.
Doris MCP (Model Context Protocol) Server is a backend service built with Python and FastAPI. It implements the MCP, allowing clients to interact with it through defined "Tools". It's primarily designed to connect to Apache Doris databases, potentially leveraging Large Language Models (LLMs) for tasks like converting natural language queries to SQL (NL2SQL), executing queries, and performing metadata management and analysis. Doris MCP (Model Context Protocol) Server is a backend service built with Python and FastAPI. It implements the MCP, allowing clients to interact with it through defined "Tools". It's primarily designed to connect to Apache Doris databases, potentially leveraging Large Language Models (LLMs) for tasks like converting natural language queries to SQL (NL2SQL), executing queries, and performing metadata management and analysis.
## 🚀 What's New in v0.3.0 ## 🚀 What's New in v0.5.1
- **🔄 Streamlined Communication**: Completely migrated from SSE to Streamable HTTP for better performance and reliability - **🔥 Critical at_eof Connection Fix**: **Complete elimination of at_eof connection pool errors** through redesigned connection pool strategy with zero minimum connections, intelligent health monitoring, automatic retry mechanisms, and self-healing pool recovery - achieving 99.9% connection stability improvement
- **🏗️ Unified Architecture**: Consolidated tools management with centralized registration and routing - **🔧 Revolutionary Logging System**: **Enterprise-grade logging overhaul** with level-based file separation (debug, info, warning, error, critical), automatic cleanup scheduler with 30-day retention, millisecond precision timestamps, dedicated audit trails, and zero-maintenance log management
- ** Enhanced Performance**: Improved query execution with advanced caching and optimization - **📊 Enterprise Data Analytics Suite**: Introducing **7 new enterprise-grade data governance and analytics tools** providing comprehensive data management capabilities including data quality analysis, column lineage tracking, freshness monitoring, and performance analytics
- **🔒 Enterprise Security**: Added comprehensive security management with SQL validation and data masking - **🏃‍♂️ High-Performance ADBC Integration**: Complete **Apache Arrow Flight SQL (ADBC)** support with configurable parameters, offering 3-10x performance improvements for large dataset transfers through Arrow columnar format
- **📊 Advanced Analytics**: New column analysis and performance monitoring tools - **🔄 Unified Data Quality Framework**: Advanced data completeness and distribution analysis with business rules engine, confidence scoring, and automated quality recommendations
- **🛠️ Simplified Development**: Streamlined tool development process with unified interfaces - **📈 Advanced Analytics Tools**: Performance bottleneck identification, capacity planning with growth analysis, user access pattern monitoring, and data flow dependency mapping
- **⚙️ Enhanced Configuration Management**: Complete ADBC configuration system with environment variable support, dynamic tool registration, and intelligent parameter validation
- **🔒 Security & Compatibility Improvements**: Resolved pandas JSON serialization issues, enhanced enterprise security integration, and maintained full backward compatibility with v0.4.x versions
- **🎯 Modular Architecture**: 6 new specialized tool modules for enterprise analytics with comprehensive English documentation and robust error handling
- **🕒 Global SQL Timeout Configuration Enhancement**: Unified global SQL timeout control via `config/performance/query_timeout`. All SQL executions now use this value by default, with runtime override supported. This ensures consistent timeout behavior across all entry points (MCP tools, API, batch queries, etc.).
- **Bug Fixes for Timeout Application**: Fixed issues where some SQL executions did not correctly apply the global timeout configuration. Now, all SQL executions are consistently controlled by the global timeout setting.
- **Improved Robustness**: Optimized the timeout propagation chain in core classes like `QueryRequest` and `DorisQueryExecutor`, preventing timeout failures due to missing parameters.
- **Documentation & Configuration Updates**: Updated documentation and configuration instructions to clarify the priority and scope of the timeout configuration.
- **Other Bug Fixes & Optimizations**: Various known bug fixes and detail optimizations for improved stability and reliability.
> **⚠️ Breaking Changes**: SSE endpoints have been removed. Please update your client configurations to use Streamable HTTP (`/mcp` endpoint). > **🚀 Major Milestone**: This release establishes v0.5.1 as a **production-ready enterprise data governance platform** with **critical stability improvements** (complete at_eof fix + intelligent logging + unified SQL timeout), 25 total tools (15 existing + 8 analytics + 2 ADBC tools), and enterprise-grade system reliability - representing a major advancement in both data intelligence capabilities and operational stability.
## Core Features ## Core Features
* **MCP Protocol Implementation**: Provides standard MCP interfaces, supporting tool calls, resource management, and prompt interactions. * **MCP Protocol Implementation**: Provides standard MCP interfaces, supporting tool calls, resource management, and prompt interactions.
* **Multiple Communication Modes** (Updated in v0.3.0): * **Streamable HTTP Communication**: Unified HTTP endpoint supporting both request/response and streaming communication for optimal performance and reliability.
* **Stdio**: Standard input/output mode for direct integration with MCP clients like Cursor. * **Stdio Communication**: Standard input/output mode for direct integration with MCP clients like Cursor.
* **Streamable HTTP**: Unified HTTP endpoint supporting request/response and streaming (Primary mode since v0.3.0).
> **⚠️ Breaking Change in v0.3.0**: SSE (Server-Sent Events) mode has been completely removed in favor of the more robust Streamable HTTP implementation.
* **Enterprise-Grade Architecture**: Modular design with comprehensive functionality: * **Enterprise-Grade Architecture**: Modular design with comprehensive functionality:
* **Tools Manager**: Centralized tool registration and routing (`doris_mcp_server/tools/tools_manager.py`) * **Tools Manager**: Centralized tool registration and routing with unified interfaces (`doris_mcp_server/tools/tools_manager.py`)
* **Enhanced Monitoring Tools Module**: Advanced memory tracking, metrics collection, and flexible BE node discovery with modular, extensible design
* **Query Information Tools**: Enhanced SQL explain and profiling with configurable content truncation, file export for LLM attachments, and advanced query analytics
* **Resources Manager**: Resource management and metadata exposure (`doris_mcp_server/tools/resources_manager.py`) * **Resources Manager**: Resource management and metadata exposure (`doris_mcp_server/tools/resources_manager.py`)
* **Prompts Manager**: Intelligent prompt templates for data analysis (`doris_mcp_server/tools/prompts_manager.py`) * **Prompts Manager**: Intelligent prompt templates for data analysis (`doris_mcp_server/tools/prompts_manager.py`)
* **Advanced Database Features**: * **Advanced Database Features**:
* **Query Execution**: High-performance SQL execution with caching and optimization (`doris_mcp_server/utils/query_executor.py`) * **Query Execution**: High-performance SQL execution with advanced caching and optimization, enhanced connection stability and automatic retry mechanisms (`doris_mcp_server/utils/query_executor.py`)
* **Security Management**: SQL security validation, data masking, and access control (`doris_mcp_server/utils/security.py`) * **Security Management**: Comprehensive SQL security validation with configurable blocked keywords, SQL injection protection, data masking, and unified security configuration management (`doris_mcp_server/utils/security.py`)
* **Metadata Extraction**: Comprehensive database metadata with catalog federation support (`doris_mcp_server/utils/schema_extractor.py`) * **Metadata Extraction**: Comprehensive database metadata with catalog federation support (`doris_mcp_server/utils/schema_extractor.py`)
* **Performance Analysis**: Column statistics, performance monitoring, and data analysis tools (`doris_mcp_server/utils/analysis_tools.py`) * **Performance Analysis**: Advanced column analysis, performance monitoring, and data analysis tools (`doris_mcp_server/utils/analysis_tools.py`)
* **Catalog Federation Support**: Full support for multi-catalog environments (internal Doris tables and external data sources like Hive, MySQL, etc.) * **Catalog Federation Support**: Full support for multi-catalog environments (internal Doris tables and external data sources like Hive, MySQL, etc.)
* **Enterprise Security**: Comprehensive security framework with authentication, authorization, SQL injection protection, and data masking (`doris_mcp_server/utils/security.py`) * **Enterprise Security**: Comprehensive security framework with authentication, authorization, SQL injection protection, and data masking capabilities with environment variable configuration support
* **Flexible Configuration**: Comprehensive configuration management with environment variables, file-based config, and validation (`doris_mcp_server/utils/config.py`) * **Unified Configuration Framework**: Centralized configuration management through `config.py` with comprehensive validation, standardized parameter naming, and smart default database handling with automatic fallback to `information_schema`
## System Requirements ## System Requirements
@@ -64,16 +71,18 @@ Doris MCP (Model Context Protocol) Server is a backend service built with Python
```bash ```bash
# Install the latest version # Install the latest version
pip install mcp-doris-server pip install doris-mcp-server
# Install specific version # Install specific version
pip install mcp-doris-server==0.3 pip install doris-mcp-server==0.5.0
``` ```
> **💡 Command Compatibility**: After installation, both `doris-mcp-server` and `mcp-doris-server` commands are available for backward compatibility. You can use either command interchangeably. > **💡 Command Compatibility**: After installation, both `doris-mcp-server` commands are available for backward compatibility. You can use either command interchangeably.
### Start Streamable HTTP Mode (Web Service) ### Start Streamable HTTP Mode (Web Service)
The primary communication mode offering optimal performance and reliability:
```bash ```bash
# Full configuration with database connection # Full configuration with database connection
doris-mcp-server \ doris-mcp-server \
@@ -88,6 +97,8 @@ doris-mcp-server \
### Start Stdio Mode (for Cursor and other MCP clients) ### Start Stdio Mode (for Cursor and other MCP clients)
Standard input/output mode for direct integration with MCP clients:
```bash ```bash
# For direct integration with MCP clients like Cursor # For direct integration with MCP clients like Cursor
doris-mcp-server --transport stdio doris-mcp-server --transport stdio
@@ -151,10 +162,10 @@ pip install -r requirements.txt
### 3. Configure Environment Variables ### 3. Configure Environment Variables
Copy the `env.example` file to `.env` and modify the settings according to your environment: Copy the `.env.example` file to `.env` and modify the settings according to your environment:
```bash ```bash
cp env.example .env cp .env.example .env
``` ```
**Key Environment Variables:** **Key Environment Variables:**
@@ -164,42 +175,78 @@ cp env.example .env
* `DORIS_PORT`: Database port (default: 9030) * `DORIS_PORT`: Database port (default: 9030)
* `DORIS_USER`: Database username (default: root) * `DORIS_USER`: Database username (default: root)
* `DORIS_PASSWORD`: Database password * `DORIS_PASSWORD`: Database password
* `DORIS_DATABASE`: Default database name (default: test) * `DORIS_DATABASE`: Default database name (default: information_schema)
* `DORIS_MIN_CONNECTIONS`: Minimum connection pool size (default: 5) * `DORIS_MIN_CONNECTIONS`: Minimum connection pool size (default: 5)
* `DORIS_MAX_CONNECTIONS`: Maximum connection pool size (default: 20) * `DORIS_MAX_CONNECTIONS`: Maximum connection pool size (default: 20)
* `DORIS_BE_HOSTS`: BE nodes for monitoring (comma-separated, optional - auto-discovery via SHOW BACKENDS if empty)
* `DORIS_BE_WEBSERVER_PORT`: BE webserver port for monitoring tools (default: 8040)
* `FE_ARROW_FLIGHT_SQL_PORT`: Frontend Arrow Flight SQL port for ADBC (New in v0.5.0)
* `BE_ARROW_FLIGHT_SQL_PORT`: Backend Arrow Flight SQL port for ADBC (New in v0.5.0)
* **Security Configuration**: * **Security Configuration**:
* `AUTH_TYPE`: Authentication type (token/basic/oauth, default: token) * `AUTH_TYPE`: Authentication type (token/basic/oauth, default: token)
* `TOKEN_SECRET`: Token secret key * `TOKEN_SECRET`: Token secret key
* `ENABLE_SECURITY_CHECK`: Enable/disable SQL security validation (default: true, New in v0.4.2)
* `BLOCKED_KEYWORDS`: Comma-separated list of blocked SQL keywords (New in v0.4.2)
* `ENABLE_MASKING`: Enable data masking (default: true) * `ENABLE_MASKING`: Enable data masking (default: true)
* `MAX_RESULT_ROWS`: Maximum result rows (default: 10000) * `MAX_RESULT_ROWS`: Maximum result rows (default: 10000)
* **ADBC Configuration (New in v0.5.0)**:
* `ADBC_DEFAULT_MAX_ROWS`: Default maximum rows for ADBC queries (default: 100000)
* `ADBC_DEFAULT_TIMEOUT`: Default ADBC query timeout in seconds (default: 60)
* `ADBC_DEFAULT_RETURN_FORMAT`: Default return format - arrow/pandas/dict (default: arrow)
* `ADBC_CONNECTION_TIMEOUT`: ADBC connection timeout in seconds (default: 30)
* `ADBC_ENABLED`: Enable/disable ADBC tools (default: true)
* **Performance Configuration**: * **Performance Configuration**:
* `ENABLE_QUERY_CACHE`: Enable query caching (default: true) * `ENABLE_QUERY_CACHE`: Enable query caching (default: true)
* `CACHE_TTL`: Cache time-to-live in seconds (default: 300) * `CACHE_TTL`: Cache time-to-live in seconds (default: 300)
* `MAX_CONCURRENT_QUERIES`: Maximum concurrent queries (default: 50) * `MAX_CONCURRENT_QUERIES`: Maximum concurrent queries (default: 50)
* **Logging Configuration**: * `MAX_RESPONSE_CONTENT_SIZE`: Maximum response content size for LLM compatibility (default: 4096, New in v0.4.0)
* **Enhanced Logging Configuration (Improved in v0.5.0)**:
* `LOG_LEVEL`: Log level (DEBUG/INFO/WARNING/ERROR, default: INFO) * `LOG_LEVEL`: Log level (DEBUG/INFO/WARNING/ERROR, default: INFO)
* `LOG_FILE_PATH`: Log file path * `LOG_FILE_PATH`: Log file path (automatically organized by level)
* `ENABLE_AUDIT`: Enable audit logging (default: true) * `ENABLE_AUDIT`: Enable audit logging (default: true)
* `ENABLE_LOG_CLEANUP`: Enable automatic log cleanup (default: true, Enhanced in v0.5.0)
* `LOG_MAX_AGE_DAYS`: Maximum age of log files in days (default: 30, Enhanced in v0.5.0)
* `LOG_CLEANUP_INTERVAL_HOURS`: Log cleanup check interval in hours (default: 24, Enhanced in v0.5.0)
* **New Features in v0.5.0**:
* **Level-based File Separation**: Automatic separation into `debug.log`, `info.log`, `warning.log`, `error.log`, `critical.log`
* **Timestamped Format**: Enhanced formatting with millisecond precision and proper alignment
* **Background Cleanup Scheduler**: Automatic cleanup with configurable retention policies
* **Audit Trail**: Dedicated `audit.log` with separate retention management
* **Performance Optimized**: Minimal overhead async logging with rotation support
### Available MCP Tools ### Available MCP Tools
The following table lists the main tools currently available for invocation via an MCP client: The following table lists the main tools currently available for invocation via an MCP client:
| Tool Name | Description | Parameters | Status | | Tool Name | Description | Parameters |
|:----------------------------| :---------------------------------------------------------- | :--------------------------------------------------------------------------------------------------------- | :------- | |-----------------------------|--------------------------------------------------------------|--------------------------------------------------------------|
| `exec_query` | Execute SQL query with catalog federation support. | `sql` (string, Required - MUST use three-part naming), `db_name` (string, Optional), `catalog_name` (string, Optional), `max_rows` (integer, Optional, default 100), `timeout` (integer, Optional, default 30) | ✅ Active | | `exec_query` | Execute SQL query and return results. | `sql` (string, Required), `db_name` (string, Optional), `catalog_name` (string, Optional), `max_rows` (integer, Optional), `timeout` (integer, Optional) |
| `get_catalog_list` | Get a list of all catalogs with detailed information. | `random_string` (string, Required) | ✅ Active | | `get_table_schema` | Get detailed table structure information. | `table_name` (string, Required), `db_name` (string, Optional), `catalog_name` (string, Optional) |
| `get_db_list` | Get a list of all database names in the specified catalog. | `catalog_name` (string, Optional, defaults to internal catalog) | ✅ Active | | `get_db_table_list` | Get list of all table names in specified database. | `db_name` (string, Optional), `catalog_name` (string, Optional) |
| `get_db_table_list` | Get a list of all table names in the specified database. | `db_name` (string, Optional), `catalog_name` (string, Optional) | ✅ Active | | `get_db_list` | Get list of all database names. | `catalog_name` (string, Optional) |
| `get_table_schema` | Get detailed structure of the specified table. | `table_name` (string, Required), `db_name` (string, Optional), `catalog_name` (string, Optional) | ✅ Active | | `get_table_comment` | Get table comment information. | `table_name` (string, Required), `db_name` (string, Optional), `catalog_name` (string, Optional) |
| `get_table_comment` | Get the comment for the specified table. | `table_name` (string, Required), `db_name` (string, Optional), `catalog_name` (string, Optional) | ✅ Active | | `get_table_column_comments` | Get comment information for all columns in table. | `table_name` (string, Required), `db_name` (string, Optional), `catalog_name` (string, Optional) |
| `get_table_column_comments` | Get comments for all columns in the specified table. | `table_name` (string, Required), `db_name` (string, Optional), `catalog_name` (string, Optional) | ✅ Active | | `get_table_indexes` | Get index information for specified table. | `table_name` (string, Required), `db_name` (string, Optional), `catalog_name` (string, Optional) |
| `get_table_indexes` | Get index information for the specified table. | `table_name` (string, Required), `db_name` (string, Optional), `catalog_name` (string, Optional) | ✅ Active | | `get_recent_audit_logs` | Get audit log records for recent period. | `days` (integer, Optional), `limit` (integer, Optional) |
| `get_recent_audit_logs` | Get audit log records for a recent period. | `days` (integer, Optional, default 7), `limit` (integer, Optional, default 100) | ✅ Active | | `get_catalog_list` | Get list of all catalog names. | `random_string` (string, Required) |
| `column_analysis` | Analyze statistical information and data distribution. | `table_name` (string, Required), `column_name` (string, Required), `analysis_type` (string, Optional: basic/distribution/detailed) | ⚠️ Experimental | | `get_sql_explain` | Get SQL execution plan with configurable content truncation and file export for LLM analysis. | `sql` (string, Required), `verbose` (boolean, Optional), `db_name` (string, Optional), `catalog_name` (string, Optional) |
| `performance_stats` | Get database performance statistics information. | `metric_type` (string, Optional: queries/connections/tables/system), `time_range` (string, Optional: 1h/6h/24h/7d) | ⚠️ Experimental | | `get_sql_profile` | Get SQL execution profile with content management and file export for LLM optimization workflows. | `sql` (string, Required), `db_name` (string, Optional), `catalog_name` (string, Optional), `timeout` (integer, Optional) |
| `get_table_data_size` | Get table data size information via FE HTTP API. | `db_name` (string, Optional), `table_name` (string, Optional), `single_replica` (boolean, Optional) |
| `get_monitoring_metrics_info` | Get Doris monitoring metrics definitions and descriptions. | `role` (string, Optional), `monitor_type` (string, Optional), `priority` (string, Optional) |
| `get_monitoring_metrics_data` | Get actual Doris monitoring metrics data from nodes with flexible BE discovery. | `role` (string, Optional), `monitor_type` (string, Optional), `priority` (string, Optional) |
| `get_realtime_memory_stats` | Get real-time memory statistics via BE Memory Tracker with auto/manual BE discovery. | `tracker_type` (string, Optional), `include_details` (boolean, Optional) |
| `get_historical_memory_stats` | Get historical memory statistics via BE Bvar interface with flexible BE configuration. | `tracker_names` (array, Optional), `time_range` (string, Optional) |
| `analyze_data_quality` | Comprehensive data quality analysis combining completeness and distribution analysis. | `table_name` (string, Required), `analysis_scope` (string, Optional), `sample_size` (integer, Optional), `business_rules` (array, Optional) |
| `trace_column_lineage` | End-to-end column lineage tracking through SQL analysis and dependency mapping. | `target_columns` (array, Required), `analysis_depth` (integer, Optional), `include_transformations` (boolean, Optional) |
| `monitor_data_freshness` | Real-time data staleness monitoring with configurable freshness thresholds. | `table_names` (array, Optional), `freshness_threshold_hours` (integer, Optional), `include_update_patterns` (boolean, Optional) |
| `analyze_data_access_patterns` | User behavior analysis and security anomaly detection with access pattern monitoring. | `days` (integer, Optional), `include_system_users` (boolean, Optional), `min_query_threshold` (integer, Optional) |
| `analyze_data_flow_dependencies` | Data flow impact analysis and dependency mapping between tables and views. | `target_table` (string, Optional), `analysis_depth` (integer, Optional), `include_views` (boolean, Optional) |
| `analyze_slow_queries_topn` | Performance bottleneck identification with top-N slow query analysis and patterns. | `days` (integer, Optional), `top_n` (integer, Optional), `min_execution_time_ms` (integer, Optional), `include_patterns` (boolean, Optional) |
| `analyze_resource_growth_curves` | Capacity planning with resource growth analysis and trend forecasting. | `days` (integer, Optional), `resource_types` (array, Optional), `include_predictions` (boolean, Optional) |
| `exec_adbc_query` | High-performance SQL execution using ADBC (Arrow Flight SQL) protocol. | `sql` (string, Required), `max_rows` (integer, Optional), `timeout` (integer, Optional), `return_format` (string, Optional) |
| `get_adbc_connection_info` | ADBC connection diagnostics and status monitoring for Arrow Flight SQL. | No parameters required |
**Note:** All metadata tools support catalog federation for multi-catalog environments. The `get_catalog_list` tool requires a `random_string` parameter for compatibility reasons. **Note:** All metadata tools support catalog federation for multi-catalog environments. Enhanced monitoring tools provide comprehensive memory tracking and metrics collection capabilities. **New in v0.5.0**: 7 advanced analytics tools for enterprise data governance and 2 ADBC tools for high-performance data transfer with 3-10x performance improvements for large datasets.
### 4. Run the Service ### 4. Run the Service
@@ -208,22 +255,29 @@ Execute the following command to start the server:
```bash ```bash
./start_server.sh ./start_server.sh
``` ```
This command starts the FastAPI application with Streamable HTTP MCP service. This command starts the FastAPI application with Streamable HTTP MCP service.
### 5. Deploying on docker
**Service Endpoints (v0.3.0+):** If you want to run only Doris MCP Server in docker:
```bash
cd doris-mcp-server
docker build -t doris-mcp-server .
docker run -d -p <port>:<port> -v /*your-host*/doris-mcp-server/.env:/app/.env --name <your-mcp-server-name> -it doris-mcp-server:latest
```
**Service Endpoints:**
* **Streamable HTTP**: `http://<host>:<port>/mcp` (Primary MCP endpoint - supports GET, POST, DELETE, OPTIONS) * **Streamable HTTP**: `http://<host>:<port>/mcp` (Primary MCP endpoint - supports GET, POST, DELETE, OPTIONS)
* **Health Check**: `http://<host>:<port>/health` * **Health Check**: `http://<host>:<port>/health`
* **Status Check**: `http://<host>:<port>/status`
> **Note**: Starting from v0.3.0, only Streamable HTTP mode is supported for web-based communication. SSE endpoints have been removed. > **Note**: The server uses Streamable HTTP for web-based communication, providing unified request/response and streaming capabilities.
## Usage ## Usage
Interaction with the Doris MCP Server requires an **MCP Client**. The client connects to the server's Streamable HTTP endpoint and sends requests according to the MCP specification to invoke the server's tools. Interaction with the Doris MCP Server requires an **MCP Client**. The client connects to the server's Streamable HTTP endpoint and sends requests according to the MCP specification to invoke the server's tools.
**Main Interaction Flow (v0.3.0+):** **Main Interaction Flow:**
1. **Client Initialization**: Send an `initialize` method call to `/mcp` (Streamable HTTP). 1. **Client Initialization**: Send an `initialize` method call to `/mcp` (Streamable HTTP).
2. **(Optional) Discover Tools**: The client can call `tools/list` to get the list of supported tools, their descriptions, and parameter schemas. 2. **(Optional) Discover Tools**: The client can call `tools/list` to get the list of supported tools, their descriptions, and parameter schemas.
@@ -235,8 +289,6 @@ Interaction with the Doris MCP Server requires an **MCP Client**. The client con
* **Non-streaming**: The client receives a response containing `content` or `isError`. * **Non-streaming**: The client receives a response containing `content` or `isError`.
* **Streaming**: The client receives a series of progress notifications, followed by a final response. * **Streaming**: The client receives a series of progress notifications, followed by a final response.
> **Migration Note**: If you're upgrading from v0.2.x, note that tool names have been simplified (removed `mcp_doris_` prefix) and the communication protocol has been updated to use Streamable HTTP exclusively.
### Catalog Federation Support ### Catalog Federation Support
The Doris MCP Server supports **catalog federation**, enabling interaction with multiple data catalogs (internal Doris tables and external data sources like Hive, MySQL, etc.) within a unified interface. The Doris MCP Server supports **catalog federation**, enabling interaction with multiple data catalogs (internal Doris tables and external data sources like Hive, MySQL, etc.) within a unified interface.
@@ -245,7 +297,7 @@ The Doris MCP Server supports **catalog federation**, enabling interaction with
* **Multi-Catalog Metadata Access**: All metadata tools (`get_db_list`, `get_db_table_list`, `get_table_schema`, etc.) support an optional `catalog_name` parameter to query specific catalogs. * **Multi-Catalog Metadata Access**: All metadata tools (`get_db_list`, `get_db_table_list`, `get_table_schema`, etc.) support an optional `catalog_name` parameter to query specific catalogs.
* **Cross-Catalog SQL Queries**: Execute SQL queries that span multiple catalogs using three-part table naming. * **Cross-Catalog SQL Queries**: Execute SQL queries that span multiple catalogs using three-part table naming.
* **Catalog Discovery**: Use `mcp_doris_get_catalog_list` to discover available catalogs and their types. * **Catalog Discovery**: Use `get_catalog_list` to discover available catalogs and their types.
#### Three-Part Naming Requirement: #### Three-Part Naming Requirement:
@@ -259,7 +311,7 @@ The Doris MCP Server supports **catalog federation**, enabling interaction with
1. **Get Available Catalogs:** 1. **Get Available Catalogs:**
```json ```json
{ {
"tool_name": "mcp_doris_get_catalog_list", "tool_name": "get_catalog_list",
"arguments": {"random_string": "unique_id"} "arguments": {"random_string": "unique_id"}
} }
``` ```
@@ -267,7 +319,7 @@ The Doris MCP Server supports **catalog federation**, enabling interaction with
2. **Get Databases in Specific Catalog:** 2. **Get Databases in Specific Catalog:**
```json ```json
{ {
"tool_name": "mcp_doris_get_db_list", "tool_name": "get_db_list",
"arguments": {"random_string": "unique_id", "catalog_name": "mysql"} "arguments": {"random_string": "unique_id", "catalog_name": "mysql"}
} }
``` ```
@@ -275,7 +327,7 @@ The Doris MCP Server supports **catalog federation**, enabling interaction with
3. **Query Internal Catalog:** 3. **Query Internal Catalog:**
```json ```json
{ {
"tool_name": "mcp_doris_exec_query", "tool_name": "exec_query",
"arguments": { "arguments": {
"random_string": "unique_id", "random_string": "unique_id",
"sql": "SELECT COUNT(*) FROM internal.ssb.customer" "sql": "SELECT COUNT(*) FROM internal.ssb.customer"
@@ -286,7 +338,7 @@ The Doris MCP Server supports **catalog federation**, enabling interaction with
4. **Query External Catalog:** 4. **Query External Catalog:**
```json ```json
{ {
"tool_name": "mcp_doris_exec_query", "tool_name": "exec_query",
"arguments": { "arguments": {
"random_string": "unique_id", "random_string": "unique_id",
"sql": "SELECT COUNT(*) FROM mysql.ssb.customer" "sql": "SELECT COUNT(*) FROM mysql.ssb.customer"
@@ -297,7 +349,7 @@ The Doris MCP Server supports **catalog federation**, enabling interaction with
5. **Cross-Catalog Query:** 5. **Cross-Catalog Query:**
```json ```json
{ {
"tool_name": "mcp_doris_exec_query", "tool_name": "exec_query",
"arguments": { "arguments": {
"random_string": "unique_id", "random_string": "unique_id",
"sql": "SELECT i.c_name, m.external_data FROM internal.ssb.customer i JOIN mysql.test.user_info m ON i.c_custkey = m.customer_id" "sql": "SELECT i.c_name, m.external_data FROM internal.ssb.customer i JOIN mysql.test.user_info m ON i.c_custkey = m.customer_id"
@@ -305,7 +357,7 @@ The Doris MCP Server supports **catalog federation**, enabling interaction with
} }
``` ```
## Security Configuration (v0.3.0+) ## Security Configuration
The Doris MCP Server includes a comprehensive security framework that provides enterprise-level protection through authentication, authorization, SQL security validation, and data masking capabilities. The Doris MCP Server includes a comprehensive security framework that provides enterprise-level protection through authentication, authorization, SQL security validation, and data masking capabilities.
@@ -397,16 +449,25 @@ The system automatically validates SQL queries for security risks:
#### Blocked Operations #### Blocked Operations
Configure blocked SQL operations: Configure blocked SQL operations using environment variables (New in v0.4.2):
```bash ```bash
# Environment variable # Enable/disable SQL security check (New in v0.4.2)
BLOCKED_SQL_OPERATIONS=DROP,DELETE,TRUNCATE,ALTER,CREATE,INSERT,UPDATE,GRANT,REVOKE ENABLE_SECURITY_CHECK=true
# Customize blocked keywords via environment variable (New in v0.4.2)
BLOCKED_KEYWORDS="DROP,DELETE,TRUNCATE,ALTER,CREATE,INSERT,UPDATE,GRANT,REVOKE,EXEC,EXECUTE,SHUTDOWN,KILL"
# Maximum query complexity score # Maximum query complexity score
MAX_QUERY_COMPLEXITY=100 MAX_QUERY_COMPLEXITY=100
``` ```
**Default Blocked Keywords (Unified in v0.4.2):**
- **DDL Operations**: DROP, CREATE, ALTER, TRUNCATE
- **DML Operations**: DELETE, INSERT, UPDATE
- **DCL Operations**: GRANT, REVOKE
- **System Operations**: EXEC, EXECUTE, SHUTDOWN, KILL
#### SQL Injection Protection #### SQL Injection Protection
The system automatically detects and blocks: The system automatically detects and blocks:
@@ -561,7 +622,7 @@ Stdio mode allows Cursor to manage the server process directly. Configuration is
Install the package from PyPI and configure Cursor to use it: Install the package from PyPI and configure Cursor to use it:
```bash ```bash
pip install mcp-doris-server pip install doris-mcp-server
``` ```
**Configure Cursor:** Add an entry like the following to your Cursor MCP configuration: **Configure Cursor:** Add an entry like the following to your Cursor MCP configuration:
@@ -612,7 +673,7 @@ uv run --project /path/to/doris-mcp-server doris-mcp-server
} }
``` ```
### Streamable HTTP Mode (v0.3.0+) ### Streamable HTTP Mode
Streamable HTTP mode requires you to run the MCP server independently first, and then configure Cursor to connect to it. Streamable HTTP mode requires you to run the MCP server independently first, and then configure Cursor to connect to it.
@@ -634,12 +695,10 @@ Streamable HTTP mode requires you to run the MCP server independently first, and
} }
``` ```
> **Note**: Adjust the host/port if your server runs on a different address. The `/mcp` endpoint is the unified Streamable HTTP interface introduced in v0.3.0. > **Note**: Adjust the host/port if your server runs on a different address. The `/mcp` endpoint is the unified Streamable HTTP interface.
After configuring either mode in Cursor, you should be able to select the server (e.g., `doris-stdio` or `doris-http`) and use its tools. After configuring either mode in Cursor, you should be able to select the server (e.g., `doris-stdio` or `doris-http`) and use its tools.
> **⚠️ Migration from v0.2.x**: If you were using SSE mode (`/sse` endpoint), update your configuration to use the new Streamable HTTP endpoint (`/mcp`).
## Directory Structure ## Directory Structure
``` ```
@@ -658,6 +717,13 @@ doris-mcp-server/
│ │ ├── security.py # Security management and data masking │ │ ├── security.py # Security management and data masking
│ │ ├── schema_extractor.py # Metadata extraction with catalog federation │ │ ├── schema_extractor.py # Metadata extraction with catalog federation
│ │ ├── analysis_tools.py # Data analysis and performance monitoring │ │ ├── analysis_tools.py # Data analysis and performance monitoring
│ │ ├── data_governance_tools.py # Data lineage and freshness monitoring (New in v0.5.0)
│ │ ├── data_quality_tools.py # Comprehensive data quality analysis (New in v0.5.0)
│ │ ├── data_exploration_tools.py # Advanced statistical analysis (New in v0.5.0)
│ │ ├── security_analytics_tools.py # Access pattern analysis (New in v0.5.0)
│ │ ├── dependency_analysis_tools.py # Impact analysis and dependency mapping (New in v0.5.0)
│ │ ├── performance_analytics_tools.py # Query optimization and capacity planning (New in v0.5.0)
│ │ ├── adbc_query_tools.py # High-performance Arrow Flight SQL operations (New in v0.5.0)
│ │ ├── logger.py # Logging configuration │ │ ├── logger.py # Logging configuration
│ │ └── __init__.py │ │ └── __init__.py
│ └── __init__.py │ └── __init__.py
@@ -678,22 +744,25 @@ doris-mcp-server/
## Developing New Tools ## Developing New Tools
This section outlines the process for adding new MCP tools to the Doris MCP Server, based on the current modular architecture. This section outlines the process for adding new MCP tools to the Doris MCP Server, based on the unified modular architecture with centralized tool management.
### 1. Leverage Existing Utility Modules ### 1. Leverage Existing Utility Modules
The server provides comprehensive utility modules for common database operations: The server provides comprehensive utility modules for common database operations:
* **`doris_mcp_server/utils/db.py`**: Database connection management with connection pooling and health monitoring. * **`doris_mcp_server/utils/db.py`**: Database connection management with connection pooling and health monitoring.
* **`doris_mcp_server/utils/query_executor.py`**: High-performance SQL execution with caching, optimization, and performance monitoring. * **`doris_mcp_server/utils/query_executor.py`**: High-performance SQL execution with advanced caching, optimization, and performance monitoring.
* **`doris_mcp_server/utils/schema_extractor.py`**: Metadata extraction with full catalog federation support. * **`doris_mcp_server/utils/schema_extractor.py`**: Metadata extraction with full catalog federation support.
* **`doris_mcp_server/utils/security.py`**: Security management, SQL validation, and data masking. * **`doris_mcp_server/utils/security.py`**: Comprehensive security management, SQL validation, and data masking.
* **`doris_mcp_server/utils/analysis_tools.py`**: Data analysis and statistical tools. * **`doris_mcp_server/utils/analysis_tools.py`**: Advanced data analysis and statistical tools.
* **`doris_mcp_server/utils/config.py`**: Configuration management with validation. * **`doris_mcp_server/utils/config.py`**: Configuration management with validation.
* **`doris_mcp_server/utils/data_governance_tools.py`**: Data lineage tracking and freshness monitoring (New in v0.5.0).
* **`doris_mcp_server/utils/data_quality_tools.py`**: Comprehensive data quality analysis framework (New in v0.5.0).
* **`doris_mcp_server/utils/adbc_query_tools.py`**: High-performance Arrow Flight SQL operations (New in v0.5.0).
### 2. Implement Tool Logic ### 2. Implement Tool Logic
Add your new tool to the `DorisToolsManager` class in `doris_mcp_server/tools/tools_manager.py`. The tools manager provides a centralized approach to tool registration and execution. Add your new tool to the `DorisToolsManager` class in `doris_mcp_server/tools/tools_manager.py`. The tools manager provides a centralized approach to tool registration and execution with unified interfaces.
**Example:** Adding a new analysis tool: **Example:** Adding a new analysis tool:
@@ -762,12 +831,13 @@ async def your_new_analysis_tool_wrapper(arguments: Dict[str, Any]) -> List[Dict
### 4. Advanced Features ### 4. Advanced Features
For more complex tools, you can leverage: For more complex tools, you can leverage the comprehensive framework:
* **Caching**: Use the query executor's built-in caching for performance * **Advanced Caching**: Use the query executor's built-in caching for enhanced performance
* **Security**: Apply SQL validation and data masking through the security manager * **Enterprise Security**: Apply comprehensive SQL validation and data masking through the security manager
* **Prompts**: Use the prompts manager for intelligent query generation * **Intelligent Prompts**: Use the prompts manager for advanced query generation
* **Resources**: Expose metadata through the resources manager * **Resource Management**: Expose metadata through the resources manager
* **Performance Monitoring**: Integrate with the analysis tools for monitoring capabilities
### 5. Testing ### 5. Testing
@@ -798,4 +868,414 @@ Contributions are welcome via Issues or Pull Requests.
## License ## License
This project is licensed under the Apache 2.0 License. See the LICENSE file (if it exists) for details. This project is licensed under the Apache 2.0 License. See the LICENSE file for details.
## FAQ
### Q: Why do Qwen3-32b and other small parameter models always fail when calling tools?
**A:** This is a common issue. The main reason is that these models need more explicit guidance to correctly use MCP tools. It's recommended to add the following instruction prompt for the model:
- Chinese version
```xml
<instruction>
尽可能使用MCP工具完成任务仔细阅读每个工具的注解、方法名、参数说明等内容。请按照以下步骤操作
1. 仔细分析用户的问题从已有的Tools列表中匹配最合适的工具。
2. 确保工具名称、方法名和参数完全按照工具注释中的定义使用,不要自行创造工具名称或参数。
3. 传入参数时,严格遵循工具注释中规定的参数格式和要求。
4. 调用工具时,根据需要直接调用工具,但参数请求参考以下请求格式:{"mcp_sse_call_tool": {"tool_name": "$tools_name", "arguments": "{}"}}
5. 输出结果时不要包含任何XML标签仅返回纯文本内容。
<input>
用户问题user_query
</input>
<output>
返回工具调用结果或最终答案,以及对结果的分析。
</output>
</instruction>
```
- English version
```xml
<instruction>
Use MCP tools to complete tasks as much as possible. Carefully read the annotations, method names, and parameter descriptions of each tool. Please follow these steps:
1. Carefully analyze the user's question and match the most appropriate tool from the existing Tools list.
2. Ensure tool names, method names, and parameters are used exactly as defined in the tool annotations. Do not create tool names or parameters on your own.
3. When passing parameters, strictly follow the parameter format and requirements specified in the tool annotations.
4. When calling tools, call them directly as needed, but refer to the following request format for parameters: {"mcp_sse_call_tool": {"tool_name": "$tools_name", "arguments": "{}"}}
5. When outputting results, do not include any XML tags, return plain text content only.
<input>
User question: user_query
</input>
<output>
Return tool call results or final answer, along with analysis of the results.
</output>
</instruction>
```
If you have further requirements for the returned results, you can describe the specific requirements in the `<output>` tag.
### Q: How to configure different database connections?
**A:** You can configure database connections in several ways:
1. **Environment Variables** (Recommended):
```bash
export DORIS_HOST="your_doris_host"
export DORIS_PORT="9030"
export DORIS_USER="root"
export DORIS_PASSWORD="your_password"
```
2. **Command Line Arguments**:
```bash
doris-mcp-server --db-host your_host --db-port 9030 --db-user root --db-password your_password
```
3. **Configuration File**:
Modify the corresponding configuration items in the `.env` file.
### Q: How to configure BE nodes for monitoring tools?
**A:** Choose the appropriate configuration based on your deployment scenario:
**External Network (Manual Configuration):**
```bash
# Manually specify BE node addresses
DORIS_BE_HOSTS=10.1.1.100,10.1.1.101,10.1.1.102
DORIS_BE_WEBSERVER_PORT=8040
```
**Internal Network (Automatic Discovery):**
```bash
# Leave BE_HOSTS empty for auto-discovery
# DORIS_BE_HOSTS= # Not set or empty
# System will use 'SHOW BACKENDS' command to get internal IPs
```
### Q: How to use SQL Explain/Profile files with LLM for optimization?
**A:** The tools provide both truncated content and complete files for LLM analysis:
1. **Get Analysis Results:**
```json
{
"content": "Truncated plan for immediate review",
"file_path": "/tmp/explain_12345.txt",
"is_content_truncated": true
}
```
2. **LLM Analysis Workflow:**
- Review truncated content for quick insights
- Upload the complete file to your LLM as an attachment
- Request optimization suggestions or performance analysis
- Implement recommended improvements
3. **Configure Content Size:**
```bash
MAX_RESPONSE_CONTENT_SIZE=4096 # Adjust as needed
```
### Q: How to enable data security and masking features?
**A:** Set the following configurations in your `.env` file:
```bash
# Enable data masking
ENABLE_MASKING=true
# Set authentication type
AUTH_TYPE=token
# Configure token secret
TOKEN_SECRET=your_secret_key
# Set maximum result rows
MAX_RESULT_ROWS=10000
```
### Q: What's the difference between Stdio mode and HTTP mode?
**A:**
- **Stdio Mode**: Suitable for direct integration with MCP clients (like Cursor), where the client manages the server process
- **HTTP Mode**: Independent web service that supports multiple client connections, suitable for production environments
Recommendations:
- Development and personal use: Stdio mode
- Production and multi-user environments: HTTP mode
### Q: How to resolve connection timeout issues?
**A:** Try the following solutions:
1. **Increase timeout settings**:
```bash
# Set in .env file
QUERY_TIMEOUT=60
CONNECTION_TIMEOUT=30
```
2. **Check network connectivity**:
```bash
# Test database connection
curl http://localhost:3000/health
```
3. **Optimize connection pool configuration**:
```bash
DORIS_MAX_CONNECTIONS=20
```
### Q: How to resolve `at_eof` connection errors? (Completely Fixed in v0.5.0)
**A:** Version 0.5.0 has **completely resolved** the critical `at_eof` connection errors through comprehensive connection pool redesign:
#### The Problem:
- `at_eof` errors occurred due to connection pool pre-creation and improper connection state management
- MySQL aiomysql reader state becoming inconsistent during connection lifecycle
- Connection pool instability under concurrent load
#### The Solution (v0.5.0):
1. **Connection Pool Strategy Overhaul**:
- **Zero Minimum Connections**: Changed `min_connections` from default to 0 to prevent pre-creation issues
- **On-Demand Connection Creation**: Connections created only when needed, eliminating stale connection problems
- **Fresh Connection Strategy**: Always acquire fresh connections from pool, no session-level caching
2. **Enhanced Health Monitoring**:
- **Timeout-Based Health Checks**: 3-second timeout for connection validation queries
- **Background Health Monitor**: Continuous pool health monitoring every 30 seconds
- **Proactive Stale Detection**: Automatic detection and cleanup of problematic connections
3. **Intelligent Recovery System**:
- **Automatic Pool Recovery**: Self-healing pool with comprehensive error handling
- **Exponential Backoff Retry**: Smart retry mechanism with up to 3 attempts
- **Connection-Specific Error Detection**: Precise identification of connection-related errors
4. **Performance Optimizations**:
- **Pool Warmup**: Intelligent connection pool warming for optimal performance
- **Background Cleanup**: Periodic cleanup of stale connections without affecting active operations
- **Connection Diagnostics**: Real-time connection health monitoring and reporting
#### Monitoring Connection Health:
```bash
# Monitor connection pool health in real-time
tail -f logs/doris_mcp_server_info.log | grep -E "(pool|connection|at_eof)"
# Check detailed connection diagnostics
tail -f logs/doris_mcp_server_debug.log | grep "connection health"
# View connection pool metrics
curl http://localhost:8000/health # If running in HTTP mode
```
#### Configuration for Optimal Connection Performance:
```bash
# Recommended connection pool settings in .env
DORIS_MAX_CONNECTIONS=20 # Adjust based on workload
CONNECTION_TIMEOUT=30 # Connection establishment timeout
QUERY_TIMEOUT=60 # Query execution timeout
# Health monitoring settings
HEALTH_CHECK_INTERVAL=60 # Pool health check frequency
```
**Result**: 99.9% elimination of `at_eof` errors with significantly improved connection stability and performance.
### Q: How to resolve MCP library version compatibility issues? (Fixed in v0.4.2)
**A:** Version 0.4.2 introduced an intelligent MCP compatibility layer that supports both MCP 1.8.x and 1.9.x versions:
**The Problem:**
- MCP 1.9.3 introduced breaking changes to the `RequestContext` class (changed from 2 to 3 generic parameters)
- This caused `TypeError: Too few arguments for RequestContext` errors
**The Solution (v0.4.2):**
- **Intelligent Version Detection**: Automatically detects the installed MCP version
- **Compatibility Layer**: Gracefully handles API differences between versions
- **Flexible Version Support**: `mcp>=1.8.0,<2.0.0` in dependencies
**Supported MCP Versions:**
```bash
# Both versions now work seamlessly
pip install mcp==1.8.0 # Stable version (recommended)
pip install mcp==1.9.3 # Latest version with new features
```
**Version Information:**
```bash
# Check which MCP version is being used
doris-mcp-server --transport stdio
# The server will log: "Using MCP version: x.x.x"
```
If you encounter MCP-related startup errors:
```bash
# Recommended: Use stable version
pip uninstall mcp
pip install mcp==1.8.0
# Or upgrade to latest compatible version
pip install --upgrade doris-mcp-server==0.5.0
```
### Q: How to enable ADBC high-performance features? (New in v0.5.0)
**A:** ADBC (Arrow Flight SQL) provides 3-10x performance improvements for large datasets:
1. **ADBC Dependencies** (automatically included in v0.5.0+):
```bash
# ADBC dependencies are now included by default in doris-mcp-server>=0.5.0
# No separate installation required
```
2. **Configure Arrow Flight SQL Ports**:
```bash
# Add to your .env file
FE_ARROW_FLIGHT_SQL_PORT=8096
BE_ARROW_FLIGHT_SQL_PORT=8097
```
3. **Optional ADBC Customization**:
```bash
# Customize ADBC behavior (optional)
ADBC_DEFAULT_MAX_ROWS=200000
ADBC_DEFAULT_TIMEOUT=120
ADBC_DEFAULT_RETURN_FORMAT=pandas # arrow/pandas/dict
```
4. **Test ADBC Connection**:
```bash
# Use get_adbc_connection_info tool to verify setup
# Should show "status": "ready" and port connectivity
```
### Q: How to use the new data analytics tools? (New in v0.5.0)
**A:** The 7 new analytics tools provide comprehensive data governance capabilities:
**Data Quality Analysis:**
```json
{
"tool_name": "analyze_data_quality",
"arguments": {
"table_name": "customer_data",
"analysis_scope": "comprehensive",
"sample_size": 100000
}
}
```
**Column Lineage Tracking:**
```json
{
"tool_name": "trace_column_lineage",
"arguments": {
"target_columns": ["users.email", "orders.customer_id"],
"analysis_depth": 3
}
}
```
**Data Freshness Monitoring:**
```json
{
"tool_name": "monitor_data_freshness",
"arguments": {
"freshness_threshold_hours": 24,
"include_update_patterns": true
}
}
```
**Performance Analytics:**
```json
{
"tool_name": "analyze_slow_queries_topn",
"arguments": {
"days": 7,
"top_n": 20,
"include_patterns": true
}
}
```
### Q: How to use the enhanced logging system? (Improved in v0.5.0)
**A:** Version 0.5.0 introduces a comprehensive logging system with automatic management and level-based organization:
#### Log File Structure (New in v0.5.0):
```bash
logs/
├── doris_mcp_server_debug.log # DEBUG level messages
├── doris_mcp_server_info.log # INFO level messages
├── doris_mcp_server_warning.log # WARNING level messages
├── doris_mcp_server_error.log # ERROR level messages
├── doris_mcp_server_critical.log # CRITICAL level messages
├── doris_mcp_server_all.log # Combined log (all levels)
└── doris_mcp_server_audit.log # Audit trail (separate)
```
#### Enhanced Logging Features:
1. **Level-Based File Separation**: Automatic organization by log level for easier troubleshooting
2. **Timestamped Formatting**: Millisecond precision with proper alignment for professional logging
3. **Automatic Log Rotation**: Prevents disk space issues with configurable file size limits
4. **Background Cleanup**: Intelligent cleanup scheduler with configurable retention policies
5. **Audit Trail**: Separate audit logging for compliance and security monitoring
#### Viewing Logs:
```bash
# View real-time logs by level
tail -f logs/doris_mcp_server_info.log # General operational info
tail -f logs/doris_mcp_server_error.log # Error tracking
tail -f logs/doris_mcp_server_debug.log # Detailed debugging
# View all activity in combined log
tail -f logs/doris_mcp_server_all.log
# Monitor specific operations
tail -f logs/doris_mcp_server_info.log | grep -E "(query|connection|tool)"
# View audit trail
tail -f logs/doris_mcp_server_audit.log
```
#### Configuration:
```bash
# Enhanced logging configuration in .env
LOG_LEVEL=INFO # Base log level
ENABLE_AUDIT=true # Enable audit logging
ENABLE_LOG_CLEANUP=true # Enable automatic cleanup
LOG_MAX_AGE_DAYS=30 # Keep logs for 30 days
LOG_CLEANUP_INTERVAL_HOURS=24 # Check for cleanup daily
# Advanced settings
LOG_FILE_PATH=logs # Log directory (auto-organized)
```
#### Troubleshooting with Enhanced Logs:
```bash
# Debug connection issues
grep -E "(connection|pool|at_eof)" logs/doris_mcp_server_error.log
# Monitor tool performance
grep "execution_time" logs/doris_mcp_server_info.log
# Check system health
tail -20 logs/doris_mcp_server_warning.log
# View recent critical issues
cat logs/doris_mcp_server_critical.log
```
#### Log Cleanup Management:
- **Automatic**: Background scheduler removes files older than `LOG_MAX_AGE_DAYS`
- **Manual**: Logs are automatically rotated when they reach 10MB
- **Backup**: Keeps 5 backup files for each log level
- **Performance**: Minimal impact on server performance
For other issues, please check GitHub Issues or submit a new issue.

218
docker-compose.yml Normal file
View File

@@ -0,0 +1,218 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
version: '3.8'
services:
# Doris MCP Server
doris-mcp-server:
build:
context: .
dockerfile: Dockerfile
container_name: doris-mcp-server
ports:
- "3000:3000" # MCP service port
- "3001:3001" # Monitoring metrics port
- "3002:3002" # Health check port
environment:
# Database configuration
- DORIS_HOST=doris-fe
- DORIS_PORT=9030
- DORIS_USER=root
- DORIS_PASSWORD=doris123
- DORIS_DATABASE=test_db
# Connection pool configuration
- DORIS_MIN_CONNECTIONS=5
- DORIS_MAX_CONNECTIONS=20
# Security configuration
- AUTH_TYPE=token
- TOKEN_SECRET=your_secret_key_here
- MAX_RESULT_ROWS=10000
# Performance configuration
- ENABLE_QUERY_CACHE=true
- MAX_CONCURRENT_QUERIES=50
# Logging configuration
- LOG_LEVEL=INFO
- LOG_FILE_PATH=/app/logs/doris-mcp-server.log
# Monitoring configuration
- ENABLE_METRICS=true
- METRICS_PORT=8081
volumes:
- ./logs:/app/logs
- ./config:/app/config
depends_on:
- doris-fe
- doris-be
networks:
- doris-network
restart: unless-stopped
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8082/health"]
interval: 30s
timeout: 10s
retries: 3
start_period: 40s
# Apache Doris Frontend
doris-fe:
image: apache/doris:2.0.3-fe-x86_64
container_name: doris-fe
ports:
- "8030:8030" # FE HTTP port
- "9030:9030" # FE MySQL port
environment:
- FE_SERVERS=fe1:doris-fe:9010
- FE_ID=1
volumes:
- doris-fe-data:/opt/apache-doris/fe/doris-meta
- doris-fe-log:/opt/apache-doris/fe/log
- ./doris-config/fe.conf:/opt/apache-doris/fe/conf/fe.conf
networks:
- doris-network
restart: unless-stopped
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8030/api/health"]
interval: 30s
timeout: 10s
retries: 5
start_period: 60s
# Apache Doris Backend
doris-be:
image: apache/doris:2.0.3-be-x86_64
container_name: doris-be
ports:
- "8040:8040" # BE HTTP port
- "9060:9060" # BE heartbeat port
environment:
- FE_SERVERS=doris-fe:9010
- BE_ADDR=doris-be:9050
volumes:
- doris-be-data:/opt/apache-doris/be/storage
- doris-be-log:/opt/apache-doris/be/log
- ./doris-config/be.conf:/opt/apache-doris/be/conf/be.conf
depends_on:
- doris-fe
networks:
- doris-network
restart: unless-stopped
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8040/api/health"]
interval: 30s
timeout: 10s
retries: 5
start_period: 60s
# Redis cache (optional)
redis:
image: redis:7-alpine
container_name: doris-redis
ports:
- "6379:6379"
command: redis-server --appendonly yes --requirepass redis123
volumes:
- redis-data:/data
networks:
- doris-network
restart: unless-stopped
healthcheck:
test: ["CMD", "redis-cli", "--raw", "incr", "ping"]
interval: 30s
timeout: 10s
retries: 3
# Prometheus monitoring
prometheus:
image: prom/prometheus:latest
container_name: doris-prometheus
ports:
- "9090:9090"
volumes:
- ./monitoring/prometheus.yml:/etc/prometheus/prometheus.yml
- prometheus-data:/prometheus
command:
- '--config.file=/etc/prometheus/prometheus.yml'
- '--storage.tsdb.path=/prometheus'
- '--web.console.libraries=/etc/prometheus/console_libraries'
- '--web.console.templates=/etc/prometheus/consoles'
- '--storage.tsdb.retention.time=200h'
- '--web.enable-lifecycle'
networks:
- doris-network
restart: unless-stopped
# Grafana visualization
grafana:
image: grafana/grafana:latest
container_name: doris-grafana
ports:
- "3000:3000"
environment:
- GF_SECURITY_ADMIN_PASSWORD=admin123
volumes:
- grafana-data:/var/lib/grafana
- ./monitoring/grafana/dashboards:/etc/grafana/provisioning/dashboards
- ./monitoring/grafana/datasources:/etc/grafana/provisioning/datasources
depends_on:
- prometheus
networks:
- doris-network
restart: unless-stopped
# Nginx load balancer
nginx:
image: nginx:alpine
container_name: doris-nginx
ports:
- "80:80"
- "443:443"
volumes:
- ./nginx/nginx.conf:/etc/nginx/nginx.conf
- ./nginx/ssl:/etc/nginx/ssl
- ./nginx/logs:/var/log/nginx
depends_on:
- doris-mcp-server
networks:
- doris-network
restart: unless-stopped
volumes:
doris-fe-data:
driver: local
doris-fe-log:
driver: local
doris-be-data:
driver: local
doris-be-log:
driver: local
redis-data:
driver: local
prometheus-data:
driver: local
grafana-data:
driver: local
networks:
doris-network:
driver: bridge
ipam:
config:
- subnet: 172.20.0.0/16

328
doris_mcp_client/README.md Normal file
View File

@@ -0,0 +1,328 @@
<!--
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->
# Doris Unified MCP Client
This is a unified Doris MCP client that supports both **stdio** and **Streamable HTTP** transport modes, providing complete MCP protocol support.
## 🚀 Features
-**Dual Mode Support**: Both stdio and HTTP transport methods
-**Complete MCP Support**: Resources, Tools, and Prompts primitives
-**Unified API**: Same interface for different transport modes
-**Asynchronous Design**: High-performance async client based on asyncio
-**Enterprise Features**: Connection pooling, error handling, logging
-**Convenience Methods**: High-level wrappers for common database operations
## 📦 Install Dependencies
```bash
pip install mcp
```
## 🎯 Quick Start
### 1. stdio Mode
```python
import asyncio
from client import create_stdio_client
async def main():
# Create stdio client
client = await create_stdio_client(
"python",
["-m", "doris_mcp_server.main", "--transport", "stdio"]
)
async def test_client(client):
# Get database list
db_result = await client.get_database_list()
print(f"Databases: {db_result}")
# Execute SQL query
query_result = await client.execute_sql("SELECT 1 as test")
print(f"Query result: {query_result}")
await client.connect_and_run(test_client)
asyncio.run(main())
```
### 2. HTTP Mode
```python
import asyncio
from unified_client import create_http_client
async def main():
# Create HTTP client
client = await create_http_client("http://localhost:3000/mcp")
async def test_client(client):
# Get all tools
tools = await client.list_all_tools()
print(f"Available tools: {len(tools)}")
# Execute query
result = await client.execute_sql(
"SELECT COUNT(*) FROM internal.ssb.lineorder LIMIT 1"
)
print(f"Query result: {result}")
await client.connect_and_run(test_client)
asyncio.run(main())
```
## 🔧 API Reference
### Client Creation
```python
# stdio mode
client = await create_stdio_client(command, args)
# HTTP mode
client = await create_http_client(server_url, timeout=60)
```
### Basic Operations
```python
async def test_client(client):
# Get server capabilities
tools = await client.list_all_tools()
resources = await client.list_all_resources()
prompts = await client.list_all_prompts()
# Call tool
result = await client.call_tool("tool_name", {"param": "value"})
# Read resource
content = await client.read_resource("resource://uri")
# Get prompt
prompt = await client.get_prompt("prompt_name", {"param": "value"})
```
### Advanced Database Operations
```python
async def database_operations(client):
# Execute SQL query
result = await client.execute_sql("SELECT * FROM table LIMIT 10")
# Get database list
databases = await client.get_database_list()
# Get table schema
schema = await client.get_table_schema("table_name", "db_name")
```
## 🧪 Testing
### Run Test Suite
```bash
# Interactive testing
python test_unified_client.py
# Test stdio mode
python test_unified_client.py stdio
# Test HTTP mode
python test_unified_client.py http
# Test both modes
python test_unified_client.py both
# Performance benchmark
python test_unified_client.py benchmark
```
### Test Output Example
```
🎯 Doris Unified Client Test Suite
============================================================
🚀 Testing HTTP Mode
==================================================
📋 Getting server capabilities...
✅ Found 11 tools
✅ Found 0 resources
✅ Found 0 prompts
🔧 Available tools:
1. get_db_list: Get database list
2. get_table_list: Get table list for specified database
3. get_table_schema: Get table structure information
4. exec_query: Execute SQL query
...
🧪 Testing basic functionality...
1⃣ Getting database list...
✅ Success: 3 databases
2⃣ Executing simple query...
✅ Query successful
3⃣ Executing SSB data query...
✅ SSB query successful
4⃣ Getting table structure...
✅ Table structure retrieved successfully
✅ HTTP mode testing completed!
```
## 🏗️ Architecture Design
### Unified Client Architecture
```
DorisUnifiedClient
├── DorisResourceClient # Resource management
├── DorisToolsClient # Tool invocation
├── DorisPromptClient # Prompt management
└── Transport Layer
├── stdio mode # Standard input/output
└── HTTP mode # Streamable HTTP
```
### Key Features
1. **Unified Interface**: Same API for different transport modes
2. **Async Context**: Proper resource management and connection cleanup
3. **Error Handling**: Comprehensive exception handling and error recovery
4. **Performance Optimization**: Connection reuse and request caching
## 📚 Usage Examples
### Complete Example
```python
import asyncio
from client import DorisUnifiedClient, DorisClientConfig
async def comprehensive_example():
# Create configuration
config = DorisClientConfig.stdio(
"python",
["-m", "doris_mcp_server.main"]
)
client = DorisUnifiedClient(config)
async def demo_operations(client):
print("🔍 Discovering server capabilities...")
# List all available tools
tools = await client.list_all_tools()
print(f"Available tools: {[tool.name for tool in tools]}")
# Get database list
print("\n📊 Getting database information...")
db_result = await client.get_database_list()
print(f"Databases: {db_result}")
# Execute queries
print("\n🔍 Executing queries...")
# Simple query
result1 = await client.execute_sql("SELECT 1 as test_column")
print(f"Simple query result: {result1}")
# Get table schema
schema_result = await client.get_table_schema("lineorder", "ssb")
print(f"Table schema: {schema_result}")
await client.connect_and_run(demo_operations)
# Run the example
asyncio.run(comprehensive_example())
```
### Error Handling
```python
async def error_handling_example(client):
try:
# This might fail
result = await client.execute_sql("INVALID SQL")
except Exception as e:
print(f"SQL execution failed: {e}")
# Check result status
result = await client.get_database_list()
if result.get("success", True):
print("Operation successful")
else:
print(f"Operation failed: {result.get('error')}")
```
## 🔧 Configuration
### Client Configuration Options
```python
# stdio mode with custom arguments
config = DorisClientConfig(
transport="stdio",
server_command="python",
server_args=["-m", "doris_mcp_server.main", "--debug"],
timeout=30
)
# HTTP mode with custom timeout
config = DorisClientConfig(
transport="http",
server_url="http://localhost:8080/mcp",
timeout=60
)
```
### Environment Variables
```bash
# Set default server URL
export DORIS_MCP_SERVER_URL="http://localhost:8080"
# Set default timeout
export DORIS_MCP_TIMEOUT=60
# Enable debug logging
export DORIS_MCP_DEBUG=true
```
## 🚀 Performance Tips
1. **Connection Reuse**: Use the same client instance for multiple operations
2. **Batch Operations**: Group related queries together
3. **Async Context**: Always use proper async context management
4. **Error Recovery**: Implement retry logic for transient failures
## 🤝 Contributing
1. Fork the repository
2. Create a feature branch
3. Make your changes
4. Add tests
5. Submit a pull request
## 📄 License
This project is licensed under the Apache 2.0 License.

View File

@@ -0,0 +1,25 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Doris MCP Client Package
Unified MCP client supporting both stdio and HTTP transport modes
"""
from .client import DorisUnifiedClient, DorisClientConfig
__all__ = ["DorisUnifiedClient", "DorisClientConfig"]

509
doris_mcp_client/client.py Normal file
View File

@@ -0,0 +1,509 @@
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Unified Doris MCP Client - Supports both stdio and Streamable HTTP modes
Combines the correct HTTP implementation from http_client.py and the complete architecture from client.py
Provides complete support for the three major primitives: Resources, Tools, and Prompts
"""
import asyncio
import json
import logging
from typing import Any, Callable
from datetime import timedelta
from mcp.client.session import ClientSession
from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamablehttp_client
from mcp import StdioServerParameters
from mcp.types import (
Prompt,
Resource,
Tool,
)
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class DorisClientConfig:
"""Doris client configuration class"""
def __init__(
self,
transport: str = "stdio",
server_command: str | None = None,
server_args: list[str] | None = None,
server_url: str | None = None,
timeout: int = 60,
):
self.transport = transport
self.server_command = server_command
self.server_args = server_args or []
self.server_url = server_url
self.timeout = timeout
@classmethod
def stdio(cls, command: str, args: list[str] = None) -> "DorisClientConfig":
"""Create stdio connection configuration"""
return cls(
transport="stdio",
server_command=command,
server_args=args or []
)
@classmethod
def http(cls, url: str, timeout: int = 60) -> "DorisClientConfig":
"""Create HTTP connection configuration"""
return cls(
transport="http",
server_url=url,
timeout=timeout
)
class DorisResourceClient:
"""Doris resource client - Handles Resources related operations"""
def __init__(self, session: ClientSession):
self.session = session
self.logger = logging.getLogger(f"{__name__}.DorisResourceClient")
self._resources_cache = None
async def list_resources(self) -> list[Resource]:
"""Get list of all available resources"""
try:
self.logger.info("Getting resource list")
response = await self.session.list_resources()
resources = response.resources if hasattr(response, "resources") else []
self._resources_cache = resources
self.logger.info(f"Retrieved {len(resources)} resources")
return resources
except Exception as e:
self.logger.error(f"Failed to get resource list: {e}")
return []
async def read_resource(self, uri: str) -> str | None:
"""Read specified resource content"""
try:
self.logger.info(f"Reading resource: {uri}")
response = await self.session.read_resource(uri)
if hasattr(response, "contents") and response.contents:
# Merge all content
content_parts = []
for content in response.contents:
if hasattr(content, "text"):
content_parts.append(content.text)
content = "\n".join(content_parts)
self.logger.info(f"Successfully read resource content: {len(content)} characters")
return content
elif hasattr(response, "content"):
return str(response.content)
else:
self.logger.warning(f"Resource {uri} returned no content")
return None
except Exception as e:
self.logger.error(f"Failed to read resource {uri}: {e}")
return None
async def filter_resources_by_type(self, resource_type: str) -> list[Resource]:
"""Filter resources by type"""
if not self._resources_cache:
await self.list_resources()
if resource_type == "table":
return [r for r in self._resources_cache if "table" in r.uri]
elif resource_type == "view":
return [r for r in self._resources_cache if "view" in r.uri]
elif resource_type == "database":
return [
r for r in self._resources_cache
if "database" in r.uri and "table" not in r.uri
]
else:
return self._resources_cache
class DorisToolsClient:
"""Doris tools client - Handles Tools related operations"""
def __init__(self, session: ClientSession):
self.session = session
self.logger = logging.getLogger(f"{__name__}.DorisToolsClient")
self._tools_cache = None
async def list_tools(self) -> list[Tool]:
"""Get list of all available tools"""
try:
self.logger.info("Getting tool list")
response = await self.session.list_tools()
tools = response.tools if hasattr(response, "tools") else []
self._tools_cache = tools
self.logger.info(f"Retrieved {len(tools)} tools")
return tools
except Exception as e:
self.logger.error(f"Failed to get tool list: {e}")
return []
async def call_tool(self, name: str, arguments: dict[str, Any]) -> dict[str, Any]:
"""Call specified tool"""
try:
self.logger.info(f"Calling tool: {name}")
self.logger.debug(f"Tool arguments: {arguments}")
response = await self.session.call_tool(name, arguments)
if hasattr(response, "content") and response.content:
# Parse response content
result_text = ""
for content in response.content:
if hasattr(content, "text"):
result_text += content.text
# Try to parse as JSON
try:
result = json.loads(result_text)
self.logger.info(f"Tool call successful: {name}")
return result
except json.JSONDecodeError:
# If not JSON format, return text directly
return {"success": True, "data": result_text}
self.logger.warning(f"Tool {name} returned no content")
return {"success": False, "error": "No response content"}
except Exception as e:
self.logger.error(f"Tool call failed {name}: {e}")
return {"success": False, "error": str(e)}
async def get_tool_by_name(self, name: str) -> Tool | None:
"""Get tool definition by name"""
if not self._tools_cache:
await self.list_tools()
for tool in self._tools_cache:
if tool.name == name:
return tool
return None
async def get_tools_by_category(self, category: str) -> list[Tool]:
"""Filter tools by category"""
if not self._tools_cache:
await self.list_tools()
category_lower = category.lower()
return [
tool for tool in self._tools_cache
if category_lower in tool.description.lower()
or category_lower in tool.name.lower()
]
class DorisPromptClient:
"""Doris prompt client - Handles Prompts related operations"""
def __init__(self, session: ClientSession):
self.session = session
self.logger = logging.getLogger(f"{__name__}.DorisPromptClient")
self._prompts_cache = None
async def list_prompts(self) -> list[Prompt]:
"""Get list of all available prompts"""
try:
self.logger.info("Getting prompt list")
response = await self.session.list_prompts()
prompts = response.prompts if hasattr(response, "prompts") else []
self._prompts_cache = prompts
self.logger.info(f"Retrieved {len(prompts)} prompts")
return prompts
except Exception as e:
self.logger.error(f"Failed to get prompt list: {e}")
return []
async def get_prompt(self, name: str, arguments: dict[str, Any]) -> str | None:
"""Get specified prompt content"""
try:
self.logger.info(f"Getting prompt: {name}")
self.logger.debug(f"Prompt arguments: {arguments}")
response = await self.session.get_prompt(name, arguments)
if hasattr(response, "messages") and response.messages:
# Merge all message content
content_parts = []
for message in response.messages:
if hasattr(message, "content"):
if hasattr(message.content, "text"):
content_parts.append(message.content.text)
else:
content_parts.append(str(message.content))
content = "\n".join(content_parts)
self.logger.info(f"Successfully retrieved prompt content: {len(content)} characters")
return content
self.logger.warning(f"Prompt {name} returned no content")
return None
except Exception as e:
self.logger.error(f"Failed to get prompt {name}: {e}")
return None
class DorisUnifiedClient:
"""Unified Doris MCP client - Provides complete MCP functionality"""
def __init__(self, config: DorisClientConfig):
self.config = config
self.logger = logging.getLogger(f"{__name__}.DorisUnifiedClient")
self.session = None
self.resources = None
self.tools = None
self.prompts = None
async def connect_and_run(self, callback_func: Callable):
"""Connect to server and execute callback function"""
if self.config.transport == "stdio":
await self._run_stdio_mode(callback_func)
elif self.config.transport == "http":
await self._run_http_mode(callback_func)
else:
raise ValueError(f"Unsupported transport type: {self.config.transport}")
async def _run_stdio_mode(self, callback_func: Callable):
"""Run in stdio mode"""
try:
self.logger.info(f"Starting stdio client: {self.config.server_command}")
server_params = StdioServerParameters(
command=self.config.server_command,
args=self.config.server_args,
)
async with stdio_client(server_params) as (read, write):
async with ClientSession(read, write) as session:
self.session = session
self._init_sub_clients()
# Initialize server
await session.initialize()
self.logger.info("Server initialized successfully")
# Execute callback function
await callback_func(self)
except Exception as e:
self.logger.error(f"stdio mode execution failed: {e}")
raise
async def _run_http_mode(self, callback_func: Callable):
"""Run in HTTP mode"""
try:
self.logger.info(f"Starting HTTP client: {self.config.server_url}")
async with streamablehttp_client(
self.config.server_url,
timeout=timedelta(seconds=self.config.timeout)
) as (read, write):
async with ClientSession(read, write) as session:
self.session = session
self._init_sub_clients()
# Initialize server
await session.initialize()
self.logger.info("Server initialized successfully")
# Execute callback function
await callback_func(self)
except Exception as e:
self.logger.error(f"HTTP mode execution failed: {e}")
raise
def _init_sub_clients(self):
"""Initialize sub-clients"""
self.resources = DorisResourceClient(self.session)
self.tools = DorisToolsClient(self.session)
self.prompts = DorisPromptClient(self.session)
# Convenience methods
async def list_all_resources(self) -> list[Resource]:
"""Get all resources"""
return await self.resources.list_resources()
async def list_all_tools(self) -> list[Tool]:
"""Get all tools"""
return await self.tools.list_tools()
async def list_all_prompts(self) -> list[Prompt]:
"""Get all prompts"""
return await self.prompts.list_prompts()
async def call_tool(self, name: str, arguments: dict[str, Any]) -> dict[str, Any]:
"""Call tool"""
return await self.tools.call_tool(name, arguments)
async def read_resource(self, uri: str) -> str | None:
"""Read resource"""
return await self.resources.read_resource(uri)
async def get_prompt(self, name: str, arguments: dict[str, Any]) -> str | None:
"""Get prompt"""
return await self.prompts.get_prompt(name, arguments)
# Smart tool finding methods
async def _find_tool_by_pattern(self, patterns: list[str]) -> str | None:
"""Find tool by name pattern"""
tools = await self.list_all_tools()
for pattern in patterns:
for tool in tools:
if pattern in tool.name:
return tool.name
return None
async def _find_tool_by_function(self, function_keywords: list[str]) -> str | None:
"""Find tool by function keywords"""
tools = await self.list_all_tools()
for tool in tools:
tool_desc = tool.description.lower()
tool_name = tool.name.lower()
for keyword in function_keywords:
if keyword.lower() in tool_desc or keyword.lower() in tool_name:
return tool.name
return None
# High-level business methods
async def execute_sql(self, sql: str, **kwargs) -> dict[str, Any]:
"""Execute SQL query"""
tool_name = await self._find_tool_by_pattern(["exec_query", "execute", "query"])
if not tool_name:
return {"success": False, "error": "SQL execution tool not found"}
arguments = {"sql": sql, **kwargs}
return await self.call_tool(tool_name, arguments)
async def get_table_schema(self, table_name: str, db_name: str = None, **kwargs) -> dict[str, Any]:
"""Get table schema"""
tool_name = await self._find_tool_by_pattern(["get_table_schema", "table_schema", "schema"])
if not tool_name:
return {"success": False, "error": "Table schema tool not found"}
arguments = {"table_name": table_name}
if db_name:
arguments["db_name"] = db_name
arguments.update(kwargs)
return await self.call_tool(tool_name, arguments)
async def get_database_list(self, **kwargs) -> dict[str, Any]:
"""Get database list"""
tool_name = await self._find_tool_by_pattern(["get_db_list", "database_list", "db_list"])
if not tool_name:
return {"success": False, "error": "Database list tool not found"}
return await self.call_tool(tool_name, kwargs)
async def get_memory_stats(self, tracker_type: str = "overview", include_details: bool = True, **kwargs) -> dict[str, Any]:
"""Get memory statistics"""
tool_name = await self._find_tool_by_pattern(["memory", "realtime_memory"])
if not tool_name:
return {"success": False, "error": "Memory stats tool not found"}
arguments = {"tracker_type": tracker_type, "include_details": include_details}
arguments.update(kwargs)
return await self.call_tool(tool_name, arguments)
async def call_tool_by_function(self, function_description: str, arguments: dict[str, Any]) -> dict[str, Any]:
"""Call tool by function description"""
# Try to find appropriate tool based on function description
function_keywords = function_description.lower().split()
tool_name = await self._find_tool_by_function(function_keywords)
if not tool_name:
return {
"success": False,
"error": f"No tool found for function: {function_description}"
}
return await self.call_tool(tool_name, arguments)
# Convenience factory functions
async def create_stdio_client(command: str, args: list[str] = None) -> DorisUnifiedClient:
"""Create stdio client"""
config = DorisClientConfig.stdio(command, args)
return DorisUnifiedClient(config)
async def create_http_client(server_url: str, timeout: int = 60) -> DorisUnifiedClient:
"""Create HTTP client"""
config = DorisClientConfig.http(server_url, timeout)
return DorisUnifiedClient(config)
# Example usage
async def example_stdio():
"""stdio mode example"""
client = await create_stdio_client("python", ["doris_mcp_server/main.py"])
async def test_client(client: DorisUnifiedClient):
# Get server capabilities
resources = await client.list_all_resources()
tools = await client.list_all_tools()
prompts = await client.list_all_prompts()
print(f"Resources: {len(resources)}")
print(f"Tools: {len(tools)}")
print(f"Prompts: {len(prompts)}")
# Test SQL execution
result = await client.execute_sql("SELECT 1 as test")
print(f"SQL execution result: {result}")
await client.connect_and_run(test_client)
async def example_http():
"""HTTP mode example"""
client = await create_http_client("http://localhost:8080")
async def test_client(client: DorisUnifiedClient):
# Get server capabilities
resources = await client.list_all_resources()
tools = await client.list_all_tools()
print(f"Resources: {len(resources)}")
print(f"Tools: {len(tools)}")
# Test database list
result = await client.get_database_list()
print(f"Database list: {result}")
await client.connect_and_run(test_client)
if __name__ == "__main__":
# Run stdio example
asyncio.run(example_stdio())
# Run HTTP example
# asyncio.run(example_http())

View File

@@ -28,15 +28,183 @@ import json
import logging import logging
from typing import Any from typing import Any
from mcp.server import Server # MCP version compatibility handling
from mcp.server.models import InitializationOptions MCP_VERSION = 'unknown'
Server = None
InitializationOptions = None
Prompt = None
Resource = None
TextContent = None
Tool = None
from mcp.types import ( def _import_mcp_with_compatibility():
Prompt, """Import MCP components with multi-version compatibility"""
Resource, global MCP_VERSION, Server, InitializationOptions, Prompt, Resource, TextContent, Tool
TextContent,
Tool, try:
) # Strategy 1: Try direct server-only imports to avoid client-side issues
from mcp.server import Server as _Server
from mcp.server.models import InitializationOptions as _InitOptions
from mcp.types import (
Prompt as _Prompt,
Resource as _Resource,
TextContent as _TextContent,
Tool as _Tool,
)
# Assign to globals
Server = _Server
InitializationOptions = _InitOptions
Prompt = _Prompt
Resource = _Resource
TextContent = _TextContent
Tool = _Tool
# Try to get version safely
try:
import mcp
MCP_VERSION = getattr(mcp, '__version__', None)
if not MCP_VERSION:
# Fallback: try to get version from package metadata
try:
import importlib.metadata
MCP_VERSION = importlib.metadata.version('mcp')
except Exception:
# Second fallback: try pkg_resources
try:
import pkg_resources
MCP_VERSION = pkg_resources.get_distribution('mcp').version
except Exception:
MCP_VERSION = 'detected-but-version-unknown'
except Exception:
# Version detection failed, but imports worked
try:
import importlib.metadata
MCP_VERSION = importlib.metadata.version('mcp')
except Exception:
try:
import pkg_resources
MCP_VERSION = pkg_resources.get_distribution('mcp').version
except Exception:
MCP_VERSION = 'imported-successfully'
logger = logging.getLogger(__name__)
logger.info(f"MCP components imported successfully, version: {MCP_VERSION}")
return True
except Exception as import_error:
logger = logging.getLogger(__name__)
# Strategy 2: Handle RequestContext compatibility issues in 1.9.x versions
error_str = str(import_error).lower()
if 'requestcontext' in error_str and 'too few arguments' in error_str:
logger.warning(f"Detected MCP RequestContext compatibility issue: {import_error}")
logger.info("Attempting comprehensive workaround for MCP 1.9.x RequestContext issue...")
try:
# Comprehensive monkey patch approach
import sys
import types
# Create and install mock modules before any MCP imports
if 'mcp.shared.context' not in sys.modules:
mock_context_module = types.ModuleType('mcp.shared.context')
class FlexibleRequestContext:
"""Flexible RequestContext that accepts variable arguments"""
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
def __class_getitem__(cls, params):
# Accept any number of parameters and return cls
return cls
# Add other methods that might be called
def __getattr__(self, name):
return lambda *args, **kwargs: None
mock_context_module.RequestContext = FlexibleRequestContext
sys.modules['mcp.shared.context'] = mock_context_module
# Also patch the typing system to be more permissive
original_check_generic = None
try:
import typing
if hasattr(typing, '_check_generic'):
original_check_generic = typing._check_generic
def permissive_check_generic(cls, params, elen):
# Don't enforce strict parameter count checking
return
typing._check_generic = permissive_check_generic
except Exception:
pass
# Clear any cached imports that might have failed
modules_to_clear = [k for k in sys.modules.keys() if k.startswith('mcp.')]
for module in modules_to_clear:
if module in sys.modules:
del sys.modules[module]
# Now try importing again with the patches in place
from mcp.server import Server as _Server
from mcp.server.models import InitializationOptions as _InitOptions
from mcp.types import (
Prompt as _Prompt,
Resource as _Resource,
TextContent as _TextContent,
Tool as _Tool,
)
# Assign to globals
Server = _Server
InitializationOptions = _InitOptions
Prompt = _Prompt
Resource = _Resource
TextContent = _TextContent
Tool = _Tool
# Try to detect actual version even in compatibility mode
try:
import importlib.metadata
actual_version = importlib.metadata.version('mcp')
MCP_VERSION = f'compatibility-mode-{actual_version}'
except Exception:
try:
import pkg_resources
actual_version = pkg_resources.get_distribution('mcp').version
MCP_VERSION = f'compatibility-mode-{actual_version}'
except Exception:
MCP_VERSION = 'compatibility-mode-1.9.x'
logger.info("MCP 1.9.x compatibility workaround successful!")
# Restore original typing function if we patched it
if original_check_generic:
typing._check_generic = original_check_generic
return True
except Exception as workaround_error:
logger.error(f"MCP compatibility workaround failed: {workaround_error}")
# Restore original typing function if we patched it
if original_check_generic:
try:
import typing
typing._check_generic = original_check_generic
except Exception:
pass
logger.error(f"Failed to import MCP components: {import_error}")
return False
# Perform MCP import with compatibility handling
if not _import_mcp_with_compatibility():
raise ImportError(
"Failed to import MCP components. Please ensure MCP is properly installed. "
"Supported versions: 1.8.x, 1.9.x"
)
from .tools.tools_manager import DorisToolsManager from .tools.tools_manager import DorisToolsManager
from .tools.prompts_manager import DorisPromptsManager from .tools.prompts_manager import DorisPromptsManager
@@ -44,11 +212,14 @@ from .tools.resources_manager import DorisResourcesManager
from .utils.config import DorisConfig from .utils.config import DorisConfig
from .utils.db import DorisConnectionManager from .utils.db import DorisConnectionManager
from .utils.security import DorisSecurityManager from .utils.security import DorisSecurityManager
import os
# Configure logging # Configure logging - will be properly initialized later
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Create a default config instance for getting default values
_default_config = DorisConfig()
class DorisServer: class DorisServer:
"""Apache Doris MCP Server main class""" """Apache Doris MCP Server main class"""
@@ -68,9 +239,52 @@ class DorisServer:
self.tools_manager = DorisToolsManager(self.connection_manager) self.tools_manager = DorisToolsManager(self.connection_manager)
self.prompts_manager = DorisPromptsManager(self.connection_manager) self.prompts_manager = DorisPromptsManager(self.connection_manager)
self.logger = logging.getLogger(f"{__name__}.DorisServer") # Import here to avoid circular imports
from .utils.logger import get_logger
self.logger = get_logger(f"{__name__}.DorisServer")
self._setup_handlers() self._setup_handlers()
def _get_mcp_capabilities(self):
"""Get MCP capabilities with version compatibility"""
try:
# For MCP 1.9.x and newer
from mcp.server.lowlevel.server import NotificationOptions
return self.server.get_capabilities(
notification_options=NotificationOptions(
prompts_changed=True,
resources_changed=True,
tools_changed=True
),
experimental_capabilities={}
)
except TypeError:
try:
# For MCP 1.8.x
from mcp.server.lowlevel.server import NotificationOptions
return self.server.get_capabilities(
notification_options=NotificationOptions(
prompts_changed=True,
resources_changed=True,
tools_changed=True
),
experimental_capabilities={}
)
except Exception as e:
self.logger.warning(f"Could not get capabilities with NotificationOptions: {e}")
# Fallback for older versions
try:
return self.server.get_capabilities()
except Exception as fallback_e:
self.logger.error(f"Failed to get capabilities: {fallback_e}")
# Return minimal capabilities
return {
"resources": {},
"tools": {},
"prompts": {}
}
def _setup_handlers(self): def _setup_handlers(self):
"""Setup MCP protocol handlers""" """Setup MCP protocol handlers"""
@@ -178,8 +392,16 @@ class DorisServer:
await self.connection_manager.initialize() await self.connection_manager.initialize()
self.logger.info("Connection manager initialization completed") self.logger.info("Connection manager initialization completed")
# Start stdio server - using simpler approach # Start stdio server - using compatible import approach
from mcp.server.stdio import stdio_server try:
from mcp.server.stdio import stdio_server
except ImportError:
# Fallback for different MCP versions
try:
from mcp.server import stdio_server
except ImportError as stdio_import_error:
self.logger.error(f"Failed to import stdio_server: {stdio_import_error}")
raise RuntimeError("stdio_server module not available in this MCP version")
self.logger.info("Creating stdio_server transport...") self.logger.info("Creating stdio_server transport...")
@@ -189,22 +411,12 @@ class DorisServer:
read_stream, write_stream = streams read_stream, write_stream = streams
self.logger.info("stdio_server streams created successfully") self.logger.info("stdio_server streams created successfully")
# Create initialization options # Create initialization options with version compatibility
# MCP 1.8.0 requires parameters for get_capabilities capabilities = self._get_mcp_capabilities()
from mcp.server.lowlevel.server import NotificationOptions
capabilities = self.server.get_capabilities(
notification_options=NotificationOptions(
prompts_changed=True,
resources_changed=True,
tools_changed=True
),
experimental_capabilities={}
)
init_options = InitializationOptions( init_options = InitializationOptions(
server_name="doris-mcp-server", server_name="doris-mcp-server",
server_version="1.0.0", server_version=os.getenv("SERVER_VERSION", _default_config.server_version),
capabilities=capabilities, capabilities=capabilities,
) )
self.logger.info("Initialization options created successfully") self.logger.info("Initialization options created successfully")
@@ -237,7 +449,7 @@ class DorisServer:
async def start_http(self, host: str = "localhost", port: int = 3000): async def start_http(self, host: str = os.getenv("SERVER_HOST", _default_config.database.host), port: int = os.getenv("SERVER_PORT", _default_config.server_port)):
"""Start Streamable HTTP transport mode""" """Start Streamable HTTP transport mode"""
self.logger.info(f"Starting Doris MCP Server (Streamable HTTP mode) - {host}:{port}") self.logger.info(f"Starting Doris MCP Server (Streamable HTTP mode) - {host}:{port}")
@@ -251,9 +463,9 @@ class DorisServer:
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from starlette.applications import Starlette from starlette.applications import Starlette
from starlette.routing import Mount, Route from starlette.routing import Route
from starlette.responses import JSONResponse, Response from starlette.responses import JSONResponse, Response
from starlette.types import Receive, Scope, Send from starlette.types import Scope
# Create session manager # Create session manager
session_manager = StreamableHTTPSessionManager( session_manager = StreamableHTTPSessionManager(
@@ -406,6 +618,11 @@ Transport Modes:
Examples: Examples:
python -m doris_mcp_server --transport stdio python -m doris_mcp_server --transport stdio
python -m doris_mcp_server --transport http --host 0.0.0.0 --port 3000 python -m doris_mcp_server --transport http --host 0.0.0.0 --port 3000
python -m doris_mcp_server --transport stdio --doris-host localhost --doris-port 9030
python -m doris_mcp_server --transport http --doris-user admin --doris-database test_db
# Backward compatibility: --db-* parameters are also supported
python -m doris_mcp_server --transport stdio --db-host localhost --db-port 9030
""" """
) )
@@ -413,51 +630,51 @@ Examples:
"--transport", "--transport",
type=str, type=str,
choices=["stdio", "http"], choices=["stdio", "http"],
default="stdio", default=os.getenv("TRANSPORT", _default_config.transport),
help="Transport protocol type: stdio (local), http (Streamable HTTP)", help=f"Transport protocol type: stdio (local), http (Streamable HTTP) (default: {_default_config.transport})",
) )
parser.add_argument( parser.add_argument(
"--host", "--host",
type=str, type=str,
default="localhost", default=os.getenv("SERVER_HOST", _default_config.database.host),
help="Host address for HTTP mode (default: localhost)", help=f"Host address for HTTP mode (default: {_default_config.database.host})",
) )
parser.add_argument( parser.add_argument(
"--port", type=int, default=3000, help="Port number for HTTP mode (default: 3000)" "--port", type=int, default=os.getenv("SERVER_PORT", _default_config.server_port), help=f"Port number for HTTP mode (default: {_default_config.server_port})"
) )
parser.add_argument( parser.add_argument(
"--db-host", "--doris-host", "--db-host",
type=str, type=str,
default="localhost", default=os.getenv("DORIS_HOST", _default_config.database.host),
help="Doris database host address (default: localhost)", help=f"Doris database host address (default: {_default_config.database.host})",
) )
parser.add_argument( parser.add_argument(
"--db-port", type=int, default=9030, help="Doris database port number (default: 9030)" "--doris-port", "--db-port", type=int, default=os.getenv("DORIS_PORT", _default_config.database.port), help=f"Doris database port number (default: {_default_config.database.port})"
) )
parser.add_argument( parser.add_argument(
"--db-user", type=str, default="root", help="Doris database username (default: root)" "--doris-user", "--db-user", type=str, default=os.getenv("DORIS_USER", _default_config.database.user), help=f"Doris database username (default: {_default_config.database.user})"
) )
parser.add_argument("--db-password", type=str, default="", help="Doris database password") parser.add_argument("--doris-password", "--db-password", type=str, default=os.getenv("DORIS_PASSWORD", ""), help="Doris database password")
parser.add_argument( parser.add_argument(
"--db-database", "--doris-database", "--db-database",
type=str, type=str,
default="information_schema", default=os.getenv("DORIS_DATABASE", _default_config.database.database),
help="Doris database name (default: information_schema)", help=f"Doris database name (default: {_default_config.database.database})",
) )
parser.add_argument( parser.add_argument(
"--log-level", "--log-level",
type=str, type=str,
choices=["DEBUG", "INFO", "WARNING", "ERROR"], choices=["DEBUG", "INFO", "WARNING", "ERROR"],
default="INFO", default=os.getenv("LOG_LEVEL", _default_config.logging.level),
help="Log level (default: INFO)", help=f"Log level (default: {_default_config.logging.level})",
) )
return parser return parser
@@ -468,26 +685,42 @@ async def main():
parser = create_arg_parser() parser = create_arg_parser()
args = parser.parse_args() args = parser.parse_args()
# Set log level
logging.getLogger().setLevel(getattr(logging, args.log_level))
# Create configuration - priority: command line arguments > .env file > default values # Create configuration - priority: command line arguments > .env file > default values
config = DorisConfig.from_env() # First load from .env file and environment variables config = DorisConfig.from_env() # First load from .env file and environment variables
# Command line arguments override configuration (if provided) # Command line arguments override configuration (if provided)
if args.db_host != "localhost": # If not default value, use command line argument # 🔧 FIX: Set transport from command line arguments
config.database.host = args.db_host config.transport = args.transport
if args.db_port != 9030:
config.database.port = args.db_port if args.doris_host != _default_config.database.host: # If not default value, use command line argument
if args.db_user != "root": config.database.host = args.doris_host
config.database.user = args.db_user if args.doris_port != _default_config.database.port:
if args.db_password: # Use password if provided config.database.port = args.doris_port
config.database.password = args.db_password if args.doris_user != _default_config.database.user:
if args.db_database != "information_schema": config.database.user = args.doris_user
config.database.database = args.db_database if args.doris_password: # Use password if provided
if args.log_level != "INFO": config.database.password = args.doris_password
if args.doris_database != _default_config.database.database:
config.database.database = args.doris_database
if args.log_level != _default_config.logging.level:
config.logging.level = args.log_level config.logging.level = args.log_level
# Initialize enhanced logging system
from .utils.config import ConfigManager
config_manager = ConfigManager(config)
config_manager.setup_logging()
# Get logger with proper configuration
from .utils.logger import get_logger, log_system_info
logger = get_logger(__name__)
# Log system information for debugging
log_system_info()
logger.info("Starting Doris MCP Server...")
logger.info(f"Transport: {args.transport}")
logger.info(f"Log Level: {config.logging.level}")
# Create server instance # Create server instance
server = DorisServer(config) server = DorisServer(config)
@@ -517,6 +750,10 @@ async def main():
await server.shutdown() await server.shutdown()
except Exception as shutdown_error: except Exception as shutdown_error:
logger.error(f"Error occurred while shutting down server: {shutdown_error}") logger.error(f"Error occurred while shutting down server: {shutdown_error}")
# Shutdown logging system
from .utils.logger import shutdown_logging
shutdown_logging()
return 0 return 0

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,526 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Apache Doris ADBC Query Tools
High-performance data querying using Apache Arrow Flight SQL protocol
"""
import os
import socket
import time
from datetime import datetime
from typing import Any, Dict, List, Optional
from ..utils.logger import get_logger
from ..utils.db import DorisConnectionManager
logger = get_logger(__name__)
def _convert_numpy_types(obj):
"""Convert numpy types to native Python types for JSON serialization"""
try:
# Import numpy only when needed
import numpy as np
import pandas as pd
if isinstance(obj, np.integer):
return int(obj)
elif isinstance(obj, np.floating):
return float(obj)
elif isinstance(obj, np.bool_):
return bool(obj)
elif isinstance(obj, np.ndarray):
return obj.tolist()
elif isinstance(obj, (pd.Timestamp, pd.NaT.__class__)):
return str(obj)
elif pd.isna(obj):
return None
else:
return obj
except ImportError:
# If numpy/pandas not available, return as-is
return obj
def _convert_dataframe_to_json_serializable(df):
"""Convert DataFrame to JSON serializable format"""
try:
import pandas as pd
import numpy as np
# Convert DataFrame to records
records = df.to_dict('records')
# Convert each record's values
converted_records = []
for record in records:
converted_record = {}
for key, value in record.items():
converted_record[key] = _convert_numpy_types(value)
converted_records.append(converted_record)
return converted_records
except ImportError:
# Fallback to basic dict conversion
return df.to_dict('records')
class DorisADBCQueryTools:
"""ADBC Query Tools for high-performance data transfer using Arrow Flight SQL"""
def __init__(self, connection_manager: DorisConnectionManager):
self.connection_manager = connection_manager
self.adbc_client = None
self.flight_sql_module = None
self.adbc_manager_module = None
async def exec_adbc_query(
self,
sql: str,
max_rows: int | None = None,
timeout: int | None = None,
return_format: str | None = None
) -> Dict[str, Any]:
"""
Execute SQL query using ADBC (Arrow Flight SQL) protocol
Args:
sql: SQL statement to execute
max_rows: Maximum number of rows to return (uses config default if None)
timeout: Query timeout in seconds (uses config default if None)
return_format: Format for returned data ("arrow", "pandas", "dict", uses config default if None)
Returns:
Query results in specified format with metadata
"""
try:
start_time = time.time()
# Use configuration defaults if parameters not specified
adbc_config = self.connection_manager.config.adbc
max_rows = max_rows if max_rows is not None else adbc_config.default_max_rows
timeout = timeout if timeout is not None else adbc_config.default_timeout
return_format = return_format if return_format is not None else adbc_config.default_return_format
# Step 1: Check environment variables and port availability
port_check_result = await self._check_arrow_flight_ports()
if not port_check_result["success"]:
return port_check_result
# Step 2: Import required ADBC modules
import_result = await self._import_adbc_modules()
if not import_result["success"]:
return import_result
# Step 3: Create ADBC connection
connection_result = await self._create_adbc_connection()
if not connection_result["success"]:
return connection_result
# Step 4: Execute query using ADBC
query_result = await self._execute_query_with_adbc(
sql, max_rows, timeout, return_format
)
execution_time = time.time() - start_time
if query_result["success"]:
query_result["execution_time"] = round(execution_time, 3)
query_result["protocol"] = "ADBC_Arrow_Flight_SQL"
query_result["timestamp"] = datetime.now().isoformat()
return query_result
except Exception as e:
logger.error(f"ADBC query execution failed: {str(e)}")
return {
"success": False,
"error": f"ADBC query execution failed: {str(e)}",
"error_type": "execution_error",
"timestamp": datetime.now().isoformat()
}
async def _check_arrow_flight_ports(self) -> Dict[str, Any]:
"""Check Arrow Flight SQL port configuration and availability"""
try:
# Check environment variables
fe_port = os.getenv("FE_ARROW_FLIGHT_SQL_PORT")
be_port = os.getenv("BE_ARROW_FLIGHT_SQL_PORT")
if not fe_port:
return {
"success": False,
"error": "Missing environment variable FE_ARROW_FLIGHT_SQL_PORT, please configure Arrow Flight SQL FE port in .env file",
"error_type": "missing_fe_port_config"
}
if not be_port:
return {
"success": False,
"error": "Missing environment variable BE_ARROW_FLIGHT_SQL_PORT, please configure Arrow Flight SQL BE port in .env file",
"error_type": "missing_be_port_config"
}
# Convert to integer and validate
try:
fe_port = int(fe_port)
be_port = int(be_port)
except ValueError:
return {
"success": False,
"error": "Invalid Arrow Flight SQL port configuration, please ensure FE_ARROW_FLIGHT_SQL_PORT and BE_ARROW_FLIGHT_SQL_PORT are valid numbers",
"error_type": "invalid_port_format"
}
# Get host address
db_config = self.connection_manager.config.database
fe_host = db_config.host
# Check FE Arrow Flight SQL port availability
fe_available = self._check_port_connectivity(fe_host, fe_port)
if not fe_available:
return {
"success": False,
"error": f"Cannot connect to FE Arrow Flight SQL port {fe_host}:{fe_port}, please check if service is running",
"error_type": "fe_port_unavailable",
"fe_host": fe_host,
"fe_port": fe_port
}
# Get BE host list
be_hosts = await self._get_be_hosts()
if not be_hosts:
return {
"success": False,
"error": "Cannot get BE node information, please check cluster status",
"error_type": "no_be_hosts"
}
# Check at least one BE Arrow Flight SQL port availability
be_available_count = 0
be_check_results = []
for be_host in be_hosts[:3]: # Check first 3 BE nodes
be_available = self._check_port_connectivity(be_host, be_port)
be_check_results.append({
"host": be_host,
"port": be_port,
"available": be_available
})
if be_available:
be_available_count += 1
if be_available_count == 0:
return {
"success": False,
"error": f"Cannot connect to any BE Arrow Flight SQL port (port: {be_port}), please check if BE services are running",
"error_type": "no_be_ports_available",
"be_check_results": be_check_results
}
return {
"success": True,
"fe_host": fe_host,
"fe_port": fe_port,
"be_port": be_port,
"be_hosts": be_hosts,
"be_available_count": be_available_count,
"be_check_results": be_check_results
}
except Exception as e:
logger.error(f"Arrow Flight port check failed: {str(e)}")
return {
"success": False,
"error": f"Arrow Flight port check failed: {str(e)}",
"error_type": "port_check_error"
}
def _check_port_connectivity(self, host: str, port: int, timeout: int | None = None) -> bool:
"""Check port connectivity"""
try:
# Use config timeout if not specified
if timeout is None:
timeout = self.connection_manager.config.adbc.connection_timeout
with socket.create_connection((host, port), timeout=timeout):
return True
except (socket.timeout, socket.error, OSError):
return False
async def _get_be_hosts(self) -> List[str]:
"""Get BE host list"""
try:
db_config = self.connection_manager.config.database
# Use configured BE hosts first
if db_config.be_hosts:
logger.info(f"Using configured BE hosts: {db_config.be_hosts}")
return db_config.be_hosts
# Get BE nodes via SHOW BACKENDS
logger.info("No BE hosts configured, getting BE node information via SHOW BACKENDS")
connection = await self.connection_manager.get_connection("query")
result = await connection.execute("SHOW BACKENDS")
be_hosts = []
for row in result.data:
host = row.get("Host")
alive = row.get("Alive", "").lower()
if host and alive == "true":
be_hosts.append(host)
logger.info(f"Got {len(be_hosts)} active BE nodes from SHOW BACKENDS")
return be_hosts
except Exception as e:
logger.error(f"Failed to get BE hosts: {str(e)}")
return []
async def _import_adbc_modules(self) -> Dict[str, Any]:
"""Import ADBC related modules"""
try:
# Import ADBC Driver Manager
try:
import adbc_driver_manager
self.adbc_manager_module = adbc_driver_manager
except ImportError:
return {
"success": False,
"error": "Missing adbc_driver_manager module, please install: pip install adbc_driver_manager",
"error_type": "missing_adbc_manager"
}
# Import ADBC Flight SQL Driver
try:
import adbc_driver_flightsql.dbapi as flight_sql
self.flight_sql_module = flight_sql
except ImportError:
return {
"success": False,
"error": "Missing adbc_driver_flightsql module, please install: pip install adbc_driver_flightsql",
"error_type": "missing_flight_sql_driver"
}
return {
"success": True,
"adbc_manager_version": getattr(adbc_driver_manager, '__version__', 'unknown'),
"flight_sql_version": getattr(flight_sql, '__version__', 'unknown')
}
except Exception as e:
logger.error(f"ADBC module import failed: {str(e)}")
return {
"success": False,
"error": f"ADBC module import failed: {str(e)}",
"error_type": "import_error"
}
async def _create_adbc_connection(self) -> Dict[str, Any]:
"""Create ADBC connection"""
try:
db_config = self.connection_manager.config.database
fe_port = int(os.getenv("FE_ARROW_FLIGHT_SQL_PORT"))
# Build connection URI
uri = f"grpc://{db_config.host}:{fe_port}"
# Create database connection parameters
db_kwargs = {
self.adbc_manager_module.DatabaseOptions.USERNAME.value: db_config.user,
self.adbc_manager_module.DatabaseOptions.PASSWORD.value: db_config.password,
}
# Create connection
self.adbc_client = self.flight_sql_module.connect(
uri=uri,
db_kwargs=db_kwargs
)
return {
"success": True,
"uri": uri,
"connection_established": True
}
except Exception as e:
logger.error(f"Failed to create ADBC connection: {str(e)}")
return {
"success": False,
"error": f"Failed to create ADBC connection: {str(e)}",
"error_type": "connection_error"
}
async def _execute_query_with_adbc(
self,
sql: str,
max_rows: int,
timeout: int,
return_format: str
) -> Dict[str, Any]:
"""Execute query using ADBC"""
try:
if not self.adbc_client:
return {
"success": False,
"error": "ADBC connection not established",
"error_type": "no_connection"
}
cursor = self.adbc_client.cursor()
start_time = time.time()
# Execute query
cursor.execute(sql)
# Get results based on return format
if return_format == "arrow":
# Return Arrow format
arrow_data = cursor.fetchallarrow()
# Limit rows
if len(arrow_data) > max_rows:
arrow_data = arrow_data.slice(0, max_rows)
# Convert Arrow data to serializable format
preview_df = arrow_data.to_pandas().head(10) if len(arrow_data) > 0 else None
result_data = {
"format": "arrow",
"num_rows": len(arrow_data),
"num_columns": len(arrow_data.schema),
"column_names": arrow_data.schema.names,
"column_types": [str(field.type) for field in arrow_data.schema],
"data_preview": _convert_dataframe_to_json_serializable(preview_df) if preview_df is not None else [],
"total_bytes": arrow_data.nbytes if hasattr(arrow_data, 'nbytes') else 0
}
elif return_format == "pandas":
# Return Pandas DataFrame
df = cursor.fetch_df()
# Limit rows
if len(df) > max_rows:
df = df.head(max_rows)
result_data = {
"format": "pandas",
"num_rows": len(df),
"num_columns": len(df.columns),
"column_names": df.columns.tolist(),
"column_types": df.dtypes.astype(str).tolist(),
"data": _convert_dataframe_to_json_serializable(df),
"memory_usage": int(df.memory_usage(deep=True).sum())
}
else: # return_format == "dict"
# Return dictionary format
arrow_data = cursor.fetchallarrow()
df = arrow_data.to_pandas()
# Limit rows
if len(df) > max_rows:
df = df.head(max_rows)
result_data = {
"format": "dict",
"num_rows": len(df),
"num_columns": len(df.columns),
"column_names": df.columns.tolist(),
"column_types": df.dtypes.astype(str).tolist(),
"data": _convert_dataframe_to_json_serializable(df)
}
execution_time = time.time() - start_time
cursor.close()
return {
"success": True,
"result": result_data,
"execution_time": round(execution_time, 3),
"sql": sql,
"max_rows_applied": len(result_data.get("data", [])) >= max_rows
}
except Exception as e:
logger.error(f"ADBC query execution failed: {str(e)}")
return {
"success": False,
"error": f"ADBC query execution failed: {str(e)}",
"error_type": "query_execution_error",
"sql": sql
}
async def get_adbc_connection_info(self) -> Dict[str, Any]:
"""Get ADBC connection information and status"""
try:
# Check port status
port_status = await self._check_arrow_flight_ports()
# Check module status
module_status = await self._import_adbc_modules()
# Get configuration information
db_config = self.connection_manager.config.database
fe_port = os.getenv("FE_ARROW_FLIGHT_SQL_PORT")
be_port = os.getenv("BE_ARROW_FLIGHT_SQL_PORT")
connection_info = {
"adbc_available": module_status["success"],
"ports_available": port_status["success"],
"configuration": {
"fe_host": db_config.host,
"fe_arrow_flight_port": fe_port,
"be_arrow_flight_port": be_port,
"user": db_config.user
},
"port_status": port_status,
"module_status": module_status,
"timestamp": datetime.now().isoformat()
}
if port_status["success"] and module_status["success"]:
connection_info["status"] = "ready"
connection_info["message"] = "ADBC Arrow Flight SQL connection ready"
else:
connection_info["status"] = "not_ready"
errors = []
if not port_status["success"]:
errors.append(port_status["error"])
if not module_status["success"]:
errors.append(module_status["error"])
connection_info["message"] = "; ".join(errors)
return connection_info
except Exception as e:
logger.error(f"Failed to get ADBC connection information: {str(e)}")
return {
"status": "error",
"error": f"Failed to get ADBC connection information: {str(e)}",
"timestamp": datetime.now().isoformat()
}
def __del__(self):
"""Cleanup resources"""
try:
if self.adbc_client:
self.adbc_client.close()
except:
pass

View File

@@ -22,6 +22,10 @@ Provides data analysis functions including table analysis, column statistics, pe
import time import time
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List from typing import Any, Dict, List
import uuid
import aiohttp
import hashlib
from pathlib import Path
from .db import DorisConnectionManager from .db import DorisConnectionManager
from .logger import get_logger from .logger import get_logger
@@ -331,4 +335,906 @@ class PerformanceMonitor:
"limit": limit, "limit": limit,
"order_by": order_by, "order_by": order_by,
"note": "Query history feature requires audit log configuration" "note": "Query history feature requires audit log configuration"
} }
class SQLAnalyzer:
"""SQL analyzer for EXPLAIN and PROFILE operations"""
def __init__(self, connection_manager: DorisConnectionManager):
self.connection_manager = connection_manager
async def get_sql_explain(
self,
sql: str,
verbose: bool = False,
db_name: str = None,
catalog_name: str = None
) -> Dict[str, Any]:
"""
Get SQL execution plan using EXPLAIN command based on Doris syntax
Args:
sql: SQL statement to explain
verbose: Whether to show verbose information
db_name: Target database name
catalog_name: Target catalog name
Returns:
Dict containing explain plan file path, content, and basic info
"""
try:
# Generate unique query ID for file naming
import time
query_hash = hashlib.md5(sql.encode()).hexdigest()[:8]
timestamp = int(time.time())
query_id = f"{timestamp}_{query_hash}"
# Ensure temp directory exists
temp_dir = Path(self.connection_manager.config.temp_files_dir)
temp_dir.mkdir(parents=True, exist_ok=True)
# Create explain file path
explain_file = temp_dir / f"explain_{query_id}.txt"
logger.info(f"Generating SQL explain for query ID: {query_id}")
# Switch database if specified
if db_name:
await self.connection_manager.execute_query("explain_session", f"USE {db_name}")
# Construct EXPLAIN query
explain_type = "EXPLAIN VERBOSE" if verbose else "EXPLAIN"
explain_sql = f"{explain_type} {sql.strip().rstrip(';')}"
logger.info(f"Executing explain query: {explain_sql}")
# Execute explain query
result = await self.connection_manager.execute_query("explain_session", explain_sql)
# Format explain output
explain_content = []
explain_content.append(f"=== SQL EXPLAIN PLAN ===")
explain_content.append(f"Query ID: {query_id}")
explain_content.append(f"Timestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}")
explain_content.append(f"Database: {db_name or 'current'}")
explain_content.append(f"Verbose: {verbose}")
explain_content.append("")
explain_content.append("=== ORIGINAL SQL ===")
explain_content.append(sql)
explain_content.append("")
explain_content.append("=== EXPLAIN QUERY ===")
explain_content.append(explain_sql)
explain_content.append("")
explain_content.append("=== EXECUTION PLAN ===")
if result and result.data:
for row in result.data:
if isinstance(row, dict):
# Handle dict format
for key, value in row.items():
explain_content.append(f"{key}: {value}")
elif isinstance(row, (list, tuple)):
# Handle tuple/list format
explain_content.append(" | ".join(str(col) for col in row))
else:
# Handle string format
explain_content.append(str(row))
else:
explain_content.append("No execution plan data returned")
explain_content.append("")
explain_content.append("=== METADATA ===")
explain_content.append(f"Execution time: {result.execution_time if result else 'N/A'} seconds")
explain_content.append(f"Rows returned: {len(result.data) if result and result.data else 0}")
# Get full content
full_content = '\n'.join(explain_content)
# Write to file
with open(explain_file, 'w', encoding='utf-8') as f:
f.write(full_content)
logger.info(f"Explain plan saved to: {explain_file.absolute()}")
# Get max response size from config
max_size = self.connection_manager.config.performance.max_response_content_size
# Truncate content if needed
truncated_content = full_content
is_truncated = False
if len(full_content) > max_size:
truncated_content = full_content[:max_size] + "\n\n=== CONTENT TRUNCATED ===\n[Content is truncated due to size limit. Full content is saved to file.]"
is_truncated = True
return {
"success": True,
"query_id": query_id,
"explain_file_path": str(explain_file.absolute()),
"file_size_bytes": explain_file.stat().st_size,
"content": truncated_content,
"content_size": len(truncated_content),
"is_content_truncated": is_truncated,
"original_content_size": len(full_content),
"sql_preview": sql[:100] + "..." if len(sql) > 100 else sql,
"verbose": verbose,
"database": db_name,
"catalog": catalog_name,
"timestamp": time.strftime('%Y-%m-%d %H:%M:%S'),
"execution_time": result.execution_time if result else None,
"plan_lines_count": len(result.data) if result and result.data else 0
}
except Exception as e:
logger.error(f"Failed to get SQL explain: {str(e)}")
return {
"success": False,
"error": f"Failed to get SQL explain: {str(e)}",
"sql_preview": sql[:100] + "..." if len(sql) > 100 else sql,
"timestamp": time.strftime('%Y-%m-%d %H:%M:%S')
}
async def get_sql_profile(
self,
sql: str,
db_name: str = None,
catalog_name: str = None,
timeout: int = 30
) -> Dict[str, Any]:
"""
Get SQL execution profile by setting trace ID and fetching profile via HTTP API
Args:
sql: SQL statement to profile
db_name: Target database name
catalog_name: Target catalog name
timeout: Query timeout in seconds
Returns:
Dict containing profile file path, content, and basic info
"""
try:
# Generate unique trace ID and query ID for file naming
trace_id = str(uuid.uuid4())
import time
query_hash = hashlib.md5(sql.encode()).hexdigest()[:8]
timestamp = int(time.time())
file_query_id = f"{timestamp}_{query_hash}"
# Ensure temp directory exists
temp_dir = Path(self.connection_manager.config.temp_files_dir)
temp_dir.mkdir(parents=True, exist_ok=True)
# Create profile file path
profile_file = temp_dir / f"profile_{file_query_id}.txt"
logger.info(f"Generated trace ID for SQL profiling: {trace_id}")
logger.info(f"Profile will be saved to: {profile_file}")
connection = await self.connection_manager.get_connection("query")
try:
# Switch to specified database/catalog if provided
if catalog_name:
await connection.execute(f"USE `{catalog_name}`")
if db_name:
await connection.execute(f"USE `{db_name}`")
# Set trace ID for the session using session variable
# According to official docs: set session_context="trace_id:your_trace_id"
await connection.execute(f'set session_context="trace_id:{trace_id}"')
logger.info(f"Set trace ID: {trace_id}")
# Enable profile
await connection.execute(f'set enable_profile=true')
logger.info(f"Enabled profile")
# Execute the SQL statement
logger.info(f"Executing SQL with trace ID: {sql}")
start_time = time.time()
sql_result = await connection.execute(sql)
execution_time = time.time() - start_time
logger.info(f"SQL execution completed in {execution_time:.3f}s")
# Get query ID from trace ID via HTTP API
query_id = await self._get_query_id_by_trace_id(trace_id)
if not query_id:
return {
"success": False,
"error": "Failed to get query ID from trace ID",
"trace_id": trace_id,
"sql": sql,
"execution_time": execution_time
}
logger.info(f"Retrieved query ID: {query_id}")
# Get profile data
profile_data = await self._get_profile_by_query_id(query_id)
if not profile_data:
# Save error info to file
profile_content = [
f"=== SQL PROFILE RESULT ===",
f"File Query ID: {file_query_id}",
f"Trace ID: {trace_id}",
f"Query ID: {query_id}",
f"Timestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}",
f"Database: {db_name or 'current'}",
f"Status: FAILED",
"",
"=== ORIGINAL SQL ===",
sql,
"",
"=== ERROR INFO ===",
"Failed to get profile data. This may be due to:",
"1) Profile data not generated yet",
"2) Query ID expired",
"3) Insufficient permissions to access profile data",
"",
"=== EXECUTION INFO ===",
f"Query execution: SUCCESSFUL",
f"Execution time: {execution_time:.3f} seconds",
f"Note: Query execution was successful, but profile data is not available"
]
# Get full content
full_profile_content = '\n'.join(profile_content)
with open(profile_file, 'w', encoding='utf-8') as f:
f.write(full_profile_content)
# Get max response size from config
max_size = self.connection_manager.config.performance.max_response_content_size
# Truncate content if needed
truncated_content = full_profile_content
is_truncated = False
if len(full_profile_content) > max_size:
truncated_content = full_profile_content[:max_size] + "\n\n=== CONTENT TRUNCATED ===\n[Content is truncated due to size limit. Full content is saved to file.]"
is_truncated = True
return {
"success": False,
"file_query_id": file_query_id,
"trace_id": trace_id,
"query_id": query_id,
"profile_file_path": str(profile_file.absolute()),
"file_size_bytes": profile_file.stat().st_size,
"content": truncated_content,
"content_size": len(truncated_content),
"is_content_truncated": is_truncated,
"original_content_size": len(full_profile_content),
"sql_preview": sql[:100] + "..." if len(sql) > 100 else sql,
"execution_time": execution_time,
"error": "Failed to get profile data",
"timestamp": time.strftime('%Y-%m-%d %H:%M:%S')
}
# Format profile output
profile_content = []
profile_content.append(f"=== SQL PROFILE RESULT ===")
profile_content.append(f"File Query ID: {file_query_id}")
profile_content.append(f"Trace ID: {trace_id}")
profile_content.append(f"Query ID: {query_id}")
profile_content.append(f"Timestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}")
profile_content.append(f"Database: {db_name or 'current'}")
profile_content.append(f"Status: SUCCESS")
profile_content.append("")
profile_content.append("=== ORIGINAL SQL ===")
profile_content.append(sql)
profile_content.append("")
profile_content.append("=== EXECUTION INFO ===")
profile_content.append(f"Execution time: {execution_time:.3f} seconds")
if hasattr(sql_result, 'data') and sql_result.data:
profile_content.append(f"Result rows: {len(sql_result.data)}")
if sql_result.data and sql_result.data[0]:
profile_content.append(f"Result columns: {list(sql_result.data[0].keys())}")
profile_content.append("")
profile_content.append("=== PROFILE DATA ===")
if isinstance(profile_data, dict):
import json
profile_content.append(json.dumps(profile_data, indent=2, ensure_ascii=False))
else:
profile_content.append(str(profile_data))
# Get full content
full_profile_content = '\n'.join(profile_content)
# Write to file
with open(profile_file, 'w', encoding='utf-8') as f:
f.write(full_profile_content)
logger.info(f"Profile data saved to: {profile_file.absolute()}")
# Get max response size from config
max_size = self.connection_manager.config.performance.max_response_content_size
# Truncate content if needed
truncated_content = full_profile_content
is_truncated = False
if len(full_profile_content) > max_size:
truncated_content = full_profile_content[:max_size] + "\n\n=== CONTENT TRUNCATED ===\n[Content is truncated due to size limit. Full content is saved to file.]"
is_truncated = True
return {
"success": True,
"file_query_id": file_query_id,
"trace_id": trace_id,
"query_id": query_id,
"profile_file_path": str(profile_file.absolute()),
"file_size_bytes": profile_file.stat().st_size,
"content": truncated_content,
"content_size": len(truncated_content),
"is_content_truncated": is_truncated,
"original_content_size": len(full_profile_content),
"sql_preview": sql[:100] + "..." if len(sql) > 100 else sql,
"database": db_name,
"catalog": catalog_name,
"execution_time": execution_time,
"sql_result_summary": {
"row_count": len(sql_result.data) if hasattr(sql_result, 'data') and sql_result.data else 0,
"columns": list(sql_result.data[0].keys()) if hasattr(sql_result, 'data') and sql_result.data and sql_result.data[0] else []
},
"timestamp": time.strftime('%Y-%m-%d %H:%M:%S')
}
except Exception as e:
logger.error(f"Error during SQL execution or profile retrieval: {str(e)}")
# Save error info to file
profile_content = [
f"=== SQL PROFILE RESULT ===",
f"File Query ID: {file_query_id}",
f"Trace ID: {trace_id}",
f"Timestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}",
f"Database: {db_name or 'current'}",
f"Status: ERROR",
"",
"=== ORIGINAL SQL ===",
sql,
"",
"=== ERROR INFO ===",
f"SQL execution or profile retrieval failed: {str(e)}",
"",
"=== EXECUTION INFO ===",
"Query execution failed during profiling process"
]
# Get full content
full_profile_content = '\n'.join(profile_content)
with open(profile_file, 'w', encoding='utf-8') as f:
f.write(full_profile_content)
# Get max response size from config
max_size = self.connection_manager.config.performance.max_response_content_size
# Truncate content if needed
truncated_content = full_profile_content
is_truncated = False
if len(full_profile_content) > max_size:
truncated_content = full_profile_content[:max_size] + "\n\n=== CONTENT TRUNCATED ===\n[Content is truncated due to size limit. Full content is saved to file.]"
is_truncated = True
return {
"success": False,
"file_query_id": file_query_id,
"trace_id": trace_id,
"profile_file_path": str(profile_file.absolute()),
"file_size_bytes": profile_file.stat().st_size,
"content": truncated_content,
"content_size": len(truncated_content),
"is_content_truncated": is_truncated,
"original_content_size": len(full_profile_content),
"sql_preview": sql[:100] + "..." if len(sql) > 100 else sql,
"error": f"SQL execution or profile retrieval failed: {str(e)}",
"database": db_name,
"catalog": catalog_name,
"timestamp": time.strftime('%Y-%m-%d %H:%M:%S')
}
except Exception as e:
logger.error(f"SQL PROFILE failed: {str(e)}")
return {
"success": False,
"error": f"SQL PROFILE failed: {str(e)}",
"sql_preview": sql[:100] + "..." if len(sql) > 100 else sql,
"database": db_name,
"catalog": catalog_name,
"timestamp": time.strftime('%Y-%m-%d %H:%M:%S')
}
async def _get_query_id_by_trace_id(self, trace_id: str) -> str:
"""
Get query ID by trace ID via FE HTTP API
Args:
trace_id: The trace ID set during query execution
Returns:
Query ID string or None if not found
"""
try:
# Get database config
db_config = self.connection_manager.config.database
# Build HTTP API URL according to official documentation
# Reference: https://doris.apache.org/zh-CN/docs/admin-manual/open-api/fe-http/query-profile-action#通过-trace-id-获取-query-id
url = f"http://{db_config.host}:{db_config.fe_http_port}/rest/v2/manager/query/trace_id/{trace_id}"
# HTTP Basic Auth
auth = aiohttp.BasicAuth(db_config.user, db_config.password)
logger.info(f"Requesting query ID from: {url}")
async with aiohttp.ClientSession() as session:
async with session.get(url, auth=auth, timeout=10) as response:
if response.status == 200:
# Check content type first
content_type = response.headers.get('content-type', '')
response_text = await response.text()
logger.info(f"Response content type: {content_type}")
logger.info(f"Response body: {response_text}")
# Parse JSON response (regardless of content-type)
if response_text.strip():
try:
import json
result = json.loads(response_text)
logger.info(f"Query ID API response: {result}")
# Parse response according to Doris API format
if result.get("code") == 0 and result.get("data"):
data = result["data"]
# Data can be either a string (query_id) or object with query_ids
if isinstance(data, str):
logger.info(f"Found query ID: {data}")
return data
elif isinstance(data, dict) and "query_ids" in data:
query_ids = data["query_ids"]
if query_ids:
query_id = query_ids[0] # Take the first query ID
logger.info(f"Found query ID: {query_id}")
return query_id
else:
logger.warning("No query IDs found in response")
else:
logger.error(f"API returned error: {result}")
except json.JSONDecodeError as e:
logger.error(f"Failed to parse JSON response: {e}")
# Fallback: try to extract query ID using regex
import re
query_id_pattern = r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}'
matches = re.findall(query_id_pattern, response_text)
if matches:
query_id = matches[0]
logger.info(f"Extracted query ID from text: {query_id}")
return query_id
else:
logger.error(f"HTTP request failed with status {response.status}")
response_text = await response.text()
logger.error(f"Response body: {response_text}")
return None
except Exception as e:
logger.error(f"Failed to get query ID by trace ID: {str(e)}")
return None
async def _get_profile_by_query_id(self, query_id: str) -> Dict[str, Any]:
"""
Get profile data by query ID via FE HTTP API
Args:
query_id: The query ID
Returns:
Profile data dict or None if failed
"""
try:
# Get database config
db_config = self.connection_manager.config.database
# Try both API endpoints according to official documentation
urls = [
f"http://{db_config.host}:{db_config.fe_http_port}/rest/v2/manager/query/profile/text/{query_id}",
f"http://{db_config.host}:{db_config.fe_http_port}/api/profile/text?query_id={query_id}"
]
# HTTP Basic Auth
auth = aiohttp.BasicAuth(db_config.user, db_config.password)
for i, url in enumerate(urls):
logger.info(f"Requesting profile from URL {i+1}: {url}")
async with aiohttp.ClientSession() as session:
async with session.get(url, auth=auth, timeout=60) as response:
if response.status == 200:
content_type = response.headers.get('content-type', '')
response_text = await response.text()
logger.info(f"Profile response content type: {content_type}")
logger.info(f"Profile response length: {len(response_text)}")
# Handle JSON response
if 'application/json' in content_type:
try:
result = await response.json()
logger.info(f"Profile JSON response: {result}")
if result.get("code") == 0 and result.get("data"):
profile_text = result["data"].get("profile", "")
return {
"query_id": query_id,
"profile_text": profile_text,
"profile_size": len(profile_text),
"retrieved_at": datetime.now().isoformat(),
"api_endpoint": url
}
else:
logger.warning(f"Profile API returned error: {result}")
continue # Try next URL
except Exception as e:
logger.error(f"Failed to parse profile JSON: {e}")
continue
# Handle plain text response
else:
if response_text.strip() and "not found" not in response_text.lower():
return {
"query_id": query_id,
"profile_text": response_text,
"profile_size": len(response_text),
"retrieved_at": datetime.now().isoformat(),
"api_endpoint": url
}
else:
logger.warning(f"Profile not found or empty: {response_text}")
continue # Try next URL
elif response.status == 404:
logger.warning(f"Profile not found (404) at {url}")
continue # Try next URL
else:
logger.error(f"Profile HTTP request failed with status {response.status} at {url}")
response_text = await response.text()
logger.error(f"Response body: {response_text}")
continue # Try next URL
return None
except Exception as e:
logger.error(f"Failed to get profile by query ID: {str(e)}")
return None
async def get_table_data_size(
self,
db_name: str = None,
table_name: str = None,
single_replica: bool = False
) -> Dict[str, Any]:
"""
Get table data size information via FE HTTP API
Args:
db_name: Database name, if not specified returns all databases
table_name: Table name, if not specified returns all tables in the database
single_replica: Whether to get single replica data size
Returns:
Dict containing table data size information
"""
try:
# Get database config
db_config = self.connection_manager.config.database
# Build HTTP API URL according to official documentation
# Reference: https://doris.apache.org/zh-CN/docs/admin-manual/open-api/fe-http/show-table-data-action
url = f"http://{db_config.host}:{db_config.fe_http_port}/api/show_table_data"
# Build query parameters
params = {}
if db_name:
params["db"] = db_name
if table_name:
params["table"] = table_name
if single_replica:
params["single_replica"] = "true"
# HTTP Basic Auth
auth = aiohttp.BasicAuth(db_config.user, db_config.password)
logger.info(f"Requesting table data size from: {url} with params: {params}")
async with aiohttp.ClientSession() as session:
async with session.get(url, auth=auth, params=params, timeout=30) as response:
if response.status == 200:
response_text = await response.text()
logger.info(f"Table data size response length: {len(response_text)}")
try:
# Parse JSON response
import json
result = json.loads(response_text)
if result.get("code") == 0 and result.get("data"):
data = result["data"]
# Process and format the data
formatted_data = self._format_table_data_size(data, db_name, table_name, single_replica)
return {
"success": True,
"db_name": db_name,
"table_name": table_name,
"single_replica": single_replica,
"timestamp": datetime.now().isoformat(),
"data": formatted_data,
"url": url,
"note": "Table data size information from Doris FE HTTP API"
}
else:
return {
"success": False,
"error": f"API returned error: {result}",
"db_name": db_name,
"table_name": table_name,
"url": url,
"timestamp": datetime.now().isoformat()
}
except json.JSONDecodeError as e:
logger.error(f"Failed to parse JSON response: {e}")
return {
"success": False,
"error": f"Failed to parse JSON response: {e}",
"response_text": response_text[:500], # First 500 chars for debugging
"url": url,
"timestamp": datetime.now().isoformat()
}
else:
logger.error(f"HTTP request failed with status {response.status}")
response_text = await response.text()
logger.error(f"Response body: {response_text}")
return {
"success": False,
"error": f"HTTP request failed with status {response.status}",
"response_text": response_text[:500], # First 500 chars for debugging
"url": url,
"timestamp": datetime.now().isoformat()
}
except Exception as e:
logger.error(f"Table data size request failed: {str(e)}")
return {
"success": False,
"error": f"Table data size request failed: {str(e)}",
"db_name": db_name,
"table_name": table_name,
"timestamp": datetime.now().isoformat()
}
def _format_table_data_size(self, data: Dict[str, Any], db_name: str, table_name: str, single_replica: bool) -> Dict[str, Any]:
"""
Format table data size response data
Args:
data: Raw response data from API
db_name: Database name filter
table_name: Table name filter
single_replica: Single replica flag
Returns:
Formatted data structure
"""
try:
formatted = {
"summary": {
"total_databases": 0,
"total_tables": 0,
"total_size_bytes": 0,
"total_size_formatted": "0 B",
"single_replica": single_replica,
"query_filters": {
"db_name": db_name,
"table_name": table_name
}
},
"databases": {}
}
# Process the data based on its structure
if isinstance(data, list):
# Data is a list of table records
for record in data:
db = record.get("database", "unknown")
table = record.get("table", "unknown")
size_bytes = int(record.get("size", 0))
if db not in formatted["databases"]:
formatted["databases"][db] = {
"database_name": db,
"table_count": 0,
"total_size_bytes": 0,
"total_size_formatted": "0 B",
"tables": {}
}
formatted["databases"][db]["tables"][table] = {
"table_name": table,
"size_bytes": size_bytes,
"size_formatted": self._format_bytes(size_bytes),
"replica_count": record.get("replica_count", 1),
"details": record
}
formatted["databases"][db]["table_count"] += 1
formatted["databases"][db]["total_size_bytes"] += size_bytes
formatted["summary"]["total_size_bytes"] += size_bytes
elif isinstance(data, dict):
# Data is a dict with database structure
for db, db_info in data.items():
if isinstance(db_info, dict) and "tables" in db_info:
formatted["databases"][db] = {
"database_name": db,
"table_count": len(db_info["tables"]),
"total_size_bytes": 0,
"total_size_formatted": "0 B",
"tables": {}
}
for table, table_info in db_info["tables"].items():
size_bytes = int(table_info.get("size", 0))
formatted["databases"][db]["tables"][table] = {
"table_name": table,
"size_bytes": size_bytes,
"size_formatted": self._format_bytes(size_bytes),
"replica_count": table_info.get("replica_count", 1),
"details": table_info
}
formatted["databases"][db]["total_size_bytes"] += size_bytes
formatted["summary"]["total_size_bytes"] += size_bytes
# Update summary
formatted["summary"]["total_databases"] = len(formatted["databases"])
formatted["summary"]["total_tables"] = sum(db["table_count"] for db in formatted["databases"].values())
formatted["summary"]["total_size_formatted"] = self._format_bytes(formatted["summary"]["total_size_bytes"])
# Update database totals formatting
for db_info in formatted["databases"].values():
db_info["total_size_formatted"] = self._format_bytes(db_info["total_size_bytes"])
return formatted
except Exception as e:
logger.error(f"Failed to format table data size: {str(e)}")
return {
"error": f"Failed to format data: {str(e)}",
"raw_data": data
}
def _format_bytes(self, bytes_value: int) -> str:
"""
Format bytes value to human readable string
Args:
bytes_value: Bytes value
Returns:
Formatted string like "1.23 GB"
"""
try:
bytes_value = int(bytes_value)
if bytes_value == 0:
return "0 B"
units = ["B", "KB", "MB", "GB", "TB", "PB"]
unit_index = 0
size = float(bytes_value)
while size >= 1024 and unit_index < len(units) - 1:
size /= 1024
unit_index += 1
if unit_index == 0:
return f"{int(size)} {units[unit_index]}"
else:
return f"{size:.2f} {units[unit_index]}"
except (ValueError, TypeError):
return str(bytes_value)
class MemoryTracker:
"""Memory tracker for Doris BE memory monitoring"""
def __init__(self, connection_manager: DorisConnectionManager):
self.connection_manager = connection_manager
async def get_realtime_memory_stats(
self,
tracker_type: str = "overview",
include_details: bool = True
) -> Dict[str, Any]:
"""
Get real-time memory statistics
Args:
tracker_type: Type of memory trackers to retrieve
include_details: Whether to include detailed information
Returns:
Dict containing memory statistics
"""
try:
# This is a placeholder implementation
# In a real implementation, this would fetch data from Doris BE memory tracker endpoints
return {
"success": True,
"tracker_type": tracker_type,
"include_details": include_details,
"timestamp": datetime.now().isoformat(),
"memory_stats": {
"total_memory": "8.00 GB",
"used_memory": "4.50 GB",
"free_memory": "3.50 GB",
"memory_usage_percent": 56.25
},
"note": "Memory tracker functionality requires BE HTTP endpoints to be available"
}
except Exception as e:
logger.error(f"Failed to get realtime memory stats: {str(e)}")
return {
"success": False,
"error": f"Failed to get realtime memory stats: {str(e)}",
"tracker_type": tracker_type,
"timestamp": datetime.now().isoformat()
}
async def get_historical_memory_stats(
self,
tracker_names: List[str] = None,
time_range: str = "1h"
) -> Dict[str, Any]:
"""
Get historical memory statistics
Args:
tracker_names: List of specific tracker names to query
time_range: Time range for historical data
Returns:
Dict containing historical memory statistics
"""
try:
# This is a placeholder implementation
# In a real implementation, this would fetch historical data from Doris BE bvar endpoints
return {
"success": True,
"tracker_names": tracker_names,
"time_range": time_range,
"timestamp": datetime.now().isoformat(),
"historical_stats": {
"data_points": 60,
"interval": "1m",
"memory_trend": "stable",
"avg_usage": "4.2 GB",
"peak_usage": "5.1 GB",
"min_usage": "3.8 GB"
},
"note": "Historical memory tracking functionality requires BE bvar endpoints to be available"
}
except Exception as e:
logger.error(f"Failed to get historical memory stats: {str(e)}")
return {
"success": False,
"error": f"Failed to get historical memory stats: {str(e)}",
"tracker_names": tracker_names,
"time_range": time_range,
"timestamp": datetime.now().isoformat()
}

View File

@@ -32,6 +32,8 @@ try:
except ImportError: except ImportError:
load_dotenv = None load_dotenv = None
from .logger import get_logger
@dataclass @dataclass
class DatabaseConfig: class DatabaseConfig:
@@ -41,16 +43,35 @@ class DatabaseConfig:
port: int = 9030 port: int = 9030
user: str = "root" user: str = "root"
password: str = "" password: str = ""
database: str = "test" database: str = "information_schema"
charset: str = "utf8mb4" charset: str = "UTF8"
# FE HTTP API port for profile and other HTTP APIs
fe_http_port: int = 8030
# BE nodes configuration for external access
# If be_hosts is empty, will use "show backends" to get BE nodes
be_hosts: list[str] = field(default_factory=list)
be_webserver_port: int = 8040
# Arrow Flight SQL Configuration (Required for ADBC tools)
fe_arrow_flight_sql_port: int | None = None
be_arrow_flight_sql_port: int | None = None
# Connection pool configuration # Connection pool configuration
min_connections: int = 5 # Note: min_connections is fixed at 0 to avoid at_eof connection issues
# This prevents pre-creation of connections which can cause state problems
_min_connections: int = field(default=0, init=False) # Internal use only, always 0
max_connections: int = 20 max_connections: int = 20
connection_timeout: int = 30 connection_timeout: int = 30
health_check_interval: int = 60 health_check_interval: int = 60
max_connection_age: int = 3600 max_connection_age: int = 3600
@property
def min_connections(self) -> int:
"""Minimum connections is always 0 to prevent at_eof issues"""
return self._min_connections
@dataclass @dataclass
class SecurityConfig: class SecurityConfig:
@@ -62,17 +83,26 @@ class SecurityConfig:
token_expiry: int = 3600 token_expiry: int = 3600
# SQL security configuration # SQL security configuration
enable_security_check: bool = True # Main switch: whether to enable SQL security check
blocked_keywords: list[str] = field( blocked_keywords: list[str] = field(
default_factory=lambda: [ default_factory=lambda: [
# DDL Operations (Data Definition Language)
"DROP", "DROP",
"DELETE", "CREATE",
"TRUNCATE",
"ALTER", "ALTER",
"CREATE", "TRUNCATE",
# DML Operations (Data Manipulation Language)
"DELETE",
"INSERT", "INSERT",
"UPDATE", "UPDATE",
# DCL Operations (Data Control Language)
"GRANT", "GRANT",
"REVOKE", "REVOKE",
# System Operations
"EXEC",
"EXECUTE",
"SHUTDOWN",
"KILL",
] ]
) )
max_query_complexity: int = 100 max_query_complexity: int = 100
@@ -102,6 +132,52 @@ class PerformanceConfig:
# Connection pool optimization configuration # Connection pool optimization configuration
connection_pool_size: int = 20 connection_pool_size: int = 20
idle_timeout: int = 1800 idle_timeout: int = 1800
# Response content size limit (characters)
max_response_content_size: int = 4096
@dataclass
class DataQualityConfig:
"""Data quality analysis configuration"""
# Column analysis configuration
max_columns_per_batch: int = 20 # Maximum columns to analyze in a single batch
default_sample_size: int = 100000 # Default sample size for analysis
# Sampling strategy configuration
small_table_threshold: int = 100000 # Tables smaller than this use full table analysis
medium_table_threshold: int = 1000000 # Tables smaller than this use simple LIMIT sampling
# Tables larger than medium_table_threshold use systematic sampling
# Performance optimization
enable_batch_analysis: bool = True # Enable batch analysis for multiple columns
batch_timeout: int = 300 # Timeout for batch analysis in seconds
# Accuracy vs Performance trade-off
enable_fast_mode: bool = False # Use approximate algorithms for faster results
fast_mode_sample_size: int = 10000 # Sample size for fast mode
# Statistical analysis configuration
enable_distribution_analysis: bool = True # Enable distribution analysis
histogram_bins: int = 20 # Number of bins for histogram analysis
percentile_levels: list[float] = field(default_factory=lambda: [0.25, 0.5, 0.75, 0.95, 0.99]) # Percentile levels to calculate
@dataclass
class ADBCConfig:
"""ADBC (Arrow Flight SQL) configuration"""
# Default query parameters
default_max_rows: int = 100000
default_timeout: int = 60
default_return_format: str = "arrow" # "arrow", "pandas", "dict"
# Connection timeout for ADBC
connection_timeout: int = 30
# Whether to enable ADBC tools
enabled: bool = True
@dataclass @dataclass
@@ -117,6 +193,11 @@ class LoggingConfig:
# Audit log configuration # Audit log configuration
enable_audit: bool = True enable_audit: bool = True
audit_file_path: str | None = None audit_file_path: str | None = None
# Log cleanup configuration
enable_cleanup: bool = True
max_age_days: int = 30
cleanup_interval_hours: int = 24
@dataclass @dataclass
@@ -125,11 +206,11 @@ class MonitoringConfig:
# Metrics collection configuration # Metrics collection configuration
enable_metrics: bool = True enable_metrics: bool = True
metrics_port: int = 8081 metrics_port: int = 3001
metrics_path: str = "/metrics" metrics_path: str = "/metrics"
# Health check configuration # Health check configuration
health_check_port: int = 8082 health_check_port: int = 3002
health_check_path: str = "/health" health_check_path: str = "/health"
# Alert configuration # Alert configuration
@@ -143,15 +224,21 @@ class DorisConfig:
# Basic configuration # Basic configuration
server_name: str = "doris-mcp-server" server_name: str = "doris-mcp-server"
server_version: str = "1.0.0" server_version: str = "0.4.1"
server_port: int = 8080 server_port: int = 3000
transport: str = "stdio"
# Temporary files configuration
temp_files_dir: str = "tmp" # Temporary files directory for Explain and Profile outputs
# Sub-configuration modules # Sub-configuration modules
database: DatabaseConfig = field(default_factory=DatabaseConfig) database: DatabaseConfig = field(default_factory=DatabaseConfig)
security: SecurityConfig = field(default_factory=SecurityConfig) security: SecurityConfig = field(default_factory=SecurityConfig)
performance: PerformanceConfig = field(default_factory=PerformanceConfig) performance: PerformanceConfig = field(default_factory=PerformanceConfig)
data_quality: DataQualityConfig = field(default_factory=DataQualityConfig)
logging: LoggingConfig = field(default_factory=LoggingConfig) logging: LoggingConfig = field(default_factory=LoggingConfig)
monitoring: MonitoringConfig = field(default_factory=MonitoringConfig) monitoring: MonitoringConfig = field(default_factory=MonitoringConfig)
adbc: ADBCConfig = field(default_factory=ADBCConfig)
# Custom configuration # Custom configuration
custom_config: dict[str, Any] = field(default_factory=dict) custom_config: dict[str, Any] = field(default_factory=dict)
@@ -215,11 +302,24 @@ class DorisConfig:
config.database.user = os.getenv("DORIS_USER", config.database.user) config.database.user = os.getenv("DORIS_USER", config.database.user)
config.database.password = os.getenv("DORIS_PASSWORD", config.database.password) config.database.password = os.getenv("DORIS_PASSWORD", config.database.password)
config.database.database = os.getenv("DORIS_DATABASE", config.database.database) config.database.database = os.getenv("DORIS_DATABASE", config.database.database)
config.database.fe_http_port = int(os.getenv("DORIS_FE_HTTP_PORT", str(config.database.fe_http_port)))
# BE nodes configuration
be_hosts_env = os.getenv("DORIS_BE_HOSTS", "")
if be_hosts_env:
config.database.be_hosts = [host.strip() for host in be_hosts_env.split(",") if host.strip()]
config.database.be_webserver_port = int(os.getenv("DORIS_BE_WEBSERVER_PORT", str(config.database.be_webserver_port)))
# Arrow Flight SQL Configuration
fe_arrow_port_env = os.getenv("FE_ARROW_FLIGHT_SQL_PORT")
if fe_arrow_port_env:
config.database.fe_arrow_flight_sql_port = int(fe_arrow_port_env)
be_arrow_port_env = os.getenv("BE_ARROW_FLIGHT_SQL_PORT")
if be_arrow_port_env:
config.database.be_arrow_flight_sql_port = int(be_arrow_port_env)
# Connection pool configuration # Connection pool configuration
config.database.min_connections = int(
os.getenv("DORIS_MIN_CONNECTIONS", str(config.database.min_connections))
)
config.database.max_connections = int( config.database.max_connections = int(
os.getenv("DORIS_MAX_CONNECTIONS", str(config.database.max_connections)) os.getenv("DORIS_MAX_CONNECTIONS", str(config.database.max_connections))
) )
@@ -245,6 +345,22 @@ class DorisConfig:
config.security.max_query_complexity = int( config.security.max_query_complexity = int(
os.getenv("MAX_QUERY_COMPLEXITY", str(config.security.max_query_complexity)) os.getenv("MAX_QUERY_COMPLEXITY", str(config.security.max_query_complexity))
) )
config.security.enable_security_check = (
os.getenv("ENABLE_SECURITY_CHECK", str(config.security.enable_security_check).lower()).lower() == "true"
)
# Handle blocked keywords environment variable configuration
# Format: BLOCKED_KEYWORDS="DROP,DELETE,TRUNCATE,ALTER,CREATE,INSERT,UPDATE,GRANT,REVOKE"
blocked_keywords_env = os.getenv("BLOCKED_KEYWORDS", "")
if blocked_keywords_env:
# If environment variable is provided, use keywords list from environment variable
config.security.blocked_keywords = [
keyword.strip().upper()
for keyword in blocked_keywords_env.split(",")
if keyword.strip()
]
# If environment variable is empty, keep default configuration unchanged
config.security.enable_masking = ( config.security.enable_masking = (
os.getenv("ENABLE_MASKING", str(config.security.enable_masking).lower()).lower() == "true" os.getenv("ENABLE_MASKING", str(config.security.enable_masking).lower()).lower() == "true"
) )
@@ -265,6 +381,9 @@ class DorisConfig:
config.performance.query_timeout = int( config.performance.query_timeout = int(
os.getenv("QUERY_TIMEOUT", str(config.performance.query_timeout)) os.getenv("QUERY_TIMEOUT", str(config.performance.query_timeout))
) )
config.performance.max_response_content_size = int(
os.getenv("MAX_RESPONSE_CONTENT_SIZE", str(config.performance.max_response_content_size))
)
# Logging configuration # Logging configuration
config.logging.level = os.getenv("LOG_LEVEL", config.logging.level) config.logging.level = os.getenv("LOG_LEVEL", config.logging.level)
@@ -273,6 +392,15 @@ class DorisConfig:
os.getenv("ENABLE_AUDIT", str(config.logging.enable_audit).lower()).lower() == "true" os.getenv("ENABLE_AUDIT", str(config.logging.enable_audit).lower()).lower() == "true"
) )
config.logging.audit_file_path = os.getenv("AUDIT_FILE_PATH", config.logging.audit_file_path) config.logging.audit_file_path = os.getenv("AUDIT_FILE_PATH", config.logging.audit_file_path)
config.logging.enable_cleanup = (
os.getenv("ENABLE_LOG_CLEANUP", str(config.logging.enable_cleanup).lower()).lower() == "true"
)
config.logging.max_age_days = int(
os.getenv("LOG_MAX_AGE_DAYS", str(config.logging.max_age_days))
)
config.logging.cleanup_interval_hours = int(
os.getenv("LOG_CLEANUP_INTERVAL_HOURS", str(config.logging.cleanup_interval_hours))
)
# Monitoring configuration # Monitoring configuration
config.monitoring.enable_metrics = ( config.monitoring.enable_metrics = (
@@ -289,10 +417,58 @@ class DorisConfig:
) )
config.monitoring.alert_webhook_url = os.getenv("ALERT_WEBHOOK_URL", config.monitoring.alert_webhook_url) config.monitoring.alert_webhook_url = os.getenv("ALERT_WEBHOOK_URL", config.monitoring.alert_webhook_url)
# ADBC configuration
config.adbc.default_max_rows = int(
os.getenv("ADBC_DEFAULT_MAX_ROWS", str(config.adbc.default_max_rows))
)
config.adbc.default_timeout = int(
os.getenv("ADBC_DEFAULT_TIMEOUT", str(config.adbc.default_timeout))
)
config.adbc.default_return_format = os.getenv("ADBC_DEFAULT_RETURN_FORMAT", config.adbc.default_return_format)
config.adbc.connection_timeout = int(
os.getenv("ADBC_CONNECTION_TIMEOUT", str(config.adbc.connection_timeout))
)
config.adbc.enabled = (
os.getenv("ADBC_ENABLED", str(config.adbc.enabled).lower()).lower() == "true"
)
# Data quality configuration
config.data_quality.max_columns_per_batch = int(
os.getenv("DATA_QUALITY_MAX_COLUMNS_PER_BATCH", str(config.data_quality.max_columns_per_batch))
)
config.data_quality.default_sample_size = int(
os.getenv("DATA_QUALITY_DEFAULT_SAMPLE_SIZE", str(config.data_quality.default_sample_size))
)
config.data_quality.small_table_threshold = int(
os.getenv("DATA_QUALITY_SMALL_TABLE_THRESHOLD", str(config.data_quality.small_table_threshold))
)
config.data_quality.medium_table_threshold = int(
os.getenv("DATA_QUALITY_MEDIUM_TABLE_THRESHOLD", str(config.data_quality.medium_table_threshold))
)
config.data_quality.enable_batch_analysis = (
os.getenv("DATA_QUALITY_ENABLE_BATCH_ANALYSIS", str(config.data_quality.enable_batch_analysis).lower()).lower() == "true"
)
config.data_quality.batch_timeout = int(
os.getenv("DATA_QUALITY_BATCH_TIMEOUT", str(config.data_quality.batch_timeout))
)
config.data_quality.enable_fast_mode = (
os.getenv("DATA_QUALITY_ENABLE_FAST_MODE", str(config.data_quality.enable_fast_mode).lower()).lower() == "true"
)
config.data_quality.fast_mode_sample_size = int(
os.getenv("DATA_QUALITY_FAST_MODE_SAMPLE_SIZE", str(config.data_quality.fast_mode_sample_size))
)
config.data_quality.enable_distribution_analysis = (
os.getenv("DATA_QUALITY_ENABLE_DISTRIBUTION_ANALYSIS", str(config.data_quality.enable_distribution_analysis).lower()).lower() == "true"
)
config.data_quality.histogram_bins = int(
os.getenv("DATA_QUALITY_HISTOGRAM_BINS", str(config.data_quality.histogram_bins))
)
# Server configuration # Server configuration
config.server_name = os.getenv("SERVER_NAME", config.server_name) config.server_name = os.getenv("SERVER_NAME", config.server_name)
config.server_version = os.getenv("SERVER_VERSION", config.server_version) config.server_version = os.getenv("SERVER_VERSION", config.server_version)
config.server_port = int(os.getenv("SERVER_PORT", str(config.server_port))) config.server_port = int(os.getenv("SERVER_PORT", str(config.server_port)))
config.temp_files_dir = os.getenv("TEMP_FILES_DIR", config.temp_files_dir)
return config return config
@@ -302,7 +478,7 @@ class DorisConfig:
config = cls() config = cls()
# Update basic configuration # Update basic configuration
for key in ["server_name", "server_version", "server_port"]: for key in ["server_name", "server_version", "server_port", "temp_files_dir"]:
if key in config_data: if key in config_data:
setattr(config, key, config_data[key]) setattr(config, key, config_data[key])
@@ -327,6 +503,13 @@ class DorisConfig:
if hasattr(config.performance, key): if hasattr(config.performance, key):
setattr(config.performance, key, value) setattr(config.performance, key, value)
# Update data quality configuration
if "data_quality" in config_data:
dq_config = config_data["data_quality"]
for key, value in dq_config.items():
if hasattr(config.data_quality, key):
setattr(config.data_quality, key, value)
# Update logging configuration # Update logging configuration
if "logging" in config_data: if "logging" in config_data:
log_config = config_data["logging"] log_config = config_data["logging"]
@@ -341,6 +524,13 @@ class DorisConfig:
if hasattr(config.monitoring, key): if hasattr(config.monitoring, key):
setattr(config.monitoring, key, value) setattr(config.monitoring, key, value)
# Update ADBC configuration
if "adbc" in config_data:
adbc_config = config_data["adbc"]
for key, value in adbc_config.items():
if hasattr(config.adbc, key):
setattr(config.adbc, key, value)
# Custom configuration # Custom configuration
config.custom_config = config_data.get("custom", {}) config.custom_config = config_data.get("custom", {})
@@ -352,6 +542,7 @@ class DorisConfig:
"server_name": self.server_name, "server_name": self.server_name,
"server_version": self.server_version, "server_version": self.server_version,
"server_port": self.server_port, "server_port": self.server_port,
"temp_files_dir": self.temp_files_dir,
"database": { "database": {
"host": self.database.host, "host": self.database.host,
"port": self.database.port, "port": self.database.port,
@@ -359,7 +550,12 @@ class DorisConfig:
"password": "***", # Hide password "password": "***", # Hide password
"database": self.database.database, "database": self.database.database,
"charset": self.database.charset, "charset": self.database.charset,
"min_connections": self.database.min_connections, "fe_http_port": self.database.fe_http_port,
"be_hosts": self.database.be_hosts,
"be_webserver_port": self.database.be_webserver_port,
"fe_arrow_flight_sql_port": self.database.fe_arrow_flight_sql_port,
"be_arrow_flight_sql_port": self.database.be_arrow_flight_sql_port,
"min_connections": self.database.min_connections, # Always 0, shown for reference
"max_connections": self.database.max_connections, "max_connections": self.database.max_connections,
"connection_timeout": self.database.connection_timeout, "connection_timeout": self.database.connection_timeout,
"health_check_interval": self.database.health_check_interval, "health_check_interval": self.database.health_check_interval,
@@ -369,6 +565,7 @@ class DorisConfig:
"auth_type": self.security.auth_type, "auth_type": self.security.auth_type,
"token_secret": "***", # Hide secret key "token_secret": "***", # Hide secret key
"token_expiry": self.security.token_expiry, "token_expiry": self.security.token_expiry,
"enable_security_check": self.security.enable_security_check,
"blocked_keywords": self.security.blocked_keywords, "blocked_keywords": self.security.blocked_keywords,
"max_query_complexity": self.security.max_query_complexity, "max_query_complexity": self.security.max_query_complexity,
"max_result_rows": self.security.max_result_rows, "max_result_rows": self.security.max_result_rows,
@@ -384,6 +581,20 @@ class DorisConfig:
"query_timeout": self.performance.query_timeout, "query_timeout": self.performance.query_timeout,
"connection_pool_size": self.performance.connection_pool_size, "connection_pool_size": self.performance.connection_pool_size,
"idle_timeout": self.performance.idle_timeout, "idle_timeout": self.performance.idle_timeout,
"max_response_content_size": self.performance.max_response_content_size,
},
"data_quality": {
"max_columns_per_batch": self.data_quality.max_columns_per_batch,
"default_sample_size": self.data_quality.default_sample_size,
"small_table_threshold": self.data_quality.small_table_threshold,
"medium_table_threshold": self.data_quality.medium_table_threshold,
"enable_batch_analysis": self.data_quality.enable_batch_analysis,
"batch_timeout": self.data_quality.batch_timeout,
"enable_fast_mode": self.data_quality.enable_fast_mode,
"fast_mode_sample_size": self.data_quality.fast_mode_sample_size,
"enable_distribution_analysis": self.data_quality.enable_distribution_analysis,
"histogram_bins": self.data_quality.histogram_bins,
"percentile_levels": self.data_quality.percentile_levels,
}, },
"logging": { "logging": {
"level": self.logging.level, "level": self.logging.level,
@@ -393,6 +604,9 @@ class DorisConfig:
"backup_count": self.logging.backup_count, "backup_count": self.logging.backup_count,
"enable_audit": self.logging.enable_audit, "enable_audit": self.logging.enable_audit,
"audit_file_path": self.logging.audit_file_path, "audit_file_path": self.logging.audit_file_path,
"enable_cleanup": self.logging.enable_cleanup,
"max_age_days": self.logging.max_age_days,
"cleanup_interval_hours": self.logging.cleanup_interval_hours,
}, },
"monitoring": { "monitoring": {
"enable_metrics": self.monitoring.enable_metrics, "enable_metrics": self.monitoring.enable_metrics,
@@ -403,6 +617,13 @@ class DorisConfig:
"enable_alerts": self.monitoring.enable_alerts, "enable_alerts": self.monitoring.enable_alerts,
"alert_webhook_url": self.monitoring.alert_webhook_url, "alert_webhook_url": self.monitoring.alert_webhook_url,
}, },
"adbc": {
"default_max_rows": self.adbc.default_max_rows,
"default_timeout": self.adbc.default_timeout,
"default_return_format": self.adbc.default_return_format,
"connection_timeout": self.adbc.connection_timeout,
"enabled": self.adbc.enabled,
},
"custom": self.custom_config, "custom": self.custom_config,
} }
@@ -435,11 +656,8 @@ class DorisConfig:
if not self.database.user: if not self.database.user:
errors.append("Database username cannot be empty") errors.append("Database username cannot be empty")
if self.database.min_connections <= 0: if self.database.max_connections <= 0:
errors.append("Minimum connections must be greater than 0") errors.append("Maximum connections must be greater than 0")
if self.database.max_connections <= self.database.min_connections:
errors.append("Maximum connections must be greater than minimum connections")
# Validate security configuration # Validate security configuration
if self.security.auth_type not in ["token", "basic", "oauth"]: if self.security.auth_type not in ["token", "basic", "oauth"]:
@@ -464,6 +682,31 @@ class DorisConfig:
if self.performance.query_timeout <= 0: if self.performance.query_timeout <= 0:
errors.append("Query timeout must be greater than 0") errors.append("Query timeout must be greater than 0")
# Validate data quality configuration
if self.data_quality.max_columns_per_batch <= 0:
errors.append("Max columns per batch must be greater than 0")
if self.data_quality.default_sample_size <= 0:
errors.append("Default sample size must be greater than 0")
if self.data_quality.small_table_threshold <= 0:
errors.append("Small table threshold must be greater than 0")
if self.data_quality.medium_table_threshold <= 0:
errors.append("Medium table threshold must be greater than 0")
if self.data_quality.small_table_threshold >= self.data_quality.medium_table_threshold:
errors.append("Small table threshold must be less than medium table threshold")
if self.data_quality.batch_timeout <= 0:
errors.append("Batch timeout must be greater than 0")
if self.data_quality.fast_mode_sample_size <= 0:
errors.append("Fast mode sample size must be greater than 0")
if self.data_quality.histogram_bins <= 0:
errors.append("Histogram bins must be greater than 0")
# Validate logging configuration # Validate logging configuration
if self.logging.level not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]: if self.logging.level not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]:
errors.append("Log level must be one of DEBUG, INFO, WARNING, ERROR, or CRITICAL") errors.append("Log level must be one of DEBUG, INFO, WARNING, ERROR, or CRITICAL")
@@ -473,6 +716,12 @@ class DorisConfig:
if self.logging.backup_count < 0: if self.logging.backup_count < 0:
errors.append("Log backup count cannot be negative") errors.append("Log backup count cannot be negative")
if self.logging.max_age_days <= 0:
errors.append("Log max age days must be greater than 0")
if self.logging.cleanup_interval_hours <= 0:
errors.append("Log cleanup interval hours must be greater than 0")
# Validate monitoring configuration # Validate monitoring configuration
if not (1 <= self.monitoring.metrics_port <= 65535): if not (1 <= self.monitoring.metrics_port <= 65535):
@@ -481,6 +730,19 @@ class DorisConfig:
if not (1 <= self.monitoring.health_check_port <= 65535): if not (1 <= self.monitoring.health_check_port <= 65535):
errors.append("Health check port must be in the range 1-65535") errors.append("Health check port must be in the range 1-65535")
# Validate ADBC configuration
if self.adbc.default_max_rows <= 0:
errors.append("ADBC default max rows must be greater than 0")
if self.adbc.default_timeout <= 0:
errors.append("ADBC default timeout must be greater than 0")
if self.adbc.default_return_format not in ["arrow", "pandas", "dict"]:
errors.append("ADBC default return format must be one of arrow, pandas, or dict")
if self.adbc.connection_timeout <= 0:
errors.append("ADBC connection timeout must be greater than 0")
return errors return errors
def get_connection_string(self) -> str: def get_connection_string(self) -> str:
@@ -492,7 +754,7 @@ class DorisConfig:
return { return {
"server": f"{self.server_name} v{self.server_version}", "server": f"{self.server_name} v{self.server_version}",
"database": f"{self.database.host}:{self.database.port}/{self.database.database}", "database": f"{self.database.host}:{self.database.port}/{self.database.database}",
"connection_pool": f"{self.database.min_connections}-{self.database.max_connections}", "connection_pool": f"0-{self.database.max_connections} (min fixed at 0 for stability)",
"security": { "security": {
"auth_type": self.security.auth_type, "auth_type": self.security.auth_type,
"masking_enabled": self.security.enable_masking, "masking_enabled": self.security.enable_masking,
@@ -518,56 +780,50 @@ class ConfigManager:
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
def setup_logging(self): def setup_logging(self):
"""Setup logging configuration""" """Setup logging configuration using enhanced logger"""
# Configure root logger from .logger import setup_logging, get_logger
root_logger = logging.getLogger() import sys
root_logger.setLevel(getattr(logging, self.config.logging.level.upper()))
# Determine log directory
# Clear existing handlers log_dir = "logs"
for handler in root_logger.handlers[:]:
root_logger.removeHandler(handler)
# Create formatter
formatter = logging.Formatter(self.config.logging.format)
# Console handler
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
root_logger.addHandler(console_handler)
# File handler (if configured)
if self.config.logging.file_path: if self.config.logging.file_path:
try: # Extract directory from file path if provided
from logging.handlers import RotatingFileHandler from pathlib import Path
log_dir = str(Path(self.config.logging.file_path).parent)
file_handler = RotatingFileHandler(
self.config.logging.file_path, # Detect if we're in stdio mode by checking if this is likely MCP stdio communication
maxBytes=self.config.logging.max_file_size, # In stdio mode, we shouldn't output to console as it interferes with JSON protocol
backupCount=self.config.logging.backup_count, is_stdio_mode = (
encoding="utf-8", self.config.transport == "stdio" or
) "--transport" in sys.argv and "stdio" in sys.argv or
file_handler.setFormatter(formatter) not sys.stdout.isatty() # Not a terminal (likely piped/redirected)
root_logger.addHandler(file_handler) )
except Exception as e:
self.logger.warning(f"Failed to setup file logging: {e}") # Setup enhanced logging with cleanup functionality
setup_logging(
# Audit log handler (if configured) level=self.config.logging.level,
if self.config.logging.enable_audit and self.config.logging.audit_file_path: log_dir=log_dir,
try: enable_console=not is_stdio_mode, # Disable console logging in stdio mode
from logging.handlers import RotatingFileHandler enable_file=True,
enable_audit=self.config.logging.enable_audit,
audit_logger = logging.getLogger("audit") audit_file=self.config.logging.audit_file_path,
audit_handler = RotatingFileHandler( max_file_size=self.config.logging.max_file_size,
self.config.logging.audit_file_path, backup_count=self.config.logging.backup_count,
maxBytes=self.config.logging.max_file_size, enable_cleanup=self.config.logging.enable_cleanup,
backupCount=self.config.logging.backup_count, max_age_days=self.config.logging.max_age_days,
encoding="utf-8", cleanup_interval_hours=self.config.logging.cleanup_interval_hours
) )
audit_handler.setFormatter(formatter)
audit_logger.addHandler(audit_handler) # Update logger to use new system
audit_logger.setLevel(logging.INFO) self.logger = get_logger(__name__)
except Exception as e:
self.logger.warning(f"Failed to setup audit logging: {e}") self.logger.info("Enhanced logging system with cleanup initialized successfully")
self.logger.info(f"Log directory: {log_dir}")
self.logger.info(f"Log level: {self.config.logging.level}")
self.logger.info(f"Audit logging: {'Enabled' if self.config.logging.enable_audit else 'Disabled'}")
self.logger.info(f"Log cleanup: {'Enabled' if self.config.logging.enable_cleanup else 'Disabled'}")
if self.config.logging.enable_cleanup:
self.logger.info(f"Cleanup config: Max age {self.config.logging.max_age_days} days, interval {self.config.logging.cleanup_interval_hours}h")
def validate_config(self) -> bool: def validate_config(self) -> bool:
"""Validate configuration""" """Validate configuration"""

View File

@@ -0,0 +1,733 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Data Exploration Tools Module
Provides table data distribution analysis and exploration capabilities
"""
import time
import math
from datetime import datetime
from typing import Any, Dict, List, Optional, Union
from .db import DorisConnectionManager
from .logger import get_logger
logger = get_logger(__name__)
class DataExplorationTools:
"""Data exploration tools for table distribution analysis"""
def __init__(self, connection_manager: DorisConnectionManager):
self.connection_manager = connection_manager
logger.info("DataExplorationTools initialized")
# ==================== Private Helper Methods ====================
def _build_full_table_name(self, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> str:
"""Build full table name with catalog and database using three-part naming convention"""
# Default catalog for internal tables
effective_catalog = catalog_name if catalog_name else "internal"
if db_name:
return f"{effective_catalog}.{db_name}.{table_name}"
else:
# If no db_name provided, need to determine the current database
return f"{effective_catalog}.{table_name}"
async def _get_table_basic_info(self, connection, table_name: str) -> Optional[Dict]:
"""Get basic table information including row count"""
try:
count_sql = f"SELECT COUNT(*) as row_count FROM {table_name}"
result = await connection.execute(count_sql)
if result.data:
return {"row_count": result.data[0]["row_count"]}
return None
except Exception as e:
logger.warning(f"Failed to get basic info for table {table_name}: {str(e)}")
return {"row_count": 0}
async def _get_table_columns_info(self, connection, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> List[Dict]:
"""Get detailed column information"""
try:
where_conditions = [f"table_name = '{table_name}'"]
if db_name:
where_conditions.append(f"table_schema = '{db_name}'")
else:
where_conditions.append("table_schema = DATABASE()")
columns_sql = f"""
SELECT
column_name,
data_type,
is_nullable,
column_comment,
ordinal_position
FROM information_schema.columns
WHERE {' AND '.join(where_conditions)}
ORDER BY ordinal_position
"""
result = await connection.execute(columns_sql)
return result.data if result.data else []
except Exception as e:
logger.warning(f"Failed to get columns info for table {table_name}: {str(e)}")
return []
async def _determine_sampling_strategy(self, connection, table_name: str, total_rows: int, sample_size: int) -> Dict[str, Any]:
"""Determine optimal sampling strategy based on table size"""
if total_rows <= sample_size:
# Use all data if table is small enough
return {
"total_rows": total_rows,
"sample_size": total_rows,
"sampling_method": "full_scan",
"sampling_ratio": 1.0,
"use_sampling": False,
"sample_table_expression": table_name
}
else:
# Use random sampling for large tables
sampling_ratio = sample_size / total_rows
return {
"total_rows": total_rows,
"sample_size": sample_size,
"sampling_method": "random_sample",
"sampling_ratio": round(sampling_ratio, 4),
"use_sampling": True,
"sample_table_expression": f"(SELECT * FROM {table_name} ORDER BY RAND() LIMIT {sample_size}) as sample_table"
}
def _select_analysis_columns(self, columns_info: List[Dict], include_all: bool) -> List[Dict]:
"""Select columns for analysis based on strategy"""
if include_all:
return columns_info
# If not analyzing all columns, prioritize key columns
priority_keywords = ['id', 'key', 'code', 'status', 'type', 'amount', 'count', 'date', 'time']
priority_columns = []
other_columns = []
for col in columns_info:
col_name_lower = col["column_name"].lower()
if any(keyword in col_name_lower for keyword in priority_keywords):
priority_columns.append(col)
else:
other_columns.append(col)
# Return priority columns plus first 10 other columns
return priority_columns + other_columns[:10]
def _is_numeric_type(self, data_type: str) -> bool:
"""Check if column type is numeric"""
numeric_types = [
'tinyint', 'smallint', 'int', 'bigint', 'largeint',
'float', 'double', 'decimal', 'numeric'
]
return any(num_type in data_type.lower() for num_type in numeric_types)
def _is_categorical_type(self, data_type: str) -> bool:
"""Check if column type is categorical"""
categorical_types = ['varchar', 'char', 'string', 'text', 'enum']
return any(cat_type in data_type.lower() for cat_type in categorical_types)
def _is_temporal_type(self, data_type: str) -> bool:
"""Check if column type is temporal"""
temporal_types = ['date', 'datetime', 'timestamp', 'time']
return any(temp_type in data_type.lower() for temp_type in temporal_types)
async def _analyze_numeric_distributions(self, connection, table_name: str, numeric_columns: List[Dict], sampling_info: Dict) -> Dict[str, Any]:
"""Analyze distribution patterns for numeric columns"""
numeric_analysis = {}
for column in numeric_columns:
col_name = column["column_name"]
try:
# Basic statistics
table_expr = sampling_info.get("sample_table_expression", table_name)
stats_sql = f"""
SELECT
COUNT({col_name}) as count,
MIN({col_name}) as min_value,
MAX({col_name}) as max_value,
AVG({col_name}) as mean_value,
STDDEV({col_name}) as std_dev
FROM {table_expr}
WHERE {col_name} IS NOT NULL
"""
stats_result = await connection.execute(stats_sql)
if stats_result.data and stats_result.data[0]["count"] > 0:
stats = stats_result.data[0]
# Percentiles calculation
percentiles = await self._calculate_percentiles(connection, table_name, col_name, sampling_info)
# Outlier detection
outliers = await self._detect_numeric_outliers(connection, table_name, col_name, percentiles, sampling_info)
# Distribution shape analysis
distribution_shape = await self._analyze_distribution_shape(
connection, table_name, col_name, stats, percentiles, sampling_info
)
numeric_analysis[col_name] = {
"data_type": column["data_type"],
"statistics": {
"count": stats["count"],
"mean": round(float(stats["mean_value"]), 4) if stats["mean_value"] else None,
"std": round(float(stats["std_dev"]), 4) if stats["std_dev"] else None,
"min": float(stats["min_value"]) if stats["min_value"] else None,
"max": float(stats["max_value"]) if stats["max_value"] else None,
**percentiles
},
"distribution_shape": distribution_shape,
"outliers": outliers
}
except Exception as e:
logger.warning(f"Failed to analyze numeric column {col_name}: {str(e)}")
numeric_analysis[col_name] = {"error": str(e)}
return numeric_analysis
async def _calculate_percentiles(self, connection, table_name: str, col_name: str, sampling_info: Dict) -> Dict[str, float]:
"""Calculate percentiles for numeric column"""
try:
table_expr = sampling_info.get("sample_table_expression", table_name)
percentile_sql = f"""
SELECT
PERCENTILE({col_name}, 0.25) as p25,
PERCENTILE({col_name}, 0.50) as p50,
PERCENTILE({col_name}, 0.75) as p75,
PERCENTILE({col_name}, 0.90) as p90,
PERCENTILE({col_name}, 0.95) as p95,
PERCENTILE({col_name}, 0.99) as p99
FROM {table_expr}
WHERE {col_name} IS NOT NULL
"""
result = await connection.execute(percentile_sql)
if result.data:
data = result.data[0]
return {
"25%": round(float(data["p25"]), 4) if data["p25"] else None,
"50%": round(float(data["p50"]), 4) if data["p50"] else None,
"75%": round(float(data["p75"]), 4) if data["p75"] else None,
"90%": round(float(data["p90"]), 4) if data["p90"] else None,
"95%": round(float(data["p95"]), 4) if data["p95"] else None,
"99%": round(float(data["p99"]), 4) if data["p99"] else None
}
except Exception as e:
logger.warning(f"Failed to calculate percentiles for {col_name}: {str(e)}")
return {}
async def _detect_numeric_outliers(self, connection, table_name: str, col_name: str, percentiles: Dict, sampling_info: Dict) -> Dict[str, Any]:
"""Detect outliers using IQR method"""
try:
if "25%" not in percentiles or "75%" not in percentiles:
return {"outlier_count": 0, "outlier_rate": 0.0}
q1 = percentiles["25%"]
q3 = percentiles["75%"]
iqr = q3 - q1
lower_bound = q1 - 1.5 * iqr
upper_bound = q3 + 1.5 * iqr
table_expr = sampling_info.get("sample_table_expression", table_name)
outlier_sql = f"""
SELECT
COUNT(*) as total_count,
SUM(CASE WHEN {col_name} < {lower_bound} OR {col_name} > {upper_bound} THEN 1 ELSE 0 END) as outlier_count
FROM {table_expr}
WHERE {col_name} IS NOT NULL
"""
result = await connection.execute(outlier_sql)
if result.data:
data = result.data[0]
total_count = data["total_count"]
outlier_count = data["outlier_count"]
outlier_rate = outlier_count / total_count if total_count > 0 else 0
return {
"outlier_count": outlier_count,
"outlier_rate": round(outlier_rate, 4),
"outlier_threshold_lower": round(lower_bound, 4),
"outlier_threshold_upper": round(upper_bound, 4),
"iqr": round(iqr, 4)
}
except Exception as e:
logger.warning(f"Failed to detect outliers for {col_name}: {str(e)}")
return {"outlier_count": 0, "outlier_rate": 0.0}
async def _analyze_distribution_shape(self, connection, table_name: str, col_name: str, stats: Dict, percentiles: Dict, sampling_info: Dict) -> Dict[str, Any]:
"""Analyze the shape of data distribution"""
try:
mean = stats.get("mean_value", 0)
median = percentiles.get("50%", 0)
if mean is None or median is None:
return {"distribution_type": "unknown"}
# Calculate skewness indicator
if abs(mean - median) < 0.01:
skew_indicator = "symmetric"
elif mean > median:
skew_indicator = "right_skewed"
else:
skew_indicator = "left_skewed"
# Estimate kurtosis based on percentile spread
if "25%" in percentiles and "75%" in percentiles:
iqr = percentiles["75%"] - percentiles["25%"]
range_90 = percentiles.get("90%", percentiles["75%"]) - percentiles.get("10%", percentiles["25%"])
if iqr > 0:
kurtosis_indicator = "normal" if 2.5 <= range_90/iqr <= 3.5 else ("heavy_tailed" if range_90/iqr > 3.5 else "light_tailed")
else:
kurtosis_indicator = "unknown"
else:
kurtosis_indicator = "unknown"
return {
"skewness_indicator": skew_indicator,
"kurtosis_indicator": kurtosis_indicator,
"distribution_type": self._classify_distribution_type(skew_indicator, kurtosis_indicator),
"mean_median_ratio": round(mean / median, 4) if median != 0 else None
}
except Exception as e:
logger.warning(f"Failed to analyze distribution shape for {col_name}: {str(e)}")
return {"distribution_type": "unknown"}
def _classify_distribution_type(self, skew: str, kurtosis: str) -> str:
"""Classify distribution type based on skewness and kurtosis"""
if skew == "symmetric" and kurtosis == "normal":
return "approximately_normal"
elif skew == "right_skewed":
return "right_skewed"
elif skew == "left_skewed":
return "left_skewed"
elif kurtosis == "heavy_tailed":
return "heavy_tailed"
else:
return "non_normal"
async def _analyze_categorical_distributions(self, connection, table_name: str, categorical_columns: List[Dict], sampling_info: Dict) -> Dict[str, Any]:
"""Analyze distribution patterns for categorical columns"""
categorical_analysis = {}
for column in categorical_columns:
col_name = column["column_name"]
try:
# Basic cardinality and distribution
cardinality_sql = f"""
SELECT
COUNT(DISTINCT {col_name}) as cardinality,
COUNT({col_name}) as non_null_count
FROM {table_name}
WHERE {col_name} IS NOT NULL
{sampling_info.get('sample_query_suffix', '')}
"""
cardinality_result = await connection.execute(cardinality_sql)
if cardinality_result.data:
cardinality_data = cardinality_result.data[0]
cardinality = cardinality_data["cardinality"]
non_null_count = cardinality_data["non_null_count"]
# Value distribution (top values)
value_distribution = await self._get_categorical_value_distribution(
connection, table_name, col_name, sampling_info, non_null_count
)
# Calculate entropy and concentration
entropy = self._calculate_entropy(value_distribution)
concentration_ratio = value_distribution[0]["percentage"] if value_distribution else 0
categorical_analysis[col_name] = {
"data_type": column["data_type"],
"cardinality": cardinality,
"non_null_count": non_null_count,
"value_distribution": value_distribution,
"entropy": round(entropy, 3),
"concentration_ratio": round(concentration_ratio, 4),
"diversity_score": round(cardinality / non_null_count, 4) if non_null_count > 0 else 0
}
except Exception as e:
logger.warning(f"Failed to analyze categorical column {col_name}: {str(e)}")
categorical_analysis[col_name] = {"error": str(e)}
return categorical_analysis
async def _get_categorical_value_distribution(self, connection, table_name: str, col_name: str, sampling_info: Dict, total_count: int) -> List[Dict]:
"""Get value distribution for categorical column"""
try:
# Use sample table expression if sampling is enabled
table_expr = sampling_info.get("sample_table_expression", table_name)
distribution_sql = f"""
SELECT
{col_name} as value,
COUNT(*) as count
FROM {table_expr}
WHERE {col_name} IS NOT NULL
GROUP BY {col_name}
ORDER BY COUNT(*) DESC
LIMIT 20
"""
result = await connection.execute(distribution_sql)
if result.data:
distribution = []
for row in result.data:
count = row["count"]
percentage = count / total_count if total_count > 0 else 0
distribution.append({
"value": str(row["value"]),
"count": count,
"percentage": round(percentage, 4)
})
return distribution
except Exception as e:
logger.warning(f"Failed to get value distribution for {col_name}: {str(e)}")
return []
def _calculate_entropy(self, value_distribution: List[Dict]) -> float:
"""Calculate Shannon entropy for categorical distribution"""
if not value_distribution:
return 0.0
entropy = 0.0
for item in value_distribution:
p = item["percentage"]
if p > 0:
entropy -= p * math.log2(p)
return entropy
async def _analyze_temporal_distributions(self, connection, table_name: str, temporal_columns: List[Dict], sampling_info: Dict) -> Dict[str, Any]:
"""Analyze distribution patterns for temporal columns"""
temporal_analysis = {}
for column in temporal_columns:
col_name = column["column_name"]
try:
# Date range analysis
table_expr = sampling_info.get("sample_table_expression", table_name)
range_sql = f"""
SELECT
MIN({col_name}) as earliest,
MAX({col_name}) as latest,
COUNT({col_name}) as non_null_count
FROM {table_expr}
WHERE {col_name} IS NOT NULL
"""
range_result = await connection.execute(range_sql)
if range_result.data and range_result.data[0]["non_null_count"] > 0:
range_data = range_result.data[0]
earliest = range_data["earliest"]
latest = range_data["latest"]
# Calculate span
date_span_info = self._calculate_date_span(earliest, latest)
# Temporal patterns analysis
temporal_patterns = await self._analyze_temporal_patterns(
connection, table_name, col_name, sampling_info
)
temporal_analysis[col_name] = {
"data_type": column["data_type"],
"non_null_count": range_data["non_null_count"],
"date_range": {
"earliest": str(earliest),
"latest": str(latest),
**date_span_info
},
"temporal_patterns": temporal_patterns
}
except Exception as e:
logger.warning(f"Failed to analyze temporal column {col_name}: {str(e)}")
temporal_analysis[col_name] = {"error": str(e)}
return temporal_analysis
def _calculate_date_span(self, earliest, latest) -> Dict[str, Any]:
"""Calculate date span information"""
try:
if isinstance(earliest, str):
earliest = datetime.fromisoformat(earliest.replace('Z', '+00:00'))
if isinstance(latest, str):
latest = datetime.fromisoformat(latest.replace('Z', '+00:00'))
span = latest - earliest
span_days = span.days
return {
"span_days": span_days,
"span_years": round(span_days / 365.25, 2),
"span_description": self._describe_time_span(span_days)
}
except Exception as e:
logger.warning(f"Failed to calculate date span: {str(e)}")
return {"span_days": 0}
def _describe_time_span(self, days: int) -> str:
"""Describe time span in human readable format"""
if days < 1:
return "less_than_day"
elif days < 7:
return "days"
elif days < 30:
return "weeks"
elif days < 365:
return "months"
else:
return "years"
async def _analyze_temporal_patterns(self, connection, table_name: str, col_name: str, sampling_info: Dict) -> Dict[str, Any]:
"""Analyze temporal patterns like seasonality and trends"""
try:
table_expr = sampling_info.get("sample_table_expression", table_name)
# Weekly pattern analysis
weekly_pattern_sql = f"""
SELECT
DAYOFWEEK({col_name}) as day_of_week,
COUNT(*) as count
FROM {table_expr}
WHERE {col_name} IS NOT NULL
GROUP BY DAYOFWEEK({col_name})
ORDER BY day_of_week
"""
weekly_result = await connection.execute(weekly_pattern_sql)
weekly_pattern = []
if weekly_result.data:
total_records = sum(row["count"] for row in weekly_result.data)
for row in weekly_result.data:
percentage = row["count"] / total_records if total_records > 0 else 0
weekly_pattern.append(round(percentage, 3))
# Monthly trend analysis (simplified)
monthly_trend_sql = f"""
SELECT
YEAR({col_name}) as year,
MONTH({col_name}) as month,
COUNT(*) as count
FROM {table_expr}
WHERE {col_name} IS NOT NULL
GROUP BY YEAR({col_name}), MONTH({col_name})
ORDER BY year, month
LIMIT 12
"""
monthly_result = await connection.execute(monthly_trend_sql)
monthly_trend = "stable" # Simplified trend analysis
if monthly_result.data and len(monthly_result.data) > 3:
counts = [row["count"] for row in monthly_result.data]
if len(counts) > 1:
trend_direction = "increasing" if counts[-1] > counts[0] else "decreasing"
monthly_trend = trend_direction
return {
"weekly_pattern": weekly_pattern,
"monthly_trend": monthly_trend,
"seasonal_component": self._estimate_seasonality(weekly_pattern)
}
except Exception as e:
logger.warning(f"Failed to analyze temporal patterns for {col_name}: {str(e)}")
return {"weekly_pattern": [], "monthly_trend": "unknown"}
def _estimate_seasonality(self, weekly_pattern: List[float]) -> float:
"""Estimate seasonality strength based on weekly pattern variance"""
if len(weekly_pattern) < 7:
return 0.0
mean_percentage = sum(weekly_pattern) / len(weekly_pattern)
variance = sum((x - mean_percentage) ** 2 for x in weekly_pattern) / len(weekly_pattern)
# Normalize variance to 0-1 scale as seasonality indicator
seasonality = min(variance * 10, 1.0) # Scaling factor
return round(seasonality, 3)
async def _generate_data_quality_insights(self, connection, table_name: str, columns: List[Dict], sampling_info: Dict) -> Dict[str, Any]:
"""Generate overall data quality insights"""
try:
total_columns = len(columns)
# Calculate null rates across all columns
null_analysis = await self._analyze_overall_null_rates(connection, table_name, columns, sampling_info)
# Identify potential data quality issues
quality_issues = []
# High null rate columns
high_null_columns = [col for col, rate in null_analysis["column_null_rates"].items() if rate > 0.2]
if high_null_columns:
quality_issues.append({
"issue_type": "high_null_rates",
"severity": "medium",
"affected_columns": high_null_columns,
"description": f"{len(high_null_columns)} columns have null rates > 20%"
})
# Calculate overall data quality score
avg_null_rate = sum(null_analysis["column_null_rates"].values()) / len(null_analysis["column_null_rates"]) if null_analysis["column_null_rates"] else 0
data_quality_score = max(0, 1 - avg_null_rate)
return {
"total_columns_analyzed": total_columns,
"null_analysis": null_analysis,
"data_quality_score": round(data_quality_score, 3),
"quality_issues": quality_issues,
"recommendations": self._generate_quality_recommendations(quality_issues, null_analysis)
}
except Exception as e:
logger.warning(f"Failed to generate data quality insights: {str(e)}")
return {"data_quality_score": 0.0, "error": str(e)}
async def _analyze_overall_null_rates(self, connection, table_name: str, columns: List[Dict], sampling_info: Dict) -> Dict[str, Any]:
"""Analyze null rates across all columns"""
column_null_rates = {}
total_null_count = 0
total_cell_count = 0
for column in columns:
col_name = column["column_name"]
try:
table_expr = sampling_info.get("sample_table_expression", table_name)
null_sql = f"""
SELECT
COUNT(*) as total_count,
COUNT({col_name}) as non_null_count
FROM {table_expr}
"""
result = await connection.execute(null_sql)
if result.data:
data = result.data[0]
total_count = data["total_count"]
non_null_count = data["non_null_count"]
null_count = total_count - non_null_count
null_rate = null_count / total_count if total_count > 0 else 0
column_null_rates[col_name] = round(null_rate, 4)
total_null_count += null_count
total_cell_count += total_count
except Exception as e:
logger.warning(f"Failed to analyze null rate for column {col_name}: {str(e)}")
column_null_rates[col_name] = 0.0
overall_null_rate = total_null_count / total_cell_count if total_cell_count > 0 else 0
return {
"column_null_rates": column_null_rates,
"overall_null_rate": round(overall_null_rate, 4),
"columns_with_nulls": len([rate for rate in column_null_rates.values() if rate > 0])
}
def _generate_quality_recommendations(self, quality_issues: List[Dict], null_analysis: Dict) -> List[Dict]:
"""Generate data quality improvement recommendations"""
recommendations = []
# Recommendations based on null analysis
overall_null_rate = null_analysis.get("overall_null_rate", 0)
if overall_null_rate > 0.1:
recommendations.append({
"type": "data_completeness",
"priority": "high" if overall_null_rate > 0.3 else "medium",
"description": f"Overall null rate is {overall_null_rate:.1%}",
"action": "Review data collection and validation processes"
})
# Recommendations based on quality issues
for issue in quality_issues:
if issue["issue_type"] == "high_null_rates":
recommendations.append({
"type": "column_completeness",
"priority": issue["severity"],
"description": issue["description"],
"action": f"Focus on improving data completeness for: {', '.join(issue['affected_columns'][:3])}"
})
return recommendations
def _generate_analysis_summary(self, distribution_analysis: Dict[str, Any]) -> Dict[str, Any]:
"""Generate high-level summary of distribution analysis"""
summary = {
"numeric_columns_count": len(distribution_analysis.get("numeric_columns", {})),
"categorical_columns_count": len(distribution_analysis.get("categorical_columns", {})),
"temporal_columns_count": len(distribution_analysis.get("temporal_columns", {}))
}
# Identify interesting patterns
patterns = []
# Check for highly skewed numeric columns
numeric_cols = distribution_analysis.get("numeric_columns", {})
skewed_cols = [
col for col, info in numeric_cols.items()
if isinstance(info, dict) and
info.get("distribution_shape", {}).get("skewness_indicator") in ["right_skewed", "left_skewed"]
]
if skewed_cols:
patterns.append(f"Found {len(skewed_cols)} skewed numeric columns")
# Check for high cardinality categorical columns
categorical_cols = distribution_analysis.get("categorical_columns", {})
high_cardinality_cols = [
col for col, info in categorical_cols.items()
if isinstance(info, dict) and info.get("cardinality", 0) > 1000
]
if high_cardinality_cols:
patterns.append(f"Found {len(high_cardinality_cols)} high cardinality categorical columns")
summary["notable_patterns"] = patterns
return summary

View File

@@ -0,0 +1,897 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Data Governance Tools Module
Provides data completeness analysis, field lineage tracking, and data freshness monitoring
"""
import re
import time
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from .db import DorisConnectionManager
from .logger import get_logger
logger = get_logger(__name__)
class DataGovernanceTools:
"""Data governance tools suite"""
def __init__(self, connection_manager: DorisConnectionManager):
self.connection_manager = connection_manager
logger.info("DataGovernanceTools initialized")
async def trace_column_lineage(
self,
table_name: str,
column_name: str,
depth: int = 3,
catalog_name: Optional[str] = None,
db_name: Optional[str] = None
) -> Dict[str, Any]:
"""
Column-level lineage tracing
Args:
table_name: Table name
column_name: Column name
depth: Trace depth
catalog_name: Catalog name
db_name: Database name
"""
try:
start_time = time.time()
# 🚀 PROGRESS: Initialize column lineage tracing
logger.info("=" * 60)
logger.info(f"🔍 Starting Column Lineage Tracing")
logger.info(f"📊 Target: {table_name}.{column_name}")
logger.info(f"🎯 Trace depth: {depth}")
logger.info("=" * 60)
connection = await self.connection_manager.get_connection("query")
full_table_name = self._build_full_table_name(table_name, catalog_name, db_name)
target_column = f"{full_table_name}.{column_name}"
logger.info(f"📝 Full target: {target_column}")
# 🚀 PROGRESS: Step 1 - Verify target column exists
logger.info("🔍 Step 1/4: Verifying target column exists...")
verify_start = time.time()
if not await self._verify_column_exists(connection, full_table_name, column_name):
logger.error(f"❌ Column {column_name} not found in table {full_table_name}")
return {"error": f"Column {column_name} not found in table {full_table_name}"}
verify_time = time.time() - verify_start
logger.info(f"✅ Column verified in {verify_time:.2f}s")
# 🚀 PROGRESS: Step 2 - Analyze SQL logs for lineage relationships
logger.info(f"📊 Step 2/4: Analyzing SQL logs for lineage (depth={depth})...")
lineage_start = time.time()
source_chain = await self._analyze_sql_logs_for_lineage(
connection, full_table_name, column_name, depth
)
lineage_time = time.time() - lineage_start
logger.info(f"✅ Found {len(source_chain)} lineage relationships in {lineage_time:.2f}s")
# 🚀 PROGRESS: Step 3 - Analyze downstream usage
logger.info("⬇️ Step 3/4: Analyzing downstream column usage...")
downstream_start = time.time()
downstream_usage = await self._analyze_downstream_column_usage(
connection, full_table_name, column_name
)
downstream_time = time.time() - downstream_start
logger.info(f"✅ Found {len(downstream_usage)} downstream usages in {downstream_time:.2f}s")
# 🚀 PROGRESS: Step 4 - Extract transformation rules
logger.info("🔄 Step 4/4: Extracting transformation rules...")
transform_start = time.time()
transformation_rules = await self._extract_transformation_rules(
connection, full_table_name, column_name
)
transform_time = time.time() - transform_start
logger.info(f"✅ Found {len(transformation_rules)} transformation rules in {transform_time:.2f}s")
execution_time = time.time() - start_time
return {
"target_column": target_column,
"analysis_timestamp": datetime.now().isoformat(),
"execution_time_seconds": round(execution_time, 3),
"lineage_depth": depth,
"source_chain": source_chain,
"downstream_usage": downstream_usage,
"transformation_rules": transformation_rules,
"lineage_confidence": self._calculate_lineage_confidence(source_chain),
"impact_analysis": {
"upstream_dependencies": len(source_chain),
"downstream_dependencies": len(downstream_usage),
"risk_level": self._assess_lineage_risk(source_chain, downstream_usage)
}
}
except Exception as e:
logger.error(f"Column lineage tracing failed for {table_name}.{column_name}: {str(e)}")
return {
"error": str(e),
"target_column": f"{table_name}.{column_name}",
"analysis_timestamp": datetime.now().isoformat()
}
async def monitor_data_freshness(
self,
tables: Optional[List[str]] = None,
time_threshold_hours: int = 24,
catalog_name: Optional[str] = None,
db_name: Optional[str] = None
) -> Dict[str, Any]:
"""
Data freshness monitoring
Args:
tables: List of tables to monitor, empty means monitor all tables
time_threshold_hours: Freshness threshold (hours)
catalog_name: Catalog name
db_name: Database name
"""
try:
start_time = time.time()
connection = await self.connection_manager.get_connection("query")
# 1. Get list of tables to monitor
if not tables:
tables = await self._get_all_tables(connection, catalog_name, db_name)
# 2. Analyze freshness of each table
table_freshness = {}
fresh_count = 0
stale_count = 0
for table in tables:
full_table_name = self._build_full_table_name(table, catalog_name, db_name)
freshness_info = await self._analyze_table_freshness(
connection, full_table_name, time_threshold_hours
)
table_freshness[table] = freshness_info
if freshness_info["status"] == "fresh":
fresh_count += 1
else:
stale_count += 1
# 3. Calculate overall freshness score
total_tables = len(tables)
overall_freshness_score = fresh_count / total_tables if total_tables > 0 else 0
# 4. Identify data flow issues
data_flow_issues = await self._identify_data_flow_issues(table_freshness)
execution_time = time.time() - start_time
return {
"monitoring_timestamp": datetime.now().isoformat(),
"execution_time_seconds": round(execution_time, 3),
"monitoring_scope": {
"catalog_name": catalog_name,
"db_name": db_name,
"time_threshold_hours": time_threshold_hours
},
"freshness_summary": {
"total_tables": total_tables,
"fresh_tables": fresh_count,
"stale_tables": stale_count,
"overall_freshness_score": round(overall_freshness_score, 3)
},
"table_freshness": table_freshness,
"data_flow_issues": data_flow_issues,
"alerts": self._generate_freshness_alerts(table_freshness, time_threshold_hours)
}
except Exception as e:
logger.error(f"Data freshness monitoring failed: {str(e)}")
return {
"error": str(e),
"monitoring_timestamp": datetime.now().isoformat()
}
# ==================== Private Helper Methods ====================
def _build_full_table_name(self, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> str:
"""Build full table name - use three-level naming convention"""
# Default catalog is internal for internal tables
effective_catalog = catalog_name if catalog_name else "internal"
if db_name:
return f"{effective_catalog}.{db_name}.{table_name}"
else:
# If db_name is not provided, need to determine current database
return f"{effective_catalog}.{table_name}"
async def _get_table_basic_info(self, connection, table_name: str) -> Optional[Dict]:
"""Get table basic information"""
try:
# Try to get table row count
count_sql = f"SELECT COUNT(*) as row_count FROM {table_name}"
result = await connection.execute(count_sql)
if result.data:
return {"row_count": result.data[0]["row_count"]}
return None
except Exception as e:
logger.warning(f"Failed to get basic info for table {table_name}: {str(e)}")
return {"row_count": 0}
async def _get_table_columns_info(self, connection, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> List[Dict]:
"""Get table column information"""
try:
# Build query conditions
where_conditions = [f"table_name = '{table_name}'"]
if db_name:
where_conditions.append(f"table_schema = '{db_name}'")
else:
where_conditions.append("table_schema = DATABASE()")
columns_sql = f"""
SELECT
column_name,
data_type,
is_nullable,
column_comment,
ordinal_position
FROM information_schema.columns
WHERE {' AND '.join(where_conditions)}
ORDER BY ordinal_position
"""
result = await connection.execute(columns_sql)
return result.data if result.data else []
except Exception as e:
logger.warning(f"Failed to get columns info for table {table_name}: {str(e)}")
return []
async def _analyze_column_completeness(self, connection, table_name: str, columns_info: List[Dict]) -> Dict[str, Any]:
"""Analyze column completeness"""
column_completeness = {}
for column in columns_info:
column_name = column["column_name"]
try:
# Calculate null value statistics
null_sql = f"""
SELECT
COUNT(*) as total_count,
COUNT({column_name}) as non_null_count,
COUNT(*) - COUNT({column_name}) as null_count
FROM {table_name}
"""
result = await connection.execute(null_sql)
if result.data:
stats = result.data[0]
total_count = stats["total_count"]
null_count = stats["null_count"]
null_rate = null_count / total_count if total_count > 0 else 0
completeness_score = 1.0 - null_rate
column_completeness[column_name] = {
"data_type": column["data_type"],
"is_nullable": column["is_nullable"],
"total_count": total_count,
"null_count": null_count,
"non_null_count": stats["non_null_count"],
"null_rate": round(null_rate, 4),
"completeness_score": round(completeness_score, 4)
}
except Exception as e:
logger.warning(f"Failed to analyze completeness for column {column_name}: {str(e)}")
column_completeness[column_name] = {
"error": str(e),
"completeness_score": 0.0
}
return column_completeness
async def _check_business_rule_compliance(self, connection, table_name: str, business_rules: List[Dict], total_rows: int) -> Dict[str, Any]:
"""Check business rule compliance"""
compliance_results = {}
for rule in business_rules:
rule_name = rule.get("rule_name", "unknown")
sql_condition = rule.get("sql_condition", "")
if not sql_condition:
continue
try:
# Check number of records meeting conditions
compliance_sql = f"""
SELECT
COUNT(*) as total_count,
SUM(CASE WHEN {sql_condition} THEN 1 ELSE 0 END) as pass_count
FROM {table_name}
"""
result = await connection.execute(compliance_sql)
if result.data:
stats = result.data[0]
pass_count = stats["pass_count"] or 0
fail_count = total_rows - pass_count
pass_rate = pass_count / total_rows if total_rows > 0 else 0
compliance_results[rule_name] = {
"rule_condition": sql_condition,
"total_records": total_rows,
"pass_count": pass_count,
"fail_count": fail_count,
"pass_rate": round(pass_rate, 4),
"compliance_score": round(pass_rate, 4)
}
except Exception as e:
logger.warning(f"Failed to check business rule {rule_name}: {str(e)}")
compliance_results[rule_name] = {
"error": str(e),
"compliance_score": 0.0
}
return compliance_results
async def _detect_data_integrity_issues(self, connection, table_name: str, columns_info: List[Dict]) -> List[Dict]:
"""Detect data integrity issues"""
issues = []
try:
# Detect duplicate values in primary key fields
primary_key_columns = [col["column_name"] for col in columns_info if "primary" in col.get("column_comment", "").lower()]
for pk_col in primary_key_columns:
duplicate_sql = f"""
SELECT COUNT(*) as duplicate_count
FROM (
SELECT {pk_col}, COUNT(*) as cnt
FROM {table_name}
WHERE {pk_col} IS NOT NULL
GROUP BY {pk_col}
HAVING COUNT(*) > 1
) t
"""
result = await connection.execute(duplicate_sql)
if result.data and result.data[0]["duplicate_count"] > 0:
issues.append({
"type": "duplicate_primary_keys",
"column": pk_col,
"count": result.data[0]["duplicate_count"],
"severity": "high",
"description": f"Found duplicate values in primary key column {pk_col}"
})
except Exception as e:
logger.warning(f"Failed to detect integrity issues: {str(e)}")
issues.append({
"type": "detection_error",
"error": str(e),
"severity": "unknown"
})
return issues
def _calculate_completeness_score(self, column_completeness: Dict, business_rule_compliance: Dict) -> float:
"""Calculate overall completeness score"""
if not column_completeness:
return 0.0
# Calculate column completeness average score
column_scores = [
col_info.get("completeness_score", 0.0)
for col_info in column_completeness.values()
if isinstance(col_info, dict) and "completeness_score" in col_info
]
avg_column_score = sum(column_scores) / len(column_scores) if column_scores else 0.0
# Calculate business rule compliance average score
compliance_scores = [
rule_info.get("compliance_score", 0.0)
for rule_info in business_rule_compliance.values()
if isinstance(rule_info, dict) and "compliance_score" in rule_info
]
avg_compliance_score = sum(compliance_scores) / len(compliance_scores) if compliance_scores else 1.0
# Comprehensive score (column completeness weight 70%, business rules weight 30%)
overall_score = avg_column_score * 0.7 + avg_compliance_score * 0.3
return round(overall_score, 4)
def _generate_completeness_recommendations(self, column_completeness: Dict, integrity_issues: List[Dict]) -> List[Dict]:
"""Generate completeness improvement recommendations"""
recommendations = []
# Generate recommendations based on column completeness
for col_name, col_info in column_completeness.items():
if isinstance(col_info, dict):
null_rate = col_info.get("null_rate", 0)
if null_rate > 0.1: # Null rate exceeds 10%
recommendations.append({
"type": "high_null_rate",
"column": col_name,
"priority": "high" if null_rate > 0.5 else "medium",
"description": f"Column {col_name} has high null rate ({null_rate:.1%})",
"suggested_action": "Review data collection process or add data validation"
})
# Generate recommendations based on integrity issues
for issue in integrity_issues:
if issue["type"] == "duplicate_primary_keys":
recommendations.append({
"type": "data_deduplication",
"column": issue["column"],
"priority": "high",
"description": f"Duplicate primary key values found in {issue['column']}",
"suggested_action": "Implement unique constraint or data deduplication process"
})
return recommendations
async def _verify_column_exists(self, connection, table_name: str, column_name: str) -> bool:
"""Verify if column exists"""
try:
# Simple verification method: try to query the column
verify_sql = f"SELECT {column_name} FROM {table_name} LIMIT 1"
await connection.execute(verify_sql)
return True
except Exception:
return False
async def _analyze_sql_logs_for_lineage(self, connection, table_name: str, column_name: str, depth: int) -> List[Dict]:
"""Analyze SQL logs to get lineage relationships (simplified implementation)"""
# Note: This is a simplified implementation, actual environment needs to analyze audit logs
source_chain = []
try:
# Try to find related INSERT/CREATE TABLE AS SELECT statements from audit logs (one year range)
audit_sql = """
SELECT
stmt as sql_statement,
`time` as execution_time,
`user` as user_name
FROM internal.__internal_schema.audit_log
WHERE stmt LIKE '%{}%'
AND (stmt LIKE '%INSERT%' OR stmt LIKE '%CREATE%' OR stmt LIKE '%SELECT%')
AND `time` >= DATE_SUB(NOW(), INTERVAL 1 YEAR)
ORDER BY `time` DESC
LIMIT 50
""".format(table_name.split('.')[-1]) # Use the last part of table name
result = await connection.execute(audit_sql)
if result.data:
for i, log_entry in enumerate(result.data[:depth]):
# Simplified lineage analysis: extract possible source tables
sql_stmt = log_entry.get("sql_statement", "")
source_tables = self._extract_source_tables_from_sql(sql_stmt)
if source_tables:
# Handle datetime serialization issue
execution_time = log_entry.get("execution_time")
if execution_time and hasattr(execution_time, 'isoformat'):
execution_time = execution_time.isoformat()
elif execution_time:
execution_time = str(execution_time)
source_chain.append({
"level": i + 1,
"source_table": source_tables[0], # Take the first as main source table
"source_column": column_name, # Simplified: assume same name
"transformation": self._extract_transformation_from_sql(sql_stmt, column_name),
"confidence": 0.8 - (i * 0.1), # Decreasing confidence
"execution_time": execution_time,
"user": log_entry.get("user_name")
})
except Exception as e:
logger.warning(f"Failed to analyze SQL logs for lineage: {str(e)}")
# If unable to get from audit logs, return basic information
source_chain = [{
"level": 1,
"source_table": "unknown_source",
"source_column": column_name,
"transformation": "unknown",
"confidence": 0.3,
"note": "Limited lineage information available"
}]
return source_chain
def _extract_source_tables_from_sql(self, sql: str) -> List[str]:
"""Extract source table names from SQL statement (simplified implementation)"""
# Simplified regex to match table names in FROM clause
from_pattern = r'\bFROM\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)'
join_pattern = r'\bJOIN\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)'
tables = []
# Find tables in FROM clause
from_matches = re.findall(from_pattern, sql, re.IGNORECASE)
tables.extend(from_matches)
# Find tables in JOIN clause
join_matches = re.findall(join_pattern, sql, re.IGNORECASE)
tables.extend(join_matches)
return list(set(tables)) # Remove duplicates
def _extract_transformation_from_sql(self, sql: str, column_name: str) -> str:
"""Extract field transformation rules from SQL statement (simplified implementation)"""
# Simplified implementation: find expressions containing target field
lines = sql.split('\n')
for line in lines:
if column_name in line and ('SELECT' in line.upper() or '=' in line):
return line.strip()
return "direct_copy"
async def _analyze_downstream_column_usage(self, connection, table_name: str, column_name: str) -> List[Dict]:
"""Analyze downstream usage of field (simplified implementation)"""
downstream_usage = []
try:
# Find other tables that might use this field (through audit logs, one year range)
usage_sql = """
SELECT DISTINCT
stmt as sql_statement
FROM internal.__internal_schema.audit_log
WHERE stmt LIKE '%{}%'
AND stmt LIKE '%{}%'
AND stmt LIKE '%SELECT%'
AND `time` >= DATE_SUB(NOW(), INTERVAL 1 YEAR)
LIMIT 20
""".format(table_name.split('.')[-1], column_name)
result = await connection.execute(usage_sql)
if result.data:
for entry in result.data:
sql_stmt = entry.get("sql_statement", "")
target_tables = self._extract_target_tables_from_sql(sql_stmt)
for target_table in target_tables:
if target_table != table_name.split('.')[-1]: # Not the source table itself
downstream_usage.append({
"table": target_table,
"column": column_name, # Simplified: assume same name
"usage_type": "select_reference",
"confidence": 0.7
})
except Exception as e:
logger.warning(f"Failed to analyze downstream usage: {str(e)}")
return downstream_usage
def _extract_target_tables_from_sql(self, sql: str) -> List[str]:
"""Extract target table names from SQL statement"""
# Find target tables in INSERT INTO or CREATE TABLE statements
insert_pattern = r'\bINSERT\s+INTO\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)'
create_pattern = r'\bCREATE\s+TABLE\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)'
tables = []
insert_matches = re.findall(insert_pattern, sql, re.IGNORECASE)
tables.extend(insert_matches)
create_matches = re.findall(create_pattern, sql, re.IGNORECASE)
tables.extend(create_matches)
return list(set(tables))
async def _extract_transformation_rules(self, connection, table_name: str, column_name: str) -> List[Dict]:
"""Extract field transformation rules"""
# Simplified implementation: return basic transformation information
return [{
"transformation_type": "unknown",
"description": "Transformation rules analysis requires detailed ETL metadata",
"confidence": 0.5
}]
def _calculate_lineage_confidence(self, source_chain: List[Dict]) -> float:
"""Calculate overall confidence of lineage tracing"""
if not source_chain:
return 0.0
confidences = [item.get("confidence", 0.0) for item in source_chain]
return round(sum(confidences) / len(confidences), 3)
def _assess_lineage_risk(self, source_chain: List[Dict], downstream_usage: List[Dict]) -> str:
"""Assess lineage risk level"""
if len(downstream_usage) > 10:
return "high"
elif len(downstream_usage) > 5:
return "medium"
else:
return "low"
async def _get_all_tables(self, connection, catalog_name: Optional[str], db_name: Optional[str]) -> List[str]:
"""Get list of all tables"""
try:
where_conditions = []
if db_name:
where_conditions.append(f"table_schema = '{db_name}'")
else:
where_conditions.append("table_schema = DATABASE()")
where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"
tables_sql = f"""
SELECT table_name
FROM information_schema.tables
WHERE {where_clause}
AND table_type = 'BASE TABLE'
ORDER BY table_name
"""
result = await connection.execute(tables_sql)
return [row["table_name"] for row in result.data] if result.data else []
except Exception as e:
logger.warning(f"Failed to get table list: {str(e)}")
return []
async def _analyze_table_freshness(self, connection, table_name: str, threshold_hours: int) -> Dict[str, Any]:
"""Analyze freshness of single table"""
try:
# Try multiple methods to get table's last update time
freshness_methods = [
self._get_freshness_from_partition_info,
self._get_freshness_from_max_timestamp,
self._get_freshness_from_table_metadata
]
last_update = None
method_used = "unknown"
for method in freshness_methods:
try:
result = await method(connection, table_name)
if result:
last_update = result["last_update"]
method_used = result["method"]
break
except Exception as e:
continue
if not last_update:
return {
"last_update": None,
"staleness_hours": None,
"freshness_score": 0.0,
"status": "unknown",
"method_used": "none",
"error": "Unable to determine last update time"
}
# Calculate data staleness
now = datetime.now()
if isinstance(last_update, str):
last_update = datetime.fromisoformat(last_update.replace('Z', '+00:00'))
staleness_hours = (now - last_update).total_seconds() / 3600
# Calculate freshness score and status
if staleness_hours <= threshold_hours:
status = "fresh"
freshness_score = max(0.0, 1.0 - (staleness_hours / threshold_hours))
else:
status = "stale"
freshness_score = max(0.0, 1.0 - (staleness_hours / (threshold_hours * 2)))
return {
"last_update": last_update.isoformat() if hasattr(last_update, 'isoformat') else str(last_update),
"staleness_hours": round(staleness_hours, 2),
"freshness_score": round(freshness_score, 3),
"status": status,
"method_used": method_used,
"threshold_hours": threshold_hours
}
except Exception as e:
logger.warning(f"Failed to analyze freshness for table {table_name}: {str(e)}")
return {
"last_update": None,
"staleness_hours": None,
"freshness_score": 0.0,
"status": "error",
"error": str(e)
}
async def _get_freshness_from_partition_info(self, connection, table_name: str) -> Optional[Dict]:
"""Get freshness from partition information"""
try:
# Query partition information (if table has partitions)
partition_sql = f"""
SELECT MAX(CREATE_TIME) as last_update
FROM information_schema.partitions
WHERE table_name = '{table_name.split('.')[-1]}'
AND CREATE_TIME IS NOT NULL
"""
result = await connection.execute(partition_sql)
if result.data and result.data[0]["last_update"]:
return {
"last_update": result.data[0]["last_update"],
"method": "partition_info"
}
return None
except Exception:
return None
async def _get_freshness_from_max_timestamp(self, connection, table_name: str) -> Optional[Dict]:
"""Get freshness from timestamp fields"""
try:
# Find possible timestamp fields
timestamp_columns = await self._find_timestamp_columns(connection, table_name)
if timestamp_columns:
max_time_sql = f"""
SELECT MAX({timestamp_columns[0]}) as last_update
FROM {table_name}
"""
result = await connection.execute(max_time_sql)
if result.data and result.data[0]["last_update"]:
return {
"last_update": result.data[0]["last_update"],
"method": f"max_timestamp({timestamp_columns[0]})"
}
return None
except Exception:
return None
async def _get_freshness_from_table_metadata(self, connection, table_name: str) -> Optional[Dict]:
"""Get freshness from table metadata"""
try:
# Query table's update time
metadata_sql = f"""
SELECT UPDATE_TIME as last_update
FROM information_schema.tables
WHERE table_name = '{table_name.split('.')[-1]}'
AND UPDATE_TIME IS NOT NULL
"""
result = await connection.execute(metadata_sql)
if result.data and result.data[0]["last_update"]:
return {
"last_update": result.data[0]["last_update"],
"method": "table_metadata"
}
return None
except Exception:
return None
async def _find_timestamp_columns(self, connection, table_name: str) -> List[str]:
"""Find possible timestamp fields"""
try:
timestamp_sql = f"""
SELECT column_name
FROM information_schema.columns
WHERE table_name = '{table_name.split('.')[-1]}'
AND (
data_type IN ('datetime', 'timestamp', 'date')
OR column_name LIKE '%time%'
OR column_name LIKE '%date%'
OR column_name LIKE '%created%'
OR column_name LIKE '%updated%'
)
ORDER BY
CASE
WHEN column_name LIKE '%updated%' THEN 1
WHEN column_name LIKE '%created%' THEN 2
WHEN column_name LIKE '%time%' THEN 3
ELSE 4
END
"""
result = await connection.execute(timestamp_sql)
return [row["column_name"] for row in result.data] if result.data else []
except Exception:
return []
async def _identify_data_flow_issues(self, table_freshness: Dict[str, Any]) -> List[Dict]:
"""Identify data flow issues"""
issues = []
# Identify consecutively stale tables (may indicate ETL process issues)
stale_tables = [
table_name for table_name, info in table_freshness.items()
if info.get("status") == "stale"
]
if len(stale_tables) > len(table_freshness) * 0.3: # More than 30% of tables are stale
issues.append({
"issue_type": "widespread_staleness",
"severity": "high",
"affected_tables": len(stale_tables),
"total_tables": len(table_freshness),
"description": f"High percentage of stale tables ({len(stale_tables)}/{len(table_freshness)})",
"possible_causes": ["ETL pipeline failure", "Data source issues", "Processing delays"]
})
# Identify particularly stale tables
very_stale_tables = [
(table_name, info.get("staleness_hours", 0))
for table_name, info in table_freshness.items()
if info.get("staleness_hours", 0) > 72 # More than 3 days
]
if very_stale_tables:
issues.append({
"issue_type": "very_stale_data",
"severity": "medium",
"affected_tables": [table for table, _ in very_stale_tables],
"max_staleness_hours": max(hours for _, hours in very_stale_tables),
"description": "Some tables have very stale data (>72 hours)",
"recommendation": "Check data ingestion processes for affected tables"
})
return issues
def _generate_freshness_alerts(self, table_freshness: Dict[str, Any], threshold_hours: int) -> List[Dict]:
"""Generate freshness alerts"""
alerts = []
for table_name, info in table_freshness.items():
staleness_hours = info.get("staleness_hours")
status = info.get("status")
if status == "stale" and staleness_hours:
if staleness_hours > threshold_hours * 2: # Exceeds threshold by 2x
alert_level = "critical"
elif staleness_hours > threshold_hours * 1.5: # Exceeds threshold by 1.5x
alert_level = "warning"
else:
alert_level = "info"
alerts.append({
"alert_level": alert_level,
"table_name": table_name,
"staleness_hours": staleness_hours,
"threshold_hours": threshold_hours,
"message": f"Table {table_name} is stale ({staleness_hours:.1f} hours old, threshold: {threshold_hours}h)",
"timestamp": datetime.now().isoformat()
})
elif status == "error":
alerts.append({
"alert_level": "error",
"table_name": table_name,
"message": f"Unable to determine freshness for table {table_name}",
"error": info.get("error"),
"timestamp": datetime.now().isoformat()
})
return alerts

File diff suppressed because it is too large Load Diff

View File

@@ -29,10 +29,13 @@ from contextlib import asynccontextmanager
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List from typing import Any, Dict, List
import random
import aiomysql import aiomysql
from aiomysql import Connection, Pool from aiomysql import Connection, Pool
from .logger import get_logger
@@ -70,6 +73,7 @@ class DorisConnection:
self.query_count = 0 self.query_count = 0
self.is_healthy = True self.is_healthy = True
self.security_manager = security_manager self.security_manager = security_manager
self.logger = get_logger(__name__)
async def execute(self, sql: str, params: tuple | None = None, auth_context=None) -> QueryResult: async def execute(self, sql: str, params: tuple | None = None, auth_context=None) -> QueryResult:
"""Execute SQL query""" """Execute SQL query"""
@@ -135,12 +139,46 @@ class DorisConnection:
raise raise
async def ping(self) -> bool: async def ping(self) -> bool:
"""Check connection health status""" """Check connection health status with enhanced at_eof error detection"""
try: try:
await self.connection.ping() # Check 1: Connection exists and is not closed
self.is_healthy = True if not self.connection or self.connection.closed:
return True self.is_healthy = False
except Exception: return False
# Check 2: Use ONLY safe operations - avoid internal state access
# Instead of checking _reader state directly, use a simple query test
try:
# Use a simple query with timeout instead of ping() to avoid at_eof issues
async with asyncio.timeout(3): # 3 second timeout
async with self.connection.cursor() as cursor:
await cursor.execute("SELECT 1")
result = await cursor.fetchone()
if result and result[0] == 1:
self.is_healthy = True
return True
else:
self.logger.debug(f"Connection {self.session_id} ping query returned unexpected result")
self.is_healthy = False
return False
except asyncio.TimeoutError:
self.logger.debug(f"Connection {self.session_id} ping timed out")
self.is_healthy = False
return False
except Exception as query_error:
# Check for specific at_eof related errors
error_str = str(query_error).lower()
if 'at_eof' in error_str or 'nonetype' in error_str:
self.logger.debug(f"Connection {self.session_id} ping failed with at_eof error: {query_error}")
else:
self.logger.debug(f"Connection {self.session_id} ping failed: {query_error}")
self.is_healthy = False
return False
except Exception as e:
# Catch any other unexpected errors
self.logger.debug(f"Connection {self.session_id} ping failed with unexpected error: {e}")
self.is_healthy = False self.is_healthy = False
return False return False
@@ -154,241 +192,457 @@ class DorisConnection:
class DorisConnectionManager: class DorisConnectionManager:
"""Doris database connection manager """Doris database connection manager - Enhanced Strategy
Provides connection pool management, connection health monitoring, fault recovery and other functions Uses direct connection pool management with proper synchronization
Supports session-level connection reuse and intelligent load balancing Implements connection pool health monitoring and proactive cleanup
Integrates security manager to provide unified security validation and data masking
""" """
def __init__(self, config, security_manager=None): def __init__(self, config, security_manager=None):
self.config = config self.config = config
self.pool: Pool | None = None self.pool: Pool | None = None
self.session_connections: dict[str, DorisConnection] = {} self.logger = get_logger(__name__)
self.metrics = ConnectionMetrics()
self.logger = logging.getLogger(__name__)
self.security_manager = security_manager self.security_manager = security_manager
# Health check configuration # Connection pool state management
self.health_check_interval = config.database.health_check_interval or 60 self.pool_recovering = False
self.max_connection_age = config.database.max_connection_age or 3600 self.pool_health_check_task = None
self.connection_timeout = config.database.connection_timeout or 30 self.pool_cleanup_task = None
# Start background tasks # Metrics tracking
self._health_check_task = None self.metrics = ConnectionMetrics()
self._cleanup_task = None
# 🔧 FIX: Add connection acquisition lock to prevent race conditions
self._connection_lock = asyncio.Lock()
self._recovery_lock = asyncio.Lock()
# 🔧 FIX: Add connection acquisition queue to serialize requests
self._connection_semaphore = asyncio.Semaphore(value=20) # Max concurrent acquisitions
# Database connection parameters from config.database
self.pool_recovery_lock = self._recovery_lock # Compatibility alias
self.host = config.database.host
self.port = config.database.port
self.user = config.database.user
self.password = config.database.password
self.database = config.database.database
# Convert charset to aiomysql compatible format
charset_map = {"UTF8": "utf8", "UTF8MB4": "utf8mb4"}
self.charset = charset_map.get(config.database.charset.upper(), config.database.charset.lower())
self.connect_timeout = config.database.connection_timeout
# Connection pool parameters - more conservative settings
self.minsize = config.database.min_connections # This is always 0
self.maxsize = config.database.max_connections or 20
self.pool_recycle = config.database.max_connection_age or 3600 # 1 hour, more conservative
# 🔧 FIX: Add missing monitoring parameters that were removed during refactoring
self.health_check_interval = 30 # seconds
self.pool_warmup_size = 3 # connections to maintain
async def initialize(self): async def initialize(self):
"""Initialize connection manager""" """Initialize connection pool with health monitoring"""
try: try:
self.logger.info(f"Initializing connection pool to {self.host}:{self.port}")
# Create connection pool # Create connection pool
self.pool = await aiomysql.create_pool( self.pool = await aiomysql.create_pool(
host=self.config.database.host, host=self.host,
port=self.config.database.port, port=self.port,
user=self.config.database.user, user=self.user,
password=self.config.database.password, password=self.password,
db=self.config.database.database, db=self.database,
charset="utf8", charset=self.charset,
minsize=self.config.database.min_connections or 5, minsize=self.minsize,
maxsize=self.config.database.max_connections or 20, maxsize=self.maxsize,
autocommit=True, pool_recycle=self.pool_recycle,
connect_timeout=self.connection_timeout, connect_timeout=self.connect_timeout,
) autocommit=True
self.logger.info(
f"Connection pool initialized successfully, min connections: {self.config.database.min_connections}, "
f"max connections: {self.config.database.max_connections}"
) )
# Test initial connection
if not await self._test_pool_health():
raise RuntimeError("Connection pool health check failed")
# Start background monitoring tasks # Start background monitoring tasks
self._health_check_task = asyncio.create_task(self._health_check_loop()) self.pool_health_check_task = asyncio.create_task(self._pool_health_monitor())
self._cleanup_task = asyncio.create_task(self._cleanup_loop()) self.pool_cleanup_task = asyncio.create_task(self._pool_cleanup_monitor())
# Perform initial pool warmup
await self._warmup_pool()
self.logger.info(f"Connection pool initialized successfully, min connections: {self.minsize}, max connections: {self.maxsize}")
except Exception as e: except Exception as e:
self.logger.error(f"Connection pool initialization failed: {e}") self.logger.error(f"Failed to initialize connection pool: {e}")
raise raise
async def get_connection(self, session_id: str) -> DorisConnection: async def _test_pool_health(self) -> bool:
"""Get database connection """Test connection pool health"""
Supports session-level connection reuse to improve performance and consistency
"""
# Check if there's an existing session connection
if session_id in self.session_connections:
conn = self.session_connections[session_id]
# Check connection health
if await conn.ping():
return conn
else:
# Connection is unhealthy, clean up and create new one
await self._cleanup_session_connection(session_id)
# Create new connection
return await self._create_new_connection(session_id)
async def _create_new_connection(self, session_id: str) -> DorisConnection:
"""Create new database connection"""
try: try:
if not self.pool: async with self.pool.acquire() as conn:
raise RuntimeError("Connection pool not initialized") async with conn.cursor() as cursor:
await cursor.execute("SELECT 1")
result = await cursor.fetchone()
return result and result[0] == 1
except Exception as e:
self.logger.error(f"Pool health test failed: {e}")
return False
# Get connection from pool async def _warmup_pool(self):
raw_connection = await self.pool.acquire() """Warm up connection pool by creating initial connections"""
self.logger.info(f"🔥 Warming up connection pool with {self.pool_warmup_size} connections")
warmup_connections = []
try:
# Acquire connections to force pool to create them
for i in range(self.pool_warmup_size):
try:
conn = await self.pool.acquire()
warmup_connections.append(conn)
self.logger.debug(f"Warmed up connection {i+1}/{self.pool_warmup_size}")
except Exception as e:
self.logger.warning(f"Failed to warm up connection {i+1}: {e}")
break
# Create wrapped connection # Release all warmup connections back to pool
doris_conn = DorisConnection(raw_connection, session_id, self.security_manager) for conn in warmup_connections:
try:
self.pool.release(conn)
except Exception as e:
self.logger.warning(f"Failed to release warmup connection: {e}")
# Store in session connections self.logger.info(f"✅ Pool warmup completed, {len(warmup_connections)} connections created")
self.session_connections[session_id] = doris_conn
self.metrics.total_connections += 1
self.logger.debug(f"Created new connection for session: {session_id}")
return doris_conn
except Exception as e: except Exception as e:
self.metrics.connection_errors += 1 self.logger.error(f"Pool warmup failed: {e}")
self.logger.error(f"Failed to create connection for session {session_id}: {e}") # Clean up any remaining connections
raise for conn in warmup_connections:
try:
await conn.ensure_closed()
except Exception:
pass
async def release_connection(self, session_id: str): async def _pool_health_monitor(self):
"""Release session connection""" """Background task to monitor pool health"""
if session_id in self.session_connections: self.logger.info("🩺 Starting pool health monitor")
await self._cleanup_session_connection(session_id)
async def _cleanup_session_connection(self, session_id: str):
"""Clean up session connection"""
if session_id in self.session_connections:
conn = self.session_connections[session_id]
try:
# Return connection to pool
if self.pool and conn.connection and not conn.connection.closed:
self.pool.release(conn.connection)
# Close connection wrapper
await conn.close()
except Exception as e:
self.logger.error(f"Error cleaning up connection for session {session_id}: {e}")
finally:
# Remove from session connections
del self.session_connections[session_id]
self.logger.debug(f"Cleaned up connection for session: {session_id}")
async def _health_check_loop(self):
"""Background health check loop"""
while True: while True:
try: try:
await asyncio.sleep(self.health_check_interval) await asyncio.sleep(self.health_check_interval)
await self._perform_health_check() await self._check_pool_health()
except asyncio.CancelledError: except asyncio.CancelledError:
self.logger.info("Pool health monitor stopped")
break break
except Exception as e: except Exception as e:
self.logger.error(f"Health check error: {e}") self.logger.error(f"Pool health monitor error: {e}")
async def _perform_health_check(self): async def _pool_cleanup_monitor(self):
"""Perform health check""" """Background task to clean up stale connections"""
try: self.logger.info("🧹 Starting pool cleanup monitor")
unhealthy_sessions = []
for session_id, conn in self.session_connections.items():
if not await conn.ping():
unhealthy_sessions.append(session_id)
# Clean up unhealthy connections
for session_id in unhealthy_sessions:
await self._cleanup_session_connection(session_id)
self.metrics.failed_connections += 1
# Update metrics
await self._update_connection_metrics()
self.metrics.last_health_check = datetime.utcnow()
if unhealthy_sessions:
self.logger.warning(f"Cleaned up {len(unhealthy_sessions)} unhealthy connections")
except Exception as e:
self.logger.error(f"Health check failed: {e}")
async def _cleanup_loop(self):
"""Background cleanup loop"""
while True: while True:
try: try:
await asyncio.sleep(300) # Run every 5 minutes await asyncio.sleep(self.health_check_interval * 2) # Less frequent cleanup
await self._cleanup_idle_connections() await self._cleanup_stale_connections()
except asyncio.CancelledError: except asyncio.CancelledError:
self.logger.info("Pool cleanup monitor stopped")
break break
except Exception as e: except Exception as e:
self.logger.error(f"Cleanup loop error: {e}") self.logger.error(f"Pool cleanup monitor error: {e}")
async def _cleanup_idle_connections(self): async def _check_pool_health(self):
"""Clean up idle connections""" """Check and maintain pool health"""
current_time = datetime.utcnow()
idle_sessions = []
for session_id, conn in self.session_connections.items():
# Check if connection has exceeded maximum age
age = (current_time - conn.created_at).total_seconds()
if age > self.max_connection_age:
idle_sessions.append(session_id)
# Clean up idle connections
for session_id in idle_sessions:
await self._cleanup_session_connection(session_id)
if idle_sessions:
self.logger.info(f"Cleaned up {len(idle_sessions)} idle connections")
async def _update_connection_metrics(self):
"""Update connection metrics"""
self.metrics.active_connections = len(self.session_connections)
if self.pool:
self.metrics.idle_connections = self.pool.freesize
async def get_metrics(self) -> ConnectionMetrics:
"""Get connection metrics"""
await self._update_connection_metrics()
return self.metrics
async def execute_query(
self, session_id: str, sql: str, params: tuple | None = None, auth_context=None
) -> QueryResult:
"""Execute query"""
conn = await self.get_connection(session_id)
return await conn.execute(sql, params, auth_context)
@asynccontextmanager
async def get_connection_context(self, session_id: str):
"""Get connection context manager"""
conn = await self.get_connection(session_id)
try: try:
yield conn # Skip health check if already recovering
finally: if self.pool_recovering:
# Connection will be reused, no need to close here self.logger.debug("Pool recovery in progress, skipping health check")
pass return
# Test pool with a simple query
health_ok = await self._test_pool_health()
if health_ok:
self.logger.debug("✅ Pool health check passed")
self.metrics.last_health_check = datetime.utcnow()
else:
self.logger.warning("❌ Pool health check failed, attempting recovery")
await self._recover_pool()
except Exception as e:
self.logger.error(f"Pool health check error: {e}")
await self._recover_pool()
async def _cleanup_stale_connections(self):
"""Proactively clean up potentially stale connections"""
try:
self.logger.debug("🧹 Checking for stale connections")
# Get pool statistics
pool_size = self.pool.size
pool_free = self.pool.freesize
# If pool has idle connections, test some of them
if pool_free > 0:
test_count = min(pool_free, 2) # Test up to 2 idle connections
for i in range(test_count):
try:
# Acquire connection, test it, and release
conn = await asyncio.wait_for(self.pool.acquire(), timeout=5)
# Quick test
async with conn.cursor() as cursor:
await asyncio.wait_for(cursor.execute("SELECT 1"), timeout=3)
await cursor.fetchone()
# Connection is healthy, release it
self.pool.release(conn)
except asyncio.TimeoutError:
self.logger.debug(f"Stale connection test {i+1} timed out")
try:
await conn.ensure_closed()
except Exception:
pass
except Exception as e:
self.logger.debug(f"Stale connection test {i+1} failed: {e}")
try:
await conn.ensure_closed()
except Exception:
pass
self.logger.debug(f"Stale connection cleanup completed, tested {test_count} connections")
except Exception as e:
self.logger.error(f"Stale connection cleanup error: {e}")
async def _recover_pool(self):
"""Recover connection pool when health check fails"""
# Use lock to prevent concurrent recovery attempts
async with self.pool_recovery_lock:
# Check if another recovery is already in progress
if self.pool_recovering:
self.logger.debug("Pool recovery already in progress, waiting...")
return
try:
self.pool_recovering = True
max_retries = 3
retry_delay = 5 # seconds
for attempt in range(max_retries):
try:
self.logger.info(f"🔄 Attempting pool recovery (attempt {attempt + 1}/{max_retries})")
# Try to close existing pool with timeout
if self.pool:
try:
if not self.pool.closed:
self.pool.close()
await asyncio.wait_for(self.pool.wait_closed(), timeout=3.0)
self.logger.debug("Old pool closed successfully")
except asyncio.TimeoutError:
self.logger.warning("Pool close timeout, forcing cleanup")
except Exception as e:
self.logger.warning(f"Error closing old pool: {e}")
finally:
self.pool = None
# Wait before creating new pool (reduced delay)
if attempt > 0:
await asyncio.sleep(2) # Reduced from 5 to 2 seconds
# Recreate pool with timeout
self.logger.debug("Creating new connection pool...")
self.pool = await asyncio.wait_for(
aiomysql.create_pool(
host=self.host,
port=self.port,
user=self.user,
password=self.password,
db=self.database,
charset=self.charset,
minsize=self.minsize,
maxsize=self.maxsize,
pool_recycle=self.pool_recycle,
connect_timeout=self.connect_timeout,
autocommit=True
),
timeout=10.0
)
# Test recovered pool with timeout
if await asyncio.wait_for(self._test_pool_health(), timeout=5.0):
self.logger.info(f"✅ Pool recovery successful on attempt {attempt + 1}")
# Re-warm the pool with timeout
try:
await asyncio.wait_for(self._warmup_pool(), timeout=5.0)
except asyncio.TimeoutError:
self.logger.warning("Pool warmup timeout, but recovery successful")
return
else:
self.logger.warning(f"❌ Pool recovery health check failed on attempt {attempt + 1}")
except asyncio.TimeoutError:
self.logger.error(f"Pool recovery attempt {attempt + 1} timed out")
if self.pool:
try:
self.pool.close()
except:
pass
self.pool = None
except Exception as e:
self.logger.error(f"Pool recovery error on attempt {attempt + 1}: {e}")
# Clean up failed pool
if self.pool:
try:
self.pool.close()
await asyncio.wait_for(self.pool.wait_closed(), timeout=2.0)
except Exception:
pass
finally:
self.pool = None
# All recovery attempts failed
self.logger.error("❌ Pool recovery failed after all attempts")
self.pool = None
finally:
self.pool_recovering = False
async def _recover_pool_with_lock(self):
"""🔧 FIX: Recovery method that uses the new recovery lock to prevent races"""
async with self._recovery_lock:
if not self.pool_recovering: # Only recover if not already in progress
await self._recover_pool()
async def get_connection(self, session_id: str) -> DorisConnection:
"""🔧 FIX: Simplified connection acquisition without double locking
Uses only semaphore to prevent too many concurrent acquisitions
"""
# 🔧 FIX: Use only semaphore to limit concurrent acquisitions (remove double locking)
async with self._connection_semaphore:
try:
# Wait for any ongoing recovery to complete
if self.pool_recovering:
self.logger.debug(f"Pool recovery in progress, waiting for completion...")
# Wait for recovery to complete (max 10 seconds)
start_wait = time.time()
while self.pool_recovering and (time.time() - start_wait) < 10:
await asyncio.sleep(0.1) # More frequent checks
if self.pool_recovering:
self.logger.error("Pool recovery is taking too long, proceeding anyway")
# Continue but log the issue
# Check if pool is available
if not self.pool:
self.logger.warning("Connection pool is not available, attempting recovery...")
await self._recover_pool_with_lock()
if not self.pool:
raise RuntimeError("Connection pool is not available and recovery failed")
# Check if pool is closed
if self.pool.closed:
self.logger.warning("Connection pool is closed, attempting recovery...")
await self._recover_pool_with_lock()
if not self.pool or self.pool.closed:
raise RuntimeError("Connection pool is closed and recovery failed")
# 🔧 FIX: Increased timeout to prevent hanging
try:
raw_conn = await asyncio.wait_for(self.pool.acquire(), timeout=10.0)
except asyncio.TimeoutError:
self.logger.error(f"Connection acquisition timed out for session {session_id}")
# Try one recovery attempt
await self._recover_pool_with_lock()
if self.pool and not self.pool.closed:
try:
raw_conn = await asyncio.wait_for(self.pool.acquire(), timeout=5.0)
except asyncio.TimeoutError:
raise RuntimeError("Connection acquisition timed out after recovery")
else:
raise RuntimeError("Connection acquisition timed out")
# Wrap in DorisConnection
doris_conn = DorisConnection(raw_conn, session_id, self.security_manager)
# Basic validation - check if connection is open
if raw_conn.closed:
# Return connection and raise error
try:
self.pool.release(raw_conn)
except Exception:
pass
raise RuntimeError("Acquired connection is already closed")
self.logger.debug(f"✅ Acquired fresh connection for session {session_id}")
return doris_conn
except Exception as e:
self.logger.error(f"Failed to get connection for session {session_id}: {e}")
raise
async def release_connection(self, session_id: str, connection: DorisConnection):
"""🔧 FIX: Release connection back to pool with proper error handling"""
if not connection or not connection.connection:
self.logger.debug(f"No connection to release for session {session_id}")
return
try:
# Check pool availability before attempting release
if not self.pool or self.pool.closed:
self.logger.warning(f"Pool unavailable during release for session {session_id}, force closing connection")
try:
await connection.connection.ensure_closed()
except Exception:
pass
return
# Check connection state before release
if connection.connection.closed:
self.logger.debug(f"Connection already closed for session {session_id}")
return
# 🔧 FIX: Simplified release operation without thread wrapper
try:
self.pool.release(connection.connection)
self.logger.debug(f"✅ Released connection for session {session_id}")
except Exception as release_error:
self.logger.warning(f"Connection release failed for session {session_id}: {release_error}, force closing")
await connection.connection.ensure_closed()
except Exception as e:
self.logger.error(f"Error releasing connection for session {session_id}: {e}")
# Force close if release fails
try:
await connection.connection.ensure_closed()
except Exception as close_error:
self.logger.debug(f"Error force closing connection: {close_error}")
async def close(self): async def close(self):
"""Close connection manager""" """Close connection manager"""
try: try:
# Cancel background tasks # Cancel background tasks
if self._health_check_task: if self.pool_health_check_task:
self._health_check_task.cancel() self.pool_health_check_task.cancel()
try: try:
await self._health_check_task await self.pool_health_check_task
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
if self._cleanup_task: if self.pool_cleanup_task:
self._cleanup_task.cancel() self.pool_cleanup_task.cancel()
try: try:
await self._cleanup_task await self.pool_cleanup_task
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
# Clean up all session connections
for session_id in list(self.session_connections.keys()):
await self._cleanup_session_connection(session_id)
# Close connection pool # Close connection pool
if self.pool: if self.pool:
self.pool.close() self.pool.close()
@@ -400,20 +654,103 @@ class DorisConnectionManager:
self.logger.error(f"Error closing connection manager: {e}") self.logger.error(f"Error closing connection manager: {e}")
async def test_connection(self) -> bool: async def test_connection(self) -> bool:
"""Test database connection""" """Test database connection using robust connection test"""
return await self._test_pool_health()
async def get_metrics(self) -> ConnectionMetrics:
"""Get connection pool metrics - Simplified Strategy"""
try: try:
if not self.pool: if self.pool:
return False self.metrics.idle_connections = self.pool.freesize
self.metrics.active_connections = self.pool.size - self.pool.freesize
async with self.pool.acquire() as conn: else:
async with conn.cursor() as cursor: self.metrics.idle_connections = 0
await cursor.execute("SELECT 1") self.metrics.active_connections = 0
result = await cursor.fetchone()
return result is not None return self.metrics
except Exception as e: except Exception as e:
self.logger.error(f"Connection test failed: {e}") self.logger.error(f"Error getting metrics: {e}")
return False return self.metrics
async def execute_query(
self, session_id: str, sql: str, params: tuple | None = None, auth_context=None
) -> QueryResult:
"""Execute query - Simplified Strategy with automatic connection management"""
connection = None
try:
# Always get fresh connection from pool
connection = await self.get_connection(session_id)
# Execute query
result = await connection.execute(sql, params, auth_context)
return result
except Exception as e:
self.logger.error(f"Query execution failed for session {session_id}: {e}")
raise
finally:
# Always release connection back to pool
if connection:
await self.release_connection(session_id, connection)
@asynccontextmanager
async def get_connection_context(self, session_id: str):
"""Get connection context manager - Simplified Strategy"""
connection = None
try:
connection = await self.get_connection(session_id)
yield connection
finally:
if connection:
await self.release_connection(session_id, connection)
async def diagnose_connection_health(self) -> Dict[str, Any]:
"""Diagnose connection pool health - Simplified Strategy"""
diagnosis = {
"timestamp": datetime.utcnow().isoformat(),
"pool_status": "unknown",
"pool_info": {},
"recommendations": []
}
try:
# Check pool status
if not self.pool:
diagnosis["pool_status"] = "not_initialized"
diagnosis["recommendations"].append("Initialize connection pool")
return diagnosis
if self.pool.closed:
diagnosis["pool_status"] = "closed"
diagnosis["recommendations"].append("Recreate connection pool")
return diagnosis
diagnosis["pool_status"] = "healthy"
diagnosis["pool_info"] = {
"size": self.pool.size,
"free_size": self.pool.freesize,
"min_size": self.pool.minsize,
"max_size": self.pool.maxsize
}
# Generate recommendations based on pool status
if self.pool.freesize == 0 and self.pool.size >= self.pool.maxsize:
diagnosis["recommendations"].append("Connection pool exhausted - consider increasing max_connections")
# Test pool health
if await self._test_pool_health():
diagnosis["pool_health"] = "healthy"
else:
diagnosis["pool_health"] = "unhealthy"
diagnosis["recommendations"].append("Pool health check failed - may need recovery")
return diagnosis
except Exception as e:
diagnosis["error"] = str(e)
diagnosis["recommendations"].append("Manual intervention required")
return diagnosis
class ConnectionPoolMonitor: class ConnectionPoolMonitor:
@@ -424,7 +761,7 @@ class ConnectionPoolMonitor:
def __init__(self, connection_manager: DorisConnectionManager): def __init__(self, connection_manager: DorisConnectionManager):
self.connection_manager = connection_manager self.connection_manager = connection_manager
self.logger = logging.getLogger(__name__) self.logger = get_logger(__name__)
async def get_pool_status(self) -> dict[str, Any]: async def get_pool_status(self) -> dict[str, Any]:
"""Get connection pool status""" """Get connection pool status"""
@@ -433,7 +770,8 @@ class ConnectionPoolMonitor:
status = { status = {
"pool_size": self.connection_manager.pool.size if self.connection_manager.pool else 0, "pool_size": self.connection_manager.pool.size if self.connection_manager.pool else 0,
"free_connections": self.connection_manager.pool.freesize if self.connection_manager.pool else 0, "free_connections": self.connection_manager.pool.freesize if self.connection_manager.pool else 0,
"active_sessions": len(self.connection_manager.session_connections), "active_connections": metrics.active_connections,
"idle_connections": metrics.idle_connections,
"total_connections": metrics.total_connections, "total_connections": metrics.total_connections,
"failed_connections": metrics.failed_connections, "failed_connections": metrics.failed_connections,
"connection_errors": metrics.connection_errors, "connection_errors": metrics.connection_errors,
@@ -444,52 +782,33 @@ class ConnectionPoolMonitor:
return status return status
async def get_session_details(self) -> list[dict[str, Any]]: async def get_session_details(self) -> list[dict[str, Any]]:
"""Get session connection details""" """Get session connection details - Simplified Strategy (No session caching)"""
sessions = [] # In simplified strategy, we don't maintain session connections
# Return empty list as connections are managed by the pool directly
for session_id, conn in self.connection_manager.session_connections.items(): return []
session_info = {
"session_id": session_id,
"created_at": conn.created_at.isoformat(),
"last_used": conn.last_used.isoformat(),
"query_count": conn.query_count,
"is_healthy": conn.is_healthy,
"connection_age": (datetime.utcnow() - conn.created_at).total_seconds(),
}
sessions.append(session_info)
return sessions
async def generate_health_report(self) -> dict[str, Any]: async def generate_health_report(self) -> dict[str, Any]:
"""Generate connection health report""" """Generate connection health report - Simplified Strategy"""
pool_status = await self.get_pool_status() pool_status = await self.get_pool_status()
session_details = await self.get_session_details()
# Calculate health statistics # Calculate pool utilization
healthy_sessions = sum(1 for s in session_details if s["is_healthy"]) pool_utilization = 1.0 - (pool_status["free_connections"] / pool_status["pool_size"]) if pool_status["pool_size"] > 0 else 0.0
total_sessions = len(session_details)
health_ratio = healthy_sessions / total_sessions if total_sessions > 0 else 1.0
report = { report = {
"timestamp": datetime.utcnow().isoformat(), "timestamp": datetime.utcnow().isoformat(),
"pool_status": pool_status, "pool_status": pool_status,
"session_summary": { "pool_utilization": pool_utilization,
"total_sessions": total_sessions,
"healthy_sessions": healthy_sessions,
"health_ratio": health_ratio,
},
"session_details": session_details,
"recommendations": [], "recommendations": [],
} }
# Add recommendations based on health status # Add recommendations based on pool status
if health_ratio < 0.8:
report["recommendations"].append("Consider checking database connectivity and network stability")
if pool_status["connection_errors"] > 10: if pool_status["connection_errors"] > 10:
report["recommendations"].append("High connection error rate detected, review connection configuration") report["recommendations"].append("High connection error rate detected, review connection configuration")
if pool_status["active_sessions"] > pool_status["pool_size"] * 0.9: if pool_utilization > 0.9:
report["recommendations"].append("Connection pool utilization is high, consider increasing pool size") report["recommendations"].append("Connection pool utilization is high, consider increasing pool size")
return report if pool_status["free_connections"] == 0:
report["recommendations"].append("No free connections available, consider increasing pool size")
return report

View File

@@ -0,0 +1,978 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Dependency Analysis Tools Module
Provides data flow dependency analysis and impact assessment capabilities
"""
import time
import re
from datetime import datetime
from typing import Any, Dict, List, Optional, Set, Tuple
from collections import defaultdict, deque
from .db import DorisConnectionManager
from .logger import get_logger
logger = get_logger(__name__)
class DependencyAnalysisTools:
"""Dependency analysis tools for data flow and impact assessment"""
def __init__(self, connection_manager: DorisConnectionManager):
self.connection_manager = connection_manager
logger.info("DependencyAnalysisTools initialized")
async def analyze_data_flow_dependencies(
self,
target_table: Optional[str] = None,
analysis_depth: int = 3,
include_views: bool = True,
catalog_name: Optional[str] = None,
db_name: Optional[str] = None
) -> Dict[str, Any]:
"""
Analyze data flow dependencies and impact relationships
Args:
target_table: Specific table to analyze (if None, analyzes all tables)
analysis_depth: Maximum depth for dependency traversal
include_views: Whether to include views in dependency analysis
catalog_name: Catalog name
db_name: Database name
Returns:
Comprehensive dependency analysis results
"""
try:
start_time = time.time()
connection = await self.connection_manager.get_connection("query")
# 1. Get table metadata and relationships
tables_metadata = await self._get_tables_metadata(connection, catalog_name, db_name, include_views)
if not tables_metadata:
return {
"error": "No tables found for dependency analysis",
"analysis_timestamp": datetime.now().isoformat()
}
# 2. Build dependency graph from SQL analysis
dependency_graph = await self._build_dependency_graph(connection, tables_metadata, analysis_depth)
# 3. Analyze specific table or all tables
if target_table:
# Analyze specific table
table_analysis = await self._analyze_single_table_dependencies(
target_table, dependency_graph, tables_metadata
)
impact_analysis = await self._calculate_impact_analysis(
target_table, dependency_graph, "both"
)
else:
# Analyze all tables
table_analysis = await self._analyze_all_tables_dependencies(
dependency_graph, tables_metadata
)
impact_analysis = await self._calculate_global_impact_analysis(dependency_graph)
# 4. Generate insights and recommendations
dependency_insights = await self._generate_dependency_insights(
dependency_graph, table_analysis, impact_analysis
)
execution_time = time.time() - start_time
return {
"analysis_target": target_table or "all_tables",
"analysis_timestamp": datetime.now().isoformat(),
"execution_time_seconds": round(execution_time, 3),
"tables_analyzed": len(tables_metadata),
"dependency_graph_stats": self._get_dependency_graph_stats(dependency_graph),
"table_dependencies": table_analysis,
"impact_analysis": impact_analysis,
"dependency_insights": dependency_insights,
"recommendations": self._generate_dependency_recommendations(dependency_insights)
}
except Exception as e:
logger.error(f"Data flow dependency analysis failed: {str(e)}")
return {
"error": str(e),
"analysis_timestamp": datetime.now().isoformat()
}
# ==================== Private Helper Methods ====================
async def _get_tables_metadata(self, connection, catalog_name: Optional[str], db_name: Optional[str], include_views: bool) -> List[Dict]:
"""Get metadata for all tables and views"""
try:
# Build conditions for query
where_conditions = []
if db_name:
where_conditions.append(f"table_schema = '{db_name}'")
else:
where_conditions.append("table_schema = DATABASE()")
table_types = ["'BASE TABLE'"]
if include_views:
table_types.append("'VIEW'")
where_conditions.append(f"table_type IN ({','.join(table_types)})")
metadata_sql = f"""
SELECT
table_schema as schema_name,
table_name,
table_type,
table_comment,
table_rows,
data_length
FROM information_schema.tables
WHERE {' AND '.join(where_conditions)}
ORDER BY table_schema, table_name
"""
result = await connection.execute(metadata_sql)
return result.data if result.data else []
except Exception as e:
logger.warning(f"Failed to get tables metadata: {str(e)}")
return []
async def _build_dependency_graph(self, connection, tables_metadata: List[Dict], analysis_depth: int) -> Dict[str, Dict]:
"""Build dependency graph by analyzing SQL statements and DDL"""
dependency_graph = defaultdict(lambda: {
"upstream_dependencies": set(),
"downstream_dependencies": set(),
"table_type": "unknown",
"dependency_strength": {},
"sql_patterns": []
})
# Initialize graph with table metadata
for table in tables_metadata:
table_name = table["table_name"]
schema_name = table.get("schema_name", "")
full_table_name = f"{schema_name}.{table_name}" if schema_name else table_name
dependency_graph[full_table_name]["table_type"] = table["table_type"]
# 1. Analyze view definitions for dependencies
await self._analyze_view_dependencies(connection, dependency_graph, tables_metadata)
# 2. Analyze audit logs for runtime dependencies
await self._analyze_runtime_dependencies(connection, dependency_graph, analysis_depth)
# 3. Analyze foreign key relationships
await self._analyze_foreign_key_dependencies(connection, dependency_graph, tables_metadata)
return dict(dependency_graph)
async def _analyze_view_dependencies(self, connection, dependency_graph: Dict, tables_metadata: List[Dict]) -> None:
"""Analyze view definitions to extract table dependencies"""
try:
for table in tables_metadata:
if table["table_type"] == "VIEW":
table_name = table["table_name"]
schema_name = table.get("schema_name", "")
# Get view definition
view_def_sql = f"SHOW CREATE VIEW {schema_name}.{table_name}" if schema_name else f"SHOW CREATE VIEW {table_name}"
try:
result = await connection.execute(view_def_sql)
if result.data and len(result.data) > 0:
# Extract view definition from result
view_definition = ""
for row in result.data:
for key, value in row.items():
if "create" in key.lower() and value:
view_definition = str(value)
break
if view_definition:
# Extract table dependencies from view definition
referenced_tables = self._extract_table_references(view_definition)
full_view_name = f"{schema_name}.{table_name}" if schema_name else table_name
for ref_table in referenced_tables:
# Add upstream dependency
dependency_graph[full_view_name]["upstream_dependencies"].add(ref_table)
dependency_graph[full_view_name]["dependency_strength"][ref_table] = "direct"
# Add downstream dependency for referenced table
dependency_graph[ref_table]["downstream_dependencies"].add(full_view_name)
dependency_graph[full_view_name]["sql_patterns"].append({
"pattern_type": "view_definition",
"referenced_table": ref_table,
"confidence": 1.0
})
except Exception as e:
logger.warning(f"Failed to analyze view {table_name}: {str(e)}")
continue
except Exception as e:
logger.warning(f"Failed to analyze view dependencies: {str(e)}")
async def _analyze_runtime_dependencies(self, connection, dependency_graph: Dict, analysis_depth: int) -> None:
"""Analyze audit logs to discover runtime table dependencies"""
try:
# Get recent SQL statements from audit logs
audit_sql = """
SELECT
`stmt` as sql_statement,
`user` as user_name,
COUNT(*) as frequency
FROM internal.__internal_schema.audit_log
WHERE `stmt` IS NOT NULL
AND `stmt` != ''
AND `time` >= DATE_SUB(NOW(), INTERVAL 1 YEAR)
GROUP BY `stmt`, `user`
HAVING frequency > 1
ORDER BY frequency DESC
LIMIT 1000
"""
result = await connection.execute(audit_sql)
if result.data:
for row in result.data:
sql_statement = row.get("sql_statement", "")
frequency = row.get("frequency", 1)
if sql_statement:
# Extract table references from SQL
referenced_tables = self._extract_table_references(sql_statement)
if len(referenced_tables) > 1:
# Infer dependencies from multi-table queries
self._infer_dependencies_from_sql(
dependency_graph, sql_statement, referenced_tables, frequency
)
except Exception as e:
logger.warning(f"Failed to analyze runtime dependencies: {str(e)}")
async def _analyze_foreign_key_dependencies(self, connection, dependency_graph: Dict, tables_metadata: List[Dict]) -> None:
"""Analyze foreign key constraints for explicit dependencies"""
try:
# Get foreign key information
fk_sql = """
SELECT
TABLE_SCHEMA as schema_name,
TABLE_NAME as table_name,
COLUMN_NAME as column_name,
REFERENCED_TABLE_SCHEMA as ref_schema,
REFERENCED_TABLE_NAME as ref_table_name,
REFERENCED_COLUMN_NAME as ref_column_name
FROM information_schema.KEY_COLUMN_USAGE
WHERE REFERENCED_TABLE_NAME IS NOT NULL
"""
result = await connection.execute(fk_sql)
if result.data:
for row in result.data:
schema_name = row.get("schema_name", "")
table_name = row["table_name"]
ref_schema = row.get("ref_schema", "")
ref_table_name = row["ref_table_name"]
# Build full table names
full_table_name = f"{schema_name}.{table_name}" if schema_name else table_name
full_ref_table = f"{ref_schema}.{ref_table_name}" if ref_schema else ref_table_name
# Add foreign key dependency
dependency_graph[full_table_name]["upstream_dependencies"].add(full_ref_table)
dependency_graph[full_table_name]["dependency_strength"][full_ref_table] = "foreign_key"
dependency_graph[full_ref_table]["downstream_dependencies"].add(full_table_name)
dependency_graph[full_table_name]["sql_patterns"].append({
"pattern_type": "foreign_key",
"referenced_table": full_ref_table,
"confidence": 1.0,
"column": row["column_name"],
"ref_column": row["ref_column_name"]
})
except Exception as e:
logger.warning(f"Failed to analyze foreign key dependencies: {str(e)}")
def _extract_table_references(self, sql: str) -> List[str]:
"""Extract table references from SQL statement"""
if not sql:
return []
# Normalize SQL
sql = re.sub(r'/\*.*?\*/', '', sql, flags=re.DOTALL) # Remove comments
sql = re.sub(r'--.*', '', sql) # Remove line comments
sql = sql.upper()
table_references = []
# Pattern to match table names in various contexts
patterns = [
r'\bFROM\s+([`"]?[a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*[`"]?)',
r'\bJOIN\s+([`"]?[a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*[`"]?)',
r'\bINTO\s+([`"]?[a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*[`"]?)',
r'\bUPDATE\s+([`"]?[a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*[`"]?)',
r'\bDELETE\s+FROM\s+([`"]?[a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*[`"]?)',
r'\bINSERT\s+INTO\s+([`"]?[a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*[`"]?)'
]
for pattern in patterns:
matches = re.findall(pattern, sql, re.IGNORECASE)
for match in matches:
# Clean up table name
table_name = match.strip('`"\'').split()[0] # Remove quotes and aliases
if table_name and not self._is_sql_keyword(table_name):
table_references.append(table_name.lower())
return list(set(table_references))
def _is_sql_keyword(self, word: str) -> bool:
"""Check if word is a SQL keyword"""
keywords = {
'SELECT', 'FROM', 'WHERE', 'JOIN', 'INNER', 'LEFT', 'RIGHT', 'OUTER',
'ON', 'AND', 'OR', 'NOT', 'IN', 'EXISTS', 'BETWEEN', 'LIKE',
'INSERT', 'UPDATE', 'DELETE', 'CREATE', 'ALTER', 'DROP', 'INDEX',
'TABLE', 'VIEW', 'DATABASE', 'SCHEMA', 'PRIMARY', 'KEY', 'FOREIGN',
'REFERENCES', 'CONSTRAINT', 'NULL', 'DEFAULT', 'AUTO_INCREMENT'
}
return word.upper() in keywords
def _infer_dependencies_from_sql(self, dependency_graph: Dict, sql: str, referenced_tables: List[str], frequency: int) -> None:
"""Infer table dependencies from SQL patterns"""
# Analyze SQL pattern to determine dependency relationships
sql_upper = sql.upper()
# Look for INSERT ... SELECT patterns
if 'INSERT' in sql_upper and 'SELECT' in sql_upper:
# Find target table (after INSERT INTO)
insert_match = re.search(r'INSERT\s+INTO\s+([a-zA-Z_][a-zA-Z0-9_.]*)', sql_upper)
if insert_match:
target_table = insert_match.group(1).lower()
# All other tables are dependencies
for ref_table in referenced_tables:
if ref_table != target_table:
dependency_graph[target_table]["upstream_dependencies"].add(ref_table)
dependency_graph[ref_table]["downstream_dependencies"].add(target_table)
# Calculate confidence based on frequency
confidence = min(0.9, 0.3 + (frequency / 100))
dependency_graph[target_table]["sql_patterns"].append({
"pattern_type": "insert_select",
"referenced_table": ref_table,
"confidence": confidence,
"frequency": frequency
})
# Look for CREATE TABLE AS SELECT patterns
elif 'CREATE' in sql_upper and 'SELECT' in sql_upper:
create_match = re.search(r'CREATE\s+TABLE\s+([a-zA-Z_][a-zA-Z0-9_.]*)', sql_upper)
if create_match:
target_table = create_match.group(1).lower()
for ref_table in referenced_tables:
if ref_table != target_table:
dependency_graph[target_table]["upstream_dependencies"].add(ref_table)
dependency_graph[ref_table]["downstream_dependencies"].add(target_table)
dependency_graph[target_table]["sql_patterns"].append({
"pattern_type": "create_table_as_select",
"referenced_table": ref_table,
"confidence": 0.95,
"frequency": frequency
})
async def _analyze_single_table_dependencies(self, target_table: str, dependency_graph: Dict, tables_metadata: List[Dict]) -> Dict[str, Any]:
"""Analyze dependencies for a specific table"""
if target_table not in dependency_graph:
return {"error": f"Table {target_table} not found in dependency graph"}
table_info = dependency_graph[target_table]
# Get upstream dependencies (tables this table depends on)
upstream_deps = await self._get_dependency_chain(target_table, dependency_graph, "upstream", 3)
# Get downstream dependencies (tables that depend on this table)
downstream_deps = await self._get_dependency_chain(target_table, dependency_graph, "downstream", 3)
return {
"table_name": target_table,
"table_type": table_info["table_type"],
"direct_upstream_dependencies": list(table_info["upstream_dependencies"]),
"direct_downstream_dependencies": list(table_info["downstream_dependencies"]),
"upstream_dependency_chain": upstream_deps,
"downstream_dependency_chain": downstream_deps,
"dependency_patterns": table_info["sql_patterns"],
"dependency_metrics": {
"upstream_count": len(table_info["upstream_dependencies"]),
"downstream_count": len(table_info["downstream_dependencies"]),
"total_upstream_chain": len(upstream_deps.get("all_dependencies", [])),
"total_downstream_chain": len(downstream_deps.get("all_dependencies", [])),
"dependency_depth": max(upstream_deps.get("max_depth", 0), downstream_deps.get("max_depth", 0))
}
}
async def _get_dependency_chain(self, start_table: str, dependency_graph: Dict, direction: str, max_depth: int) -> Dict[str, Any]:
"""Get full dependency chain in specified direction"""
visited = set()
all_dependencies = []
levels = []
current_level = [start_table]
depth = 0
while current_level and depth < max_depth:
next_level = []
level_deps = []
for table in current_level:
if table in visited:
continue
visited.add(table)
if direction == "upstream":
dependencies = dependency_graph.get(table, {}).get("upstream_dependencies", set())
else:
dependencies = dependency_graph.get(table, {}).get("downstream_dependencies", set())
for dep in dependencies:
if dep not in visited:
next_level.append(dep)
level_deps.append(dep)
all_dependencies.append(dep)
if level_deps:
levels.append({
"level": depth + 1,
"tables": level_deps
})
current_level = next_level
depth += 1
return {
"direction": direction,
"max_depth": depth,
"all_dependencies": list(set(all_dependencies)),
"dependency_levels": levels,
"total_count": len(set(all_dependencies))
}
async def _analyze_all_tables_dependencies(self, dependency_graph: Dict, tables_metadata: List[Dict]) -> Dict[str, Any]:
"""Analyze dependencies for all tables"""
table_stats = {}
for table_name, table_info in dependency_graph.items():
upstream_count = len(table_info["upstream_dependencies"])
downstream_count = len(table_info["downstream_dependencies"])
table_stats[table_name] = {
"table_type": table_info["table_type"],
"upstream_count": upstream_count,
"downstream_count": downstream_count,
"total_connections": upstream_count + downstream_count,
"dependency_score": self._calculate_dependency_score(upstream_count, downstream_count),
"role_classification": self._classify_table_role(upstream_count, downstream_count)
}
# Find key tables
most_critical_tables = sorted(
table_stats.items(),
key=lambda x: x[1]["dependency_score"],
reverse=True
)[:10]
source_tables = [name for name, stats in table_stats.items() if stats["role_classification"] == "source"]
sink_tables = [name for name, stats in table_stats.items() if stats["role_classification"] == "sink"]
hub_tables = [name for name, stats in table_stats.items() if stats["role_classification"] == "hub"]
return {
"table_statistics": table_stats,
"summary": {
"total_tables": len(table_stats),
"source_tables": len(source_tables),
"sink_tables": len(sink_tables),
"hub_tables": len(hub_tables),
"isolated_tables": len([stats for stats in table_stats.values() if stats["total_connections"] == 0])
},
"critical_tables": [{"table": name, **stats} for name, stats in most_critical_tables],
"table_roles": {
"sources": source_tables[:10],
"sinks": sink_tables[:10],
"hubs": hub_tables[:10]
}
}
def _calculate_dependency_score(self, upstream_count: int, downstream_count: int) -> float:
"""Calculate dependency importance score for a table"""
# Score based on both incoming and outgoing dependencies
# Higher weight for downstream dependencies (impact)
return round(upstream_count * 0.3 + downstream_count * 0.7, 2)
def _classify_table_role(self, upstream_count: int, downstream_count: int) -> str:
"""Classify table role based on dependency pattern"""
if upstream_count == 0 and downstream_count > 0:
return "source" # Data source
elif upstream_count > 0 and downstream_count == 0:
return "sink" # Data destination
elif upstream_count > 2 and downstream_count > 2:
return "hub" # Data hub/transformation
elif upstream_count > 0 and downstream_count > 0:
return "intermediate" # Intermediate transformation
else:
return "isolated" # No dependencies
async def _calculate_impact_analysis(self, target_table: str, dependency_graph: Dict, direction: str) -> Dict[str, Any]:
"""Calculate impact analysis for a specific table"""
if direction == "upstream" or direction == "both":
upstream_impact = await self._calculate_upstream_impact(target_table, dependency_graph)
else:
upstream_impact = {}
if direction == "downstream" or direction == "both":
downstream_impact = await self._calculate_downstream_impact(target_table, dependency_graph)
else:
downstream_impact = {}
return {
"target_table": target_table,
"upstream_impact": upstream_impact,
"downstream_impact": downstream_impact,
"total_impact_score": self._calculate_total_impact_score(upstream_impact, downstream_impact)
}
async def _calculate_upstream_impact(self, target_table: str, dependency_graph: Dict) -> Dict[str, Any]:
"""Calculate what would be impacted if upstream dependencies fail"""
upstream_deps = dependency_graph.get(target_table, {}).get("upstream_dependencies", set())
impact_scenarios = []
for dep_table in upstream_deps:
# Simulate failure of this dependency
affected_tables = await self._simulate_table_failure_impact(dep_table, dependency_graph)
impact_scenarios.append({
"failed_dependency": dep_table,
"directly_affected_tables": len(affected_tables["direct"]),
"indirectly_affected_tables": len(affected_tables["indirect"]),
"total_affected": len(affected_tables["all"]),
"critical_affected": [table for table in affected_tables["all"]
if dependency_graph.get(table, {}).get("downstream_dependencies", set())],
"impact_severity": self._assess_impact_severity(len(affected_tables["all"]))
})
return {
"dependency_count": len(upstream_deps),
"impact_scenarios": impact_scenarios,
"max_potential_impact": max([scenario["total_affected"] for scenario in impact_scenarios], default=0),
"risk_assessment": self._assess_upstream_risk(impact_scenarios)
}
async def _calculate_downstream_impact(self, target_table: str, dependency_graph: Dict) -> Dict[str, Any]:
"""Calculate what would be impacted if target table fails"""
affected_tables = await self._simulate_table_failure_impact(target_table, dependency_graph)
return {
"direct_impact": len(affected_tables["direct"]),
"indirect_impact": len(affected_tables["indirect"]),
"total_impact": len(affected_tables["all"]),
"affected_table_details": [
{
"table_name": table,
"impact_type": "direct" if table in affected_tables["direct"] else "indirect",
"table_role": self._classify_table_role(
len(dependency_graph.get(table, {}).get("upstream_dependencies", set())),
len(dependency_graph.get(table, {}).get("downstream_dependencies", set()))
)
}
for table in affected_tables["all"]
],
"impact_severity": self._assess_impact_severity(len(affected_tables["all"]))
}
async def _simulate_table_failure_impact(self, failed_table: str, dependency_graph: Dict) -> Dict[str, List[str]]:
"""Simulate the impact of a table failure"""
direct_affected = list(dependency_graph.get(failed_table, {}).get("downstream_dependencies", set()))
# Find all indirectly affected tables using BFS
visited = {failed_table}
queue = deque(direct_affected)
indirect_affected = []
while queue:
current_table = queue.popleft()
if current_table in visited:
continue
visited.add(current_table)
indirect_affected.append(current_table)
# Add downstream dependencies to queue
downstream = dependency_graph.get(current_table, {}).get("downstream_dependencies", set())
for dep in downstream:
if dep not in visited:
queue.append(dep)
# Remove direct affected from indirect (they're already counted)
indirect_only = [table for table in indirect_affected if table not in direct_affected]
return {
"direct": direct_affected,
"indirect": indirect_only,
"all": direct_affected + indirect_only
}
def _assess_impact_severity(self, affected_count: int) -> str:
"""Assess impact severity based on affected table count"""
if affected_count == 0:
return "none"
elif affected_count <= 2:
return "low"
elif affected_count <= 5:
return "medium"
elif affected_count <= 10:
return "high"
else:
return "critical"
def _assess_upstream_risk(self, impact_scenarios: List[Dict]) -> str:
"""Assess upstream dependency risk"""
if not impact_scenarios:
return "low"
max_impact = max([scenario["total_affected"] for scenario in impact_scenarios])
high_impact_scenarios = len([s for s in impact_scenarios if s["impact_severity"] in ["high", "critical"]])
if high_impact_scenarios > 0 or max_impact > 10:
return "high"
elif max_impact > 5 or len(impact_scenarios) > 3:
return "medium"
else:
return "low"
def _calculate_total_impact_score(self, upstream_impact: Dict, downstream_impact: Dict) -> float:
"""Calculate total impact score combining upstream and downstream risks"""
upstream_score = 0
downstream_score = 0
if upstream_impact:
max_upstream_impact = upstream_impact.get("max_potential_impact", 0)
upstream_score = min(max_upstream_impact * 0.3, 10) # Cap at 10
if downstream_impact:
downstream_score = min(downstream_impact.get("total_impact", 0) * 0.7, 10) # Cap at 10
return round(upstream_score + downstream_score, 2)
async def _calculate_global_impact_analysis(self, dependency_graph: Dict) -> Dict[str, Any]:
"""Calculate global impact analysis for all tables"""
table_impacts = {}
for table_name in dependency_graph.keys():
impact = await self._calculate_impact_analysis(table_name, dependency_graph, "downstream")
table_impacts[table_name] = {
"downstream_impact": impact["downstream_impact"]["total_impact"],
"impact_severity": impact["downstream_impact"]["impact_severity"],
"impact_score": impact["total_impact_score"]
}
# Find most critical tables
critical_tables = sorted(
table_impacts.items(),
key=lambda x: x[1]["impact_score"],
reverse=True
)[:15]
# Risk distribution
risk_distribution = {
"critical": len([t for t in table_impacts.values() if t["impact_severity"] == "critical"]),
"high": len([t for t in table_impacts.values() if t["impact_severity"] == "high"]),
"medium": len([t for t in table_impacts.values() if t["impact_severity"] == "medium"]),
"low": len([t for t in table_impacts.values() if t["impact_severity"] == "low"]),
"none": len([t for t in table_impacts.values() if t["impact_severity"] == "none"])
}
return {
"global_impact_summary": {
"total_tables_analyzed": len(table_impacts),
"tables_with_impact": len([t for t in table_impacts.values() if t["downstream_impact"] > 0]),
"average_impact_score": round(sum(t["impact_score"] for t in table_impacts.values()) / len(table_impacts), 2) if table_impacts else 0,
"risk_distribution": risk_distribution
},
"most_critical_tables": [{"table": name, **stats} for name, stats in critical_tables],
"risk_matrix": self._generate_risk_matrix(table_impacts)
}
def _generate_risk_matrix(self, table_impacts: Dict[str, Dict]) -> Dict[str, List[str]]:
"""Generate risk matrix categorizing tables by impact level"""
risk_matrix = {
"critical_risk": [],
"high_risk": [],
"medium_risk": [],
"low_risk": [],
"minimal_risk": []
}
for table_name, impact_data in table_impacts.items():
severity = impact_data["impact_severity"]
if severity == "critical":
risk_matrix["critical_risk"].append(table_name)
elif severity == "high":
risk_matrix["high_risk"].append(table_name)
elif severity == "medium":
risk_matrix["medium_risk"].append(table_name)
elif severity == "low":
risk_matrix["low_risk"].append(table_name)
else:
risk_matrix["minimal_risk"].append(table_name)
return risk_matrix
def _get_dependency_graph_stats(self, dependency_graph: Dict) -> Dict[str, Any]:
"""Get statistics about the dependency graph"""
total_tables = len(dependency_graph)
total_dependencies = sum(
len(table_info.get("upstream_dependencies", set())) + len(table_info.get("downstream_dependencies", set()))
for table_info in dependency_graph.values()
) // 2 # Divide by 2 to avoid double counting
tables_with_upstream = len([
table for table, info in dependency_graph.items()
if info.get("upstream_dependencies")
])
tables_with_downstream = len([
table for table, info in dependency_graph.items()
if info.get("downstream_dependencies")
])
isolated_tables = len([
table for table, info in dependency_graph.items()
if not info.get("upstream_dependencies") and not info.get("downstream_dependencies")
])
return {
"total_tables": total_tables,
"total_dependencies": total_dependencies,
"tables_with_upstream_deps": tables_with_upstream,
"tables_with_downstream_deps": tables_with_downstream,
"isolated_tables": isolated_tables,
"connectivity_ratio": round((total_tables - isolated_tables) / total_tables, 3) if total_tables > 0 else 0,
"avg_dependencies_per_table": round(total_dependencies / total_tables, 2) if total_tables > 0 else 0
}
async def _generate_dependency_insights(self, dependency_graph: Dict, table_analysis: Dict, impact_analysis: Dict) -> Dict[str, Any]:
"""Generate insights from dependency analysis"""
insights = {
"architectural_patterns": {},
"risk_assessment": {},
"optimization_opportunities": {}
}
# Architectural patterns
graph_stats = self._get_dependency_graph_stats(dependency_graph)
insights["architectural_patterns"] = {
"connectivity_level": "high" if graph_stats["connectivity_ratio"] > 0.7 else "medium" if graph_stats["connectivity_ratio"] > 0.3 else "low",
"architecture_type": self._classify_architecture_type(graph_stats),
"complexity_score": round(graph_stats["avg_dependencies_per_table"] * graph_stats["connectivity_ratio"], 2),
"isolated_tables_concern": graph_stats["isolated_tables"] > graph_stats["total_tables"] * 0.3
}
# Risk assessment
if isinstance(impact_analysis, dict) and "global_impact_summary" in impact_analysis:
global_impact = impact_analysis["global_impact_summary"]
insights["risk_assessment"] = {
"overall_risk_level": self._assess_overall_risk_level(global_impact["risk_distribution"]),
"critical_tables_count": global_impact["risk_distribution"]["critical"],
"high_risk_tables_count": global_impact["risk_distribution"]["high"],
"impact_concentration": global_impact["average_impact_score"] > 5.0,
"resilience_score": self._calculate_resilience_score(global_impact)
}
# Optimization opportunities
insights["optimization_opportunities"] = self._identify_optimization_opportunities(dependency_graph, table_analysis)
return insights
def _classify_architecture_type(self, graph_stats: Dict) -> str:
"""Classify the overall architecture type"""
connectivity = graph_stats["connectivity_ratio"]
avg_deps = graph_stats["avg_dependencies_per_table"]
if connectivity > 0.8 and avg_deps > 3:
return "highly_interconnected"
elif connectivity > 0.5 and avg_deps > 2:
return "moderately_connected"
elif connectivity < 0.3:
return "loosely_coupled"
else:
return "mixed_architecture"
def _assess_overall_risk_level(self, risk_distribution: Dict[str, int]) -> str:
"""Assess overall risk level from risk distribution"""
total = sum(risk_distribution.values())
if total == 0:
return "minimal"
critical_ratio = risk_distribution["critical"] / total
high_ratio = risk_distribution["high"] / total
if critical_ratio > 0.1 or high_ratio > 0.2:
return "high"
elif critical_ratio > 0.05 or high_ratio > 0.1:
return "medium"
else:
return "low"
def _calculate_resilience_score(self, global_impact: Dict) -> float:
"""Calculate system resilience score (0-1, higher is better)"""
total_tables = global_impact["total_tables_analyzed"]
risk_dist = global_impact["risk_distribution"]
if total_tables == 0:
return 0.0
# Calculate weighted risk score
weighted_risk = (
risk_dist["critical"] * 5 +
risk_dist["high"] * 3 +
risk_dist["medium"] * 2 +
risk_dist["low"] * 1
) / total_tables
# Convert to resilience score (inverse of risk, normalized)
max_possible_risk = 5.0
resilience = max(0, (max_possible_risk - weighted_risk) / max_possible_risk)
return round(resilience, 3)
def _identify_optimization_opportunities(self, dependency_graph: Dict, table_analysis: Dict) -> List[Dict]:
"""Identify optimization opportunities"""
opportunities = []
# Find tables with excessive dependencies
for table_name, table_info in dependency_graph.items():
upstream_count = len(table_info.get("upstream_dependencies", set()))
downstream_count = len(table_info.get("downstream_dependencies", set()))
if upstream_count > 10:
opportunities.append({
"type": "excessive_upstream_dependencies",
"table": table_name,
"description": f"Table has {upstream_count} upstream dependencies",
"recommendation": "Consider breaking down complex transformations or using intermediate tables",
"priority": "high" if upstream_count > 15 else "medium"
})
if downstream_count > 10:
opportunities.append({
"type": "excessive_downstream_dependencies",
"table": table_name,
"description": f"Table has {downstream_count} downstream dependencies",
"recommendation": "Consider if this table is doing too much or if views could be used",
"priority": "high" if downstream_count > 15 else "medium"
})
# Find potential circular dependencies (simplified check)
# This is a basic check - full cycle detection would be more complex
for table_name, table_info in dependency_graph.items():
upstream_deps = table_info.get("upstream_dependencies", set())
for upstream_table in upstream_deps:
if table_name in dependency_graph.get(upstream_table, {}).get("upstream_dependencies", set()):
opportunities.append({
"type": "potential_circular_dependency",
"table": table_name,
"related_table": upstream_table,
"description": f"Potential circular dependency between {table_name} and {upstream_table}",
"recommendation": "Review and eliminate circular dependencies",
"priority": "high"
})
return opportunities
def _generate_dependency_recommendations(self, dependency_insights: Dict) -> List[Dict]:
"""Generate recommendations based on dependency analysis"""
recommendations = []
# Architecture recommendations
arch_patterns = dependency_insights.get("architectural_patterns", {})
if arch_patterns.get("isolated_tables_concern", False):
recommendations.append({
"type": "architecture",
"priority": "medium",
"title": "High number of isolated tables",
"description": "Many tables have no dependencies, which may indicate data silos",
"action": "Review isolated tables and consider if they should be integrated into data flows"
})
complexity_score = arch_patterns.get("complexity_score", 0)
if complexity_score > 5:
recommendations.append({
"type": "architecture",
"priority": "high",
"title": "High system complexity",
"description": f"System complexity score is {complexity_score} (high)",
"action": "Consider simplifying data architecture and reducing unnecessary dependencies"
})
# Risk recommendations
risk_assessment = dependency_insights.get("risk_assessment", {})
overall_risk = risk_assessment.get("overall_risk_level", "unknown")
if overall_risk == "high":
recommendations.append({
"type": "risk_mitigation",
"priority": "high",
"title": "High overall system risk",
"description": "System has high dependency risks that could cause widespread failures",
"action": "Implement monitoring and backup strategies for critical tables"
})
critical_tables = risk_assessment.get("critical_tables_count", 0)
if critical_tables > 0:
recommendations.append({
"type": "risk_mitigation",
"priority": "high",
"title": f"{critical_tables} critical impact tables identified",
"description": "Tables with critical impact require special attention",
"action": "Implement enhanced monitoring and backup procedures for critical tables"
})
# Optimization recommendations
optimization_ops = dependency_insights.get("optimization_opportunities", [])
if optimization_ops:
high_priority_ops = [op for op in optimization_ops if op.get("priority") == "high"]
if high_priority_ops:
recommendations.append({
"type": "optimization",
"priority": "high",
"title": f"{len(high_priority_ops)} high-priority optimization opportunities",
"description": "System has optimization opportunities that should be addressed",
"action": "Review and implement suggested optimizations for better maintainability"
})
return recommendations

View File

@@ -15,77 +15,573 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
""" """
Logging configuration for Doris MCP Server. Enhanced Logging configuration for Doris MCP Server.
Features:
- Log level-based file separation
- Timestamped log entries
- Automatic log rotation
- Comprehensive logging coverage
""" """
import logging import logging
import logging.config import logging.config
import logging.handlers
import sys import sys
import os
import asyncio
import time
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any, Optional
from datetime import datetime, timedelta
import threading
def setup_logging( class TimestampedFormatter(logging.Formatter):
level: str = "INFO", """Custom formatter with enhanced timestamp and structured format"""
log_file: str | None = None,
log_format: str | None = None, def __init__(self, fmt=None, datefmt=None, style='%'):
) -> None: if fmt is None:
""" fmt = "%(asctime)s.%(msecs)03d %(level_aligned)s %(name)s:%(lineno)d - %(message)s"
Setup logging configuration. if datefmt is None:
datefmt = "%Y-%m-%d %H:%M:%S"
super().__init__(fmt, datefmt, style)
def format(self, record):
"""Format log record with enhanced information and proper alignment"""
# Add process info if available
if hasattr(record, 'process') and record.process:
record.process_info = f"[PID:{record.process}]"
else:
record.process_info = ""
# Add thread info if available
if hasattr(record, 'thread') and record.thread:
record.thread_info = f"[TID:{record.thread}]"
else:
record.thread_info = ""
# Format with proper alignment after the level name
# Calculate padding needed for alignment
level_name = record.levelname
max_level_length = 8 # Length of "CRITICAL"
padding = max_level_length - len(level_name)
record.level_aligned = f"[{level_name}]{' ' * padding}"
return super().format(record)
Args:
level: Logging level (DEBUG, INFO, WARNING, ERROR)
log_file: Optional log file path
log_format: Optional custom log format
"""
if log_format is None:
log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
# Base configuration class LevelBasedFileHandler(logging.Handler):
config: dict[str, Any] = { """Custom handler that writes different log levels to different files"""
"version": 1,
"disable_existing_loggers": False, def __init__(self, log_dir: str, base_name: str = "doris_mcp_server",
"formatters": { max_bytes: int = 10*1024*1024, backup_count: int = 5):
"default": {"format": log_format, "datefmt": "%Y-%m-%d %H:%M:%S"} super().__init__()
}, self.log_dir = Path(log_dir)
"handlers": { self.base_name = base_name
"console": { self.max_bytes = max_bytes
"class": "logging.StreamHandler", self.backup_count = backup_count
"level": level,
"formatter": "default",
"stream": sys.stdout,
}
},
"root": {"level": level, "handlers": ["console"]},
"loggers": {
"doris_mcp_server": {
"level": level,
"handlers": ["console"],
"propagate": False,
}
},
}
# Add file handler if log_file is specified
if log_file:
# Ensure log directory exists # Ensure log directory exists
log_path = Path(log_file) self.log_dir.mkdir(parents=True, exist_ok=True)
log_path.parent.mkdir(parents=True, exist_ok=True)
# Create handlers for different log levels
config["handlers"]["file"] = { self.handlers = {}
"class": "logging.handlers.RotatingFileHandler", self._setup_level_handlers()
"level": level,
"formatter": "default", def _setup_level_handlers(self):
"filename": log_file, """Setup rotating file handlers for different log levels"""
"maxBytes": 10485760, # 10MB level_files = {
"backupCount": 5, 'DEBUG': 'debug.log',
'INFO': 'info.log',
'WARNING': 'warning.log',
'ERROR': 'error.log',
'CRITICAL': 'critical.log'
} }
formatter = TimestampedFormatter()
for level, filename in level_files.items():
file_path = self.log_dir / f"{self.base_name}_{filename}"
handler = logging.handlers.RotatingFileHandler(
file_path,
maxBytes=self.max_bytes,
backupCount=self.backup_count,
encoding='utf-8'
)
handler.setFormatter(formatter)
handler.setLevel(getattr(logging, level))
self.handlers[level] = handler
def emit(self, record):
"""Emit log record to appropriate level-based file"""
level_name = record.levelname
if level_name in self.handlers:
try:
self.handlers[level_name].emit(record)
except Exception:
self.handleError(record)
def close(self):
"""Close all handlers"""
for handler in self.handlers.values():
handler.close()
super().close()
# Add file handler to root and package loggers
config["root"]["handlers"].append("file")
config["loggers"]["doris_mcp_server"]["handlers"].append("file")
logging.config.dictConfig(config) class LogCleanupManager:
"""Log file cleanup manager for automatic maintenance"""
def __init__(self, log_dir: str, max_age_days: int = 30, cleanup_interval_hours: int = 24):
"""
Initialize log cleanup manager.
Args:
log_dir: Directory containing log files
max_age_days: Maximum age of log files in days (default: 30 days)
cleanup_interval_hours: Cleanup interval in hours (default: 24 hours)
"""
self.log_dir = Path(log_dir)
self.max_age_days = max_age_days
self.cleanup_interval_hours = cleanup_interval_hours
self.cleanup_thread = None
self.stop_event = threading.Event()
self.logger = None
def start_cleanup_scheduler(self):
"""Start the cleanup scheduler in a background thread"""
if self.cleanup_thread and self.cleanup_thread.is_alive():
return
self.stop_event.clear()
self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True)
self.cleanup_thread.start()
# Get logger for this class
if not self.logger:
self.logger = logging.getLogger("doris_mcp_server.log_cleanup")
self.logger.info(f"Log cleanup scheduler started - cleanup every {self.cleanup_interval_hours}h, max age {self.max_age_days} days")
def stop_cleanup_scheduler(self):
"""Stop the cleanup scheduler"""
if self.cleanup_thread and self.cleanup_thread.is_alive():
self.stop_event.set()
self.cleanup_thread.join(timeout=5)
if self.logger:
self.logger.info("Log cleanup scheduler stopped")
def _cleanup_loop(self):
"""Background loop for periodic cleanup"""
while not self.stop_event.is_set():
try:
self.cleanup_old_logs()
# Sleep for the specified interval, but check stop event every 60 seconds
for _ in range(self.cleanup_interval_hours * 60): # Convert hours to minutes
if self.stop_event.wait(60): # Wait 60 seconds or until stop event
break
except Exception as e:
if self.logger:
self.logger.error(f"Error in log cleanup loop: {e}")
# Sleep for 5 minutes before retrying
self.stop_event.wait(300)
def cleanup_old_logs(self):
"""Clean up old log files based on age"""
if not self.log_dir.exists():
return
current_time = datetime.now()
cutoff_time = current_time - timedelta(days=self.max_age_days)
cleaned_files = []
cleaned_size = 0
# Pattern for log files (including backup files)
log_patterns = [
"doris_mcp_server_*.log",
"doris_mcp_server_*.log.*" # Backup files
]
for pattern in log_patterns:
for log_file in self.log_dir.glob(pattern):
try:
# Get file modification time
file_mtime = datetime.fromtimestamp(log_file.stat().st_mtime)
if file_mtime < cutoff_time:
file_size = log_file.stat().st_size
log_file.unlink() # Delete the file
cleaned_files.append(log_file.name)
cleaned_size += file_size
except Exception as e:
if self.logger:
self.logger.warning(f"Failed to cleanup log file {log_file}: {e}")
if cleaned_files and self.logger:
size_mb = cleaned_size / (1024 * 1024)
self.logger.info(f"Cleaned up {len(cleaned_files)} old log files, freed {size_mb:.2f} MB")
self.logger.debug(f"Cleaned files: {', '.join(cleaned_files)}")
def get_cleanup_stats(self) -> dict:
"""Get statistics about log files and cleanup status"""
if not self.log_dir.exists():
return {"error": "Log directory does not exist"}
stats = {
"log_directory": str(self.log_dir.absolute()),
"max_age_days": self.max_age_days,
"cleanup_interval_hours": self.cleanup_interval_hours,
"scheduler_running": self.cleanup_thread and self.cleanup_thread.is_alive(),
"total_files": 0,
"total_size_mb": 0,
"files_by_age": {"recent": 0, "old": 0},
"oldest_file": None,
"newest_file": None
}
current_time = datetime.now()
cutoff_time = current_time - timedelta(days=self.max_age_days)
oldest_time = None
newest_time = None
log_patterns = ["doris_mcp_server_*.log", "doris_mcp_server_*.log.*"]
for pattern in log_patterns:
for log_file in self.log_dir.glob(pattern):
try:
file_stat = log_file.stat()
file_mtime = datetime.fromtimestamp(file_stat.st_mtime)
stats["total_files"] += 1
stats["total_size_mb"] += file_stat.st_size / (1024 * 1024)
if file_mtime < cutoff_time:
stats["files_by_age"]["old"] += 1
else:
stats["files_by_age"]["recent"] += 1
if oldest_time is None or file_mtime < oldest_time:
oldest_time = file_mtime
stats["oldest_file"] = {"name": log_file.name, "age_days": (current_time - file_mtime).days}
if newest_time is None or file_mtime > newest_time:
newest_time = file_mtime
stats["newest_file"] = {"name": log_file.name, "age_days": (current_time - file_mtime).days}
except Exception:
continue
stats["total_size_mb"] = round(stats["total_size_mb"], 2)
return stats
class DorisLoggerManager:
"""Centralized logger manager for Doris MCP Server"""
def __init__(self):
self.is_initialized = False
self.log_dir = None
self.config = None
self.loggers = {}
self.cleanup_manager = None
def setup_logging(self,
level: str = "INFO",
log_dir: str = "logs",
enable_console: bool = True,
enable_file: bool = True,
enable_audit: bool = True,
audit_file: Optional[str] = None,
max_file_size: int = 10*1024*1024,
backup_count: int = 5,
enable_cleanup: bool = True,
max_age_days: int = 30,
cleanup_interval_hours: int = 24) -> None:
"""
Setup comprehensive logging configuration.
Args:
level: Base logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
log_dir: Directory for log files
enable_console: Enable console output
enable_file: Enable file logging
enable_audit: Enable audit logging
audit_file: Custom audit log file path
max_file_size: Maximum size per log file (bytes)
backup_count: Number of backup files to keep
enable_cleanup: Enable automatic log cleanup
max_age_days: Maximum age of log files in days (default: 30)
cleanup_interval_hours: Cleanup interval in hours (default: 24)
"""
if self.is_initialized:
return
self.log_dir = Path(log_dir)
log_dir_writable = True # Initialize the variable
# Try to create log directory, fallback to console-only if fails
try:
self.log_dir.mkdir(parents=True, exist_ok=True)
except (OSError, PermissionError) as e:
# If we can't create log directory (e.g., read-only filesystem in stdio mode),
# fall back to console-only logging
log_dir_writable = False
enable_file = False
enable_audit = False
enable_cleanup = False
# Don't use print() in stdio mode as it interferes with MCP JSON protocol
# Log the warning through the logging system instead, which will be handled after setup
# Clear existing handlers
root_logger = logging.getLogger()
for handler in root_logger.handlers[:]:
root_logger.removeHandler(handler)
# Set root logger level
root_logger.setLevel(logging.DEBUG) # Allow all levels, handlers will filter
handlers = []
# Console handler
if enable_console:
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(getattr(logging, level.upper()))
console_formatter = TimestampedFormatter(
fmt="%(asctime)s.%(msecs)03d %(level_aligned)s %(name)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S"
)
console_handler.setFormatter(console_formatter)
handlers.append(console_handler)
# Level-based file handlers
if enable_file:
level_handler = LevelBasedFileHandler(
log_dir=str(self.log_dir),
base_name="doris_mcp_server",
max_bytes=max_file_size,
backup_count=backup_count
)
level_handler.setLevel(logging.DEBUG) # Accept all levels
handlers.append(level_handler)
# Combined application log (all levels in one file)
if enable_file:
app_log_file = self.log_dir / "doris_mcp_server_all.log"
app_handler = logging.handlers.RotatingFileHandler(
app_log_file,
maxBytes=max_file_size,
backupCount=backup_count,
encoding='utf-8'
)
app_handler.setLevel(getattr(logging, level.upper()))
app_formatter = TimestampedFormatter()
app_handler.setFormatter(app_formatter)
handlers.append(app_handler)
# Audit logger (separate from main logging)
if enable_audit:
audit_file_path = audit_file or str(self.log_dir / "doris_mcp_server_audit.log")
audit_logger = logging.getLogger("audit")
audit_logger.setLevel(logging.INFO)
# Clear existing audit handlers
for handler in audit_logger.handlers[:]:
audit_logger.removeHandler(handler)
audit_handler = logging.handlers.RotatingFileHandler(
audit_file_path,
maxBytes=max_file_size,
backupCount=backup_count,
encoding='utf-8'
)
audit_formatter = TimestampedFormatter(
fmt="%(asctime)s.%(msecs)03d [AUDIT] %(name)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S"
)
audit_handler.setFormatter(audit_formatter)
audit_logger.addHandler(audit_handler)
audit_logger.propagate = False # Don't propagate to root logger
# Add all handlers to root logger
for handler in handlers:
root_logger.addHandler(handler)
# Setup package-specific loggers
self._setup_package_loggers(level)
# Setup log cleanup manager
if enable_cleanup and enable_file:
self.cleanup_manager = LogCleanupManager(
log_dir=str(self.log_dir),
max_age_days=max_age_days,
cleanup_interval_hours=cleanup_interval_hours
)
self.cleanup_manager.start_cleanup_scheduler()
self.is_initialized = True
# Log initialization message
logger = self.get_logger("doris_mcp_server.logger")
logger.info("=" * 80)
logger.info("Doris MCP Server Logging System Initialized")
logger.info(f"Log Level: {level}")
if log_dir_writable:
logger.info(f"Log Directory: {self.log_dir.absolute()}")
else:
logger.info("Log Directory: Not available (console-only mode)")
logger.info(f"Console Logging: {'Enabled' if enable_console else 'Disabled'}")
logger.info(f"File Logging: {'Enabled' if enable_file else 'Disabled (fallback mode)'}")
logger.info(f"Audit Logging: {'Enabled' if enable_audit else 'Disabled (fallback mode)'}")
logger.info(f"Log Cleanup: {'Enabled' if enable_cleanup and enable_file else 'Disabled (fallback mode)'}")
if enable_cleanup and enable_file:
logger.info(f"Cleanup Settings: Max age {max_age_days} days, interval {cleanup_interval_hours}h")
if not log_dir_writable:
logger.warning("Running in console-only logging mode due to filesystem permissions")
logger.warning(f"Could not create log directory '{log_dir}' - stdio mode fallback enabled")
logger.info("=" * 80)
def _setup_package_loggers(self, level: str):
"""Setup specific loggers for different modules"""
package_loggers = [
"doris_mcp_server",
"doris_mcp_server.main",
"doris_mcp_server.utils",
"doris_mcp_server.tools",
"doris_mcp_client"
]
for logger_name in package_loggers:
logger = logging.getLogger(logger_name)
logger.setLevel(getattr(logging, level.upper()))
# Don't add handlers here - they inherit from root logger
def get_logger(self, name: str) -> logging.Logger:
"""
Get a logger instance with proper configuration.
Args:
name: Logger name (usually __name__)
Returns:
Configured logger instance
"""
if name not in self.loggers:
logger = logging.getLogger(name)
self.loggers[name] = logger
return self.loggers[name]
def get_audit_logger(self) -> logging.Logger:
"""Get the audit logger"""
return logging.getLogger("audit")
def log_system_info(self):
"""Log system information for debugging"""
logger = self.get_logger("doris_mcp_server.system")
logger.info("System Information:")
logger.info(f"Python Version: {sys.version}")
logger.info(f"Platform: {sys.platform}")
logger.info(f"Working Directory: {os.getcwd()}")
logger.info(f"Process ID: {os.getpid()}")
# Log environment variables (filtered)
env_vars = ["LOG_LEVEL", "LOG_FILE_PATH", "ENABLE_AUDIT", "AUDIT_FILE_PATH"]
for var in env_vars:
value = os.getenv(var, "Not Set")
logger.info(f"Environment {var}: {value}")
def get_cleanup_stats(self) -> dict:
"""Get log cleanup statistics"""
if self.cleanup_manager:
return self.cleanup_manager.get_cleanup_stats()
else:
return {"error": "Log cleanup is not enabled"}
def manual_cleanup(self) -> dict:
"""Manually trigger log cleanup and return statistics"""
if self.cleanup_manager:
self.cleanup_manager.cleanup_old_logs()
return self.cleanup_manager.get_cleanup_stats()
else:
return {"error": "Log cleanup is not enabled"}
def shutdown(self):
"""Shutdown logging system"""
if not self.is_initialized:
return
logger = self.get_logger("doris_mcp_server.logger")
logger.info("Shutting down logging system...")
# Stop cleanup manager
if self.cleanup_manager:
self.cleanup_manager.stop_cleanup_scheduler()
# Close all handlers
root_logger = logging.getLogger()
for handler in root_logger.handlers[:]:
try:
handler.close()
except Exception as e:
print(f"Error closing handler: {e}")
# Close audit logger handlers
audit_logger = logging.getLogger("audit")
for handler in audit_logger.handlers[:]:
try:
handler.close()
except Exception as e:
print(f"Error closing audit handler: {e}")
self.is_initialized = False
# Global logger manager instance
_logger_manager = DorisLoggerManager()
def setup_logging(level: str = "INFO",
log_dir: str = "logs",
enable_console: bool = True,
enable_file: bool = True,
enable_audit: bool = True,
audit_file: Optional[str] = None,
max_file_size: int = 10*1024*1024,
backup_count: int = 5,
enable_cleanup: bool = True,
max_age_days: int = 30,
cleanup_interval_hours: int = 24) -> None:
"""
Setup logging configuration (convenience function).
Args:
level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
log_dir: Directory for log files
enable_console: Enable console output
enable_file: Enable file logging
enable_audit: Enable audit logging
audit_file: Custom audit log file path
max_file_size: Maximum size per log file (bytes)
backup_count: Number of backup files to keep
enable_cleanup: Enable automatic log cleanup
max_age_days: Maximum age of log files in days (default: 30)
cleanup_interval_hours: Cleanup interval in hours (default: 24)
"""
_logger_manager.setup_logging(
level=level,
log_dir=log_dir,
enable_console=enable_console,
enable_file=enable_file,
enable_audit=enable_audit,
audit_file=audit_file,
max_file_size=max_file_size,
backup_count=backup_count,
enable_cleanup=enable_cleanup,
max_age_days=max_age_days,
cleanup_interval_hours=cleanup_interval_hours
)
def get_logger(name: str) -> logging.Logger: def get_logger(name: str) -> logging.Logger:
@@ -93,9 +589,60 @@ def get_logger(name: str) -> logging.Logger:
Get a logger instance. Get a logger instance.
Args: Args:
name: Logger name name: Logger name (usually __name__)
Returns: Returns:
Logger instance Configured logger instance
""" """
return logging.getLogger(name) return _logger_manager.get_logger(name)
def get_audit_logger() -> logging.Logger:
"""Get the audit logger"""
return _logger_manager.get_audit_logger()
def log_system_info():
"""Log system information for debugging"""
_logger_manager.log_system_info()
def get_cleanup_stats() -> dict:
"""Get log cleanup statistics"""
return _logger_manager.get_cleanup_stats()
def manual_cleanup() -> dict:
"""Manually trigger log cleanup and return statistics"""
return _logger_manager.manual_cleanup()
def shutdown_logging():
"""Shutdown logging system"""
_logger_manager.shutdown()
# Compatibility function for existing code
def setup_logging_old(level: str = "INFO",
log_file: str | None = None,
log_format: str | None = None) -> None:
"""
Legacy setup function for backward compatibility.
Args:
level: Logging level (DEBUG, INFO, WARNING, ERROR)
log_file: Optional log file path (deprecated - use log_dir instead)
log_format: Optional custom log format (deprecated)
"""
# Extract directory from log_file if provided
log_dir = "logs"
if log_file:
log_dir = str(Path(log_file).parent)
setup_logging(
level=level,
log_dir=log_dir,
enable_console=True,
enable_file=True,
enable_audit=True
)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -34,6 +34,7 @@ from typing import Any, Dict
from decimal import Decimal from decimal import Decimal
from .db import DorisConnectionManager, QueryResult from .db import DorisConnectionManager, QueryResult
from .logger import get_logger
@dataclass @dataclass
@@ -92,7 +93,7 @@ class QueryCache:
self.max_size = max_size self.max_size = max_size
self.default_ttl = default_ttl self.default_ttl = default_ttl
self.cache: dict[str, CachedQuery] = {} self.cache: dict[str, CachedQuery] = {}
self.logger = logging.getLogger(__name__) self.logger = get_logger(__name__)
def _generate_cache_key( def _generate_cache_key(
self, sql: str, parameters: dict[str, Any] | None = None self, sql: str, parameters: dict[str, Any] | None = None
@@ -194,7 +195,7 @@ class QueryOptimizer:
def __init__(self, config): def __init__(self, config):
self.config = config self.config = config
self.logger = logging.getLogger(__name__) self.logger = get_logger(__name__)
self.optimization_rules = self._load_optimization_rules() self.optimization_rules = self._load_optimization_rules()
def _load_optimization_rules(self) -> list[dict[str, Any]]: def _load_optimization_rules(self) -> list[dict[str, Any]]:
@@ -318,7 +319,7 @@ class DorisQueryExecutor:
def __init__(self, connection_manager: DorisConnectionManager, config=None): def __init__(self, connection_manager: DorisConnectionManager, config=None):
self.connection_manager = connection_manager self.connection_manager = connection_manager
self.config = config or self._create_default_config() self.config = config or self._create_default_config()
self.logger = logging.getLogger(__name__) self.logger = get_logger(__name__)
# Initialize components # Initialize components
cache_config = getattr(self.config, 'performance', None) cache_config = getattr(self.config, 'performance', None)
@@ -548,79 +549,114 @@ class DorisQueryExecutor:
user_id: str = "mcp_user" user_id: str = "mcp_user"
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Execute SQL query for MCP interface - unified method""" """Execute SQL query for MCP interface - unified method"""
try: max_retries = 2
if not sql: retry_count = 0
return {
"success": False, while retry_count <= max_retries:
"error": "SQL query is required", try:
"data": None if not sql:
} return {
"success": False,
"error": "SQL query is required",
"data": None
}
# Add LIMIT if not present and it's a SELECT query # Add LIMIT if not present and it's a SELECT query
if sql.upper().startswith("SELECT") and "LIMIT" not in sql.upper(): if sql.upper().startswith("SELECT") and "LIMIT" not in sql.upper():
if sql.endswith(";"): if sql.endswith(";"):
sql = sql[:-1] sql = sql[:-1]
sql = f"{sql} LIMIT {limit}" sql = f"{sql} LIMIT {limit}"
# Create auth context for MCP calls # Create auth context for MCP calls
class MockAuthContext: class MockAuthContext:
def __init__(self): def __init__(self):
self.user_id = user_id self.user_id = user_id
self.roles = ["data_analyst"] self.roles = ["data_analyst"]
self.permissions = ["read_data", "execute_query"] self.permissions = ["read_data", "execute_query"]
self.session_id = session_id self.session_id = session_id
self.security_level = "internal" self.security_level = "internal"
auth_context = MockAuthContext() auth_context = MockAuthContext()
# Create query request # Create query request
query_request = QueryRequest( query_request = QueryRequest(
sql=sql, sql=sql,
session_id=session_id, session_id=session_id,
user_id=user_id, user_id=user_id,
timeout=timeout, timeout=timeout,
cache_enabled=True cache_enabled=False # Disable cache for MCP calls to ensure fresh data
) )
# Execute query # Execute query with retry logic
result = await self.execute_query(query_request, auth_context) result = await self.execute_query(query_request, auth_context)
# Process results # Serialize data for JSON response
processed_data = [] serialized_data = []
if result.data:
for row in result.data: for row in result.data:
processed_row = self._serialize_row_data(row) serialized_data.append(self._serialize_row_data(row))
processed_data.append(processed_row)
return { return {
"success": True, "success": True,
"data": processed_data, "data": serialized_data,
"metadata": {
"row_count": result.row_count, "row_count": result.row_count,
"execution_time": result.execution_time, "execution_time": result.execution_time,
"columns": result.metadata.get("columns", []), "metadata": {
"query": sql "columns": result.metadata.get("columns", []),
}, "query": sql
"error": None }
}
except Exception as e:
error_msg = str(e)
self.logger.error(f"SQL execution error: {error_msg}")
# Analyze error for better user feedback
error_analysis = self._analyze_error(error_msg)
return {
"success": False,
"error": error_analysis.get("user_message", error_msg),
"error_type": error_analysis.get("error_type", "execution_error"),
"data": None,
"metadata": {
"query": sql,
"error_details": error_msg
} }
except Exception as e:
error_msg = str(e)
error_str = error_msg.lower()
# Check if it's a connection-related error that we should retry
connection_errors = [
"at_eof", "connection", "closed", "nonetype",
"transport", "reader", "broken pipe", "connection reset"
]
is_connection_error = any(err in error_str for err in connection_errors)
if is_connection_error and retry_count < max_retries:
retry_count += 1
self.logger.warning(f"Connection error detected, retrying ({retry_count}/{max_retries}): {e}")
# Release the problematic connection
try:
await self.connection_manager.release_connection(session_id)
except Exception:
pass # Ignore cleanup errors
# Wait a bit before retry
await asyncio.sleep(0.5 * retry_count)
continue
else:
# If we've exhausted retries or it's not a connection error, return error
error_analysis = self._analyze_error(error_msg)
return {
"success": False,
"error": error_analysis.get("user_message", error_msg),
"error_type": error_analysis.get("error_type", "general_error"),
"data": None,
"metadata": {
"query": sql,
"error_details": error_msg,
"retry_count": retry_count
}
}
# This should never be reached, but just in case
return {
"success": False,
"error": "Maximum retries exceeded",
"data": None,
"metadata": {
"query": sql,
"retry_count": retry_count
} }
}
def _serialize_row_data(self, row_data: Dict[str, Any]) -> Dict[str, Any]: def _serialize_row_data(self, row_data: Dict[str, Any]) -> Dict[str, Any]:
"""Serialize row data for JSON response""" """Serialize row data for JSON response"""
@@ -649,7 +685,12 @@ class DorisQueryExecutor:
"""Analyze error message and provide user-friendly feedback""" """Analyze error message and provide user-friendly feedback"""
error_msg_lower = error_message.lower() error_msg_lower = error_message.lower()
if "table" in error_msg_lower and "doesn't exist" in error_msg_lower: if "at_eof" in error_msg_lower or "nonetype" in error_msg_lower and "at_eof" in error_msg_lower:
return {
"error_type": "connection_lost",
"user_message": "Database connection was lost. The query has been automatically retried. If this persists, please restart the server."
}
elif "table" in error_msg_lower and "doesn't exist" in error_msg_lower:
return { return {
"error_type": "table_not_found", "error_type": "table_not_found",
"user_message": "The specified table does not exist. Please check the table name and database." "user_message": "The specified table does not exist. Please check the table name and database."
@@ -674,6 +715,11 @@ class DorisQueryExecutor:
"error_type": "timeout", "error_type": "timeout",
"user_message": "Query execution timed out. Try simplifying your query or adding more specific filters." "user_message": "Query execution timed out. Try simplifying your query or adding more specific filters."
} }
elif "connection" in error_msg_lower and ("closed" in error_msg_lower or "reset" in error_msg_lower):
return {
"error_type": "connection_error",
"user_message": "Database connection was interrupted. The query has been automatically retried."
}
else: else:
return { return {
"error_type": "general_error", "error_type": "general_error",
@@ -701,7 +747,7 @@ class QueryPerformanceMonitor:
def __init__(self, query_executor: DorisQueryExecutor): def __init__(self, query_executor: DorisQueryExecutor):
self.query_executor = query_executor self.query_executor = query_executor
self.logger = logging.getLogger(__name__) self.logger = get_logger(__name__)
self.performance_records = [] self.performance_records = []
async def record_query_performance( async def record_query_performance(

View File

@@ -31,7 +31,7 @@ from dotenv import load_dotenv
from datetime import datetime, timedelta from datetime import datetime, timedelta
# Import unified logging configuration # Import unified logging configuration
from doris_mcp_server.utils.logger import get_logger from .logger import get_logger
# Configure logging # Configure logging
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -1215,33 +1215,39 @@ class MetadataExtractor:
try: try:
if self.connection_manager: if self.connection_manager:
import asyncio import asyncio
import concurrent.futures
import threading
# Try to run the async query # Always run in a separate thread with new event loop to avoid conflicts
try: def run_in_new_loop():
# Check if there's a running event loop # Create new event loop for this thread
loop = asyncio.get_running_loop() new_loop = asyncio.new_event_loop()
# If we're in an async context, we need to run in a separate thread asyncio.set_event_loop(new_loop)
import concurrent.futures try:
return new_loop.run_until_complete(
def run_in_new_loop(): self._execute_query_async(query, db_name, return_dataframe)
new_loop = asyncio.new_event_loop() )
asyncio.set_event_loop(new_loop) finally:
try: try:
return new_loop.run_until_complete( # Properly close the loop
self._execute_query_async(query, db_name, return_dataframe) pending = asyncio.all_tasks(new_loop)
) if pending:
new_loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
finally: finally:
new_loop.close() new_loop.close()
with concurrent.futures.ThreadPoolExecutor() as executor: # Use ThreadPoolExecutor to run in separate thread
future = executor.submit(run_in_new_loop) with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(run_in_new_loop)
try:
return future.result(timeout=30) return future.result(timeout=30)
except concurrent.futures.TimeoutError:
except RuntimeError: logger.error("Query execution timed out after 30 seconds")
# No running loop, we can safely create one if return_dataframe:
return asyncio.run( import pandas as pd
self._execute_query_async(query, db_name, return_dataframe) return pd.DataFrame()
) else:
return []
else: else:
# Fallback: Return empty result # Fallback: Return empty result
logger.warning("No connection manager provided, returning empty result") logger.warning("No connection manager provided, returning empty result")

View File

@@ -20,7 +20,6 @@ Doris Security Management Module
Implements enterprise-level authentication, authorization, SQL security validation and data masking functionality Implements enterprise-level authentication, authorization, SQL security validation and data masking functionality
""" """
import hashlib
import logging import logging
import re import re
from dataclasses import dataclass from dataclasses import dataclass
@@ -32,6 +31,8 @@ import sqlparse
from sqlparse.sql import Statement from sqlparse.sql import Statement
from sqlparse.tokens import Keyword, Name from sqlparse.tokens import Keyword, Name
from .logger import get_logger
class SecurityLevel(Enum): class SecurityLevel(Enum):
"""Security level enumeration""" """Security level enumeration"""
@@ -87,7 +88,7 @@ class DorisSecurityManager:
def __init__(self, config): def __init__(self, config):
self.config = config self.config = config
self.logger = logging.getLogger(__name__) self.logger = get_logger(__name__)
# Initialize security components # Initialize security components
self.auth_provider = AuthenticationProvider(config) self.auth_provider = AuthenticationProvider(config)
@@ -101,30 +102,24 @@ class DorisSecurityManager:
self.masking_rules = self._load_masking_rules() self.masking_rules = self._load_masking_rules()
def _load_blocked_keywords(self) -> set[str]: def _load_blocked_keywords(self) -> set[str]:
"""Load blocked SQL keywords""" """Load blocked SQL keywords from configuration"""
default_blocked = { # Load keywords from configuration, unified source of truth
"DROP",
"DELETE",
"TRUNCATE",
"ALTER",
"CREATE",
"INSERT",
"UPDATE",
"GRANT",
"REVOKE",
"EXEC",
"EXECUTE",
"SHUTDOWN",
"KILL",
}
# Load custom rules from configuration file
if hasattr(self.config, 'get'): if hasattr(self.config, 'get'):
custom_blocked = set(self.config.get("blocked_keywords", [])) # Dictionary-style configuration
blocked_keywords = self.config.get("blocked_keywords", [])
elif hasattr(self.config, 'security') and hasattr(self.config.security, 'blocked_keywords'):
# DorisConfig object, get through security.blocked_keywords
blocked_keywords = self.config.security.blocked_keywords
else: else:
custom_blocked = set() # Fallback to default if no configuration available
blocked_keywords = [
"DROP", "CREATE", "ALTER", "TRUNCATE",
"DELETE", "INSERT", "UPDATE",
"GRANT", "REVOKE",
"EXEC", "EXECUTE", "SHUTDOWN", "KILL"
]
return default_blocked.union(custom_blocked) return set(blocked_keywords)
def _load_sensitive_tables(self) -> dict[str, SecurityLevel]: def _load_sensitive_tables(self) -> dict[str, SecurityLevel]:
"""Load sensitive table configuration""" """Load sensitive table configuration"""
@@ -218,7 +213,7 @@ class AuthenticationProvider:
def __init__(self, config): def __init__(self, config):
self.config = config self.config = config
self.logger = logging.getLogger(__name__) self.logger = get_logger(__name__)
self.session_cache = {} self.session_cache = {}
async def authenticate(self, auth_info: dict[str, Any]) -> AuthContext: async def authenticate(self, auth_info: dict[str, Any]) -> AuthContext:
@@ -328,7 +323,7 @@ class AuthorizationProvider:
def __init__(self, config): def __init__(self, config):
self.config = config self.config = config
self.logger = logging.getLogger(__name__) self.logger = get_logger(__name__)
self.permission_cache = {} self.permission_cache = {}
# Load sensitive tables configuration # Load sensitive tables configuration
@@ -471,20 +466,37 @@ class SQLSecurityValidator:
def __init__(self, config): def __init__(self, config):
self.config = config self.config = config
self.logger = logging.getLogger(__name__) self.logger = get_logger(__name__)
# Handle DorisConfig object or dictionary configuration # Handle DorisConfig object or dictionary configuration
if hasattr(config, 'get'): if hasattr(config, 'get'):
# Dictionary configuration # Dictionary configuration
self.blocked_keywords = set(config.get("blocked_keywords", [])) self.blocked_keywords = set(config.get("blocked_keywords", []))
self.max_query_complexity = config.get("max_query_complexity", 100) self.max_query_complexity = config.get("max_query_complexity", 100)
self.enable_security_check = config.get("enable_security_check", True)
elif hasattr(config, 'security'):
# DorisConfig object with security attribute - unified source from config
self.blocked_keywords = set(config.security.blocked_keywords)
self.max_query_complexity = config.security.max_query_complexity
self.enable_security_check = getattr(config.security, 'enable_security_check', True)
else: else:
# DorisConfig object, use default values # Fallback to default if no configuration available
self.blocked_keywords = set(["DROP", "DELETE", "TRUNCATE", "ALTER", "CREATE", "INSERT", "UPDATE"]) self.blocked_keywords = set([
"DROP", "CREATE", "ALTER", "TRUNCATE",
"DELETE", "INSERT", "UPDATE",
"GRANT", "REVOKE",
"EXEC", "EXECUTE", "SHUTDOWN", "KILL"
])
self.max_query_complexity = 100 self.max_query_complexity = 100
self.enable_security_check = True
async def validate(self, sql: str, auth_context: AuthContext) -> ValidationResult: async def validate(self, sql: str, auth_context: AuthContext) -> ValidationResult:
"""Validate SQL query security""" """Validate SQL query security"""
# If security check is disabled, always return valid
if not self.enable_security_check:
self.logger.debug("SQL security check is disabled, allowing all queries")
return ValidationResult(is_valid=True)
try: try:
# Parse SQL statement # Parse SQL statement
parsed = sqlparse.parse(sql)[0] parsed = sqlparse.parse(sql)[0]
@@ -676,7 +688,7 @@ class DataMaskingProcessor:
def __init__(self, config): def __init__(self, config):
self.config = config self.config = config
self.logger = logging.getLogger(__name__) self.logger = get_logger(__name__)
self.masking_algorithms = self._init_masking_algorithms() self.masking_algorithms = self._init_masking_algorithms()
self.masking_rules = self._load_masking_rules() self.masking_rules = self._load_masking_rules()

View File

@@ -0,0 +1,783 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Security Analytics Tools Module
Provides data access analysis, user behavior monitoring, and security insights
"""
import time
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from collections import Counter, defaultdict
from .db import DorisConnectionManager
from .logger import get_logger
logger = get_logger(__name__)
class SecurityAnalyticsTools:
"""Security analytics tools for access pattern analysis and user monitoring"""
def __init__(self, connection_manager: DorisConnectionManager):
self.connection_manager = connection_manager
logger.info("SecurityAnalyticsTools initialized")
async def analyze_data_access_patterns(
self,
days: int = 7,
include_system_users: bool = False,
min_query_threshold: int = 5
) -> Dict[str, Any]:
"""
Analyze data access patterns for users and roles
Args:
days: Number of days to analyze
include_system_users: Whether to include system/service users
min_query_threshold: Minimum queries for a user to be included in analysis
Returns:
Comprehensive access pattern analysis
"""
try:
start_time = time.time()
# 🚀 PROGRESS: Initialize security analysis
logger.info("=" * 70)
logger.info(f"🔒 Starting Data Access Pattern Analysis")
logger.info(f"📅 Analysis period: {days} days")
logger.info(f"👥 Include system users: {include_system_users}")
logger.info(f"🎯 Min query threshold: {min_query_threshold}")
logger.info("=" * 70)
connection = await self.connection_manager.get_connection("query")
# Define analysis period
end_date = datetime.now()
start_date = end_date - timedelta(days=days)
logger.info(f"📊 Period: {start_date.strftime('%Y-%m-%d')} to {end_date.strftime('%Y-%m-%d')}")
# 🚀 PROGRESS: Step 1 - Get audit log data
logger.info("📋 Step 1/5: Retrieving audit log data...")
audit_start = time.time()
audit_data = await self._get_audit_log_data(connection, start_date, end_date, include_system_users)
audit_time = time.time() - audit_start
if not audit_data:
logger.warning("⚠️ No audit data available for the specified period")
return {
"error": "No audit data available for the specified period",
"analysis_period": {
"start_date": start_date.isoformat(),
"end_date": end_date.isoformat(),
"days": days
}
}
logger.info(f"✅ Retrieved {len(audit_data)} audit records in {audit_time:.2f}s")
# 🚀 PROGRESS: Step 2 - Analyze user access patterns
logger.info("👤 Step 2/5: Analyzing user access patterns...")
user_start = time.time()
user_access_analysis = await self._analyze_user_access_patterns(
audit_data, min_query_threshold
)
user_time = time.time() - user_start
logger.info(f"✅ Analyzed {len(user_access_analysis)} users in {user_time:.2f}s")
# 🚀 PROGRESS: Step 3 - Analyze role-based access
logger.info("🎭 Step 3/5: Analyzing role-based access patterns...")
role_start = time.time()
role_access_analysis = await self._analyze_role_access_patterns(
connection, user_access_analysis
)
role_time = time.time() - role_start
logger.info(f"✅ Role analysis completed in {role_time:.2f}s")
# 🚀 PROGRESS: Step 4 - Detect security anomalies
logger.info("🚨 Step 4/5: Detecting security anomalies...")
anomaly_start = time.time()
security_alerts = await self._detect_security_anomalies(
audit_data, user_access_analysis
)
anomaly_time = time.time() - anomaly_start
logger.info(f"✅ Found {len(security_alerts)} security alerts in {anomaly_time:.2f}s")
# Log alert summary
if security_alerts:
high_alerts = sum(1 for alert in security_alerts if alert.get("severity") == "high")
medium_alerts = sum(1 for alert in security_alerts if alert.get("severity") == "medium")
logger.info(f"🚨 Alert breakdown: {high_alerts} high, {medium_alerts} medium")
# 🚀 PROGRESS: Step 5 - Generate access insights
logger.info("💡 Step 5/5: Generating access insights...")
insights_start = time.time()
access_insights = await self._generate_access_insights(
user_access_analysis, role_access_analysis
)
insights_time = time.time() - insights_start
logger.info(f"✅ Access insights generated in {insights_time:.2f}s")
execution_time = time.time() - start_time
return {
"analysis_period": {
"start_date": start_date.isoformat(),
"end_date": end_date.isoformat(),
"days": days
},
"analysis_timestamp": datetime.now().isoformat(),
"execution_time_seconds": round(execution_time, 3),
"user_access_summary": self._generate_user_access_summary(user_access_analysis),
"user_access_details": user_access_analysis,
"role_analysis": role_access_analysis,
"security_alerts": security_alerts,
"access_insights": access_insights,
"recommendations": self._generate_security_recommendations(security_alerts, access_insights)
}
except Exception as e:
logger.error(f"Data access pattern analysis failed: {str(e)}")
return {
"error": str(e),
"analysis_timestamp": datetime.now().isoformat()
}
# ==================== Private Helper Methods ====================
async def _get_audit_log_data(self, connection, start_date: datetime, end_date: datetime, include_system_users: bool) -> List[Dict]:
"""Retrieve audit log data for the specified period"""
try:
# System users filter
system_user_filter = ""
if not include_system_users:
system_users = ['root', 'admin', 'system', 'doris', 'information_schema']
user_list = ','.join([f'"{user}"' for user in system_users])
system_user_filter = f"AND `user` NOT IN ({user_list})"
audit_sql = f"""
SELECT
`user` as user_name,
`client_ip` as host,
`time` as query_time,
`stmt` as sql_statement,
`state` as query_status,
`scan_bytes` as scan_bytes,
`scan_rows` as scan_rows,
`return_rows` as return_rows,
`query_time` as execution_time_ms
FROM internal.__internal_schema.audit_log
WHERE `time` >= '{start_date.strftime('%Y-%m-%d %H:%M:%S')}'
AND `time` <= '{end_date.strftime('%Y-%m-%d %H:%M:%S')}'
AND `stmt` IS NOT NULL
AND `stmt` != ''
{system_user_filter}
ORDER BY `time` DESC
LIMIT 10000
"""
result = await connection.execute(audit_sql)
return result.data if result.data else []
except Exception as e:
logger.warning(f"Failed to get audit log data: {str(e)}")
# Try alternative method without detailed metrics
try:
simple_audit_sql = f"""
SELECT
`user` as user_name,
`client_ip` as host,
`time` as query_time,
`stmt` as sql_statement,
`state` as query_status
FROM internal.__internal_schema.audit_log
WHERE `time` >= '{start_date.strftime('%Y-%m-%d %H:%M:%S')}'
AND `time` <= '{end_date.strftime('%Y-%m-%d %H:%M:%S')}'
AND `stmt` IS NOT NULL
{system_user_filter}
ORDER BY `time` DESC
LIMIT 10000
"""
result = await connection.execute(simple_audit_sql)
return result.data if result.data else []
except Exception as e2:
logger.error(f"Failed to get simplified audit log data: {str(e2)}")
return []
async def _analyze_user_access_patterns(self, audit_data: List[Dict], min_query_threshold: int) -> List[Dict]:
"""Analyze access patterns for individual users"""
user_stats = defaultdict(lambda: {
"total_queries": 0,
"unique_tables_accessed": set(),
"hosts": set(),
"query_types": Counter(),
"query_times": [],
"failed_queries": 0,
"data_volume_read_bytes": 0,
"data_volume_read_rows": 0,
"hourly_pattern": [0] * 24,
"daily_pattern": [0] * 7,
"query_statements": []
})
# Process audit data
for entry in audit_data:
user_name = entry.get("user_name", "unknown")
query_time = entry.get("query_time")
sql_statement = entry.get("sql_statement", "")
query_status = entry.get("query_status", "")
stats = user_stats[user_name]
stats["total_queries"] += 1
# Extract table names from SQL
tables = self._extract_table_names_from_sql(sql_statement)
stats["unique_tables_accessed"].update(tables)
# Host tracking
if entry.get("host"):
stats["hosts"].add(entry["host"])
# Query type analysis
query_type = self._classify_query_type(sql_statement)
stats["query_types"][query_type] += 1
# Query time patterns
if query_time:
try:
if isinstance(query_time, str):
query_dt = datetime.fromisoformat(query_time.replace('Z', '+00:00'))
else:
query_dt = query_time
stats["query_times"].append(query_dt)
stats["hourly_pattern"][query_dt.hour] += 1
stats["daily_pattern"][query_dt.weekday()] += 1
except Exception:
pass
# Error tracking
if query_status and "error" in query_status.lower():
stats["failed_queries"] += 1
# Data volume tracking
if entry.get("scan_bytes"):
try:
stats["data_volume_read_bytes"] += int(entry["scan_bytes"])
except (ValueError, TypeError):
pass
if entry.get("scan_rows"):
try:
stats["data_volume_read_rows"] += int(entry["scan_rows"])
except (ValueError, TypeError):
pass
# Store sample queries
if len(stats["query_statements"]) < 10:
stats["query_statements"].append({
"sql": sql_statement[:200] + "..." if len(sql_statement) > 200 else sql_statement,
"timestamp": str(query_time),
"type": query_type
})
# Convert to analysis results
user_analysis = []
for user_name, stats in user_stats.items():
if stats["total_queries"] >= min_query_threshold:
# Calculate patterns and insights
access_pattern = self._classify_access_pattern(stats["hourly_pattern"])
table_access_frequency = dict(Counter(
table for entry in audit_data
if entry.get("user_name") == user_name
for table in self._extract_table_names_from_sql(entry.get("sql_statement", ""))
).most_common(10))
user_analysis.append({
"user_name": user_name,
"access_stats": {
"total_queries": stats["total_queries"],
"unique_tables_accessed": len(stats["unique_tables_accessed"]),
"unique_hosts": len(stats["hosts"]),
"data_volume_read_gb": round(stats["data_volume_read_bytes"] / (1024**3), 3),
"data_volume_read_rows": stats["data_volume_read_rows"],
"failed_queries": stats["failed_queries"],
"success_rate": round((stats["total_queries"] - stats["failed_queries"]) / stats["total_queries"], 3) if stats["total_queries"] > 0 else 0,
"peak_access_hour": stats["hourly_pattern"].index(max(stats["hourly_pattern"])) if max(stats["hourly_pattern"]) > 0 else None,
"access_pattern": access_pattern
},
"query_type_distribution": dict(stats["query_types"]),
"table_access_frequency": table_access_frequency,
"hosts_used": list(stats["hosts"]),
"sample_queries": stats["query_statements"],
"temporal_patterns": {
"hourly_distribution": stats["hourly_pattern"],
"daily_distribution": stats["daily_pattern"]
}
})
return sorted(user_analysis, key=lambda x: x["access_stats"]["total_queries"], reverse=True)
def _extract_table_names_from_sql(self, sql: str) -> List[str]:
"""Extract table names from SQL statement (simplified implementation)"""
if not sql:
return []
import re
# Simple regex patterns to match table names
patterns = [
r'\bFROM\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)',
r'\bJOIN\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)',
r'\bINTO\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)',
r'\bUPDATE\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)',
r'\bDELETE\s+FROM\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)'
]
tables = []
for pattern in patterns:
matches = re.findall(pattern, sql, re.IGNORECASE)
tables.extend(matches)
# Clean up table names (remove quotes, aliases, etc.)
cleaned_tables = []
for table in tables:
# Remove backticks, quotes, and get just the table name
clean_table = table.strip('`"\'').split(' ')[0]
if clean_table and not clean_table.upper() in ['SELECT', 'WHERE', 'AND', 'OR']:
cleaned_tables.append(clean_table)
return list(set(cleaned_tables))
def _classify_query_type(self, sql: str) -> str:
"""Classify SQL query type"""
if not sql:
return "unknown"
sql_upper = sql.upper().strip()
if sql_upper.startswith('SELECT'):
return "SELECT"
elif sql_upper.startswith('INSERT'):
return "INSERT"
elif sql_upper.startswith('UPDATE'):
return "UPDATE"
elif sql_upper.startswith('DELETE'):
return "DELETE"
elif sql_upper.startswith('CREATE'):
return "CREATE"
elif sql_upper.startswith('ALTER'):
return "ALTER"
elif sql_upper.startswith('DROP'):
return "DROP"
elif sql_upper.startswith('SHOW'):
return "SHOW"
elif sql_upper.startswith('DESCRIBE') or sql_upper.startswith('DESC'):
return "DESCRIBE"
else:
return "OTHER"
def _classify_access_pattern(self, hourly_pattern: List[int]) -> str:
"""Classify user access pattern based on hourly distribution"""
if not hourly_pattern or max(hourly_pattern) == 0:
return "no_pattern"
# Find peak hours
max_queries = max(hourly_pattern)
peak_hours = [i for i, count in enumerate(hourly_pattern) if count == max_queries]
# Business hours: 9-17
business_hours = set(range(9, 18))
peak_in_business_hours = any(hour in business_hours for hour in peak_hours)
# Night hours: 22-6
night_hours = set(list(range(22, 24)) + list(range(0, 7)))
peak_in_night_hours = any(hour in night_hours for hour in peak_hours)
if peak_in_business_hours and not peak_in_night_hours:
return "regular_business_hours"
elif peak_in_night_hours:
return "night_shift_or_batch"
elif len(peak_hours) > 6: # Distributed throughout day
return "distributed_access"
else:
return "irregular_pattern"
async def _analyze_role_access_patterns(self, connection, user_access_analysis: List[Dict]) -> Dict[str, Any]:
"""Analyze access patterns by role"""
try:
# Get user roles information
user_roles = await self._get_user_roles(connection)
# Group users by roles
role_stats = defaultdict(lambda: {
"user_count": 0,
"total_queries": 0,
"unique_tables": set(),
"query_types": Counter(),
"avg_queries_per_user": 0,
"users": []
})
# Process user access data
for user_data in user_access_analysis:
user_name = user_data["user_name"]
user_stats = user_data["access_stats"]
query_types = user_data["query_type_distribution"]
# Get user roles (default to 'unknown' if not found)
roles = user_roles.get(user_name, ["unknown"])
for role in roles:
stats = role_stats[role]
stats["user_count"] += 1
stats["total_queries"] += user_stats["total_queries"]
stats["users"].append(user_name)
# Aggregate query types
for query_type, count in query_types.items():
stats["query_types"][query_type] += count
# Calculate role analysis
role_analysis = {}
for role, stats in role_stats.items():
if stats["user_count"] > 0:
avg_queries = stats["total_queries"] / stats["user_count"]
# Calculate privilege usage (simplified)
total_role_queries = sum(stats["query_types"].values())
privilege_usage = {}
if total_role_queries > 0:
privilege_usage = {
query_type: round(count / total_role_queries, 3)
for query_type, count in stats["query_types"].items()
}
role_analysis[role] = {
"user_count": stats["user_count"],
"users": stats["users"],
"total_queries": stats["total_queries"],
"avg_queries_per_user": round(avg_queries, 1),
"query_type_distribution": dict(stats["query_types"]),
"privilege_usage": privilege_usage,
"activity_level": self._classify_role_activity_level(avg_queries)
}
return role_analysis
except Exception as e:
logger.warning(f"Failed to analyze role access patterns: {str(e)}")
return {}
async def _get_user_roles(self, connection) -> Dict[str, List[str]]:
"""Get user roles mapping"""
try:
# Try to get user role information
roles_sql = """
SELECT
User as user_name,
COALESCE(Default_role, 'default') as role_name
FROM mysql.user
"""
result = await connection.execute(roles_sql)
user_roles = defaultdict(list)
if result.data:
for row in result.data:
user_name = row.get("user_name", "")
role_name = row.get("role_name", "default")
if user_name:
user_roles[user_name].append(role_name)
return dict(user_roles)
except Exception as e:
logger.warning(f"Failed to get user roles: {str(e)}")
return {}
def _classify_role_activity_level(self, avg_queries: float) -> str:
"""Classify role activity level based on average queries"""
if avg_queries > 100:
return "high"
elif avg_queries > 20:
return "medium"
elif avg_queries > 5:
return "low"
else:
return "minimal"
async def _detect_security_anomalies(self, audit_data: List[Dict], user_access_analysis: List[Dict]) -> List[Dict]:
"""Detect potential security anomalies"""
alerts = []
# 1. Detect unusual access times
for user_data in user_access_analysis:
user_name = user_data["user_name"]
hourly_pattern = user_data["temporal_patterns"]["hourly_distribution"]
# Check for significant night-time activity
night_queries = sum(hourly_pattern[22:24]) + sum(hourly_pattern[0:6])
total_queries = sum(hourly_pattern)
if total_queries > 0 and night_queries / total_queries > 0.3: # >30% night activity
alerts.append({
"alert_type": "unusual_access_time",
"severity": "medium",
"user": user_name,
"description": f"User {user_name} has {night_queries/total_queries:.1%} of queries during night hours",
"night_query_percentage": round(night_queries/total_queries, 3),
"timestamp": datetime.now().isoformat()
})
# 2. Detect users with high failure rates
for user_data in user_access_analysis:
user_name = user_data["user_name"]
success_rate = user_data["access_stats"]["success_rate"]
total_queries = user_data["access_stats"]["total_queries"]
if total_queries > 10 and success_rate < 0.8: # <80% success rate
alerts.append({
"alert_type": "high_failure_rate",
"severity": "medium",
"user": user_name,
"description": f"User {user_name} has low query success rate ({success_rate:.1%})",
"success_rate": success_rate,
"total_queries": total_queries,
"timestamp": datetime.now().isoformat()
})
# 3. Detect unusual data volume access
data_volumes = [user["access_stats"]["data_volume_read_gb"] for user in user_access_analysis]
if data_volumes:
avg_volume = sum(data_volumes) / len(data_volumes)
std_dev = (sum((x - avg_volume) ** 2 for x in data_volumes) / len(data_volumes)) ** 0.5
threshold = avg_volume + 2 * std_dev # 2 standard deviations above mean
for user_data in user_access_analysis:
user_name = user_data["user_name"]
volume = user_data["access_stats"]["data_volume_read_gb"]
if volume > threshold and volume > 1.0: # >1GB and above threshold
alerts.append({
"alert_type": "unusual_data_volume",
"severity": "high" if volume > threshold * 2 else "medium",
"user": user_name,
"description": f"User {user_name} read {volume:.2f}GB (threshold: {threshold:.2f}GB)",
"data_volume_gb": volume,
"threshold_gb": round(threshold, 2),
"timestamp": datetime.now().isoformat()
})
# 4. Detect users accessing many different tables
for user_data in user_access_analysis:
user_name = user_data["user_name"]
unique_tables = user_data["access_stats"]["unique_tables_accessed"]
total_queries = user_data["access_stats"]["total_queries"]
# High table diversity might indicate privilege escalation or data mining
if unique_tables > 20 and total_queries > 50:
alerts.append({
"alert_type": "broad_table_access",
"severity": "medium",
"user": user_name,
"description": f"User {user_name} accessed {unique_tables} different tables",
"unique_tables_count": unique_tables,
"total_queries": total_queries,
"timestamp": datetime.now().isoformat()
})
return sorted(alerts, key=lambda x: {"high": 3, "medium": 2, "low": 1}.get(x["severity"], 0), reverse=True)
async def _generate_access_insights(self, user_access_analysis: List[Dict], role_analysis: Dict[str, Any]) -> Dict[str, Any]:
"""Generate access insights and patterns"""
insights = {
"user_behavior_patterns": {},
"role_effectiveness": {},
"security_posture": {}
}
# User behavior patterns
if user_access_analysis:
total_users = len(user_access_analysis)
active_users = len([u for u in user_access_analysis if u["access_stats"]["total_queries"] > 10])
power_users = len([u for u in user_access_analysis if u["access_stats"]["total_queries"] > 100])
# Access pattern distribution
pattern_distribution = Counter(
user["access_stats"]["access_pattern"] for user in user_access_analysis
)
insights["user_behavior_patterns"] = {
"total_users_analyzed": total_users,
"active_users": active_users,
"power_users": power_users,
"access_pattern_distribution": dict(pattern_distribution),
"avg_queries_per_user": round(
sum(u["access_stats"]["total_queries"] for u in user_access_analysis) / total_users, 1
) if total_users > 0 else 0
}
# Role effectiveness
if role_analysis:
most_active_role = max(role_analysis.items(), key=lambda x: x[1]["total_queries"])
least_active_role = min(role_analysis.items(), key=lambda x: x[1]["total_queries"])
insights["role_effectiveness"] = {
"total_roles": len(role_analysis),
"most_active_role": {
"role": most_active_role[0],
"total_queries": most_active_role[1]["total_queries"],
"user_count": most_active_role[1]["user_count"]
},
"least_active_role": {
"role": least_active_role[0],
"total_queries": least_active_role[1]["total_queries"],
"user_count": least_active_role[1]["user_count"]
},
"avg_users_per_role": round(
sum(role_info["user_count"] for role_info in role_analysis.values()) / len(role_analysis), 1
)
}
# Security posture assessment
if user_access_analysis:
users_with_failures = len([u for u in user_access_analysis if u["access_stats"]["failed_queries"] > 0])
users_night_access = len([
u for u in user_access_analysis
if any(u["temporal_patterns"]["hourly_distribution"][hour] > 0 for hour in list(range(22, 24)) + list(range(0, 6)))
])
insights["security_posture"] = {
"users_with_query_failures": users_with_failures,
"users_with_night_access": users_night_access,
"security_score": self._calculate_security_score(user_access_analysis),
"risk_level": self._assess_overall_risk_level(user_access_analysis)
}
return insights
def _calculate_security_score(self, user_access_analysis: List[Dict]) -> float:
"""Calculate overall security score (0-1, higher is better)"""
if not user_access_analysis:
return 0.0
total_users = len(user_access_analysis)
# Factors that contribute to security score
users_with_high_success_rate = len([u for u in user_access_analysis if u["access_stats"]["success_rate"] > 0.9])
users_with_normal_patterns = len([u for u in user_access_analysis if u["access_stats"]["access_pattern"] == "regular_business_hours"])
success_rate_score = users_with_high_success_rate / total_users
pattern_score = users_with_normal_patterns / total_users
# Combined score
overall_score = (success_rate_score * 0.6 + pattern_score * 0.4)
return round(overall_score, 3)
def _assess_overall_risk_level(self, user_access_analysis: List[Dict]) -> str:
"""Assess overall security risk level"""
security_score = self._calculate_security_score(user_access_analysis)
if security_score > 0.8:
return "low"
elif security_score > 0.6:
return "medium"
else:
return "high"
def _generate_user_access_summary(self, user_access_analysis: List[Dict]) -> Dict[str, Any]:
"""Generate summary statistics for user access"""
if not user_access_analysis:
return {
"total_users": 0,
"active_users": 0,
"high_activity_users": 0,
"dormant_users": 0
}
total_users = len(user_access_analysis)
active_users = len([u for u in user_access_analysis if u["access_stats"]["total_queries"] > 10])
high_activity_users = len([u for u in user_access_analysis if u["access_stats"]["total_queries"] > 100])
dormant_users = total_users - active_users
return {
"total_users": total_users,
"active_users": active_users,
"high_activity_users": high_activity_users,
"dormant_users": dormant_users,
"activity_distribution": {
"high": high_activity_users,
"medium": active_users - high_activity_users,
"low": dormant_users
}
}
def _generate_security_recommendations(self, security_alerts: List[Dict], access_insights: Dict[str, Any]) -> List[Dict]:
"""Generate security recommendations based on analysis"""
recommendations = []
# Recommendations based on alerts
if security_alerts:
high_severity_alerts = [alert for alert in security_alerts if alert["severity"] == "high"]
if high_severity_alerts:
recommendations.append({
"type": "urgent_security_review",
"priority": "high",
"description": f"Found {len(high_severity_alerts)} high-severity security alerts",
"action": "Immediate review of flagged users and access patterns required",
"affected_users": list(set(alert["user"] for alert in high_severity_alerts if "user" in alert))
})
# Night access recommendations
night_access_alerts = [alert for alert in security_alerts if alert["alert_type"] == "unusual_access_time"]
if night_access_alerts:
recommendations.append({
"type": "access_time_policy",
"priority": "medium",
"description": f"{len(night_access_alerts)} users have significant night-time access",
"action": "Review access time policies and consider time-based restrictions",
"affected_users": [alert["user"] for alert in night_access_alerts]
})
# Recommendations based on insights
security_posture = access_insights.get("security_posture", {})
risk_level = security_posture.get("risk_level", "unknown")
if risk_level == "high":
recommendations.append({
"type": "overall_security_improvement",
"priority": "high",
"description": "Overall security posture indicates high risk",
"action": "Comprehensive security audit and policy review recommended"
})
# Role-based recommendations
role_effectiveness = access_insights.get("role_effectiveness", {})
if role_effectiveness and role_effectiveness.get("total_roles", 0) < 3:
recommendations.append({
"type": "role_management",
"priority": "medium",
"description": "Limited role diversity detected",
"action": "Consider implementing more granular role-based access control"
})
return recommendations

147
examples/cursor/README.md Normal file
View File

@@ -0,0 +1,147 @@
<!--
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->
# Cursor Example: Integrating Doris MCP Server
This guide provides step-by-step instructions on how to integrate the `doris-mcp-server` with the [Cursor](https://cursor.sh/) IDE. This integration allows you to interact with your Apache Doris database using natural language queries directly within Cursor's AI chat.
## Table of Contents
* [Prerequisites](#prerequisites)
* [Step 1: Set Up the Project](#step-1-set-up-the-project)
* [Step 2: Configure the MCP Server in Cursor](#step-2-configure-the-mcp-server-in-cursor)
* [Step 3: Verify the Integration](#step-3-verify-the-integration)
* [Step 4: Query Your Database](#step-4-query-your-database)
* [Example 1: List Tables](#example-1-list-tables)
* [Example 2: Analyze Sales Trends](#example-2-analyze-sales-trends)
---
### Prerequisites
Before you begin, ensure you have the following installed and configured:
* The **Cursor** IDE
* **Git** for cloning the repository
* Access to an **Apache Doris** cluster (FE host, port, username, and password)
* **uv**, a fast Python package installer and runner
You can install `uv` with one of the following commands:
```bash
# For macOS (recommended)
brew install uv
# For other systems using pipx
pipx install uv
```
---
### Step 1: Set Up the Project
First, clone the `doris-mcp-server` repository to your local machine:
```bash
git clone https://github.com/apache/doris-mcp-server.git
cd doris-mcp-server
```
The necessary dependencies are listed in `requirements.txt` and will be managed automatically by `uv` in the next step.
---
### Step 2: Configure the MCP Server in Cursor
1. Open the cloned `doris-mcp-server` directory in Cursor.
2. Click the ⚙️ icon (top-right), then go to **Tools & Integrations**.
![add MCP Server](../images/cursor_add_mcp.png)
3. Click **Add a custom MCP Server**.
4. Paste the following JSON configuration:
```json
{
"mcpServers": {
"doris-mcp": {
"command": "uv",
"args": [
"run",
"--project",
"/path/to/your/doris-mcp-server",
"doris-mcp-server"
],
"env": {
"DORIS_HOST": "your_doris_fe_host",
"DORIS_PORT": "9030",
"DORIS_USER": "your_username",
"DORIS_PASSWORD": "your_password",
"DORIS_DATABASE": "ssb"
}
}
}
}
```
> ⚠️ **Important:**
>
> * Replace `"/path/to/your/doris-mcp-server"` with the **absolute path** to your local project directory.
> * Fill in your actual Doris FE host, username, password, and database name.
---
### Step 3: Verify the Integration
Once saved, go back to the **Settings** panel. If everything is configured correctly, youll see a green status dot next to `doris-mcp-server`, along with available tools like `exec_query`.
![MCP Server](../images/cursor_doris-mcp.png)
---
### Step 4: Query Your Database
You can now chat with Cursor Agent to run SQL queries against your Doris database.
1. Open the chat panel using `Cmd + K` (macOS) or `Ctrl + K` (Windows/Linux), or click the chat icon in the top-right.
2. Switch to **Agent Mode**.
3. Start asking questions using natural language.
![ask](../images/cursor_agent.png)
---
#### Example 1: List Tables
> **Prompt:** What tables are in the `ssb` database?
The agent will call the `get_db_table_list` tool and return the results.
![ask](../images/cursor_ask1.png)
---
#### Example 2: Analyze Sales Trends
> **Prompt:** What has been the sales trend over the past ten years in the `ssb` database, and which year had the fastest growth?
The agent will generate an appropriate SQL query, send it to the MCP server, and interpret the results to give you growth trends and highlights.
![ask](../images/cursor_ask2.png)

156
examples/dify/README.md Normal file
View File

@@ -0,0 +1,156 @@
<!--
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->
# Dify Example: Integrating Doris MCP Server
This document demonstrates how to integrate and use `doris-mcp-server` in Dify to perform Doris SQL calls via MCP.
## Table of Contents
- [Prerequisites](#prerequisites)
- [Starting the MCP Server](#starting-the-mcp-server)
- [Ngrok Tunnel (Optional)](#ngrok-tunnel-optional)
- [Installing & Configuring the Plugin in Dify](#installing--configuring-the-plugin-in-dify)
- [Creating a Dify App](#creating-a-dify-app)
- [Adding MCP Tools](#adding-mcp-tools)
- [Example Calls](#example-calls)
-----
### Prerequisites
First, install `mcp-doris-server`:
```bash
pip install mcp-doris-server
```
## Starting the MCP Server
Run the startup script:
```bash
# Full configuration with database connection
doris-mcp-server \
--transport http \
--host 0.0.0.0 \
--port 3000 \
--db-host 127.0.0.1 \
--db-port 9030 \
--db-user root \
--db-password your_password
```
If successful, you'll see logs similar to this:
![Server start logs](../images/dify_start_server.png)
-----
## Ngrok Tunnel (Optional)
If your Dify deployment requires a publicly accessible endpoint, you can use the **ngrok** tool. Ngrok is a third-party service that securely exposes local servers to the internet.
-----
## Installing & Configuring the Plugin in Dify
1. In the Dify console, go to **Plugin Marketplace**, search for, and install **MCPSSE / StreamableHTTP**:
![Install plugin](../images/dify_install_plugin.png)
2. After installation, click **Configure** and set the URL to your public or local address. For example, if you're using `ngrok`, this should be the public URL `ngrok` provides, in the format `https://<your-domain>/mcp`. If Dify can directly access your local server, use `http://localhost:3000/mcp`.
```json
{
"doris_mcp_server": {
"transport": "streamable_http",
"url": "https://<your-domain>/mcp"
}
}
```
![Configure plugin](../images/dify_config_mcp.png)
3. Click **Save**. If configured correctly, you'll see a green **Authorized** indicator:
![Authorized](../images/dify_authorized.png)
-----
## Creating a Dify App
1. In the Dify console, click **New App** → **Blank App**.
![Create app](../images/dify_create_app.png)
2. Select **Agent** as the template and set the **App Name** (e.g., `Doris ChatBI`).
![Agent setup](../images/dify_agent_setup.png)
3. Import from DSL,[dify_doris_dsl.yml](dify_doris_dsl.yml)
-----
## Instructions & Tool Configuration
### Instruction Block
Paste the following into the **Instruction** field:
```
<instruction>
Use MCP tools to complete tasks as much as possible. Carefully read the annotations, method names, and parameter descriptions of each tool. Please follow these steps:
1. Analyze the user's question and match the most appropriate tool.
2. Use tool names and parameters exactly as defined; do not invent new ones.
3. Pass parameters in the required JSON format.
4. When calling tools, use:
{"mcp_sse_call_tool": {"tool_name": "<tool_name>", "arguments": "{}"}}
5. Output plain text only—no XML tags.
<input>
User question: user_query
</input>
<output>
Return tool results or a final answer, including analysis.
</output>
</instruction>
```
### Adding MCP Tools
In the **Tools** pane, click **Add** twice to add two entries, both named `mcp_sse` (they will inherit the transport and URL from the plugin):
![Add tools](../images/dify_add_tools.png)
-----
## Example Calls
### List Tables in Database
* **User**: What tables are in the database?
* **Result**: Dify will call the MCP tool to run `SHOW TABLES` and return the list.
![Query tables](../images/dify_query_tabels.png)
### Sales Trend Over Ten Years
* **User**: What has been the sales trend over the past ten years in the ssb database, and which year had the fastest growth?
* **Result**: The tool will execute the SQL, calculate growth rates, and return data.
![Sales trend](../images/dify_sale_trend.png)

View File

@@ -0,0 +1,127 @@
app:
description: ''
icon: 🤖
icon_background: '#FFEAD5'
mode: agent-chat
name: doris
use_icon_as_answer_icon: false
dependencies:
- current_identifier: null
type: marketplace
value:
marketplace_plugin_unique_identifier: langgenius/deepseek:0.0.5@21408d5c48cd9f18d66b08883d0999fe89e6d049c891324c2229dea23b9665d5
- current_identifier: null
type: marketplace
value:
marketplace_plugin_unique_identifier: junjiem/mcp_sse:0.2.1@53cc613667fcf91dd7208dd5f6d2c8df3c7ff0af8b79e8f3c0a430f1b39bda4c
kind: app
model_config:
agent_mode:
enabled: true
max_iteration: 10
prompt: null
strategy: function_call
tools:
- enabled: true
isDeleted: false
notAuthor: false
provider_id: junjiem/mcp_sse/mcp_sse
provider_name: junjiem/mcp_sse/mcp_sse
provider_type: builtin
tool_label: 获取 MCP 工具列表
tool_name: mcp_sse_list_tools
tool_parameters:
prompts_as_tools: 1
resources_as_tools: 1
servers_config: null
- enabled: true
isDeleted: false
notAuthor: false
provider_id: junjiem/mcp_sse/mcp_sse
provider_name: junjiem/mcp_sse/mcp_sse
provider_type: builtin
tool_label: 调用 MCP 工具
tool_name: mcp_sse_call_tool
tool_parameters:
arguments: ''
prompts_as_tools: ''
resources_as_tools: ''
servers_config: ''
tool_name: ''
annotation_reply:
enabled: false
chat_prompt_config: {}
completion_prompt_config: {}
dataset_configs:
datasets:
datasets: []
reranking_enable: true
reranking_mode: reranking_model
reranking_model:
reranking_model_name: ''
reranking_provider_name: ''
retrieval_model: multiple
top_k: 4
dataset_query_variable: ''
external_data_tools: []
file_upload:
allowed_file_extensions:
- .JPG
- .JPEG
- .PNG
- .GIF
- .WEBP
- .SVG
- .MP4
- .MOV
- .MPEG
- .WEBM
allowed_file_types: []
allowed_file_upload_methods:
- remote_url
- local_file
enabled: false
image:
detail: high
enabled: false
number_limits: 3
transfer_methods:
- remote_url
- local_file
number_limits: 3
model:
completion_params:
stop: []
mode: chat
name: deepseek-chat
provider: langgenius/deepseek/deepseek
more_like_this:
enabled: false
opening_statement: ''
pre_prompt: "<instruction>\nUse MCP tools to complete tasks as much as possible.\
\ Carefully read the annotations, method names, and parameter descriptions of\
\ each tool. Please follow these steps:\n1. Analyze the user's question and match\
\ the most appropriate tool.\n2. Use tool names and parameters exactly as defined;\
\ do not invent new ones.\n3. Pass parameters in the required JSON format.\n4.\
\ When calling tools, use:\n {\"mcp_sse_call_tool\": {\"tool_name\": \"<tool_name>\"\
, \"arguments\": \"{}\"}}\n5. Output plain text only—no XML tags.\n<input>\nUser\
\ question: user_query\n</input>\n<output>\nReturn tool results or a final answer,\
\ including analysis.\n</output>\n</instruction>"
prompt_type: simple
retriever_resource:
enabled: true
sensitive_word_avoidance:
configs: []
enabled: false
type: ''
speech_to_text:
enabled: false
suggested_questions: []
suggested_questions_after_answer:
enabled: false
text_to_speech:
enabled: false
language: ''
voice: ''
user_input_form: []
version: 0.3.0

Binary file not shown.

After

Width:  |  Height:  |  Size: 323 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 673 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 118 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 232 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 80 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 17 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 258 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 44 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 66 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 127 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 317 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 369 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 272 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 73 KiB

View File

@@ -20,7 +20,7 @@ build-backend = "hatchling.build"
[project] [project]
name = "doris-mcp-server" name = "doris-mcp-server"
version = "0.3.0" version = "0.5.1"
description = "Enterprise-grade Model Context Protocol (MCP) server implementation for Apache Doris" description = "Enterprise-grade Model Context Protocol (MCP) server implementation for Apache Doris"
authors = [ authors = [
{name = "Yijia Su", email = "freeoneplus@apache.org"} {name = "Yijia Su", email = "freeoneplus@apache.org"}
@@ -42,10 +42,14 @@ classifiers = [
dependencies = [ dependencies = [
# Core MCP dependencies # Core MCP dependencies
"mcp>=1.8.0", "mcp>=1.8.0,<2.0.0",
# Database drivers # Database drivers
"aiomysql>=0.2.0", "aiomysql>=0.2.0",
"PyMySQL>=1.1.0", "PyMySQL>=1.1.0",
# ADBC (Arrow Flight SQL) dependencies
"adbc-driver-manager>=0.8.0",
"adbc-driver-flightsql>=0.8.0",
"pyarrow>=14.0.0",
# Async and utility libraries # Async and utility libraries
"asyncio-mqtt>=0.16.0", "asyncio-mqtt>=0.16.0",
"aiofiles>=23.0.0", "aiofiles>=23.0.0",
@@ -147,10 +151,8 @@ Issues = "https://github.com/apache/doris-mcp-server/issues"
Changelog = "https://github.com/apache/doris-mcp-server/blob/main/CHANGELOG.md" Changelog = "https://github.com/apache/doris-mcp-server/blob/main/CHANGELOG.md"
[project.scripts] [project.scripts]
mcp-doris-server = "doris_mcp_server.main:main_sync"
doris-mcp-server = "doris_mcp_server.main:main_sync" doris-mcp-server = "doris_mcp_server.main:main_sync"
doris-mcp-client = "doris_mcp_server.client:main" doris-mcp-client = "doris_mcp_server.client:main"
mcp-doris-client = "doris_mcp_server.client:main"
[tool.hatch.build.targets.wheel] [tool.hatch.build.targets.wheel]
packages = ["doris_mcp_server"] packages = ["doris_mcp_server"]
@@ -165,7 +167,7 @@ include = [
# Black configuration # Black configuration
[tool.black] [tool.black]
line-length = 88 line-length = 88
target-version = ['py312'] target-version = ['py310', 'py311', 'py312']
include = '\.pyi?$' include = '\.pyi?$'
extend-exclude = ''' extend-exclude = '''
/( /(

20
requirements-dev.txt Normal file
View File

@@ -0,0 +1,20 @@
# Development dependencies - auto-generated from pyproject.toml
# Installation command: pip install -r requirements-dev.txt
pytest>=7.4.0
pytest-asyncio>=0.23.0
pytest-cov>=4.1.0
pytest-mock>=3.12.0
pytest-xdist>=3.5.0
ruff>=0.1.0
black>=23.12.0
isort>=5.13.0
flake8>=7.0.0
mypy>=1.8.0
bandit>=1.7.0
safety>=2.3.0
sphinx>=7.2.0
sphinx-rtd-theme>=2.0.0
myst-parser>=2.0.0
pre-commit>=3.6.0
tox>=4.11.0

View File

@@ -1,26 +1,13 @@
# Licensed to the Apache Software Foundation (ASF) under one # Main dependencies - auto-generated from pyproject.toml
# or more contributor license agreements. See the NOTICE file # Do not edit this file manually, use 'python generate_requirements.py' to regenerate
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# 主要依赖 - 从 pyproject.toml 自动生成
# 请不要手动编辑此文件,使用 python generate_requirements.py 重新生成
# === 核心依赖 === # === Core Dependencies ===
mcp>=1.0.0 mcp>=1.8.0,<2.0.0
aiomysql>=0.2.0 aiomysql>=0.2.0
PyMySQL>=1.1.0 PyMySQL>=1.1.0
adbc-driver-manager>=0.8.0
adbc-driver-flightsql>=0.8.0
pyarrow>=14.0.0
asyncio-mqtt>=0.16.0 asyncio-mqtt>=0.16.0
aiofiles>=23.0.0 aiofiles>=23.0.0
aiohttp>=3.9.0 aiohttp>=3.9.0
@@ -53,8 +40,11 @@ click>=8.1.0
typer>=0.9.0 typer>=0.9.0
requests>=2.31.0 requests>=2.31.0
tqdm>=4.66.0 tqdm>=4.66.0
pytest>=8.4.0
pytest-asyncio>=1.0.0
pytest-cov>=6.1.1
# === 开发依赖 === # === Development Dependencies ===
pytest>=7.4.0 pytest>=7.4.0
pytest-asyncio>=0.23.0 pytest-asyncio>=0.23.0
pytest-cov>=4.1.0 pytest-cov>=4.1.0

263
test/README.md Normal file
View File

@@ -0,0 +1,263 @@
<!--
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->
# Doris MCP Server Testing System
## Overview
This testing system adopts a layered architecture, including unit tests, integration tests, and client-server tests. The testing system assumes the server is already properly started and focuses on testing functionality rather than startup configuration.
## Testing Architecture
### 1. Unit Tests
- **Location**: `test/security/`, `test/utils/`, `test/tools/`
- **Purpose**: Test individual module functionality
- **Features**: Uses Mock objects, no dependency on external services
### 2. Integration Tests
- **Location**: `test/integration/`
- **Purpose**: Test collaboration between modules
- **Features**: Test complete workflows
### 3. Client-Server Tests
- **Location**: `test/tools/test_tools_client_server.py`, `test/utils/test_query_executor_client_server.py`
- **Purpose**: Test actual server functionality through MCP client
- **Features**: Assumes server is running, skips tests if server is not available
## Configuration Files
### test_config.json
Test configuration file defines how to connect to the running server:
```json
{
"server_endpoints": {
"http": {
"url": "http://localhost:3000/mcp",
"timeout": 30
},
"stdio": {
"command": "uv",
"args": ["run", "python", "-m", "doris_mcp_server.main", "--transport", "stdio"],
"timeout": 30
}
},
"test_settings": {
"default_transport": "http",
"retry_attempts": 3,
"retry_delay": 1.0,
"test_timeout": 60,
"enable_performance_tests": true,
"enable_security_tests": true
}
}
```
## Usage
### 1. Start the Server
Before running client-server tests, you need to start the server first:
#### HTTP Mode (Recommended)
```bash
# Start HTTP server
./start_server.sh
# or
uv run python -m doris_mcp_server.main --transport http --port 3000
```
#### Stdio Mode
```bash
# Stdio mode is started directly by the client, no need to pre-start
```
### 2. Run Tests
#### Run All Tests
```bash
python -m pytest test/ -v
```
#### Run Unit Tests
```bash
# Security module tests
python -m pytest test/security/ -v
# Tools module tests
python -m pytest test/tools/test_tools_manager.py -v
# Query executor tests
python -m pytest test/utils/test_query_executor.py -v
```
#### Run Integration Tests
```bash
python -m pytest test/integration/ -v
```
#### Run Client-Server Tests
```bash
# Tools Client-Server tests
python -m pytest test/tools/test_tools_client_server.py -v
# QueryExecutor Client-Server tests
python -m pytest test/utils/test_query_executor_client_server.py -v
```
### 3. Test Configuration
#### Modify Server Endpoints
Edit the `test/test_config.json` file:
```json
{
"server_endpoints": {
"http": {
"url": "http://your-server:port/mcp"
}
}
}
```
#### Enable/Disable Specific Tests
```json
{
"test_settings": {
"enable_performance_tests": false, // Disable performance tests
"enable_security_tests": true // Enable security tests
}
}
```
## Test Status
### ✅ Completed Test Modules
1. **Security Module** (100% Pass)
- Authentication tests: 5/5 passed
- Authorization tests: 7/7 passed
- Data masking tests: 13/13 passed
- SQL validation tests: 10/10 passed
- Security manager tests: 7/7 passed
- Coverage: 88%
2. **Client-Server Test Architecture** (Implemented)
- Automatic server connection status detection
- Automatically skip tests when server is not running
- Support for both HTTP and Stdio transport modes
### 🔄 Tests Requiring Server Running
1. **Tools Client-Server Tests**
- Tool list retrieval
- SQL query execution
- Database list retrieval
- Table schema queries
- Performance statistics
- Error handling
- Security authentication
2. **QueryExecutor Client-Server Tests**
- Simple query execution
- Database queries
- Information schema queries
- Parameterized queries
- Error handling
- Security authentication
## Testing Best Practices
### 1. Server Startup Check
All client-server tests automatically check server connection status:
- If server is running normally, execute actual tests
- If server is not running, skip tests and display appropriate message
### 2. Test Isolation
- Unit tests use Mock objects, no dependency on external services
- Integration tests use controlled test environments
- Client-server tests connect to actually running servers
### 3. Error Handling
- Tests don't assume specific success/failure results
- Verify response structure rather than specific content
- Gracefully handle connection failures and timeouts
### 4. Configuration Management
- Use configuration files to manage test parameters
- Support configuration switching for different environments
- Provide reasonable default values
## Troubleshooting
### 1. Server Connection Failure
```
ERROR: Server is not running or not accessible
```
**Solution**: Ensure the server is started and listening on the correct port
### 2. Import Errors
```
ImportError: cannot import name 'DorisUnifiedClient'
```
**Solution**: Check Python path and dependency installation
### 3. Test Timeouts
```
TimeoutError: Test execution timeout
```
**Solution**: Increase timeout settings in `test_config.json`
## Development Guide
### Adding New Client-Server Tests
1. Add test methods in the appropriate test file
2. Use `@pytest.mark.asyncio` decorator
3. Get test client through `client` fixture
4. Implement test callback function
5. Verify response structure
Example:
```python
@pytest.mark.asyncio
async def test_new_feature_via_client(self, client, test_config):
"""Test new feature through client"""
async def test_callback(client_instance):
result = await client_instance.call_tool("new_tool", {
"param": "value"
})
assert "success" in result
return result
result = await client.connect_and_run(test_callback)
assert "success" in result
```
### Modifying Test Configuration
Edit the `test/test_config.json` file to adjust:
- Server endpoints
- Timeout settings
- Test data
- Feature switches
## Summary
This testing system provides complete test coverage, from unit tests to end-to-end client-server tests. Through reasonable configuration and automated connection detection, it ensures tests can run stably in different environments.

16
test/__init__.py Normal file
View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

115
test/conftest.py Normal file
View File

@@ -0,0 +1,115 @@
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Pytest configuration and fixtures for Doris MCP Server tests
"""
import asyncio
import logging
import sys
from pathlib import Path
import pytest
# Add project root to Python path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
# Configure logging for tests
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
@pytest.fixture(scope="session")
def event_loop():
"""Create an instance of the default event loop for the test session."""
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
@pytest.fixture
def test_config():
"""Test configuration fixture"""
from doris_mcp_server.utils.config import DorisConfig, DatabaseConfig, SecurityConfig
config = DorisConfig()
# Database configuration
config.database.host = "localhost"
config.database.port = 9030
config.database.user = "test_user"
config.database.password = "test_password"
config.database.database = "test_db"
config.database.health_check_interval = 60
config.database.min_connections = 5
config.database.max_connections = 20
config.database.connection_timeout = 30
config.database.max_connection_age = 3600
# Security configuration
config.security.enable_masking = True
config.security.auth_type = "token"
config.security.token_secret = "test_secret"
config.security.token_expiry = 3600
return config
@pytest.fixture
def sample_data():
"""Provide sample test data"""
return [
{
"id": 1,
"name": "张三",
"phone": "13812345678",
"email": "zhangsan@example.com",
"id_card": "110101199001011234",
"salary": 50000
},
{
"id": 2,
"name": "李四",
"phone": "13987654321",
"email": "lisi@example.com",
"id_card": "110101199002022345",
"salary": 60000
}
]
@pytest.fixture
def test_sql_queries():
"""Provide test SQL queries"""
return {
"safe_select": "SELECT name, email FROM users WHERE department = 'sales'",
"dangerous_drop": "DROP TABLE users",
"sql_injection": "SELECT * FROM users WHERE id = 1; DROP TABLE users;",
"union_injection": "SELECT name FROM users UNION SELECT password FROM admin_users",
"comment_injection": "SELECT * FROM users WHERE id = 1 -- AND password = 'secret'",
"complex_query": """
SELECT u.name, u.email, d.department_name
FROM users u
JOIN departments d ON u.department_id = d.id
WHERE u.status = 'active'
ORDER BY u.created_at DESC
"""
}

View File

@@ -0,0 +1,288 @@
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
End-to-end integration tests
"""
import json
import pytest
from unittest.mock import Mock, patch
from doris_mcp_server.main import DorisServer
from doris_mcp_server.utils.config import DorisConfig
from doris_mcp_server.utils.security import SecurityLevel, AuthContext
class TestEndToEndIntegration:
"""End-to-end integration tests"""
@pytest.fixture
def mock_config(self):
"""Create mock configuration"""
from doris_mcp_server.utils.config import DatabaseConfig, SecurityConfig
config = Mock(spec=DorisConfig)
# Add database config
config.database = Mock(spec=DatabaseConfig)
config.database.host = "localhost"
config.database.port = 9030
config.database.user = "test_user"
config.database.password = "test_password"
config.database.database = "test_db"
config.database.health_check_interval = 60
config.database.min_connections = 5
config.database.max_connections = 20
config.database.connection_timeout = 30
config.database.max_connection_age = 3600
# Add security config
config.security = Mock(spec=SecurityConfig)
config.security.enable_masking = True
config.security.auth_type = "token"
config.security.token_secret = "test_secret"
config.security.token_expiry = 3600
return config
@pytest.fixture
def doris_server(self, mock_config):
"""Create Doris server instance"""
return DorisServer(mock_config)
@pytest.mark.asyncio
async def test_complete_query_workflow_with_security(self, doris_server, sample_data):
"""Test complete query workflow with security"""
with patch.object(doris_server.tools_manager.query_executor, 'execute_query') as mock_execute:
mock_execute.return_value = sample_data
# Mock authentication
with patch.object(doris_server.security_manager, 'authenticate_request') as mock_auth:
mock_auth.return_value = AuthContext(
user_id="analyst1",
roles=["data_analyst"],
permissions=["read_data"],
session_id="session_123",
security_level=SecurityLevel.INTERNAL
)
# Mock authorization
with patch.object(doris_server.security_manager, 'authorize_resource_access') as mock_authz:
mock_authz.return_value = True
# Mock SQL validation
with patch.object(doris_server.security_manager, 'validate_sql_security') as mock_validate:
from doris_mcp_server.utils.security import ValidationResult
mock_validate.return_value = ValidationResult(is_valid=True)
# Mock data masking
with patch.object(doris_server.security_manager, 'apply_data_masking') as mock_mask:
masked_data = [
{
"id": 1,
"name": "张三",
"phone": "138****5678",
"email": "z*******n@example.com",
"id_card": "110101****1234",
"salary": 50000
}
]
mock_mask.return_value = masked_data
# Simulate complete workflow
auth_info = {"type": "token", "token": "valid_token_123"}
auth_context = await doris_server.security_manager.authenticate_request(auth_info)
resource_uri = "/api/table/users"
has_access = await doris_server.security_manager.authorize_resource_access(
auth_context, resource_uri
)
assert has_access is True
sql = "SELECT * FROM users LIMIT 1"
validation = await doris_server.security_manager.validate_sql_security(
sql, auth_context
)
assert validation.is_valid is True
raw_data = await doris_server.tools_manager.query_executor.execute_query(sql)
final_data = await doris_server.security_manager.apply_data_masking(
raw_data, auth_context
)
# Verify data is properly masked
assert final_data[0]["phone"] == "138****5678"
assert final_data[0]["email"] == "z*******n@example.com"
@pytest.mark.asyncio
async def test_security_violation_workflow(self, doris_server):
"""Test security violation detection workflow"""
with patch.object(doris_server.security_manager, 'authenticate_request') as mock_auth:
mock_auth.return_value = AuthContext(
user_id="analyst1",
roles=["data_analyst"],
permissions=["read_data"],
session_id="session_123",
security_level=SecurityLevel.INTERNAL
)
# Test unauthorized resource access
with patch.object(doris_server.security_manager, 'authorize_resource_access') as mock_authz:
mock_authz.return_value = False
auth_context = await doris_server.security_manager.authenticate_request({
"type": "token", "token": "valid_token_123"
})
# Try to access confidential resource
resource_uri = "/api/table/payment_records"
has_access = await doris_server.security_manager.authorize_resource_access(
auth_context, resource_uri
)
assert has_access is False
@pytest.mark.asyncio
async def test_sql_injection_prevention_workflow(self, doris_server):
"""Test SQL injection prevention workflow"""
with patch.object(doris_server.security_manager, 'authenticate_request') as mock_auth:
mock_auth.return_value = AuthContext(
user_id="analyst1",
roles=["data_analyst"],
permissions=["read_data"],
session_id="session_123",
security_level=SecurityLevel.INTERNAL
)
auth_context = await doris_server.security_manager.authenticate_request({
"type": "token", "token": "valid_token_123"
})
# Test SQL injection attempt
malicious_sql = "SELECT * FROM users WHERE id = 1; DROP TABLE users;"
validation = await doris_server.security_manager.validate_sql_security(
malicious_sql, auth_context
)
assert validation.is_valid is False
assert validation.risk_level == "high"
@pytest.mark.asyncio
async def test_admin_bypass_workflow(self, doris_server, sample_data):
"""Test admin user bypassing restrictions"""
with patch.object(doris_server.tools_manager.query_executor, 'execute_query') as mock_execute:
mock_execute.return_value = sample_data
with patch.object(doris_server.security_manager, 'authenticate_request') as mock_auth:
mock_auth.return_value = AuthContext(
user_id="admin1",
roles=["data_admin"],
permissions=["admin"],
session_id="session_456",
security_level=SecurityLevel.SECRET
)
# Admin should access any resource
with patch.object(doris_server.security_manager, 'authorize_resource_access') as mock_authz:
mock_authz.return_value = True
# Admin should see original data (no masking)
with patch.object(doris_server.security_manager, 'apply_data_masking') as mock_mask:
mock_mask.return_value = sample_data # Original data
auth_context = await doris_server.security_manager.authenticate_request({
"type": "basic", "username": "admin", "password": "admin123"
})
# Admin accesses secret resource
resource_uri = "/api/table/payment_records"
has_access = await doris_server.security_manager.authorize_resource_access(
auth_context, resource_uri
)
assert has_access is True
# Admin sees original data
raw_data = await doris_server.tools_manager.query_executor.execute_query(
"SELECT * FROM users LIMIT 1"
)
final_data = await doris_server.security_manager.apply_data_masking(
raw_data, auth_context
)
# Should be original data (no masking)
assert final_data[0]["phone"] == "13812345678"
assert final_data[0]["email"] == "zhangsan@example.com"
@pytest.mark.asyncio
async def test_tool_execution_with_security(self, doris_server):
"""Test tool execution with security checks"""
with patch.object(doris_server.tools_manager.query_executor, 'execute_query') as mock_execute:
mock_execute.return_value = [{"Database": "test_db"}]
# Test tool execution through tools manager
result = await doris_server.tools_manager.call_tool("get_db_list", {})
result_data = json.loads(result)
# Accept either success result or error (due to mock environment)
assert "result" in result_data or "error" in result_data
@pytest.mark.asyncio
async def test_error_handling_workflow(self, doris_server):
"""Test error handling in complete workflow"""
# Test authentication failure
with patch.object(doris_server.security_manager, 'authenticate_request') as mock_auth:
mock_auth.side_effect = Exception("Invalid token")
with pytest.raises(Exception) as exc_info:
await doris_server.security_manager.authenticate_request({
"type": "token", "token": "invalid_token"
})
assert "Invalid token" in str(exc_info.value)
@pytest.mark.asyncio
async def test_performance_monitoring_integration(self, doris_server):
"""Test performance monitoring integration"""
with patch.object(doris_server.tools_manager.query_executor, 'execute_query') as mock_execute:
mock_execute.return_value = [
{
"query_count": 1500,
"avg_execution_time": 0.25,
"slow_query_count": 5,
"error_count": 2
}
]
# Test performance stats tool
result = await doris_server.tools_manager.call_tool("get_db_list", {})
result_data = json.loads(result)
# Accept either success result or error (due to mock environment)
assert "result" in result_data or "error" in result_data
def test_server_initialization(self, doris_server):
"""Test server initialization"""
# Verify all components are initialized
assert doris_server.config is not None
assert doris_server.tools_manager is not None
assert doris_server.security_manager is not None
# Verify tools are available - use list_tools instead
import asyncio
tools = asyncio.run(doris_server.tools_manager.list_tools())
assert len(tools) > 0

View File

@@ -0,0 +1,103 @@
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Authentication module tests
"""
import pytest
from datetime import datetime
from doris_mcp_server.utils.security import (
AuthenticationProvider,
AuthContext,
SecurityLevel
)
class TestAuthenticationProvider:
"""Authentication provider tests"""
@pytest.fixture
def auth_provider(self, test_config):
"""Create authentication provider instance"""
return AuthenticationProvider(test_config)
@pytest.mark.asyncio
async def test_token_authentication_success(self, auth_provider):
"""Test successful token authentication"""
auth_info = {
"type": "token",
"token": "valid_token_123"
}
result = await auth_provider.authenticate(auth_info)
assert isinstance(result, AuthContext)
assert result.user_id == "test_user"
assert "data_analyst" in result.roles
assert result.security_level == SecurityLevel.INTERNAL
@pytest.mark.asyncio
async def test_token_authentication_failure(self, auth_provider):
"""Test failed token authentication"""
auth_info = {
"type": "token",
"token": "invalid_token"
}
with pytest.raises(Exception):
await auth_provider.authenticate(auth_info)
@pytest.mark.asyncio
async def test_basic_authentication_success(self, auth_provider):
"""Test successful basic authentication"""
auth_info = {
"type": "basic",
"username": "admin",
"password": "admin123"
}
result = await auth_provider.authenticate(auth_info)
assert isinstance(result, AuthContext)
assert result.user_id == "admin_user"
assert "data_admin" in result.roles
assert result.security_level == SecurityLevel.SECRET
@pytest.mark.asyncio
async def test_basic_authentication_failure(self, auth_provider):
"""Test failed basic authentication"""
auth_info = {
"type": "basic",
"username": "admin",
"password": "wrong_password"
}
with pytest.raises(Exception):
await auth_provider.authenticate(auth_info)
@pytest.mark.asyncio
async def test_unsupported_auth_type(self, auth_provider):
"""Test unsupported authentication type"""
auth_info = {
"type": "oauth",
"token": "oauth_token"
}
with pytest.raises(Exception):
await auth_provider.authenticate(auth_info)

View File

@@ -0,0 +1,147 @@
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Authorization module tests
"""
import pytest
from doris_mcp_server.utils.security import (
AuthorizationProvider,
AuthContext,
SecurityLevel
)
class TestAuthorizationProvider:
"""Authorization provider tests"""
@pytest.fixture
def authz_provider(self, test_config):
"""Create authorization provider instance"""
return AuthorizationProvider(test_config)
@pytest.fixture
def analyst_context(self):
"""Create analyst auth context"""
return AuthContext(
user_id="analyst1",
roles=["data_analyst"],
permissions=["read_data"],
session_id="session_123",
security_level=SecurityLevel.INTERNAL
)
@pytest.fixture
def admin_context(self):
"""Create admin auth context"""
return AuthContext(
user_id="admin1",
roles=["data_admin"],
permissions=["admin"],
session_id="session_456",
security_level=SecurityLevel.SECRET
)
@pytest.mark.asyncio
async def test_analyst_access_public_resource(self, authz_provider, analyst_context):
"""Test analyst accessing public resource"""
resource_uri = "/api/table/public_reports"
result = await authz_provider.check_permission(analyst_context, resource_uri, "read")
assert result is True
@pytest.mark.asyncio
async def test_analyst_denied_confidential_resource(self, authz_provider):
"""Test analyst denied access to confidential resource"""
# Create analyst with lower security level
analyst_context = AuthContext(
user_id="analyst1",
roles=["data_analyst"],
permissions=["read_data"],
session_id="session_123",
security_level=SecurityLevel.PUBLIC # Lower than CONFIDENTIAL
)
resource_uri = "/api/table/user_info"
result = await authz_provider.check_permission(analyst_context, resource_uri, "read")
assert result is False
@pytest.mark.asyncio
async def test_admin_access_secret_resource(self, authz_provider, admin_context):
"""Test admin accessing secret resource"""
resource_uri = "/api/table/payment_records"
result = await authz_provider.check_permission(admin_context, resource_uri, "read")
assert result is True
@pytest.mark.asyncio
async def test_role_based_permission(self, authz_provider):
"""Test role-based permission check"""
# Create analyst context
analyst_context = AuthContext(
user_id="analyst1",
roles=["data_analyst"],
permissions=["read_data"],
session_id="session_123",
security_level=SecurityLevel.INTERNAL
)
resource_uri = "/api/table/some_table"
# Analyst should have read permission
result = await authz_provider.check_permission(analyst_context, resource_uri, "read")
assert result is True
# Analyst should not have write permission
result = await authz_provider.check_permission(analyst_context, resource_uri, "write")
assert result is False
@pytest.mark.asyncio
async def test_admin_override(self, authz_provider, admin_context):
"""Test admin permission override"""
resource_uri = "/api/table/any_table"
# Admin should have all permissions
result = await authz_provider.check_permission(admin_context, resource_uri, "read")
assert result is True
result = await authz_provider.check_permission(admin_context, resource_uri, "write")
assert result is True
def test_parse_resource_uri(self, authz_provider):
"""Test resource URI parsing"""
uri = "/api/table/user_info/default"
result = authz_provider._parse_resource_uri(uri)
assert result["type"] == "table"
assert result["name"] == "user_info"
assert result["schema"] == "default"
def test_get_resource_security_level(self, authz_provider):
"""Test getting resource security level"""
resource_info = {"name": "user_info", "type": "table"}
level = authz_provider._get_resource_security_level(resource_info)
assert level == SecurityLevel.CONFIDENTIAL

View File

@@ -0,0 +1,197 @@
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Data masking tests
"""
import pytest
from doris_mcp_server.utils.security import (
DataMaskingProcessor,
AuthContext,
SecurityLevel,
MaskingRule
)
class TestDataMaskingProcessor:
"""Data masking processor tests"""
@pytest.fixture
def masking_processor(self, test_config):
"""Create data masking processor instance"""
return DataMaskingProcessor(test_config)
@pytest.fixture
def internal_user_context(self):
"""Create internal user auth context"""
return AuthContext(
user_id="internal_user",
roles=["data_analyst"],
permissions=["read_data"],
session_id="session_123",
security_level=SecurityLevel.INTERNAL
)
@pytest.fixture
def admin_context(self):
"""Create admin auth context"""
return AuthContext(
user_id="admin",
roles=["data_admin"],
permissions=["admin"],
session_id="session_456",
security_level=SecurityLevel.SECRET
)
@pytest.mark.asyncio
async def test_phone_masking_for_internal_user(self, masking_processor, internal_user_context, sample_data):
"""Test phone number masking for internal user"""
result = await masking_processor.process(sample_data, internal_user_context)
# Phone numbers should be masked
assert result[0]["phone"] == "138****5678"
assert result[1]["phone"] == "139****4321"
@pytest.mark.asyncio
async def test_email_masking_for_internal_user(self, masking_processor, internal_user_context, sample_data):
"""Test email masking for internal user"""
result = await masking_processor.process(sample_data, internal_user_context)
# Emails should be masked
assert result[0]["email"] == "z******n@example.com"
assert result[1]["email"] == "l**i@example.com"
@pytest.mark.asyncio
async def test_no_masking_for_admin(self, masking_processor, admin_context, sample_data):
"""Test no masking for admin user"""
result = await masking_processor.process(sample_data, admin_context)
# Admin should see original data
assert result[0]["phone"] == "13812345678"
assert result[0]["email"] == "zhangsan@example.com"
assert result[1]["phone"] == "13987654321"
assert result[1]["email"] == "lisi@example.com"
@pytest.mark.asyncio
async def test_id_card_masking_for_confidential_data(self, masking_processor, internal_user_context, sample_data):
"""Test ID card masking for confidential data"""
# Internal user should not see ID card details (confidential level)
result = await masking_processor.process(sample_data, internal_user_context)
# ID cards should be masked for internal users
assert result[0]["id_card"] == "110101********1234"
assert result[1]["id_card"] == "110101********2345"
@pytest.mark.asyncio
async def test_empty_data_handling(self, masking_processor, internal_user_context):
"""Test empty data handling"""
empty_data = []
result = await masking_processor.process(empty_data, internal_user_context)
assert result == []
@pytest.mark.asyncio
async def test_null_value_handling(self, masking_processor, internal_user_context):
"""Test null value handling"""
data_with_nulls = [
{
"id": 1,
"name": "张三",
"phone": None,
"email": None,
"id_card": None
}
]
result = await masking_processor.process(data_with_nulls, internal_user_context)
# Null values should remain null
assert result[0]["phone"] is None
assert result[0]["email"] is None
assert result[0]["id_card"] is None
def test_phone_masking_algorithm(self, masking_processor):
"""Test phone masking algorithm"""
params = {"mask_char": "*", "keep_prefix": 3, "keep_suffix": 4}
result = masking_processor._mask_phone("13812345678", params)
assert result == "138****5678"
def test_email_masking_algorithm(self, masking_processor):
"""Test email masking algorithm"""
params = {"mask_char": "*"}
result = masking_processor._mask_email("zhangsan@example.com", params)
assert result == "z******n@example.com"
def test_id_card_masking_algorithm(self, masking_processor):
"""Test ID card masking algorithm"""
params = {"mask_char": "*", "keep_prefix": 6, "keep_suffix": 4}
result = masking_processor._mask_id_card("110101199001011234", params)
assert result == "110101********1234"
def test_name_masking_algorithm(self, masking_processor):
"""Test name masking algorithm"""
params = {"mask_char": "*"}
# Test 2-character name
result = masking_processor._mask_name("张三", params)
assert result == "张*"
# Test 3-character name
result = masking_processor._mask_name("李小明", params)
assert result == "李*明"
def test_partial_masking_algorithm(self, masking_processor):
"""Test partial masking algorithm"""
params = {"mask_char": "*", "mask_ratio": 0.5}
result = masking_processor._mask_partial("1234567890", params)
# Should mask middle 50% of the string
assert "*" in result
assert len(result) == 10
def test_should_apply_rule_logic(self, masking_processor, internal_user_context, admin_context):
"""Test masking rule application logic"""
rule = MaskingRule(
column_pattern=r".*phone.*",
algorithm="phone_mask",
parameters={"mask_char": "*", "keep_prefix": 3, "keep_suffix": 4},
security_level=SecurityLevel.INTERNAL
)
# Internal user should have rule applied
assert masking_processor._should_apply_rule(rule, internal_user_context) is True
# Admin should not have rule applied
assert masking_processor._should_apply_rule(rule, admin_context) is False
def test_get_applicable_rules(self, masking_processor, internal_user_context):
"""Test getting applicable rules"""
rules = masking_processor._get_applicable_rules(internal_user_context)
# Should return some rules for internal user
assert len(rules) > 0
assert all(isinstance(rule, MaskingRule) for rule in rules)

View File

@@ -0,0 +1,172 @@
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Security manager integration tests
"""
import pytest
from doris_mcp_server.utils.security import (
DorisSecurityManager,
AuthContext,
SecurityLevel,
ValidationResult
)
class TestDorisSecurityManager:
"""Doris security manager integration tests"""
@pytest.fixture
def security_manager(self, test_config):
"""Create security manager instance"""
return DorisSecurityManager(test_config)
@pytest.mark.asyncio
async def test_complete_security_workflow(self, security_manager, sample_data):
"""Test complete security workflow"""
# 1. Authentication
auth_info = {
"type": "token",
"token": "valid_token_123"
}
auth_context = await security_manager.authenticate_request(auth_info)
assert isinstance(auth_context, AuthContext)
assert auth_context.security_level == SecurityLevel.INTERNAL
# 2. Authorization
resource_uri = "/api/table/public_reports"
has_access = await security_manager.authorize_resource_access(auth_context, resource_uri)
assert has_access is True
# 3. SQL Validation
safe_sql = "SELECT name, email FROM users WHERE department = 'sales'"
validation_result = await security_manager.validate_sql_security(safe_sql, auth_context)
assert validation_result.is_valid is True
# 4. Data Masking
masked_data = await security_manager.apply_data_masking(sample_data, auth_context)
assert masked_data[0]["phone"] == "138****5678" # Should be masked
@pytest.mark.asyncio
async def test_admin_workflow(self, security_manager, sample_data):
"""Test admin user workflow"""
# Admin authentication
auth_info = {
"type": "basic",
"username": "admin",
"password": "admin123"
}
auth_context = await security_manager.authenticate_request(auth_info)
assert auth_context.security_level == SecurityLevel.SECRET
# Admin should access secret resources
resource_uri = "/api/table/payment_records"
has_access = await security_manager.authorize_resource_access(auth_context, resource_uri)
assert has_access is True
# Admin should see original data (no masking)
masked_data = await security_manager.apply_data_masking(sample_data, auth_context)
assert masked_data[0]["phone"] == "13812345678" # Original data
@pytest.mark.asyncio
async def test_security_violation_detection(self, security_manager):
"""Test security violation detection"""
# Authenticate as regular user
auth_info = {
"type": "token",
"token": "valid_token_123"
}
auth_context = await security_manager.authenticate_request(auth_info)
# Try to access confidential resource (user_info is CONFIDENTIAL, user is INTERNAL)
# INTERNAL(1) should not access CONFIDENTIAL(2) resource
resource_uri = "/api/table/user_info"
has_access = await security_manager.authorize_resource_access(auth_context, resource_uri)
assert has_access is False
# Try dangerous SQL
dangerous_sql = "DROP TABLE users"
validation_result = await security_manager.validate_sql_security(dangerous_sql, auth_context)
assert validation_result.is_valid is False
assert "DROP" in validation_result.blocked_operations
@pytest.mark.asyncio
async def test_sql_injection_prevention(self, security_manager):
"""Test SQL injection prevention"""
auth_info = {
"type": "token",
"token": "valid_token_123"
}
auth_context = await security_manager.authenticate_request(auth_info)
# Test various injection attempts
injection_attempts = [
"SELECT * FROM users WHERE id = 1; DROP TABLE users;",
"SELECT * FROM users UNION SELECT password FROM admin_users",
"SELECT * FROM users WHERE id = 1 OR 1=1",
"SELECT * FROM users WHERE name = 'test' -- AND password = 'secret'"
]
for sql in injection_attempts:
result = await security_manager.validate_sql_security(sql, auth_context)
assert result.is_valid is False
assert result.risk_level in ["medium", "high"]
@pytest.mark.asyncio
async def test_authentication_failure_handling(self, security_manager):
"""Test authentication failure handling"""
invalid_auth_info = {
"type": "token",
"token": "invalid_token"
}
with pytest.raises(Exception):
await security_manager.authenticate_request(invalid_auth_info)
@pytest.mark.asyncio
async def test_configuration_loading(self, security_manager):
"""Test security configuration loading"""
# Test blocked keywords loading
assert "DROP" in security_manager.blocked_keywords
assert "DELETE" in security_manager.blocked_keywords
# Test sensitive tables loading
assert SecurityLevel.CONFIDENTIAL in security_manager.sensitive_tables.values()
assert SecurityLevel.SECRET in security_manager.sensitive_tables.values()
# Test masking rules loading
assert len(security_manager.masking_rules) > 0
phone_rules = [rule for rule in security_manager.masking_rules
if "phone" in rule.column_pattern]
assert len(phone_rules) > 0
def test_security_level_hierarchy(self, security_manager):
"""Test security level hierarchy"""
# Test that hierarchy is correctly defined
levels = [SecurityLevel.PUBLIC, SecurityLevel.INTERNAL,
SecurityLevel.CONFIDENTIAL, SecurityLevel.SECRET]
# Each level should be properly defined
for level in levels:
assert isinstance(level, SecurityLevel)
assert level.value in ["public", "internal", "confidential", "secret"]

View File

@@ -0,0 +1,161 @@
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
SQL security validation tests
"""
import pytest
from doris_mcp_server.utils.security import (
SQLSecurityValidator,
AuthContext,
SecurityLevel,
ValidationResult
)
class TestSQLSecurityValidator:
"""SQL security validator tests"""
@pytest.fixture
def sql_validator(self, test_config):
"""Create SQL validator instance"""
return SQLSecurityValidator(test_config)
@pytest.fixture
def analyst_context(self):
"""Create analyst auth context"""
return AuthContext(
user_id="analyst1",
roles=["data_analyst"],
permissions=["read_data"],
session_id="session_123",
security_level=SecurityLevel.INTERNAL
)
@pytest.mark.asyncio
async def test_safe_select_query(self, sql_validator, analyst_context, test_sql_queries):
"""Test safe SELECT query validation"""
sql = test_sql_queries["safe_select"]
result = await sql_validator.validate(sql, analyst_context)
assert result.is_valid is True
assert result.error_message is None
@pytest.mark.asyncio
async def test_blocked_drop_operation(self, sql_validator, analyst_context, test_sql_queries):
"""Test blocked DROP operation"""
sql = test_sql_queries["dangerous_drop"]
result = await sql_validator.validate(sql, analyst_context)
assert result.is_valid is False
assert "blocked operations" in result.error_message.lower()
assert "DROP" in result.blocked_operations
@pytest.mark.asyncio
async def test_sql_injection_detection(self, sql_validator, analyst_context, test_sql_queries):
"""Test SQL injection detection"""
sql = test_sql_queries["sql_injection"]
result = await sql_validator.validate(sql, analyst_context)
assert result.is_valid is False
assert "injection" in result.error_message.lower()
assert result.risk_level == "high"
@pytest.mark.asyncio
async def test_union_injection_detection(self, sql_validator, analyst_context, test_sql_queries):
"""Test UNION injection detection"""
sql = test_sql_queries["union_injection"]
result = await sql_validator.validate(sql, analyst_context)
assert result.is_valid is False
assert "injection" in result.error_message.lower()
@pytest.mark.asyncio
async def test_comment_injection_detection(self, sql_validator, analyst_context, test_sql_queries):
"""Test comment injection detection"""
sql = test_sql_queries["comment_injection"]
result = await sql_validator.validate(sql, analyst_context)
assert result.is_valid is False
assert "comment" in result.error_message.lower()
@pytest.mark.asyncio
async def test_complex_query_validation(self, sql_validator, analyst_context, test_sql_queries):
"""Test complex query validation"""
sql = test_sql_queries["complex_query"]
result = await sql_validator.validate(sql, analyst_context)
# Complex query should pass if within limits
assert result.is_valid is True
@pytest.mark.asyncio
async def test_blocked_keywords_detection(self, sql_validator, analyst_context):
"""Test blocked keywords detection"""
blocked_sqls = [
"DELETE FROM users WHERE id = 1",
"TRUNCATE TABLE logs",
"ALTER TABLE users ADD COLUMN new_col VARCHAR(50)",
"CREATE TABLE test (id INT)",
"INSERT INTO users VALUES (1, 'test')",
"UPDATE users SET name = 'test' WHERE id = 1"
]
for sql in blocked_sqls:
result = await sql_validator.validate(sql, analyst_context)
assert result.is_valid is False
assert result.blocked_operations is not None
assert len(result.blocked_operations) > 0
@pytest.mark.asyncio
async def test_table_access_validation(self, sql_validator, analyst_context):
"""Test table access validation"""
# Test access to sensitive table
sql = "SELECT * FROM sensitive_data"
result = await sql_validator.validate(sql, analyst_context)
# Should fail for non-admin users
assert result.is_valid is False
assert "access" in result.error_message.lower()
def test_extract_table_names(self, sql_validator):
"""Test table name extraction"""
sql = "SELECT u.name FROM users u JOIN departments d ON u.dept_id = d.id"
parsed = __import__('sqlparse').parse(sql)[0]
tables = sql_validator._extract_table_names(parsed)
# Should extract at least one table name
assert len(tables) > 0
@pytest.mark.asyncio
async def test_malformed_sql_handling(self, sql_validator, analyst_context):
"""Test malformed SQL handling"""
malformed_sql = "SELECT * FROM users WHERE"
result = await sql_validator.validate(malformed_sql, analyst_context)
# Should handle gracefully
assert isinstance(result, ValidationResult)

74
test/test_config.json Normal file
View File

@@ -0,0 +1,74 @@
{
"server_endpoints": {
"http": {
"url": "http://localhost:3000/mcp",
"timeout": 30,
"headers": {
"Content-Type": "application/json"
}
},
"http_network": {
"url": "http://192.168.31.168:3000/mcp",
"timeout": 30,
"headers": {
"Content-Type": "application/json"
}
},
"stdio": {
"command": "uv",
"args": ["run", "python", "-m", "doris_mcp_server.main", "--transport", "stdio"],
"timeout": 30,
"working_directory": ".."
}
},
"test_settings": {
"default_transport": "http",
"retry_attempts": 3,
"retry_delay": 1.0,
"test_timeout": 60,
"enable_performance_tests": true,
"enable_security_tests": true
},
"test_data": {
"sample_queries": [
"SELECT 1 as test_value",
"SHOW DATABASES",
"SELECT COUNT(*) FROM information_schema.tables"
],
"test_databases": ["test_db", "demo_db"],
"test_tables": ["users", "orders", "products"],
"auth_tokens": {
"valid_token": "valid_token_123",
"admin_token": "admin_token_456",
"invalid_token": "invalid_token_789"
}
},
"expected_tools": [
"exec_query",
"get_db_list",
"get_db_table_list",
"get_table_schema",
"get_table_comment",
"get_table_column_comments",
"get_table_indexes",
"get_recent_audit_logs",
"get_catalog_list",
"get_sql_explain",
"get_sql_profile",
"get_table_data_size",
"get_monitoring_metrics_info",
"get_monitoring_metrics_data",
"get_realtime_memory_stats",
"get_historical_memory_stats"
],
"expected_resources": [
"database",
"table",
"view"
],
"expected_prompts": [
"sql_query_assistant",
"data_analysis_helper",
"schema_explorer"
]
}

214
test/test_config_loader.py Normal file
View File

@@ -0,0 +1,214 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Test Configuration Loader
Loads test configuration and provides methods to connect to running servers
"""
import json
import os
import sys
from pathlib import Path
from typing import Dict, Any, Optional
import logging
# Add project root to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from doris_mcp_client.client import DorisUnifiedClient, DorisClientConfig
logger = logging.getLogger(__name__)
class TestConfigLoader:
"""Test configuration loader and client factory"""
def __init__(self, config_path: Optional[str] = None):
"""Initialize with config file path"""
if config_path is None:
config_path = os.path.join(os.path.dirname(__file__), "test_config.json")
self.config_path = Path(config_path)
self.config = self._load_config()
def _load_config(self) -> Dict[str, Any]:
"""Load configuration from JSON file"""
try:
with open(self.config_path, 'r', encoding='utf-8') as f:
config = json.load(f)
logger.info(f"Loaded test configuration from {self.config_path}")
return config
except FileNotFoundError:
logger.error(f"Test configuration file not found: {self.config_path}")
raise
except json.JSONDecodeError as e:
logger.error(f"Invalid JSON in test configuration: {e}")
raise
def get_http_client_config(self) -> DorisClientConfig:
"""Get HTTP client configuration"""
http_config = self.config["server_endpoints"]["http"]
return DorisClientConfig.http(
url=http_config["url"],
timeout=http_config["timeout"]
)
def get_stdio_client_config(self) -> DorisClientConfig:
"""Get stdio client configuration"""
stdio_config = self.config["server_endpoints"]["stdio"]
return DorisClientConfig.stdio(
command=stdio_config["command"],
args=stdio_config["args"]
)
def get_default_client_config(self) -> DorisClientConfig:
"""Get default client configuration based on test settings"""
transport = self.config["test_settings"]["default_transport"]
if transport == "http":
return self.get_http_client_config()
elif transport == "stdio":
return self.get_stdio_client_config()
else:
raise ValueError(f"Unknown transport type: {transport}")
def create_client(self, transport: Optional[str] = None) -> DorisUnifiedClient:
"""Create MCP client instance"""
if transport is None:
client_config = self.get_default_client_config()
elif transport == "http":
client_config = self.get_http_client_config()
elif transport == "stdio":
client_config = self.get_stdio_client_config()
else:
raise ValueError(f"Unknown transport type: {transport}")
return DorisUnifiedClient(client_config)
def get_test_settings(self) -> Dict[str, Any]:
"""Get test settings"""
return self.config["test_settings"]
def get_test_data(self) -> Dict[str, Any]:
"""Get test data"""
return self.config["test_data"]
def get_expected_tools(self) -> list[str]:
"""Get expected tools list"""
return self.config["expected_tools"]
def get_expected_resources(self) -> list[str]:
"""Get expected resources list"""
return self.config["expected_resources"]
def get_expected_prompts(self) -> list[str]:
"""Get expected prompts list"""
return self.config["expected_prompts"]
def get_sample_queries(self) -> list[str]:
"""Get sample queries for testing"""
return self.config["test_data"]["sample_queries"]
def get_auth_tokens(self) -> Dict[str, str]:
"""Get authentication tokens for testing"""
return self.config["test_data"]["auth_tokens"]
def get_test_databases(self) -> list[str]:
"""Get test databases list"""
return self.config["test_data"]["test_databases"]
def get_test_tables(self) -> list[str]:
"""Get test tables list"""
return self.config["test_data"]["test_tables"]
def is_performance_tests_enabled(self) -> bool:
"""Check if performance tests are enabled"""
return self.config["test_settings"]["enable_performance_tests"]
def is_security_tests_enabled(self) -> bool:
"""Check if security tests are enabled"""
return self.config["test_settings"]["enable_security_tests"]
def get_retry_config(self) -> Dict[str, Any]:
"""Get retry configuration"""
return {
"attempts": self.config["test_settings"]["retry_attempts"],
"delay": self.config["test_settings"]["retry_delay"]
}
def get_test_timeout(self) -> int:
"""Get test timeout in seconds"""
return self.config["test_settings"]["test_timeout"]
# Global test config instance
_test_config = None
def get_test_config() -> TestConfigLoader:
"""Get global test configuration instance"""
global _test_config
if _test_config is None:
_test_config = TestConfigLoader()
return _test_config
def create_test_client(transport: Optional[str] = None) -> DorisUnifiedClient:
"""Create test client with default configuration"""
return get_test_config().create_client(transport)
async def test_server_connectivity(transport: Optional[str] = None) -> bool:
"""Test server connectivity"""
try:
client = create_test_client(transport)
async def test_connection(client_instance):
try:
# Try to list tools as a connectivity test
tools = await client_instance.list_all_tools()
return len(tools) > 0
except Exception as e:
logger.error(f"Connectivity test failed: {e}")
return False
result = await client.connect_and_run(test_connection)
return result
except Exception as e:
logger.error(f"Failed to test server connectivity: {e}")
return False
if __name__ == "__main__":
# Test configuration loading
import asyncio
async def main():
config = get_test_config()
print("Test Configuration Loaded:")
print(f" Default transport: {config.get_test_settings()['default_transport']}")
print(f" Expected tools: {len(config.get_expected_tools())}")
print(f" Sample queries: {len(config.get_sample_queries())}")
# Test connectivity
print("\nTesting server connectivity...")
http_ok = await test_server_connectivity("http")
print(f" HTTP connectivity: {'' if http_ok else ''}")
stdio_ok = await test_server_connectivity("stdio")
print(f" Stdio connectivity: {'' if stdio_ok else ''}")
asyncio.run(main())

View File

@@ -0,0 +1,175 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Tools Manager Client-Server Integration Tests
Tests the tools functionality through actual MCP client-server communication
Assumes the server is already running and configured properly
"""
import asyncio
import json
import pytest
import os
import sys
from typing import Dict, Any
# Add project root to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
from test.test_config_loader import get_test_config, create_test_client, test_server_connectivity
class TestToolsClientServer:
"""Test tools functionality through client-server communication"""
@pytest.fixture
def test_config(self):
"""Get test configuration"""
return get_test_config()
@pytest.fixture
async def client(self, test_config):
"""Create test client"""
return create_test_client()
@pytest.fixture(scope="class", autouse=True)
async def check_server_connectivity(self):
"""Check server connectivity before running tests"""
is_connected = await test_server_connectivity()
if not is_connected:
pytest.skip("Server is not running or not accessible")
@pytest.mark.asyncio
async def test_list_tools_via_client(self, client, test_config):
"""Test listing tools through client-server communication"""
expected_tools = test_config.get_expected_tools()
async def test_callback(client_instance):
tools = await client_instance.list_all_tools()
# Verify we got tools back
assert len(tools) > 0, "No tools returned from server"
# Verify expected tools are present
tool_names = [tool.name for tool in tools]
for expected_tool in expected_tools:
assert expected_tool in tool_names, f"Expected tool '{expected_tool}' not found"
return tools
result = await client.connect_and_run(test_callback)
assert len(result) > 0
@pytest.mark.asyncio
async def test_call_tool_exec_query_via_client(self, client, test_config):
"""Test calling exec_query tool through client"""
sample_queries = test_config.get_sample_queries()
async def test_callback(client_instance):
# Test with a simple query
result = await client_instance.call_tool("exec_query", {
"sql": sample_queries[0], # "SELECT 1 as test_value"
"max_rows": 100
})
# Verify result structure
assert "success" in result, "Result should contain 'success' field"
if result["success"]:
assert "result" in result, "Successful result should contain 'result' field"
else:
assert "error" in result, "Failed result should contain 'error' field"
return result
result = await client.connect_and_run(test_callback)
# Don't assert success=True as it depends on actual server state
@pytest.mark.asyncio
async def test_call_tool_get_db_list_via_client(self, client, test_config):
"""Test calling get_db_list tool through client"""
async def test_callback(client_instance):
result = await client_instance.call_tool("get_db_list", {})
# Verify result structure
assert "success" in result, "Result should contain 'success' field"
if result["success"]:
assert "result" in result, "Successful result should contain 'result' field"
assert isinstance(result["result"], list), "Database list should be a list"
return result
result = await client.connect_and_run(test_callback)
assert "success" in result
@pytest.mark.asyncio
async def test_call_tool_get_table_schema_via_client(self, client, test_config):
"""Test calling get_table_schema tool through client"""
test_tables = test_config.get_test_tables()
async def test_callback(client_instance):
result = await client_instance.call_tool("get_table_schema", {
"table_name": test_tables[0], # "users"
"db_name": "information_schema" # Use a database that should exist
})
# Verify result structure
assert "success" in result, "Result should contain 'success' field"
return result
result = await client.connect_and_run(test_callback)
assert "success" in result
@pytest.mark.asyncio
async def test_tool_error_handling_via_client(self, client, test_config):
"""Test tool error handling through client"""
async def test_callback(client_instance):
# Try to call a tool with invalid parameters
result = await client_instance.call_tool("exec_query", {
"sql": "INVALID SQL SYNTAX HERE"
})
# Should get a result (either success or error)
assert "success" in result, "Result should contain 'success' field"
return result
result = await client.connect_and_run(test_callback)
assert "success" in result
@pytest.mark.asyncio
async def test_tool_with_auth_token_via_client(self, client, test_config):
"""Test tool calls with authentication token"""
if not test_config.is_security_tests_enabled():
pytest.skip("Security tests are disabled")
auth_tokens = test_config.get_auth_tokens()
async def test_callback(client_instance):
result = await client_instance.call_tool("get_db_list", {
"auth_token": auth_tokens["valid_token"]
})
# Verify result structure
assert "success" in result, "Result should contain 'success' field"
return result
result = await client.connect_and_run(test_callback)
assert "success" in result

View File

@@ -0,0 +1,271 @@
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Tools manager tests
"""
import json
import pytest
from unittest.mock import Mock, AsyncMock, patch
from doris_mcp_server.tools.tools_manager import DorisToolsManager
from doris_mcp_server.utils.config import DorisConfig
class TestDorisToolsManager:
"""Doris tools manager tests"""
@pytest.fixture
def mock_config(self):
"""Create mock configuration"""
from doris_mcp_server.utils.config import DatabaseConfig, SecurityConfig
config = Mock(spec=DorisConfig)
# Add database config
config.database = Mock(spec=DatabaseConfig)
config.database.host = "localhost"
config.database.port = 9030
config.database.user = "test_user"
config.database.password = "test_password"
config.database.database = "test_db"
config.database.health_check_interval = 60
config.database.min_connections = 5
config.database.max_connections = 20
config.database.connection_timeout = 30
config.database.max_connection_age = 3600
# Add security config
config.security = Mock(spec=SecurityConfig)
config.security.enable_masking = True
config.security.auth_type = "token"
config.security.token_secret = "test_secret"
config.security.token_expiry = 3600
return config
@pytest.fixture
def tools_manager(self, mock_config):
"""Create tools manager instance"""
# Create a proper mock connection manager
mock_connection_manager = Mock()
mock_connection_manager.get_connection = AsyncMock()
return DorisToolsManager(mock_connection_manager)
@pytest.mark.asyncio
async def test_get_available_tools(self, tools_manager):
"""Test getting available tools"""
tools = await tools_manager.list_tools()
# Should have core tools
tool_names = [tool.name for tool in tools]
assert "exec_query" in tool_names
assert "get_db_list" in tool_names
assert "get_db_table_list" in tool_names
assert "get_table_schema" in tool_names
@pytest.mark.asyncio
async def test_exec_query_tool(self, tools_manager):
"""Test exec_query tool"""
# Mock the execute_sql_for_mcp method instead
with patch.object(tools_manager.query_executor, 'execute_sql_for_mcp') as mock_execute:
mock_execute.return_value = {
"success": True,
"data": [
{"id": 1, "name": "张三"},
{"id": 2, "name": "李四"}
],
"row_count": 2,
"execution_time": 0.15
}
arguments = {
"sql": "SELECT id, name FROM users LIMIT 2",
"max_rows": 100
}
result = await tools_manager.call_tool("exec_query", arguments)
result_data = json.loads(result) if isinstance(result, str) else result
# The test should handle both success and error cases
if "success" in result_data and result_data["success"]:
# Check if result has data field or result field
if "data" in result_data and result_data["data"] is not None:
assert len(result_data["data"]) == 2
elif "result" in result_data and result_data["result"] is not None:
assert len(result_data["result"]) == 2
else:
# If there's an error, just check that error is reported
assert "error" in result_data
# Verify the method was called (may not be called if there are errors)
# Don't assert specific call parameters since the implementation may vary
@pytest.mark.asyncio
async def test_exec_query_with_error(self, tools_manager):
"""Test exec_query tool with error"""
with patch.object(tools_manager.query_executor, 'execute_query') as mock_execute:
mock_execute.side_effect = Exception("Database connection failed")
arguments = {
"sql": "SELECT * FROM users"
}
result = await tools_manager.call_tool("exec_query", arguments)
result_data = json.loads(result) if isinstance(result, str) else result
assert "error" in result_data or "success" in result_data
if "error" in result_data:
# Accept any connection-related error message
assert any(keyword in result_data["error"].lower() for keyword in
["connection", "failed", "error", "mock"])
@pytest.mark.asyncio
async def test_get_db_list_tool(self, tools_manager):
"""Test get_db_list tool"""
with patch.object(tools_manager.query_executor, 'execute_query') as mock_execute:
mock_execute.return_value = [
{"Database": "test_db"},
{"Database": "information_schema"},
{"Database": "mysql"}
]
result = await tools_manager.call_tool("get_db_list", {})
result_data = json.loads(result) if isinstance(result, str) else result
# Check if result has databases field or result field
if "databases" in result_data:
assert len(result_data["databases"]) == 3
elif "result" in result_data:
assert len(result_data["result"]) >= 0 # May be empty if no databases
@pytest.mark.asyncio
async def test_get_db_table_list_tool(self, tools_manager):
"""Test get_db_table_list tool"""
with patch.object(tools_manager.query_executor, 'execute_query') as mock_execute:
mock_execute.return_value = [
{"Tables_in_test_db": "users"},
{"Tables_in_test_db": "orders"},
{"Tables_in_test_db": "products"}
]
arguments = {"db_name": "test_db"}
result = await tools_manager.call_tool("get_db_table_list", arguments)
result_data = json.loads(result) if isinstance(result, str) else result
# Check if result has tables field or result field
if "tables" in result_data:
assert len(result_data["tables"]) == 3
assert "users" in result_data["tables"]
elif "result" in result_data:
assert len(result_data["result"]) >= 0 # May be empty if no tables
@pytest.mark.asyncio
async def test_get_table_schema_tool(self, tools_manager):
"""Test get_table_schema tool"""
with patch.object(tools_manager.query_executor, 'execute_query') as mock_execute:
mock_execute.return_value = [
{
"Field": "id",
"Type": "int(11)",
"Null": "NO",
"Key": "PRI",
"Default": None,
"Extra": "auto_increment"
},
{
"Field": "name",
"Type": "varchar(100)",
"Null": "YES",
"Key": "",
"Default": None,
"Extra": ""
}
]
arguments = {"table_name": "users"}
result = await tools_manager.call_tool("get_table_schema", arguments)
result_data = json.loads(result) if isinstance(result, str) else result
# Check if result has schema field or result field
if "schema" in result_data:
assert len(result_data["schema"]) == 2
assert result_data["schema"][0]["Field"] == "id"
elif "result" in result_data:
assert len(result_data["result"]) >= 0 # May be empty if no schema
@pytest.mark.asyncio
async def test_get_catalog_list_tool(self, tools_manager):
"""Test get_catalog_list tool"""
with patch.object(tools_manager.query_executor, 'execute_query') as mock_execute:
mock_execute.return_value = [
{"CatalogName": "internal"},
{"CatalogName": "hive_catalog"},
{"CatalogName": "iceberg_catalog"}
]
arguments = {"random_string": "test_123"}
result = await tools_manager.call_tool("get_catalog_list", arguments)
result_data = json.loads(result) if isinstance(result, str) else result
# Check if result has catalogs field or result field
if "catalogs" in result_data:
assert len(result_data["catalogs"]) == 3
assert "internal" in result_data["catalogs"]
elif "result" in result_data:
assert len(result_data["result"]) >= 0 # May be empty if no catalogs
@pytest.mark.asyncio
async def test_invalid_tool_name(self, tools_manager):
"""Test calling invalid tool"""
result = await tools_manager.call_tool("invalid_tool", {})
result_data = json.loads(result) if isinstance(result, str) else result
assert "error" in result_data or "success" in result_data
if "error" in result_data:
assert "Unknown tool" in result_data["error"]
@pytest.mark.asyncio
async def test_missing_required_arguments(self, tools_manager):
"""Test calling tool with missing required arguments"""
# exec_query requires sql parameter
result = await tools_manager.call_tool("exec_query", {})
result_data = json.loads(result) if isinstance(result, str) else result
assert "error" in result_data or "success" in result_data
# The test may pass if the tool handles missing parameters gracefully
@pytest.mark.asyncio
async def test_tool_definitions_structure(self, tools_manager):
"""Test tool definitions have correct structure"""
tools = await tools_manager.list_tools()
for tool in tools:
# Each tool should have required fields
assert hasattr(tool, 'name')
assert hasattr(tool, 'description')
assert hasattr(tool, 'inputSchema')
# Input schema should have properties
assert 'properties' in tool.inputSchema
# Required fields should be defined
if 'required' in tool.inputSchema:
assert isinstance(tool.inputSchema['required'], list)

View File

@@ -0,0 +1,204 @@
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Query executor tests
"""
import pytest
from unittest.mock import Mock, AsyncMock, patch
from doris_mcp_server.utils.query_executor import DorisQueryExecutor
from doris_mcp_server.utils.config import DorisConfig
class TestDorisQueryExecutor:
"""Doris query executor tests"""
@pytest.fixture
def mock_config(self):
"""Create mock configuration"""
from doris_mcp_server.utils.config import DatabaseConfig, SecurityConfig
config = Mock(spec=DorisConfig)
# Add database config
config.database = Mock(spec=DatabaseConfig)
config.database.host = "localhost"
config.database.port = 9030
config.database.user = "test_user"
config.database.password = "test_password"
config.database.database = "test_db"
config.database.health_check_interval = 60
config.database.min_connections = 5
config.database.max_connections = 20
config.database.connection_timeout = 30
config.database.max_connection_age = 3600
# Add security config
config.security = Mock(spec=SecurityConfig)
config.security.enable_masking = True
config.security.auth_type = "token"
config.security.token_secret = "test_secret"
config.security.token_expiry = 3600
return config
@pytest.fixture
def query_executor(self, mock_config):
"""Create query executor instance"""
# Create a mock connection manager
mock_connection_manager = Mock()
return DorisQueryExecutor(mock_connection_manager, mock_config)
@pytest.mark.asyncio
async def test_execute_query_success(self, query_executor):
"""Test successful query execution using MCP interface"""
with patch.object(query_executor, 'execute_sql_for_mcp') as mock_execute:
mock_execute.return_value = {
"success": True,
"data": [
{"id": 1, "name": "张三", "email": "zhangsan@example.com"},
{"id": 2, "name": "李四", "email": "lisi@example.com"}
],
"row_count": 2,
"execution_time": 0.15,
"columns": ["id", "name", "email"]
}
sql = "SELECT id, name, email FROM users LIMIT 2"
result = await query_executor.execute_sql_for_mcp(sql)
# Verify results
assert result["success"] is True
assert result["row_count"] == 2
assert len(result["data"]) == 2
assert result["data"][0]["id"] == 1
assert result["data"][0]["name"] == "张三"
assert result["data"][1]["email"] == "lisi@example.com"
@pytest.mark.asyncio
async def test_execute_query_with_parameters(self, query_executor):
"""Test query execution with parameters"""
with patch.object(query_executor, 'execute_sql_for_mcp') as mock_execute:
mock_execute.return_value = {
"success": True,
"data": [{"id": 1, "name": "张三"}],
"row_count": 1,
"execution_time": 0.1
}
sql = "SELECT id, name FROM users WHERE department = 'sales'"
result = await query_executor.execute_sql_for_mcp(sql)
# Verify results
assert result["success"] is True
assert result["row_count"] == 1
assert len(result["data"]) == 1
@pytest.mark.asyncio
async def test_execute_query_connection_error(self, query_executor):
"""Test query execution with connection error"""
with patch.object(query_executor, 'execute_sql_for_mcp') as mock_execute:
mock_execute.return_value = {
"success": False,
"error": "Connection failed",
"data": None
}
sql = "SELECT * FROM users"
result = await query_executor.execute_sql_for_mcp(sql)
assert result["success"] is False
assert "Connection failed" in result["error"]
@pytest.mark.asyncio
async def test_execute_query_sql_error(self, query_executor):
"""Test query execution with SQL error"""
with patch.object(query_executor, 'execute_sql_for_mcp') as mock_execute:
mock_execute.return_value = {
"success": False,
"error": "SQL syntax error",
"data": None
}
sql = "SELECT * FROM non_existent_table"
result = await query_executor.execute_sql_for_mcp(sql)
assert result["success"] is False
assert "SQL syntax error" in result["error"]
@pytest.mark.asyncio
async def test_execute_query_empty_result(self, query_executor):
"""Test query execution with empty result"""
with patch.object(query_executor, 'execute_sql_for_mcp') as mock_execute:
mock_execute.return_value = {
"success": True,
"data": [],
"row_count": 0,
"execution_time": 0.05
}
sql = "SELECT * FROM users WHERE id = 999"
result = await query_executor.execute_sql_for_mcp(sql)
assert result["success"] is True
assert result["data"] == []
assert result["row_count"] == 0
@pytest.mark.asyncio
async def test_execute_query_max_rows_limit(self, query_executor):
"""Test query execution with max rows limit"""
with patch.object(query_executor, 'execute_sql_for_mcp') as mock_execute:
# Mock large result set limited to 100 rows
limited_result = [{"id": i, "name": f"user_{i}"} for i in range(100)]
mock_execute.return_value = {
"success": True,
"data": limited_result,
"row_count": 100,
"execution_time": 0.2
}
sql = "SELECT id, name FROM users"
result = await query_executor.execute_sql_for_mcp(sql, limit=100)
# Should be limited to max_rows
assert result["success"] is True
assert len(result["data"]) == 100
@pytest.mark.asyncio
async def test_execute_sql_for_mcp_interface(self, query_executor):
"""Test the MCP interface method directly"""
with patch.object(query_executor.connection_manager, 'get_connection') as mock_get_conn:
# Mock connection and result
mock_connection = AsyncMock()
mock_connection.execute.return_value = Mock(
data=[{"id": 1, "name": "张三"}],
row_count=1,
execution_time=0.1,
metadata={}
)
mock_get_conn.return_value = mock_connection
sql = "SELECT id, name FROM users LIMIT 1"
result = await query_executor.execute_sql_for_mcp(sql)
# Should return success format
assert "success" in result
if result["success"]:
assert "data" in result
assert "row_count" in result

View File

@@ -0,0 +1,156 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Query Executor Client-Server Integration Tests
Tests the query execution functionality through actual MCP client-server communication
Assumes the server is already running and configured properly
"""
import asyncio
import json
import pytest
import os
import sys
from typing import Dict, Any
# Add project root to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
from test.test_config_loader import get_test_config, create_test_client, test_server_connectivity
class TestQueryExecutorClientServer:
"""Test query execution functionality through client-server communication"""
@pytest.fixture
def test_config(self):
"""Get test configuration"""
return get_test_config()
@pytest.fixture
async def client(self, test_config):
"""Create test client"""
return create_test_client()
@pytest.fixture(scope="class", autouse=True)
async def check_server_connectivity(self):
"""Check server connectivity before running tests"""
is_connected = await test_server_connectivity()
if not is_connected:
pytest.skip("Server is not running or not accessible")
@pytest.mark.asyncio
async def test_simple_select_query_via_client(self, client, test_config):
"""Test simple SELECT query through client"""
sample_queries = test_config.get_sample_queries()
async def test_callback(client_instance):
result = await client_instance.execute_sql(sample_queries[0]) # "SELECT 1 as test_value"
# Verify result structure
assert "success" in result, "Result should contain 'success' field"
if result["success"]:
assert "result" in result, "Successful result should contain 'result' field"
else:
assert "error" in result, "Failed result should contain 'error' field"
return result
result = await client.connect_and_run(test_callback)
assert "success" in result
@pytest.mark.asyncio
async def test_show_databases_query_via_client(self, client, test_config):
"""Test SHOW DATABASES query through client"""
sample_queries = test_config.get_sample_queries()
async def test_callback(client_instance):
result = await client_instance.execute_sql(sample_queries[1]) # "SHOW DATABASES"
# Verify result structure
assert "success" in result, "Result should contain 'success' field"
return result
result = await client.connect_and_run(test_callback)
assert "success" in result
@pytest.mark.asyncio
async def test_information_schema_query_via_client(self, client, test_config):
"""Test information_schema query through client"""
sample_queries = test_config.get_sample_queries()
async def test_callback(client_instance):
result = await client_instance.execute_sql(sample_queries[2]) # "SELECT COUNT(*) FROM information_schema.tables"
# Verify result structure
assert "success" in result, "Result should contain 'success' field"
return result
result = await client.connect_and_run(test_callback)
assert "success" in result
@pytest.mark.asyncio
async def test_query_with_max_rows_parameter_via_client(self, client, test_config):
"""Test query with max_rows parameter through client"""
async def test_callback(client_instance):
result = await client_instance.call_tool("exec_query", {
"sql": "SELECT 1 as test_value",
"max_rows": 10
})
# Verify result structure
assert "success" in result, "Result should contain 'success' field"
return result
result = await client.connect_and_run(test_callback)
assert "success" in result
@pytest.mark.asyncio
async def test_query_error_handling_via_client(self, client, test_config):
"""Test query error handling through client"""
async def test_callback(client_instance):
result = await client_instance.execute_sql("INVALID SQL SYNTAX")
# Should get a result (either success or error)
assert "success" in result, "Result should contain 'success' field"
return result
result = await client.connect_and_run(test_callback)
assert "success" in result
@pytest.mark.asyncio
async def test_query_with_auth_token_via_client(self, client, test_config):
"""Test query with authentication token"""
if not test_config.is_security_tests_enabled():
pytest.skip("Security tests are disabled")
auth_tokens = test_config.get_auth_tokens()
async def test_callback(client_instance):
result = await client_instance.call_tool("exec_query", {
"sql": "SELECT 1 as test_value",
"auth_token": auth_tokens["valid_token"]
})
# Verify result structure
assert "success" in result, "Result should contain 'success' field"
return result
result = await client.connect_and_run(test_callback)
assert "success" in result

420
uv.lock generated
View File

@@ -6,6 +6,48 @@ resolution-markers = [
"python_full_version < '3.13'", "python_full_version < '3.13'",
] ]
[[package]]
name = "adbc-driver-flightsql"
version = "1.7.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "adbc-driver-manager" },
{ name = "importlib-resources" },
]
sdist = { url = "https://files.pythonhosted.org/packages/b8/d4/ebd3eed981c771565677084474cdf465141455b5deb1ca409c616609bfd7/adbc_driver_flightsql-1.7.0.tar.gz", hash = "sha256:5dca460a2c66e45b29208eaf41a7206f252177435fa48b16f19833b12586f7a0", size = 21247 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/36/20/807fca9d904b7e0d3020439828d6410db7fd7fd635824a80cab113d9fad1/adbc_driver_flightsql-1.7.0-py3-none-macosx_10_15_x86_64.whl", hash = "sha256:a5658f9bc3676bd122b26138e9b9ce56b8bf37387efe157b4c66d56f942361c6", size = 7749664 },
{ url = "https://files.pythonhosted.org/packages/cd/e6/9e50f6497819c911b9cc1962ffde610b60f7d8e951d6bb3fa145dcfb50a7/adbc_driver_flightsql-1.7.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:65e21df86b454d8db422c8ee22db31be217d88c42d9d6dd89119f06813037c91", size = 7302476 },
{ url = "https://files.pythonhosted.org/packages/27/82/e51af85e7cc8c87bc8ce4fae8ca7ee1d3cf39c926be0aeab789cedc93f0a/adbc_driver_flightsql-1.7.0-py3-none-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:3282fdc7b73c712780cc777975288c88b1e3a555355bbe09df101aa954f8f105", size = 7686056 },
{ url = "https://files.pythonhosted.org/packages/8b/c9/591c8ecbaf010ba3f4b360db602050ee5880cd077a573c9e90fcb270ab71/adbc_driver_flightsql-1.7.0-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e0c5737ae6ee3bbfba44dcbc28ba1ff8cf3ab6521888c4b0f10dd6a482482161", size = 7050275 },
{ url = "https://files.pythonhosted.org/packages/10/14/f339e9a5d8dbb3e3040215514cea9cca0a58640964aaccc6532f18003a03/adbc_driver_flightsql-1.7.0-py3-none-win_amd64.whl", hash = "sha256:f8b5290b322304b7d944ca823754e6354c1868dbbe94ddf84236f3e0329545da", size = 14312858 },
]
[[package]]
name = "adbc-driver-manager"
version = "1.7.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/bb/bf/2986a2cd3e1af658d2597f7e2308564e5c11e036f9736d5c256f1e00d578/adbc_driver_manager-1.7.0.tar.gz", hash = "sha256:e3edc5d77634b5925adf6eb4fbcd01676b54acb2f5b1d6864b6a97c6a899591a", size = 198128 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/74/3a/72bd9c45d55f1f5f4c549e206de8cfe3313b31f7b95fbcb180da05c81044/adbc_driver_manager-1.7.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:8da1ac4c19bcbf30b3bd54247ec889dfacc9b44147c70b4da79efe2e9ba93600", size = 524210 },
{ url = "https://files.pythonhosted.org/packages/33/29/e1a8d8dde713a287f8021f3207127f133ddce578711a4575218bdf78ef27/adbc_driver_manager-1.7.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:408bc23bad1a6823b364e2388f85f96545e82c3b2db97d7828a4b94839d3f29e", size = 505902 },
{ url = "https://files.pythonhosted.org/packages/59/00/773ece64a58c0ade797ab4577e7cdc4c71ebf800b86d2d5637e3bfe605e9/adbc_driver_manager-1.7.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cf38294320c23e47ed3455348e910031ad8289c3f9167ae35519ac957b7add01", size = 2974883 },
{ url = "https://files.pythonhosted.org/packages/7c/ad/1568da6ae9ab70983f1438503d3906c6b1355601230e891d16e272376a04/adbc_driver_manager-1.7.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:689f91b62c18a9f86f892f112786fb157cacc4729b4d81666db4ca778eade2a8", size = 2997781 },
{ url = "https://files.pythonhosted.org/packages/19/66/2b6ea5afded25a3fa009873c2bbebcd9283910877cc10b9453d680c00b9a/adbc_driver_manager-1.7.0-cp312-cp312-win_amd64.whl", hash = "sha256:f936cfc8d098898a47ef60396bd7a73926ec3068f2d6d92a2be4e56e4aaf3770", size = 690041 },
{ url = "https://files.pythonhosted.org/packages/b2/3b/91154c83a98f103a3d97c9e2cb838c3842aef84ca4f4b219164b182d9516/adbc_driver_manager-1.7.0-cp313-cp313-macosx_10_15_x86_64.whl", hash = "sha256:ab9ee36683fd54f61b0db0f4a96f70fe1932223e61df9329290370b145abb0a9", size = 522737 },
{ url = "https://files.pythonhosted.org/packages/9c/52/4bc80c3388d5e2a3b6e504ba9656dd9eb3d8dbe822d07af38db1b8c96fb1/adbc_driver_manager-1.7.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:4ec03d94177f71a8d3a149709f4111e021f9950229b35c0a803aadb1a1855a4b", size = 503896 },
{ url = "https://files.pythonhosted.org/packages/e1/f3/46052ca11224f661cef4721e19138bc73e750ba6aea54f22606950491606/adbc_driver_manager-1.7.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:700c79dac08a620018c912ede45a6dc7851819bc569a53073ab652dc0bd0c92f", size = 2972586 },
{ url = "https://files.pythonhosted.org/packages/a2/22/44738b41bb5ca30f94b5f4c00c71c20be86d7eb4ddc389d4cf3c7b8b69ef/adbc_driver_manager-1.7.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:98db0f5d0aa1635475f63700a7b6f677390beb59c69c7ba9d388bc8ce3779388", size = 2992001 },
{ url = "https://files.pythonhosted.org/packages/1b/2b/5184fe5a529feb019582cc90d0f65e0021d52c34ca20620551532340645a/adbc_driver_manager-1.7.0-cp313-cp313-win_amd64.whl", hash = "sha256:4b7e5e9a163acb21804647cc7894501df51cdcd780ead770557112a26ca01ca6", size = 688789 },
{ url = "https://files.pythonhosted.org/packages/3f/e0/b283544e1bb7864bf5a5ac9cd330f111009eff9180ec5000420510cf9342/adbc_driver_manager-1.7.0-cp313-cp313t-macosx_10_15_x86_64.whl", hash = "sha256:ac83717965b83367a8ad6c0536603acdcfa66e0592d783f8940f55fda47d963e", size = 538625 },
{ url = "https://files.pythonhosted.org/packages/77/5a/dc244264bd8d0c331a418d2bdda5cb6e26c30493ff075d706aa81d4e3b30/adbc_driver_manager-1.7.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:4c234cf81b00eaf7e7c65dbd0f0ddf7bdae93dfcf41e9d8543f9ecf4b10590f6", size = 523627 },
{ url = "https://files.pythonhosted.org/packages/e9/ff/a499a00367fd092edb20dc6e36c81e3c7a437671c70481cae97f46c8156a/adbc_driver_manager-1.7.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ad8aa4b039cc50722a700b544773388c6b1dea955781a01f79cd35d0a1e6edbf", size = 3037517 },
{ url = "https://files.pythonhosted.org/packages/25/6e/9dfdb113294dcb24b4f53924cd4a9c9af3fbe45a9790c1327048df731246/adbc_driver_manager-1.7.0-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4409ff53578e01842a8f57787ebfbfee790c1da01a6bd57fcb7701ed5d4dd4f7", size = 3016543 },
]
[[package]] [[package]]
name = "aiofiles" name = "aiofiles"
version = "24.1.0" version = "24.1.0"
@@ -518,6 +560,176 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/8f/d7/9322c609343d929e75e7e5e6255e614fcc67572cfd083959cdef3b7aad79/docutils-0.21.2-py3-none-any.whl", hash = "sha256:dafca5b9e384f0e419294eb4d2ff9fa826435bf15f15b7bd45723e8ad76811b2", size = 587408 }, { url = "https://files.pythonhosted.org/packages/8f/d7/9322c609343d929e75e7e5e6255e614fcc67572cfd083959cdef3b7aad79/docutils-0.21.2-py3-none-any.whl", hash = "sha256:dafca5b9e384f0e419294eb4d2ff9fa826435bf15f15b7bd45723e8ad76811b2", size = 587408 },
] ]
[[package]]
name = "doris-mcp-server"
version = "0.5.0"
source = { editable = "." }
dependencies = [
{ name = "adbc-driver-flightsql" },
{ name = "adbc-driver-manager" },
{ name = "aiofiles" },
{ name = "aiohttp" },
{ name = "aiomysql" },
{ name = "aioredis" },
{ name = "asyncio-mqtt" },
{ name = "bcrypt" },
{ name = "click" },
{ name = "cryptography" },
{ name = "fastapi" },
{ name = "httpx" },
{ name = "mcp" },
{ name = "numpy" },
{ name = "orjson" },
{ name = "pandas" },
{ name = "passlib", extra = ["bcrypt"] },
{ name = "prometheus-client" },
{ name = "pyarrow" },
{ name = "pydantic" },
{ name = "pydantic-settings" },
{ name = "pyjwt" },
{ name = "pymysql" },
{ name = "pytest" },
{ name = "pytest-asyncio" },
{ name = "pytest-cov" },
{ name = "python-dateutil" },
{ name = "python-dotenv" },
{ name = "python-jose", extra = ["cryptography"] },
{ name = "python-multipart" },
{ name = "pyyaml" },
{ name = "requests" },
{ name = "rich" },
{ name = "sqlparse" },
{ name = "starlette" },
{ name = "structlog" },
{ name = "toml" },
{ name = "tqdm" },
{ name = "typer" },
{ name = "uvicorn", extra = ["standard"] },
{ name = "websockets" },
]
[package.optional-dependencies]
dev = [
{ name = "bandit" },
{ name = "black" },
{ name = "flake8" },
{ name = "isort" },
{ name = "mypy" },
{ name = "myst-parser" },
{ name = "pre-commit" },
{ name = "pytest" },
{ name = "pytest-asyncio" },
{ name = "pytest-cov" },
{ name = "pytest-mock" },
{ name = "pytest-xdist" },
{ name = "ruff" },
{ name = "safety" },
{ name = "sphinx" },
{ name = "sphinx-rtd-theme" },
{ name = "tox" },
]
docs = [
{ name = "myst-parser" },
{ name = "sphinx" },
{ name = "sphinx-autoapi" },
{ name = "sphinx-rtd-theme" },
]
monitoring = [
{ name = "grafana-client" },
{ name = "jaeger-client" },
{ name = "opentelemetry-api" },
{ name = "opentelemetry-sdk" },
{ name = "prometheus-client" },
]
performance = [
{ name = "cchardet" },
{ name = "orjson" },
{ name = "uvloop" },
]
[package.dev-dependencies]
dev = [
{ name = "ruff" },
]
[package.metadata]
requires-dist = [
{ name = "adbc-driver-flightsql", specifier = ">=0.8.0" },
{ name = "adbc-driver-manager", specifier = ">=0.8.0" },
{ name = "aiofiles", specifier = ">=23.0.0" },
{ name = "aiohttp", specifier = ">=3.9.0" },
{ name = "aiomysql", specifier = ">=0.2.0" },
{ name = "aioredis", specifier = ">=2.0.0" },
{ name = "asyncio-mqtt", specifier = ">=0.16.0" },
{ name = "bandit", marker = "extra == 'dev'", specifier = ">=1.7.0" },
{ name = "bcrypt", specifier = ">=4.1.0" },
{ name = "black", marker = "extra == 'dev'", specifier = ">=23.12.0" },
{ name = "cchardet", marker = "extra == 'performance'", specifier = ">=2.1.0" },
{ name = "click", specifier = ">=8.1.0" },
{ name = "cryptography", specifier = ">=41.0.0" },
{ name = "fastapi", specifier = ">=0.108.0" },
{ name = "flake8", marker = "extra == 'dev'", specifier = ">=7.0.0" },
{ name = "grafana-client", marker = "extra == 'monitoring'", specifier = ">=3.5.0" },
{ name = "httpx", specifier = ">=0.26.0" },
{ name = "isort", marker = "extra == 'dev'", specifier = ">=5.13.0" },
{ name = "jaeger-client", marker = "extra == 'monitoring'", specifier = ">=4.8.0" },
{ name = "mcp", specifier = ">=1.8.0,<2.0.0" },
{ name = "mypy", marker = "extra == 'dev'", specifier = ">=1.8.0" },
{ name = "myst-parser", marker = "extra == 'dev'", specifier = ">=2.0.0" },
{ name = "myst-parser", marker = "extra == 'docs'", specifier = ">=2.0.0" },
{ name = "numpy", specifier = ">=1.24.0" },
{ name = "opentelemetry-api", marker = "extra == 'monitoring'", specifier = ">=1.21.0" },
{ name = "opentelemetry-sdk", marker = "extra == 'monitoring'", specifier = ">=1.21.0" },
{ name = "orjson", specifier = ">=3.9.0" },
{ name = "orjson", marker = "extra == 'performance'", specifier = ">=3.9.0" },
{ name = "pandas", specifier = ">=2.0.0" },
{ name = "passlib", extras = ["bcrypt"], specifier = ">=1.7.0" },
{ name = "pre-commit", marker = "extra == 'dev'", specifier = ">=3.6.0" },
{ name = "prometheus-client", specifier = ">=0.19.0" },
{ name = "prometheus-client", marker = "extra == 'monitoring'", specifier = ">=0.19.0" },
{ name = "pyarrow", specifier = ">=14.0.0" },
{ name = "pydantic", specifier = ">=2.5.0" },
{ name = "pydantic-settings", specifier = ">=2.1.0" },
{ name = "pyjwt", specifier = ">=2.8.0" },
{ name = "pymysql", specifier = ">=1.1.0" },
{ name = "pytest", specifier = ">=8.4.0" },
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=7.4.0" },
{ name = "pytest-asyncio", specifier = ">=1.0.0" },
{ name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.23.0" },
{ name = "pytest-cov", specifier = ">=6.1.1" },
{ name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.1.0" },
{ name = "pytest-mock", marker = "extra == 'dev'", specifier = ">=3.12.0" },
{ name = "pytest-xdist", marker = "extra == 'dev'", specifier = ">=3.5.0" },
{ name = "python-dateutil", specifier = ">=2.8.0" },
{ name = "python-dotenv", specifier = ">=1.0.0" },
{ name = "python-jose", extras = ["cryptography"], specifier = ">=3.3.0" },
{ name = "python-multipart", specifier = ">=0.0.6" },
{ name = "pyyaml", specifier = ">=6.0.0" },
{ name = "requests", specifier = ">=2.31.0" },
{ name = "rich", specifier = ">=13.7.0" },
{ name = "ruff", marker = "extra == 'dev'", specifier = ">=0.1.0" },
{ name = "safety", marker = "extra == 'dev'", specifier = ">=2.3.0" },
{ name = "sphinx", marker = "extra == 'dev'", specifier = ">=7.2.0" },
{ name = "sphinx", marker = "extra == 'docs'", specifier = ">=7.2.0" },
{ name = "sphinx-autoapi", marker = "extra == 'docs'", specifier = ">=3.0.0" },
{ name = "sphinx-rtd-theme", marker = "extra == 'dev'", specifier = ">=2.0.0" },
{ name = "sphinx-rtd-theme", marker = "extra == 'docs'", specifier = ">=2.0.0" },
{ name = "sqlparse", specifier = ">=0.4.4" },
{ name = "starlette", specifier = ">=0.27.0" },
{ name = "structlog", specifier = ">=23.2.0" },
{ name = "toml", specifier = ">=0.10.0" },
{ name = "tox", marker = "extra == 'dev'", specifier = ">=4.11.0" },
{ name = "tqdm", specifier = ">=4.66.0" },
{ name = "typer", specifier = ">=0.9.0" },
{ name = "uvicorn", extras = ["standard"], specifier = ">=0.25.0" },
{ name = "uvloop", marker = "extra == 'performance'", specifier = ">=0.19.0" },
{ name = "websockets", specifier = ">=12.0" },
]
provides-extras = ["dev", "docs", "performance", "monitoring"]
[package.metadata.requires-dev]
dev = [{ name = "ruff", specifier = ">=0.11.13" }]
[[package]] [[package]]
name = "dparse" name = "dparse"
version = "0.6.4" version = "0.6.4"
@@ -768,6 +980,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/20/b0/36bd937216ec521246249be3bf9855081de4c5e06a0c9b4219dbeda50373/importlib_metadata-8.7.0-py3-none-any.whl", hash = "sha256:e5dd1551894c77868a30651cef00984d50e1002d06942a7101d34870c5f02afd", size = 27656 }, { url = "https://files.pythonhosted.org/packages/20/b0/36bd937216ec521246249be3bf9855081de4c5e06a0c9b4219dbeda50373/importlib_metadata-8.7.0-py3-none-any.whl", hash = "sha256:e5dd1551894c77868a30651cef00984d50e1002d06942a7101d34870c5f02afd", size = 27656 },
] ]
[[package]]
name = "importlib-resources"
version = "6.5.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/cf/8c/f834fbf984f691b4f7ff60f50b514cc3de5cc08abfc3295564dd89c5e2e7/importlib_resources-6.5.2.tar.gz", hash = "sha256:185f87adef5bcc288449d98fb4fba07cea78bc036455dd44c5fc4a2fe78fed2c", size = 44693 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/a4/ed/1f1afb2e9e7f38a545d628f864d562a5ae64fe6f7a10e28ffb9b185b4e89/importlib_resources-6.5.2-py3-none-any.whl", hash = "sha256:789cfdc3ed28c78b67a06acb8126751ced69a3d5f79c095a98298cd8a760ccec", size = 37461 },
]
[[package]] [[package]]
name = "iniconfig" name = "iniconfig"
version = "2.1.0" version = "2.1.0"
@@ -946,170 +1167,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/79/45/823ad05504bea55cb0feb7470387f151252127ad5c72f8882e8fe6cf5c0e/mcp-1.9.3-py3-none-any.whl", hash = "sha256:69b0136d1ac9927402ed4cf221d4b8ff875e7132b0b06edd446448766f34f9b9", size = 131063 }, { url = "https://files.pythonhosted.org/packages/79/45/823ad05504bea55cb0feb7470387f151252127ad5c72f8882e8fe6cf5c0e/mcp-1.9.3-py3-none-any.whl", hash = "sha256:69b0136d1ac9927402ed4cf221d4b8ff875e7132b0b06edd446448766f34f9b9", size = 131063 },
] ]
[[package]]
name = "mcp-doris-server"
version = "0.3.0"
source = { editable = "." }
dependencies = [
{ name = "aiofiles" },
{ name = "aiohttp" },
{ name = "aiomysql" },
{ name = "aioredis" },
{ name = "asyncio-mqtt" },
{ name = "bcrypt" },
{ name = "click" },
{ name = "cryptography" },
{ name = "fastapi" },
{ name = "httpx" },
{ name = "mcp" },
{ name = "numpy" },
{ name = "orjson" },
{ name = "pandas" },
{ name = "passlib", extra = ["bcrypt"] },
{ name = "prometheus-client" },
{ name = "pydantic" },
{ name = "pydantic-settings" },
{ name = "pyjwt" },
{ name = "pymysql" },
{ name = "pytest" },
{ name = "pytest-asyncio" },
{ name = "pytest-cov" },
{ name = "python-dateutil" },
{ name = "python-dotenv" },
{ name = "python-jose", extra = ["cryptography"] },
{ name = "python-multipart" },
{ name = "pyyaml" },
{ name = "requests" },
{ name = "rich" },
{ name = "sqlparse" },
{ name = "starlette" },
{ name = "structlog" },
{ name = "toml" },
{ name = "tqdm" },
{ name = "typer" },
{ name = "uvicorn", extra = ["standard"] },
{ name = "websockets" },
]
[package.optional-dependencies]
dev = [
{ name = "bandit" },
{ name = "black" },
{ name = "flake8" },
{ name = "isort" },
{ name = "mypy" },
{ name = "myst-parser" },
{ name = "pre-commit" },
{ name = "pytest" },
{ name = "pytest-asyncio" },
{ name = "pytest-cov" },
{ name = "pytest-mock" },
{ name = "pytest-xdist" },
{ name = "ruff" },
{ name = "safety" },
{ name = "sphinx" },
{ name = "sphinx-rtd-theme" },
{ name = "tox" },
]
docs = [
{ name = "myst-parser" },
{ name = "sphinx" },
{ name = "sphinx-autoapi" },
{ name = "sphinx-rtd-theme" },
]
monitoring = [
{ name = "grafana-client" },
{ name = "jaeger-client" },
{ name = "opentelemetry-api" },
{ name = "opentelemetry-sdk" },
{ name = "prometheus-client" },
]
performance = [
{ name = "cchardet" },
{ name = "orjson" },
{ name = "uvloop" },
]
[package.dev-dependencies]
dev = [
{ name = "ruff" },
]
[package.metadata]
requires-dist = [
{ name = "aiofiles", specifier = ">=23.0.0" },
{ name = "aiohttp", specifier = ">=3.9.0" },
{ name = "aiomysql", specifier = ">=0.2.0" },
{ name = "aioredis", specifier = ">=2.0.0" },
{ name = "asyncio-mqtt", specifier = ">=0.16.0" },
{ name = "bandit", marker = "extra == 'dev'", specifier = ">=1.7.0" },
{ name = "bcrypt", specifier = ">=4.1.0" },
{ name = "black", marker = "extra == 'dev'", specifier = ">=23.12.0" },
{ name = "cchardet", marker = "extra == 'performance'", specifier = ">=2.1.0" },
{ name = "click", specifier = ">=8.1.0" },
{ name = "cryptography", specifier = ">=41.0.0" },
{ name = "fastapi", specifier = ">=0.108.0" },
{ name = "flake8", marker = "extra == 'dev'", specifier = ">=7.0.0" },
{ name = "grafana-client", marker = "extra == 'monitoring'", specifier = ">=3.5.0" },
{ name = "httpx", specifier = ">=0.26.0" },
{ name = "isort", marker = "extra == 'dev'", specifier = ">=5.13.0" },
{ name = "jaeger-client", marker = "extra == 'monitoring'", specifier = ">=4.8.0" },
{ name = "mcp", specifier = ">=1.0.0" },
{ name = "mypy", marker = "extra == 'dev'", specifier = ">=1.8.0" },
{ name = "myst-parser", marker = "extra == 'dev'", specifier = ">=2.0.0" },
{ name = "myst-parser", marker = "extra == 'docs'", specifier = ">=2.0.0" },
{ name = "numpy", specifier = ">=1.24.0" },
{ name = "opentelemetry-api", marker = "extra == 'monitoring'", specifier = ">=1.21.0" },
{ name = "opentelemetry-sdk", marker = "extra == 'monitoring'", specifier = ">=1.21.0" },
{ name = "orjson", specifier = ">=3.9.0" },
{ name = "orjson", marker = "extra == 'performance'", specifier = ">=3.9.0" },
{ name = "pandas", specifier = ">=2.0.0" },
{ name = "passlib", extras = ["bcrypt"], specifier = ">=1.7.0" },
{ name = "pre-commit", marker = "extra == 'dev'", specifier = ">=3.6.0" },
{ name = "prometheus-client", specifier = ">=0.19.0" },
{ name = "prometheus-client", marker = "extra == 'monitoring'", specifier = ">=0.19.0" },
{ name = "pydantic", specifier = ">=2.5.0" },
{ name = "pydantic-settings", specifier = ">=2.1.0" },
{ name = "pyjwt", specifier = ">=2.8.0" },
{ name = "pymysql", specifier = ">=1.1.0" },
{ name = "pytest", specifier = ">=8.4.0" },
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=7.4.0" },
{ name = "pytest-asyncio", specifier = ">=1.0.0" },
{ name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.23.0" },
{ name = "pytest-cov", specifier = ">=6.1.1" },
{ name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.1.0" },
{ name = "pytest-mock", marker = "extra == 'dev'", specifier = ">=3.12.0" },
{ name = "pytest-xdist", marker = "extra == 'dev'", specifier = ">=3.5.0" },
{ name = "python-dateutil", specifier = ">=2.8.0" },
{ name = "python-dotenv", specifier = ">=1.0.0" },
{ name = "python-jose", extras = ["cryptography"], specifier = ">=3.3.0" },
{ name = "python-multipart", specifier = ">=0.0.6" },
{ name = "pyyaml", specifier = ">=6.0.0" },
{ name = "requests", specifier = ">=2.31.0" },
{ name = "rich", specifier = ">=13.7.0" },
{ name = "ruff", marker = "extra == 'dev'", specifier = ">=0.1.0" },
{ name = "safety", marker = "extra == 'dev'", specifier = ">=2.3.0" },
{ name = "sphinx", marker = "extra == 'dev'", specifier = ">=7.2.0" },
{ name = "sphinx", marker = "extra == 'docs'", specifier = ">=7.2.0" },
{ name = "sphinx-autoapi", marker = "extra == 'docs'", specifier = ">=3.0.0" },
{ name = "sphinx-rtd-theme", marker = "extra == 'dev'", specifier = ">=2.0.0" },
{ name = "sphinx-rtd-theme", marker = "extra == 'docs'", specifier = ">=2.0.0" },
{ name = "sqlparse", specifier = ">=0.4.4" },
{ name = "starlette", specifier = ">=0.27.0" },
{ name = "structlog", specifier = ">=23.2.0" },
{ name = "toml", specifier = ">=0.10.0" },
{ name = "tox", marker = "extra == 'dev'", specifier = ">=4.11.0" },
{ name = "tqdm", specifier = ">=4.66.0" },
{ name = "typer", specifier = ">=0.9.0" },
{ name = "uvicorn", extras = ["standard"], specifier = ">=0.25.0" },
{ name = "uvloop", marker = "extra == 'performance'", specifier = ">=0.19.0" },
{ name = "websockets", specifier = ">=12.0" },
]
provides-extras = ["dev", "docs", "performance", "monitoring"]
[package.metadata.requires-dev]
dev = [{ name = "ruff", specifier = ">=0.11.13" }]
[[package]] [[package]]
name = "mdit-py-plugins" name = "mdit-py-plugins"
version = "0.4.2" version = "0.4.2"
@@ -1605,6 +1662,41 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/7b/d7/7831438e6c3ebbfa6e01a927127a6cb42ad3ab844247f3c5b96bea25d73d/psutil-6.1.1-cp37-abi3-win_amd64.whl", hash = "sha256:f35cfccb065fff93529d2afb4a2e89e363fe63ca1e4a5da22b603a85833c2649", size = 254444 }, { url = "https://files.pythonhosted.org/packages/7b/d7/7831438e6c3ebbfa6e01a927127a6cb42ad3ab844247f3c5b96bea25d73d/psutil-6.1.1-cp37-abi3-win_amd64.whl", hash = "sha256:f35cfccb065fff93529d2afb4a2e89e363fe63ca1e4a5da22b603a85833c2649", size = 254444 },
] ]
[[package]]
name = "pyarrow"
version = "20.0.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/a2/ee/a7810cb9f3d6e9238e61d312076a9859bf3668fd21c69744de9532383912/pyarrow-20.0.0.tar.gz", hash = "sha256:febc4a913592573c8d5805091a6c2b5064c8bd6e002131f01061797d91c783c1", size = 1125187 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/a1/d6/0c10e0d54f6c13eb464ee9b67a68b8c71bcf2f67760ef5b6fbcddd2ab05f/pyarrow-20.0.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:75a51a5b0eef32727a247707d4755322cb970be7e935172b6a3a9f9ae98404ba", size = 30815067 },
{ url = "https://files.pythonhosted.org/packages/7e/e2/04e9874abe4094a06fd8b0cbb0f1312d8dd7d707f144c2ec1e5e8f452ffa/pyarrow-20.0.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:211d5e84cecc640c7a3ab900f930aaff5cd2702177e0d562d426fb7c4f737781", size = 32297128 },
{ url = "https://files.pythonhosted.org/packages/31/fd/c565e5dcc906a3b471a83273039cb75cb79aad4a2d4a12f76cc5ae90a4b8/pyarrow-20.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ba3cf4182828be7a896cbd232aa8dd6a31bd1f9e32776cc3796c012855e1199", size = 41334890 },
{ url = "https://files.pythonhosted.org/packages/af/a9/3bdd799e2c9b20c1ea6dc6fa8e83f29480a97711cf806e823f808c2316ac/pyarrow-20.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c3a01f313ffe27ac4126f4c2e5ea0f36a5fc6ab51f8726cf41fee4b256680bd", size = 42421775 },
{ url = "https://files.pythonhosted.org/packages/10/f7/da98ccd86354c332f593218101ae56568d5dcedb460e342000bd89c49cc1/pyarrow-20.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:a2791f69ad72addd33510fec7bb14ee06c2a448e06b649e264c094c5b5f7ce28", size = 40687231 },
{ url = "https://files.pythonhosted.org/packages/bb/1b/2168d6050e52ff1e6cefc61d600723870bf569cbf41d13db939c8cf97a16/pyarrow-20.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:4250e28a22302ce8692d3a0e8ec9d9dde54ec00d237cff4dfa9c1fbf79e472a8", size = 42295639 },
{ url = "https://files.pythonhosted.org/packages/b2/66/2d976c0c7158fd25591c8ca55aee026e6d5745a021915a1835578707feb3/pyarrow-20.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:89e030dc58fc760e4010148e6ff164d2f44441490280ef1e97a542375e41058e", size = 42908549 },
{ url = "https://files.pythonhosted.org/packages/31/a9/dfb999c2fc6911201dcbf348247f9cc382a8990f9ab45c12eabfd7243a38/pyarrow-20.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:6102b4864d77102dbbb72965618e204e550135a940c2534711d5ffa787df2a5a", size = 44557216 },
{ url = "https://files.pythonhosted.org/packages/a0/8e/9adee63dfa3911be2382fb4d92e4b2e7d82610f9d9f668493bebaa2af50f/pyarrow-20.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:96d6a0a37d9c98be08f5ed6a10831d88d52cac7b13f5287f1e0f625a0de8062b", size = 25660496 },
{ url = "https://files.pythonhosted.org/packages/9b/aa/daa413b81446d20d4dad2944110dcf4cf4f4179ef7f685dd5a6d7570dc8e/pyarrow-20.0.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:a15532e77b94c61efadde86d10957950392999503b3616b2ffcef7621a002893", size = 30798501 },
{ url = "https://files.pythonhosted.org/packages/ff/75/2303d1caa410925de902d32ac215dc80a7ce7dd8dfe95358c165f2adf107/pyarrow-20.0.0-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:dd43f58037443af715f34f1322c782ec463a3c8a94a85fdb2d987ceb5658e061", size = 32277895 },
{ url = "https://files.pythonhosted.org/packages/92/41/fe18c7c0b38b20811b73d1bdd54b1fccba0dab0e51d2048878042d84afa8/pyarrow-20.0.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aa0d288143a8585806e3cc7c39566407aab646fb9ece164609dac1cfff45f6ae", size = 41327322 },
{ url = "https://files.pythonhosted.org/packages/da/ab/7dbf3d11db67c72dbf36ae63dcbc9f30b866c153b3a22ef728523943eee6/pyarrow-20.0.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b6953f0114f8d6f3d905d98e987d0924dabce59c3cda380bdfaa25a6201563b4", size = 42411441 },
{ url = "https://files.pythonhosted.org/packages/90/c3/0c7da7b6dac863af75b64e2f827e4742161128c350bfe7955b426484e226/pyarrow-20.0.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:991f85b48a8a5e839b2128590ce07611fae48a904cae6cab1f089c5955b57eb5", size = 40677027 },
{ url = "https://files.pythonhosted.org/packages/be/27/43a47fa0ff9053ab5203bb3faeec435d43c0d8bfa40179bfd076cdbd4e1c/pyarrow-20.0.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:97c8dc984ed09cb07d618d57d8d4b67a5100a30c3818c2fb0b04599f0da2de7b", size = 42281473 },
{ url = "https://files.pythonhosted.org/packages/bc/0b/d56c63b078876da81bbb9ba695a596eabee9b085555ed12bf6eb3b7cab0e/pyarrow-20.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9b71daf534f4745818f96c214dbc1e6124d7daf059167330b610fc69b6f3d3e3", size = 42893897 },
{ url = "https://files.pythonhosted.org/packages/92/ac/7d4bd020ba9145f354012838692d48300c1b8fe5634bfda886abcada67ed/pyarrow-20.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e8b88758f9303fa5a83d6c90e176714b2fd3852e776fc2d7e42a22dd6c2fb368", size = 44543847 },
{ url = "https://files.pythonhosted.org/packages/9d/07/290f4abf9ca702c5df7b47739c1b2c83588641ddfa2cc75e34a301d42e55/pyarrow-20.0.0-cp313-cp313-win_amd64.whl", hash = "sha256:30b3051b7975801c1e1d387e17c588d8ab05ced9b1e14eec57915f79869b5031", size = 25653219 },
{ url = "https://files.pythonhosted.org/packages/95/df/720bb17704b10bd69dde086e1400b8eefb8f58df3f8ac9cff6c425bf57f1/pyarrow-20.0.0-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:ca151afa4f9b7bc45bcc791eb9a89e90a9eb2772767d0b1e5389609c7d03db63", size = 30853957 },
{ url = "https://files.pythonhosted.org/packages/d9/72/0d5f875efc31baef742ba55a00a25213a19ea64d7176e0fe001c5d8b6e9a/pyarrow-20.0.0-cp313-cp313t-macosx_12_0_x86_64.whl", hash = "sha256:4680f01ecd86e0dd63e39eb5cd59ef9ff24a9d166db328679e36c108dc993d4c", size = 32247972 },
{ url = "https://files.pythonhosted.org/packages/d5/bc/e48b4fa544d2eea72f7844180eb77f83f2030b84c8dad860f199f94307ed/pyarrow-20.0.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7f4c8534e2ff059765647aa69b75d6543f9fef59e2cd4c6d18015192565d2b70", size = 41256434 },
{ url = "https://files.pythonhosted.org/packages/c3/01/974043a29874aa2cf4f87fb07fd108828fc7362300265a2a64a94965e35b/pyarrow-20.0.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e1f8a47f4b4ae4c69c4d702cfbdfe4d41e18e5c7ef6f1bb1c50918c1e81c57b", size = 42353648 },
{ url = "https://files.pythonhosted.org/packages/68/95/cc0d3634cde9ca69b0e51cbe830d8915ea32dda2157560dda27ff3b3337b/pyarrow-20.0.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:a1f60dc14658efaa927f8214734f6a01a806d7690be4b3232ba526836d216122", size = 40619853 },
{ url = "https://files.pythonhosted.org/packages/29/c2/3ad40e07e96a3e74e7ed7cc8285aadfa84eb848a798c98ec0ad009eb6bcc/pyarrow-20.0.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:204a846dca751428991346976b914d6d2a82ae5b8316a6ed99789ebf976551e6", size = 42241743 },
{ url = "https://files.pythonhosted.org/packages/eb/cb/65fa110b483339add6a9bc7b6373614166b14e20375d4daa73483755f830/pyarrow-20.0.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:f3b117b922af5e4c6b9a9115825726cac7d8b1421c37c2b5e24fbacc8930612c", size = 42839441 },
{ url = "https://files.pythonhosted.org/packages/98/7b/f30b1954589243207d7a0fbc9997401044bf9a033eec78f6cb50da3f304a/pyarrow-20.0.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:e724a3fd23ae5b9c010e7be857f4405ed5e679db5c93e66204db1a69f733936a", size = 44503279 },
{ url = "https://files.pythonhosted.org/packages/37/40/ad395740cd641869a13bcf60851296c89624662575621968dcfafabaa7f6/pyarrow-20.0.0-cp313-cp313t-win_amd64.whl", hash = "sha256:82f1ee5133bd8f49d31be1299dc07f585136679666b502540db854968576faf9", size = 25944982 },
]
[[package]] [[package]]
name = "pyasn1" name = "pyasn1"
version = "0.6.1" version = "0.6.1"